From efe18dda1e28b65fdacf7dd9eb7f57ec33a8f593 Mon Sep 17 00:00:00 2001
From: Oscar Knagg
Date: Wed, 8 Sep 2021 12:36:53 +0100
Subject: [PATCH 01/56] Add use_tls_ member to GrpcServer

---
 src/ray/rpc/grpc_server.cc | 3 ++-
 src/ray/rpc/grpc_server.h  | 4 +++-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc
index 9a4cdfa7e63c3..44dcc75e44b03 100644
--- a/src/ray/rpc/grpc_server.cc
+++ b/src/ray/rpc/grpc_server.cc
@@ -35,10 +35,11 @@ DEFINE_stats(grpc_server_req_finished, "Finished request number in grpc server",
 namespace ray {
 namespace rpc {
 
-GrpcServer::GrpcServer(std::string name, const uint32_t port, int num_threads,
+GrpcServer::GrpcServer(std::string name, const uint32_t port, int num_threads, bool use_tls,
                        int64_t keepalive_time_ms)
     : name_(std::move(name)),
       port_(port),
+      use_tls_(use_tls),
       is_closed_(true),
       num_threads_(num_threads),
       keepalive_time_ms_(keepalive_time_ms) {
diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h
index ddc88aa82beb0..5c901fb0ff167 100644
--- a/src/ray/rpc/grpc_server.h
+++ b/src/ray/rpc/grpc_server.h
@@ -61,7 +61,7 @@ class GrpcServer {
   /// \param[in] name Name of this server, used for logging and debugging purpose.
   /// \param[in] port The port to bind this server to. If it's 0, a random available port
   /// will be chosen.
-  GrpcServer(std::string name, const uint32_t port, int num_threads = 1,
+  GrpcServer(std::string name, const uint32_t port, int num_threads = 1, bool use_tls = true,
              int64_t keepalive_time_ms = 7200000 /*2 hours, grpc default*/);
 
   /// Destruct this gRPC server.
@@ -107,6 +107,8 @@ class GrpcServer {
   const std::string name_;
   /// Port of this server.
   int port_;
+  /// Whether to use TLS.
+  bool use_tls_;
   /// Indicates whether this server has been closed.
   bool is_closed_;
   /// The `grpc::Service` objects which should be registered to `ServerBuilder`.

From d38af3541caafb975f23863a934074bf05b15e73 Mon Sep 17 00:00:00 2001
From: Oscar Knagg
Date: Wed, 8 Sep 2021 16:52:10 +0100
Subject: [PATCH 02/56] Hacky TLS

---
 src/ray/rpc/grpc_client.h  | 61 +++++++++++++++++++++++++++++++-------
 src/ray/rpc/grpc_server.cc | 35 +++++++++++++++++++++-
 src/ray/rpc/grpc_server.h  |  7 ++++-
 3 files changed, 91 insertions(+), 12 deletions(-)

diff --git a/src/ray/rpc/grpc_client.h b/src/ray/rpc/grpc_client.h
index 6ca3b1f47f68b..eed14c2d1a4cb 100644
--- a/src/ray/rpc/grpc_client.h
+++ b/src/ray/rpc/grpc_client.h
@@ -17,6 +17,8 @@
 #include
 #include
+#include <fstream>
+#include <sstream>
 
 #include "ray/common/grpc_util.h"
 #include "ray/common/ray_config.h"
@@ -43,23 +45,27 @@ namespace rpc {
 template <class GrpcService>
 class GrpcClient {
  public:
-  GrpcClient(const std::string &address, const int port, ClientCallManager &call_manager)
-      : client_call_manager_(call_manager) {
+  GrpcClient(const std::string &address, const int port, ClientCallManager &call_manager,
+             bool use_tls = true)
+      : client_call_manager_(call_manager),
+        use_tls_(use_tls) {
     grpc::ChannelArguments argument;
     // Disable http proxy since it disrupts local connections. TODO(ekl) we should make
    // this configurable, or selectively set it for known local connections only.
    argument.SetInt(GRPC_ARG_ENABLE_HTTP_PROXY, 0);
    argument.SetMaxSendMessageSize(::RayConfig::instance().max_grpc_message_size());
    argument.SetMaxReceiveMessageSize(::RayConfig::instance().max_grpc_message_size());
-    std::shared_ptr<grpc::Channel> channel =
-        grpc::CreateCustomChannel(address + ":" + std::to_string(port),
-                                  grpc::InsecureChannelCredentials(), argument);
+
+    use_tls_ = std::strcmp(std::getenv("RAY_CLIENT_TLS"), "0") != 0;
+    std::shared_ptr<grpc::Channel> channel = BuildChannel(argument, address, port);
+
     stub_ = GrpcService::NewStub(channel);
   }
 
   GrpcClient(const std::string &address, const int port, ClientCallManager &call_manager,
-             int num_threads)
-      : client_call_manager_(call_manager) {
+             int num_threads, bool use_tls = true)
+      : client_call_manager_(call_manager),
+        use_tls_(use_tls) {
     grpc::ResourceQuota quota;
     quota.SetMaxThreads(num_threads);
     grpc::ChannelArguments argument;
@@ -67,9 +73,10 @@ class GrpcClient {
     argument.SetInt(GRPC_ARG_ENABLE_HTTP_PROXY, 0);
     argument.SetMaxSendMessageSize(::RayConfig::instance().max_grpc_message_size());
     argument.SetMaxReceiveMessageSize(::RayConfig::instance().max_grpc_message_size());
-    std::shared_ptr<grpc::Channel> channel =
-        grpc::CreateCustomChannel(address + ":" + std::to_string(port),
-                                  grpc::InsecureChannelCredentials(), argument);
+
+    use_tls_ = std::strcmp(std::getenv("RAY_CLIENT_TLS"), "0") != 0;
+    std::shared_ptr<grpc::Channel> channel = BuildChannel(argument, address, port);
+
     stub_ = GrpcService::NewStub(channel);
   }
 
@@ -98,6 +105,40 @@
   ClientCallManager &client_call_manager_;
   /// The gRPC-generated stub.
   std::unique_ptr<typename GrpcService::Stub> stub_;
+  /// Whether to use TLS.
+  bool use_tls_;
+
+  std::string ReadFile(std::string filename) {
+    std::ifstream t(filename);
+    std::stringstream buffer;
+    buffer << t.rdbuf();
+    return buffer.str();
+  };
+
+  std::shared_ptr<grpc::Channel> BuildChannel(
+      grpc::ChannelArguments argument,
+      std::string address,
+      int port) {
+    std::shared_ptr<grpc::Channel> channel;
+    if (use_tls_) {
+      std::cout << "Using TLS" << std::endl;
+      std::string server_key_file = std::string(std::getenv("RAY_TLS_SERVER_CERT"));
+      std::string cacert = ReadFile(server_key_file);
+      grpc::SslCredentialsOptions ssl_opts;
+      ssl_opts.pem_root_certs=cacert;
+      auto ssl_creds = grpc::SslCredentials(ssl_opts);
+      channel =
+          grpc::CreateCustomChannel(address + ":" + std::to_string(port),
+                                    ssl_creds, argument);
+    } else {
+      std::cout << "Not using TLS";
+      channel =
+          grpc::CreateCustomChannel(address + ":" + std::to_string(port),
+                                    grpc::InsecureChannelCredentials(), argument);
+    }
+    return channel;
+  };
+
 };
 
 }  // namespace rpc
diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc
index 44dcc75e44b03..332ee9b6220b2 100644
--- a/src/ray/rpc/grpc_server.cc
+++ b/src/ray/rpc/grpc_server.cc
@@ -17,6 +17,8 @@
 #include
 #include
+#include <fstream>
+#include <sstream>
 
 #include "ray/common/ray_config.h"
 #include "ray/rpc/grpc_server.h"
@@ -46,6 +48,13 @@ GrpcServer::GrpcServer(std::string name, const uint32_t port, int num_threads, b
   cqs_.resize(num_threads_);
 }
 
+std::string GrpcServer::ReadFile(std::string filename) {
+  std::ifstream t(filename);
+  std::stringstream buffer;
+  buffer << t.rdbuf();
+  return buffer.str();
+};
+
 void GrpcServer::Run() {
   uint32_t specified_port = port_;
   std::string server_address("0.0.0.0:" + std::to_string(port_));
@@ -64,7 +73,31 @@ void GrpcServer::Run() {
   builder.AddChannelArgument(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 0);
   // TODO(hchen): Add options for authentication.
-  builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_);
+  use_tls_ = std::strcmp(std::getenv("RAY_SERVER_TLS"), "0") != 0;
+  if (use_tls_) {
+    std::cout << "Look at me I'm using authentication (std::cout)" << std::endl;
+
+    std::string server_cert_file = std::string(std::getenv("RAY_TLS_SERVER_CERT"));
+    std::string server_key_file = std::string(std::getenv("RAY_TLS_SERVER_KEY"));
+
+    // Create credentials from hardcoded location
+    std::string rootcert = "";  // for verifying clients
+    std::string servercert = ReadFile(server_cert_file);
+    std::string serverkey = ReadFile(server_key_file);
+    grpc::SslServerCredentialsOptions::PemKeyCertPair pkcp = {serverkey.c_str(),
+                                                              servercert.c_str()};
+//    grpc::SslServerCredentialsOptions ssl_opts;
+    grpc::SslServerCredentialsOptions ssl_opts(GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE);
+    ssl_opts.pem_root_certs = rootcert;
+    ssl_opts.pem_key_cert_pairs.push_back(pkcp);
+
+    // Create server credentials
+    std::shared_ptr<grpc::ServerCredentials> server_creds;
+    server_creds = grpc::SslServerCredentials(ssl_opts);
+    builder.AddListeningPort(server_address, server_creds, &port_);
+  } else {
+    builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_);
+  }
   // Register all the services to this server.
   if (services_.empty()) {
     RAY_LOG(WARNING) << "No service is found when start grpc server " << name_;
diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h
index 5c901fb0ff167..360a9bf10a90d 100644
--- a/src/ray/rpc/grpc_server.h
+++ b/src/ray/rpc/grpc_server.h
@@ -19,6 +19,8 @@
 #include
 #include
 #include
+#include
+#include
 
 #include "ray/common/asio/instrumented_io_context.h"
 #include "ray/common/status.h"
@@ -87,7 +89,10 @@ class GrpcServer {
     }
   }
 
-  /// Get the port of this gRPC server.
+  /// Read a file
+  std::string ReadFile(std::string filename);
+
+  /// Get the port of this gRPC server.
   int GetPort() const { return port_; }
 
   /// Register a grpc service. Multiple services can be registered to the same server.
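The Python patches that follow thread this same TLS switch through every gRPC channel and server that Ray's dashboard and client code create. For orientation, here is a minimal sketch (not part of the patch series) of the credential wiring patch 02 implements in C++, expressed with the Python grpc API the later patches use; the environment variable names are the ones the patches read, everything else is illustrative:

import os
import grpc

def _read(path):
    with open(path, "rb") as f:
        return f.read()

def server_credentials():
    # The server presents its certificate/key pair; clients are not asked for
    # a certificate at this stage (one-way TLS, matching
    # GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE above).
    private_key = _read(os.environ["RAY_TLS_SERVER_KEY"])
    cert_chain = _read(os.environ["RAY_TLS_SERVER_CERT"])
    return grpc.ssl_server_credentials([(private_key, cert_chain)])

def channel_credentials():
    # The client pins the server's own certificate as its trust root, which
    # assumes a self-signed server certificate.
    return grpc.ssl_channel_credentials(
        root_certificates=_read(os.environ["RAY_TLS_SERVER_CERT"]))

Note the pair ordering in ssl_server_credentials: (private_key, certificate_chain). Patches 05 and 08 below show exactly this ordering being gotten wrong and then fixed.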
From 3b5f210b5a403fc739bf0329b9326fc23a082b47 Mon Sep 17 00:00:00 2001
From: Oscar Knagg
Date: Wed, 8 Sep 2021 17:33:08 +0100
Subject: [PATCH 03/56] Create secure gRPC channels in Python code

---
 dashboard/agent.py                          |  4 ++--
 dashboard/head.py                           |  3 ++-
 dashboard/modules/actor/actor_head.py       |  5 ++---
 dashboard/modules/event/event_agent.py      |  4 ++--
 dashboard/modules/job/job_head.py           |  5 ++++-
 dashboard/modules/node/node_head.py         |  2 +-
 dashboard/modules/reporter/reporter_head.py |  3 +--
 dashboard/utils.py                          | 15 ++++++++++++++-
 python/ray/_private/utils.py                | 13 +++++++++++++
 python/ray/autoscaler/_private/monitor.py   |  4 ++--
 python/ray/internal/internal_api.py         | 21 ++++++++++++---------
 python/ray/scripts/scripts.py               |  2 +-
 python/ray/util/client/server/proxier.py    |  4 ++--
 13 files changed, 58 insertions(+), 27 deletions(-)

diff --git a/dashboard/agent.py b/dashboard/agent.py
index a66358fd99306..0b66929c5f269 100644
--- a/dashboard/agent.py
+++ b/dashboard/agent.py
@@ -9,6 +9,7 @@
 import json
 import time
 import traceback
+import grpc
 
 from grpc.experimental import aio as aiogrpc
 
@@ -89,8 +90,7 @@ def __init__(self,
                     self.grpc_port)
         self.aioredis_client = None
         options = (("grpc.enable_http_proxy", 0), )
-        self.aiogrpc_raylet_channel = aiogrpc.insecure_channel(
-            f"{self.ip}:{self.node_manager_port}", options=options)
+        self.aiogrpc_raylet_channel = dashboard_utils.init_aiogrpc_channel(f"{self.ip}:{self.node_manager_port}", options)
         self.http_session = None
         ip, port = redis_address.split(":")
         self.gcs_client = connect_to_gcs(ip, int(port), redis_password)
diff --git a/dashboard/head.py b/dashboard/head.py
index 5251281446d43..039bdf5373bd3 100644
--- a/dashboard/head.py
+++ b/dashboard/head.py
@@ -1,5 +1,6 @@
 import os
 import sys
+import grpc
 import socket
 import asyncio
 import logging
@@ -38,7 +39,7 @@ async def make_gcs_grpc_channel(redis_client):
                 raise Exception("GCS address not found.")
             logger.info("Connect to GCS at %s", gcs_address)
             options = (("grpc.enable_http_proxy", 0), )
-            channel = aiogrpc.insecure_channel(gcs_address, options=options)
+            channel = dashboard_utils.init_aiogrpc_channel(gcs_address, options)
             return channel
         except Exception as ex:
             logger.error("Connect to GCS failed: %s, retry...", ex)
diff --git a/dashboard/modules/actor/actor_head.py b/dashboard/modules/actor/actor_head.py
index fd00b643284eb..1d281ff1ffe8a 100644
--- a/dashboard/modules/actor/actor_head.py
+++ b/dashboard/modules/actor/actor_head.py
@@ -51,7 +51,7 @@ async def _update_stubs(self, change):
             address = "{}:{}".format(node_info["nodeManagerAddress"],
                                      int(node_info["nodeManagerPort"]))
             options = (("grpc.enable_http_proxy", 0), )
-            channel = aiogrpc.insecure_channel(address, options=options)
+            channel = dashboard_utils.init_aiogrpc_channel(address, options)
             stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
             self._stubs[node_id] = stub
 
@@ -180,8 +180,7 @@ async def kill_actor(self, req) -> aiohttp.web.Response:
             return rest_response(success=False, message="Bad Request")
         try:
             options = (("grpc.enable_http_proxy", 0), )
-            channel = aiogrpc.insecure_channel(
-                f"{ip_address}:{port}", options=options)
+            channel = ray._private.utils.init_grpc_channel(f"{ip_address}:{port}", options=options)
             stub = core_worker_pb2_grpc.CoreWorkerServiceStub(channel)
 
             await stub.KillActor(
diff --git a/dashboard/modules/event/event_agent.py b/dashboard/modules/event/event_agent.py
index bc4651777ac6b..cc7bd609685c2 100644
--- a/dashboard/modules/event/event_agent.py
+++ b/dashboard/modules/event/event_agent.py
@@ -4,6 +4,7 @@ from typing import Union
 
 from grpc.experimental import aio as aiogrpc
 
+import ray._private.utils as utils
 import ray.new_dashboard.utils as dashboard_utils
 import ray.new_dashboard.consts as dashboard_consts
 from ray.ray_constants import env_bool
@@ -46,8 +47,7 @@ async def _connect_to_dashboard(self):
                 if dashboard_rpc_address:
                     logger.info("Report events to %s", dashboard_rpc_address)
                     options = (("grpc.enable_http_proxy", 0), )
-                    channel = aiogrpc.insecure_channel(
-                        dashboard_rpc_address, options=options)
+                    channel = utils.init_grpc_channel(dashboard_rpc_address, options=options)
                     return event_pb2_grpc.ReportEventServiceStub(channel)
             except Exception:
                 logger.exception("Connect to dashboard failed.")
diff --git a/dashboard/modules/job/job_head.py b/dashboard/modules/job/job_head.py
index 8b69a051de75a..85d8c7cc259c2 100644
--- a/dashboard/modules/job/job_head.py
+++ b/dashboard/modules/job/job_head.py
@@ -1,4 +1,6 @@
+import os
 import json
+import grpc
 import logging
 import asyncio
 
@@ -52,7 +54,8 @@ async def submit_job(self, req) -> aiohttp.web.Response:
             ip = DataSource.node_id_to_ip[node_id]
             address = f"{ip}:{ports[1]}"
             options = (("grpc.enable_http_proxy", 0), )
-            channel = aiogrpc.insecure_channel(address, options=options)
+            channel = dashboard_utils.init_aiogrpc_channel(address, options)
+
             stub = job_agent_pb2_grpc.JobAgentServiceStub(channel)
             request = job_agent_pb2.InitializeJobEnvRequest(
                 job_description=json.dumps(job_description_data))
diff --git a/dashboard/modules/node/node_head.py b/dashboard/modules/node/node_head.py
index 2fb4c92bb89c6..a6acd1b549566 100644
--- a/dashboard/modules/node/node_head.py
+++ b/dashboard/modules/node/node_head.py
@@ -66,7 +66,7 @@ async def _update_stubs(self, change):
             address = "{}:{}".format(node_info["nodeManagerAddress"],
                                      int(node_info["nodeManagerPort"]))
             options = (("grpc.enable_http_proxy", 0), )
-            channel = aiogrpc.insecure_channel(address, options=options)
+            channel = dashboard_utils.init_aiogrpc_channel(address, options)
             stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
             self._stubs[node_id] = stub
 
diff --git a/dashboard/modules/reporter/reporter_head.py b/dashboard/modules/reporter/reporter_head.py
index 9e3d6a1abaa05..228396549410d 100644
--- a/dashboard/modules/reporter/reporter_head.py
+++ b/dashboard/modules/reporter/reporter_head.py
@@ -38,8 +38,7 @@ async def _update_stubs(self, change):
             node_id, ports = change.new
             ip = DataSource.node_id_to_ip[node_id]
             options = (("grpc.enable_http_proxy", 0), )
-            channel = aiogrpc.insecure_channel(
-                f"{ip}:{ports[1]}", options=options)
+            channel = ray._private.utils.init_grpc_channel(f"{ip}:{ports[1]}", options=options)
             stub = reporter_pb2_grpc.ReporterServiceStub(channel)
             self._stubs[ip] = stub
 
diff --git a/dashboard/utils.py b/dashboard/utils.py
index e8b72f0bf893f..ee72a61227c02 100644
--- a/dashboard/utils.py
+++ b/dashboard/utils.py
@@ -12,12 +12,13 @@
 import socket
 import time
 import traceback
+import grpc
 from abc import ABCMeta, abstractmethod
 from base64 import b64decode
 from collections import namedtuple
 from collections.abc import MutableMapping, Mapping, Sequence
 from typing import Any
-
+from grpc.experimental import aio as aiogrpc
 from google.protobuf.json_format import MessageToDict
 
 import ray.new_dashboard.consts as dashboard_consts
@@ -690,3 +691,15 @@ async def _looper(*args, **kwargs):
         return _looper
 
     return _wrapper
+
+
+def init_aiogrpc_channel(address, options):
+    if os.environ["RAY_CLIENT_TLS"] == "1":
+        with open(os.environ["RAY_TLS_SERVER_CERT"], 'rb') as f:
+            root_certs = f.read()
+        credentials = grpc.ssl_channel_credentials(root_certs)
+        channel = aiogrpc.secure_channel(address, credentials, options=options)
+    else:
+        channel = aiogrpc.insecure_channel(address, options=options)
+
+    return channel
diff --git a/python/ray/_private/utils.py b/python/ray/_private/utils.py
index f45c4ab64c313..a6f92e90dc60f 100644
--- a/python/ray/_private/utils.py
+++ b/python/ray/_private/utils.py
@@ -15,6 +15,7 @@
 import time
 from typing import Optional
 import uuid
+import grpc
 import warnings
 
 import inspect
@@ -1104,3 +1105,15 @@ def validate_namespace(namespace: str):
     elif namespace == "":
         raise ValueError("\"\" is not a valid namespace. "
                          "Pass None to not specify a namespace.")
+
+
+def init_grpc_channel(address, options=None):
+    if os.environ["RAY_CLIENT_TLS"] == "1":
+        with open(os.environ["RAY_TLS_SERVER_CERT"], 'rb') as f:
+            root_certs = f.read()
+        credentials = grpc.ssl_channel_credentials(root_certs)
+        channel = grpc.secure_channel(address, credentials, options=options)
+    else:
+        channel = grpc.insecure_channel(address, options=options)
+
+    return channel
diff --git a/python/ray/autoscaler/_private/monitor.py b/python/ray/autoscaler/_private/monitor.py
index ef12dcd7ad801..3b8565b119514 100644
--- a/python/ray/autoscaler/_private/monitor.py
+++ b/python/ray/autoscaler/_private/monitor.py
@@ -38,6 +38,7 @@
 from ray.experimental.internal_kv import _internal_kv_put, \
     _internal_kv_initialized, _internal_kv_get, _internal_kv_del
 from ray._raylet import connect_to_gcs, disconnect_from_gcs
+import ray.new_dashboard.utils as dashboard_utils
 
 logger = logging.getLogger(__name__)
 
@@ -113,9 +114,8 @@ def __init__(self,
         self.gcs_client = connect_to_gcs(ip, int(port), redis_password)
         # Initialize the gcs stub for getting all node resource usage.
         gcs_address = self.redis.get("GcsServerAddress").decode("utf-8")
-        options = (("grpc.enable_http_proxy", 0), )
-        gcs_channel = grpc.insecure_channel(gcs_address, options=options)
+        gcs_channel = dashboard_utils.init_aiogrpc_channel(gcs_address, options)
         self.gcs_node_resources_stub = \
             gcs_service_pb2_grpc.NodeResourceInfoGcsServiceStub(gcs_channel)
 
diff --git a/python/ray/internal/internal_api.py b/python/ray/internal/internal_api.py
index 9686125faf1c8..f14e882486524 100644
--- a/python/ray/internal/internal_api.py
+++ b/python/ray/internal/internal_api.py
@@ -1,7 +1,10 @@
+import os
+
 import ray
 import ray._private.services as services
 import ray.worker
 import ray._private.profiling as profiling
+import ray._private.utils as utils
 from ray import ray_constants
 from ray.state import GlobalState
 
@@ -60,13 +63,12 @@ def get_store_stats(state, node_manager_address=None, node_manager_port=None):
     else:
         raylet_address = "{}:{}".format(node_manager_address,
                                         node_manager_port)
-    channel = grpc.insecure_channel(
-        raylet_address,
-        options=[
-            ("grpc.max_send_message_length", MAX_MESSAGE_LENGTH),
-            ("grpc.max_receive_message_length", MAX_MESSAGE_LENGTH),
-        ],
-    )
+
+    channel = utils.init_grpc_channel(raylet_address, options=[
+        ("grpc.max_send_message_length", MAX_MESSAGE_LENGTH),
+        ("grpc.max_receive_message_length", MAX_MESSAGE_LENGTH),
+    ])
+
     stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
     reply = stub.FormatGlobalMemoryInfo(
         node_manager_pb2.FormatGlobalMemoryInfoRequest(
@@ -87,13 +89,14 @@ def node_stats(node_manager_address=None,
     # We can ask any Raylet for the global memory info.
    assert (node_manager_address is not None and node_manager_port is not None)
     raylet_address = "{}:{}".format(node_manager_address, node_manager_port)
-    channel = grpc.insecure_channel(
+    channel = utils.init_grpc_channel(
         raylet_address,
         options=[
             ("grpc.max_send_message_length", MAX_MESSAGE_LENGTH),
             ("grpc.max_receive_message_length", MAX_MESSAGE_LENGTH),
-        ],
+        ]
     )
+
     stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
     node_stats = stub.GetNodeStats(
         node_manager_pb2.GetNodeStatsRequest(
diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py
index a8ac4008fc52d..ff1f29b596af4 100644
--- a/python/ray/scripts/scripts.py
+++ b/python/ray/scripts/scripts.py
@@ -1784,7 +1784,7 @@ def healthcheck(address, redis_password, component):
     try:
         gcs_address = redis_client.get("GcsServerAddress").decode("utf-8")
         options = (("grpc.enable_http_proxy", 0), )
-        channel = grpc.insecure_channel(gcs_address, options=options)
+        channel = ray._private.utils.init_grpc_channel(gcs_address, options)
         stub = gcs_service_pb2_grpc.HeartbeatInfoGcsServiceStub(channel)
         request = gcs_service_pb2.CheckAliveRequest()
         reply = stub.CheckAlive(
diff --git a/python/ray/util/client/server/proxier.py b/python/ray/util/client/server/proxier.py
index 67d48d8f2b7cd..a250ce7fd5eca 100644
--- a/python/ray/util/client/server/proxier.py
+++ b/python/ray/util/client/server/proxier.py
@@ -198,8 +198,8 @@ def create_specific_server(self, client_id: str) -> SpecificServer:
         server = SpecificServer(
             port=port,
             process_handle_future=futures.Future(),
-            channel=grpc.insecure_channel(
-                f"localhost:{port}", options=GRPC_OPTIONS))
+            channel=ray._private.utils.init_grpc_channel(f"localhost:{port}", options=GRPC_OPTIONS)
+            )
         self.servers[client_id] = server
         return server
 
From 01c5cd97e0de2804af62a412e9021766d9e16bb8 Mon Sep 17 00:00:00 2001
From: Oscar Knagg
Date: Wed, 8 Sep 2021 17:33:24 +0100
Subject: [PATCH 04/56] Remove unnecessary std::cout

---
 src/ray/rpc/grpc_client.h  | 2 --
 src/ray/rpc/grpc_server.cc | 2 --
 2 files changed, 4 deletions(-)

diff --git a/src/ray/rpc/grpc_client.h b/src/ray/rpc/grpc_client.h
index eed14c2d1a4cb..7a14d5dd007fc 100644
--- a/src/ray/rpc/grpc_client.h
+++ b/src/ray/rpc/grpc_client.h
@@ -121,7 +121,6 @@ class GrpcClient {
       int port) {
     std::shared_ptr<grpc::Channel> channel;
     if (use_tls_) {
-      std::cout << "Using TLS" << std::endl;
       std::string server_key_file = std::string(std::getenv("RAY_TLS_SERVER_CERT"));
       std::string cacert = ReadFile(server_key_file);
       grpc::SslCredentialsOptions ssl_opts;
       ssl_opts.pem_root_certs=cacert;
       auto ssl_creds = grpc::SslCredentials(ssl_opts);
       channel =
           grpc::CreateCustomChannel(address + ":" + std::to_string(port),
                                     ssl_creds, argument);
     } else {
-      std::cout << "Not using TLS";
       channel =
           grpc::CreateCustomChannel(address + ":" + std::to_string(port),
                                     grpc::InsecureChannelCredentials(), argument);
diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc
index 332ee9b6220b2..0af69b08d6a1d 100644
--- a/src/ray/rpc/grpc_server.cc
+++ b/src/ray/rpc/grpc_server.cc
@@ -75,8 +75,6 @@ void GrpcServer::Run() {
   // TODO(hchen): Add options for authentication.
   use_tls_ = std::strcmp(std::getenv("RAY_SERVER_TLS"), "0") != 0;
   if (use_tls_) {
-    std::cout << "Look at me I'm using authentication (std::cout)" << std::endl;
-
     std::string server_cert_file = std::string(std::getenv("RAY_TLS_SERVER_CERT"));
     std::string server_key_file = std::string(std::getenv("RAY_TLS_SERVER_KEY"));
 
From 27696756ebd3b8e607015e68cfcec4be0dffa129 Mon Sep 17 00:00:00 2001
From: Oscar Knagg
Date: Wed, 8 Sep 2021 18:47:34 +0100
Subject: [PATCH 05/56] More TLS

---
 dashboard/agent.py                        |  3 +--
 dashboard/head.py                         |  2 +-
 dashboard/utils.py                        | 13 +++++++++++++
 python/ray/autoscaler/_private/monitor.py |  1 +
 python/ray/util/client/server/proxier.py  | 10 +++++++++-
 python/ray/util/client/server/server.py   | 11 ++++++++++-
 python/ray/util/client/worker.py          | 23 ++++++++++++++---------
 7 files changed, 49 insertions(+), 14 deletions(-)

diff --git a/dashboard/agent.py b/dashboard/agent.py
index 0b66929c5f269..08e3c4900a28a 100644
--- a/dashboard/agent.py
+++ b/dashboard/agent.py
@@ -84,8 +84,7 @@ def __init__(self,
         assert self.ppid > 0
         logger.info("Parent pid is %s", self.ppid)
         self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0), ))
-        self.grpc_port = self.server.add_insecure_port(
-            f"[::]:{self.dashboard_agent_port}")
+        self.grpc_port = dashboard_utils.add_port(self.server, f"[::]:{self.dashboard_agent_port}")
         logger.info("Dashboard agent grpc address: %s:%s", self.ip,
                     self.grpc_port)
         self.aioredis_client = None
diff --git a/dashboard/head.py b/dashboard/head.py
index 039bdf5373bd3..d56caf4ae8dae 100644
--- a/dashboard/head.py
+++ b/dashboard/head.py
@@ -113,7 +113,7 @@ def __init__(self, http_host, http_port, http_port_retries, redis_address,
         ip, port = redis_address.split(":")
         self.gcs_client = connect_to_gcs(ip, int(port), redis_password)
         self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0), ))
-        self.grpc_port = self.server.add_insecure_port("[::]:0")
+        self.grpc_port = dashboard_utils.add_port(self.server, "[::]:0")
         logger.info("Dashboard head grpc address: %s:%s", self.ip,
                     self.grpc_port)
 
diff --git a/dashboard/utils.py b/dashboard/utils.py
index ee72a61227c02..a9b8e670002c6 100644
--- a/dashboard/utils.py
+++ b/dashboard/utils.py
@@ -703,3 +703,16 @@ def init_aiogrpc_channel(address, options):
         channel = aiogrpc.insecure_channel(address, options=options)
 
     return channel
+
+
+def add_port(server, address):
+    if os.environ["RAY_SERVER_TLS"] == "1":
+        with open(os.environ["RAY_TLS_SERVER_CERT"], 'rb') as f:
+            root_certs = f.read()
+        with open(os.environ["RAY_TLS_SERVER_KEY"], 'rb') as f:
+            private_key = f.read()
+        credentials = grpc.ssl_server_credentials([(private_key, root_certs)])
+        return server.add_secure_port(address, credentials)
+    else:
+        return server.add_insecure_port(address)
+
diff --git a/python/ray/autoscaler/_private/monitor.py b/python/ray/autoscaler/_private/monitor.py
index 3b8565b119514..8d87e3f32cef1 100644
--- a/python/ray/autoscaler/_private/monitor.py
+++ b/python/ray/autoscaler/_private/monitor.py
@@ -175,6 +175,7 @@ def update_load_metrics(self):
         request = gcs_service_pb2.GetAllResourceUsageRequest()
         response = self.gcs_node_resources_stub.GetAllResourceUsage(
             request, timeout=4)
+        print(response)
         resources_batch_data = response.resource_usage_data
 
         for resource_message in resources_batch_data.batch:
diff --git a/python/ray/util/client/server/proxier.py b/python/ray/util/client/server/proxier.py
index a250ce7fd5eca..a537426641bf9 100644
--- a/python/ray/util/client/server/proxier.py
+++ b/python/ray/util/client/server/proxier.py
@@ -644,7 +644,15 @@ def serve_proxier(connection_str: str,
         data_servicer, server)
     ray_client_pb2_grpc.add_RayletLogStreamerServicer_to_server(
         logs_servicer, server)
-    server.add_insecure_port(connection_str)
+    if os.environ["RAY_SERVER_TLS"] == "1":
+        with open(os.environ["RAY_TLS_SERVER_CERT"], 'rb') as f:
+            root_certs = f.read()
+        with open(os.environ["RAY_TLS_SERVER_KEY"], 'rb') as f:
+            private_key = f.read()
+        credentials = grpc.ssl_server_credentials([(root_certs, private_key)])
+        server.add_secure_port(connection_str, credentials)
+    else:
+        server.add_insecure_port(connection_str)
     server.start()
     return ClientServerHandle(
         task_servicer=task_servicer,
diff --git a/python/ray/util/client/server/server.py b/python/ray/util/client/server/server.py
index 436dc03889d6d..9378b39b01ea7 100644
--- a/python/ray/util/client/server/server.py
+++ b/python/ray/util/client/server/server.py
@@ -1,4 +1,5 @@
 import logging
+import os
 from concurrent import futures
 import grpc
 import base64
@@ -606,7 +607,15 @@ def default_connect_handler(job_config: JobConfig = None,
         data_servicer, server)
     ray_client_pb2_grpc.add_RayletLogStreamerServicer_to_server(
         logs_servicer, server)
-    server.add_insecure_port(connection_str)
+    if os.environ["RAY_SERVER_TLS"] == "1":
+        with open(os.environ["RAY_TLS_SERVER_CERT"], 'rb') as f:
+            root_certs = f.read()
+        with open(os.environ["RAY_TLS_SERVER_KEY"], 'rb') as f:
+            private_key = f.read()
+        credentials = grpc.ssl_server_credentials([(root_certs, private_key)])
+        server.add_secure_port(connection_str, credentials)
+    else:
+        server.add_insecure_port(connection_str)
     current_handle = ClientServerHandle(
         task_servicer=task_servicer,
         data_servicer=data_servicer,
diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py
index 762ec637cf66b..78afd142f4bba 100644
--- a/python/ray/util/client/worker.py
+++ b/python/ray/util/client/worker.py
@@ -29,6 +29,7 @@
 from ray.util.client.dataclient import DataClient
 from ray.util.client.logsclient import LogstreamClient
 from ray.util.debug import log_once
+from ray._private import utils
 
 if TYPE_CHECKING:
     from ray.actor import ActorClass
@@ -98,15 +99,19 @@ def __init__(
         self._conn_state = grpc.ChannelConnectivity.IDLE
         self._converted: Dict[str, ClientStub] = {}
 
-        if secure and _credentials is None:
-            _credentials = grpc.ssl_channel_credentials()
-
-        if _credentials is not None:
-            self.channel = grpc.secure_channel(
-                conn_str, _credentials, options=GRPC_OPTIONS)
-        else:
-            self.channel = grpc.insecure_channel(
-                conn_str, options=GRPC_OPTIONS)
+        # if secure and _credentials is None:
+        #     _credentials = grpc.ssl_channel_credentials()
+        #
+        # if _credentials is not None:
+        #     self.channel = grpc.secure_channel(
+        #         conn_str, _credentials, options=GRPC_OPTIONS)
+        # else:
+        #     self.channel = grpc.insecure_channel(
+        #         conn_str, options=GRPC_OPTIONS)
+        #
+        print("DEBUG:", conn_str)
+        print("DEBUG:", GRPC_OPTIONS)
+        self.channel = utils.init_grpc_channel(conn_str, GRPC_OPTIONS)
 
         self.channel.subscribe(self._on_channel_state_change)
 
From 2962be3759dbaed4c99d2d2c77aeddc928abebe7 Mon Sep 17 00:00:00 2001
From: Oscar Knagg
Date: Wed, 8 Sep 2021 19:10:50 +0100
Subject: [PATCH 06/56] Linting

---
 dashboard/agent.py                          |  6 ++++--
 dashboard/head.py                           |  3 ++-
 dashboard/modules/actor/actor_head.py       |  3 ++-
 dashboard/modules/event/event_agent.py      |  3 ++-
 dashboard/modules/reporter/reporter_head.py |  3 ++-
 dashboard/utils.py                          |  1 -
 python/ray/autoscaler/_private/monitor.py   |  3 ++-
 python/ray/internal/internal_api.py         | 13 +++++++------
 python/ray/scripts/scripts.py               |  3 ++-
 python/ray/util/client/server/proxier.py    |  4 ++--
 10 files changed, 25 insertions(+), 17 deletions(-)

diff --git a/dashboard/agent.py b/dashboard/agent.py
index 08e3c4900a28a..0597e24b5e41a 100644
--- a/dashboard/agent.py
+++ b/dashboard/agent.py
@@ -84,12 +84,14 @@ def __init__(self,
         assert self.ppid > 0
         logger.info("Parent pid is %s", self.ppid)
         self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0), ))
-        self.grpc_port = dashboard_utils.add_port(self.server, f"[::]:{self.dashboard_agent_port}")
+        self.grpc_port = dashboard_utils.add_port(
+            self.server, f"[::]:{self.dashboard_agent_port}")
         logger.info("Dashboard agent grpc address: %s:%s", self.ip,
                     self.grpc_port)
         self.aioredis_client = None
         options = (("grpc.enable_http_proxy", 0), )
-        self.aiogrpc_raylet_channel = dashboard_utils.init_aiogrpc_channel(f"{self.ip}:{self.node_manager_port}", options)
+        self.aiogrpc_raylet_channel = dashboard_utils.init_aiogrpc_channel(
+            f"{self.ip}:{self.node_manager_port}", options)
         self.http_session = None
         ip, port = redis_address.split(":")
         self.gcs_client = connect_to_gcs(ip, int(port), redis_password)
diff --git a/dashboard/head.py b/dashboard/head.py
index d56caf4ae8dae..54011ed3c0264 100644
--- a/dashboard/head.py
+++ b/dashboard/head.py
@@ -39,7 +39,8 @@ async def make_gcs_grpc_channel(redis_client):
                 raise Exception("GCS address not found.")
             logger.info("Connect to GCS at %s", gcs_address)
             options = (("grpc.enable_http_proxy", 0), )
-            channel = dashboard_utils.init_aiogrpc_channel(gcs_address, options)
+            channel = dashboard_utils.init_aiogrpc_channel(
+                gcs_address, options)
             return channel
         except Exception as ex:
             logger.error("Connect to GCS failed: %s, retry...", ex)
diff --git a/dashboard/modules/actor/actor_head.py b/dashboard/modules/actor/actor_head.py
index 1d281ff1ffe8a..d719de56c1f85 100644
--- a/dashboard/modules/actor/actor_head.py
+++ b/dashboard/modules/actor/actor_head.py
@@ -180,7 +180,8 @@ async def kill_actor(self, req) -> aiohttp.web.Response:
             return rest_response(success=False, message="Bad Request")
         try:
             options = (("grpc.enable_http_proxy", 0), )
-            channel = ray._private.utils.init_grpc_channel(f"{ip_address}:{port}", options=options)
+            channel = ray._private.utils.init_grpc_channel(
+                f"{ip_address}:{port}", options=options)
             stub = core_worker_pb2_grpc.CoreWorkerServiceStub(channel)
 
             await stub.KillActor(
diff --git a/dashboard/modules/event/event_agent.py b/dashboard/modules/event/event_agent.py
index cc7bd609685c2..6cf8ca61fbc60 100644
--- a/dashboard/modules/event/event_agent.py
+++ b/dashboard/modules/event/event_agent.py
@@ -47,7 +47,8 @@ async def _connect_to_dashboard(self):
                 if dashboard_rpc_address:
                     logger.info("Report events to %s", dashboard_rpc_address)
                     options = (("grpc.enable_http_proxy", 0), )
-                    channel = utils.init_grpc_channel(dashboard_rpc_address, options=options)
+                    channel = utils.init_grpc_channel(
+                        dashboard_rpc_address, options=options)
                     return event_pb2_grpc.ReportEventServiceStub(channel)
             except Exception:
                 logger.exception("Connect to dashboard failed.")
diff --git a/dashboard/modules/reporter/reporter_head.py b/dashboard/modules/reporter/reporter_head.py
index 228396549410d..dd61e921c981f 100644
--- a/dashboard/modules/reporter/reporter_head.py
+++ b/dashboard/modules/reporter/reporter_head.py
@@ -38,7 +38,8 @@ async def _update_stubs(self, change):
             node_id, ports = change.new
             ip = DataSource.node_id_to_ip[node_id]
             options = (("grpc.enable_http_proxy", 0), )
-            channel = ray._private.utils.init_grpc_channel(f"{ip}:{ports[1]}", options=options)
+            channel = ray._private.utils.init_grpc_channel(
+                f"{ip}:{ports[1]}", options=options)
             stub = reporter_pb2_grpc.ReporterServiceStub(channel)
             self._stubs[ip] = stub
 
diff --git a/dashboard/utils.py b/dashboard/utils.py
index a9b8e670002c6..342c8fedfb39c 100644
--- a/dashboard/utils.py
+++ b/dashboard/utils.py
@@ -715,4 +715,3 @@ def add_port(server, address):
         return server.add_secure_port(address, credentials)
     else:
         return server.add_insecure_port(address)
-
diff --git a/python/ray/autoscaler/_private/monitor.py b/python/ray/autoscaler/_private/monitor.py
index 8d87e3f32cef1..a6620106c833d 100644
--- a/python/ray/autoscaler/_private/monitor.py
+++ b/python/ray/autoscaler/_private/monitor.py
@@ -115,7 +115,8 @@ def __init__(self,
         # Initialize the gcs stub for getting all node resource usage.
         gcs_address = self.redis.get("GcsServerAddress").decode("utf-8")
         options = (("grpc.enable_http_proxy", 0), )
-        gcs_channel = dashboard_utils.init_aiogrpc_channel(gcs_address, options)
+        gcs_channel = dashboard_utils.init_aiogrpc_channel(
+            gcs_address, options)
         self.gcs_node_resources_stub = \
             gcs_service_pb2_grpc.NodeResourceInfoGcsServiceStub(gcs_channel)
 
diff --git a/python/ray/internal/internal_api.py b/python/ray/internal/internal_api.py
index f14e882486524..ada909e15f9f0 100644
--- a/python/ray/internal/internal_api.py
+++ b/python/ray/internal/internal_api.py
@@ -64,10 +64,12 @@ def get_store_stats(state, node_manager_address=None, node_manager_port=None):
         raylet_address = "{}:{}".format(node_manager_address,
                                         node_manager_port)
 
-    channel = utils.init_grpc_channel(raylet_address, options=[
-        ("grpc.max_send_message_length", MAX_MESSAGE_LENGTH),
-        ("grpc.max_receive_message_length", MAX_MESSAGE_LENGTH),
-    ])
+    channel = utils.init_grpc_channel(
+        raylet_address,
+        options=[
+            ("grpc.max_send_message_length", MAX_MESSAGE_LENGTH),
+            ("grpc.max_receive_message_length", MAX_MESSAGE_LENGTH),
+        ])
 
     stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
     reply = stub.FormatGlobalMemoryInfo(
@@ -94,8 +96,7 @@ def node_stats(node_manager_address=None,
         options=[
             ("grpc.max_send_message_length", MAX_MESSAGE_LENGTH),
             ("grpc.max_receive_message_length", MAX_MESSAGE_LENGTH),
-        ]
-    )
+        ])
 
     stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
     node_stats = stub.GetNodeStats(
diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py
index ff1f29b596af4..b9be07a6d093e 100644
--- a/python/ray/scripts/scripts.py
+++ b/python/ray/scripts/scripts.py
@@ -1784,7 +1784,8 @@ def healthcheck(address, redis_password, component):
     try:
         gcs_address = redis_client.get("GcsServerAddress").decode("utf-8")
         options = (("grpc.enable_http_proxy", 0), )
-        channel = ray._private.utils.init_grpc_channel(gcs_address, options)
+        channel = ray._private.utils.init_grpc_channel(
+            gcs_address, options)
         stub = gcs_service_pb2_grpc.HeartbeatInfoGcsServiceStub(channel)
         request = gcs_service_pb2.CheckAliveRequest()
         reply = stub.CheckAlive(
diff --git a/python/ray/util/client/server/proxier.py b/python/ray/util/client/server/proxier.py
index a537426641bf9..5da50809ab397 100644
--- a/python/ray/util/client/server/proxier.py
+++ b/python/ray/util/client/server/proxier.py
@@ -198,8 +198,8 @@ def create_specific_server(self, client_id: str) -> SpecificServer:
         server = SpecificServer(
             port=port,
             process_handle_future=futures.Future(),
-            channel=ray._private.utils.init_grpc_channel(f"localhost:{port}", options=GRPC_OPTIONS)
-            )
+            channel=ray._private.utils.init_grpc_channel(
+                f"localhost:{port}", options=GRPC_OPTIONS))
         self.servers[client_id] = server
         return server
 
From 64be21aae338126b468939ab856fe06368208fab Mon Sep 17 00:00:00 2001
From: Oscar Knagg
Date: Thu, 9 Sep 2021 00:07:54 +0100
Subject: [PATCH 07/56] Add secure grpc in tests

---
 python/ray/tests/test_metrics.py       | 3 ++-
 python/ray/tests/test_multi_tenancy.py | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/python/ray/tests/test_metrics.py b/python/ray/tests/test_metrics.py
index e834686da16a7..ec1cde816ad74 100644
--- a/python/ray/tests/test_metrics.py
+++ b/python/ray/tests/test_metrics.py
@@ -9,6 +9,7 @@
 from ray.core.generated import node_manager_pb2_grpc
 from ray._private.test_utils import (RayTestTimeoutException,
                                      wait_until_succeeded_without_exception)
+from ray._private.utils import init_grpc_channel
 
 import psutil  # We must import psutil after ray because we bundle it with ray.
 
@@ -20,7 +21,7 @@ def test_worker_stats(shutdown_only):
     raylet_address = "{}:{}".format(raylet["NodeManagerAddress"],
                                     ray.nodes()[0]["NodeManagerPort"])
 
-    channel = grpc.insecure_channel(raylet_address)
+    channel = init_grpc_channel(raylet_address)
     stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
 
     def try_get_node_stats(num_retry=5, timeout=2):
diff --git a/python/ray/tests/test_multi_tenancy.py b/python/ray/tests/test_multi_tenancy.py
index 359f590536fe9..acbd41b52c42c 100644
--- a/python/ray/tests/test_multi_tenancy.py
+++ b/python/ray/tests/test_multi_tenancy.py
@@ -12,13 +12,14 @@
 from ray.core.generated import node_manager_pb2, node_manager_pb2_grpc
 from ray._private.test_utils import (wait_for_condition, run_string_as_driver,
                                      run_string_as_driver_nonblocking)
+from ray._private.utils import init_grpc_channel
 
 
 def get_workers():
     raylet = ray.nodes()[0]
     raylet_address = "{}:{}".format(raylet["NodeManagerAddress"],
                                     raylet["NodeManagerPort"])
-    channel = grpc.insecure_channel(raylet_address)
+    channel = init_grpc_channel(raylet_address)
     stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
     return [
         worker for worker in stub.GetNodeStats(
 
From d38e2b0af03985fe81340a02676333a3bc6fa602 Mon Sep 17 00:00:00 2001
From: Oscar Knagg
Date: Thu, 9 Sep 2021 00:13:00 +0100
Subject: [PATCH 08/56] Fix secure grpc server initialisation

---
 python/ray/util/client/server/proxier.py |  2 +-
 python/ray/util/client/server/server.py  |  2 +-
 python/ray/util/client/worker.py         | 23 +++++++++--------------
 3 files changed, 11 insertions(+), 16 deletions(-)

diff --git a/python/ray/util/client/server/proxier.py b/python/ray/util/client/server/proxier.py
index 5da50809ab397..63b295a9f806b 100644
--- a/python/ray/util/client/server/proxier.py
+++ b/python/ray/util/client/server/proxier.py
@@ -649,7 +649,7 @@ def serve_proxier(connection_str: str,
             root_certs = f.read()
         with open(os.environ["RAY_TLS_SERVER_KEY"], 'rb') as f:
             private_key = f.read()
-        credentials = grpc.ssl_server_credentials([(root_certs, private_key)])
+        credentials = grpc.ssl_server_credentials([(private_key, root_certs)])
         server.add_secure_port(connection_str, credentials)
     else:
         server.add_insecure_port(connection_str)
diff --git a/python/ray/util/client/server/server.py b/python/ray/util/client/server/server.py
index 9378b39b01ea7..28e05ae32832d 100644
--- a/python/ray/util/client/server/server.py
+++ b/python/ray/util/client/server/server.py
@@ -612,7 +612,7 @@ def default_connect_handler(job_config: JobConfig = None,
             root_certs = f.read()
         with open(os.environ["RAY_TLS_SERVER_KEY"], 'rb') as f:
            private_key = f.read()
-        credentials = grpc.ssl_server_credentials([(root_certs, private_key)])
+        credentials = grpc.ssl_server_credentials([(private_key, root_certs)])
         server.add_secure_port(connection_str, credentials)
     else:
         server.add_insecure_port(connection_str)
diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py
index 78afd142f4bba..762ec637cf66b 100644
--- a/python/ray/util/client/worker.py
+++ b/python/ray/util/client/worker.py
@@ -29,7 +29,6 @@
 from ray.util.client.dataclient import DataClient
 from ray.util.client.logsclient import LogstreamClient
 from ray.util.debug import log_once
-from ray._private import utils
 
 if TYPE_CHECKING:
     from ray.actor import ActorClass
@@ -99,19 +98,15 @@ def __init__(
         self._conn_state = grpc.ChannelConnectivity.IDLE
         self._converted: Dict[str, ClientStub] = {}
 
-        # if secure and _credentials is None:
-        #     _credentials = grpc.ssl_channel_credentials()
-        #
-        # if _credentials is not None:
-        #     self.channel = grpc.secure_channel(
-        #         conn_str, _credentials, options=GRPC_OPTIONS)
-        # else:
-        #     self.channel = grpc.insecure_channel(
-        #         conn_str, options=GRPC_OPTIONS)
-        #
-        print("DEBUG:", conn_str)
-        print("DEBUG:", GRPC_OPTIONS)
-        self.channel = utils.init_grpc_channel(conn_str, GRPC_OPTIONS)
+        if secure and _credentials is None:
+            _credentials = grpc.ssl_channel_credentials()
+
+        if _credentials is not None:
+            self.channel = grpc.secure_channel(
+                conn_str, _credentials, options=GRPC_OPTIONS)
+        else:
+            self.channel = grpc.insecure_channel(
+                conn_str, options=GRPC_OPTIONS)
 
         self.channel.subscribe(self._on_channel_state_change)
 
From 1668ecc50f885a7b2d1ccfb7ba142ad3ff5b8fb0 Mon Sep 17 00:00:00 2001
From: Oscar Knagg
Date: Thu, 9 Sep 2021 15:04:42 +0100
Subject: [PATCH 09/56] Use single environment variable as feature flag

---
 dashboard/utils.py                        | 4 ++--
 python/ray/_private/utils.py              | 2 +-
 python/ray/autoscaler/_private/monitor.py | 6 +++---
 python/ray/util/client/server/proxier.py  | 3 ++-
 python/ray/util/client/server/server.py   | 2 +-
 src/ray/rpc/grpc_client.h                 | 4 ++--
 src/ray/rpc/grpc_server.cc                | 2 +-
 7 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/dashboard/utils.py b/dashboard/utils.py
index 342c8fedfb39c..06490e03de93b 100644
--- a/dashboard/utils.py
+++ b/dashboard/utils.py
@@ -694,7 +694,7 @@ async def _looper(*args, **kwargs):
 
 
 def init_aiogrpc_channel(address, options):
-    if os.environ["RAY_CLIENT_TLS"] == "1":
+    if os.environ["RAY_USE_TLS"] == "1":
         with open(os.environ["RAY_TLS_SERVER_CERT"], 'rb') as f:
             root_certs = f.read()
         credentials = grpc.ssl_channel_credentials(root_certs)
@@ -706,7 +706,7 @@ def init_aiogrpc_channel(address, options):
 
 
 def add_port(server, address):
-    if os.environ["RAY_SERVER_TLS"] == "1":
+    if os.environ["RAY_USE_TLS"] == "1":
         with open(os.environ["RAY_TLS_SERVER_CERT"], 'rb') as f:
             root_certs = f.read()
         with open(os.environ["RAY_TLS_SERVER_KEY"], 'rb') as f:
             private_key = f.read()
diff --git a/python/ray/_private/utils.py b/python/ray/_private/utils.py
index a6f92e90dc60f..54906aef693de 100644
--- a/python/ray/_private/utils.py
+++ b/python/ray/_private/utils.py
@@ -1108,7 +1108,7 @@ def validate_namespace(namespace: str):
 
 
 def init_grpc_channel(address, options=None):
-    if os.environ["RAY_CLIENT_TLS"] == "1":
+    if os.environ["RAY_USE_TLS"] == "1":
         with open(os.environ["RAY_TLS_SERVER_CERT"], 'rb') as f:
             root_certs = f.read()
         credentials = grpc.ssl_channel_credentials(root_certs)
diff --git a/python/ray/autoscaler/_private/monitor.py b/python/ray/autoscaler/_private/monitor.py
index a6620106c833d..cdf7427b892f6 100644
--- a/python/ray/autoscaler/_private/monitor.py
+++ b/python/ray/autoscaler/_private/monitor.py
@@ -38,7 +38,7 @@
 from ray.experimental.internal_kv import _internal_kv_put, \
     _internal_kv_initialized, _internal_kv_get, _internal_kv_del
 from ray._raylet import connect_to_gcs, disconnect_from_gcs
-import ray.new_dashboard.utils as dashboard_utils
+import ray._private.utils
 
 logger = logging.getLogger(__name__)
 
@@ -115,7 +115,7 @@ def __init__(self,
         # Initialize the gcs stub for getting all node resource usage.
         gcs_address = self.redis.get("GcsServerAddress").decode("utf-8")
         options = (("grpc.enable_http_proxy", 0), )
-        gcs_channel = dashboard_utils.init_aiogrpc_channel(
+        gcs_channel = ray._private.utils.init_grpc_channel(
             gcs_address, options)
         self.gcs_node_resources_stub = \
             gcs_service_pb2_grpc.NodeResourceInfoGcsServiceStub(gcs_channel)
@@ -176,7 +176,7 @@ def update_load_metrics(self):
         request = gcs_service_pb2.GetAllResourceUsageRequest()
         response = self.gcs_node_resources_stub.GetAllResourceUsage(
             request, timeout=4)
-        print(response)
+        print(type(response), response, [i for i in dir(response) if not i.startswith("__")])
         resources_batch_data = response.resource_usage_data
 
         for resource_message in resources_batch_data.batch:
diff --git a/python/ray/util/client/server/proxier.py b/python/ray/util/client/server/proxier.py
index 74e6aefd49b52..2d36771894f82 100644
--- a/python/ray/util/client/server/proxier.py
+++ b/python/ray/util/client/server/proxier.py
@@ -11,6 +11,7 @@
 import time
 import traceback
 from typing import Any, Callable, Dict, List, Optional, Tuple
+import os
 
 import ray
 from ray.cloudpickle.compat import pickle
@@ -634,7 +635,7 @@ def serve_proxier(connection_str: str,
         data_servicer, server)
     ray_client_pb2_grpc.add_RayletLogStreamerServicer_to_server(
         logs_servicer, server)
-    if os.environ["RAY_SERVER_TLS"] == "1":
+    if os.environ["RAY_USE_TLS"] == "1":
         with open(os.environ["RAY_TLS_SERVER_CERT"], 'rb') as f:
             root_certs = f.read()
         with open(os.environ["RAY_TLS_SERVER_KEY"], 'rb') as f:
diff --git a/python/ray/util/client/server/server.py b/python/ray/util/client/server/server.py
index 28e05ae32832d..b0c6380d59385 100644
--- a/python/ray/util/client/server/server.py
+++ b/python/ray/util/client/server/server.py
@@ -607,7 +607,7 @@ def default_connect_handler(job_config: JobConfig = None,
         data_servicer, server)
     ray_client_pb2_grpc.add_RayletLogStreamerServicer_to_server(
         logs_servicer, server)
-    if os.environ["RAY_SERVER_TLS"] == "1":
+    if os.environ["RAY_USE_TLS"] == "1":
         with open(os.environ["RAY_TLS_SERVER_CERT"], 'rb') as f:
             root_certs = f.read()
         with open(os.environ["RAY_TLS_SERVER_KEY"], 'rb') as f:
diff --git a/src/ray/rpc/grpc_client.h b/src/ray/rpc/grpc_client.h
index 7a14d5dd007fc..4adde7195807d 100644
--- a/src/ray/rpc/grpc_client.h
+++ b/src/ray/rpc/grpc_client.h
@@ -56,7 +56,7 @@ class GrpcClient {
     argument.SetMaxSendMessageSize(::RayConfig::instance().max_grpc_message_size());
     argument.SetMaxReceiveMessageSize(::RayConfig::instance().max_grpc_message_size());
 
-    use_tls_ = std::strcmp(std::getenv("RAY_CLIENT_TLS"), "0") != 0;
+    use_tls_ = std::strcmp(std::getenv("RAY_USE_TLS"), "0") != 0;
     std::shared_ptr<grpc::Channel> channel = BuildChannel(argument, address, port);
 
     stub_ = GrpcService::NewStub(channel);
@@ -74,7 +74,7 @@ class GrpcClient {
     argument.SetMaxSendMessageSize(::RayConfig::instance().max_grpc_message_size());
     argument.SetMaxReceiveMessageSize(::RayConfig::instance().max_grpc_message_size());
 
-    use_tls_ = std::strcmp(std::getenv("RAY_CLIENT_TLS"), "0") != 0;
+    use_tls_ = std::strcmp(std::getenv("RAY_USE_TLS"), "0") != 0;
     std::shared_ptr<grpc::Channel> channel = BuildChannel(argument, address, port);
 
     stub_ = GrpcService::NewStub(channel);
diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc
index 0af69b08d6a1d..25e68ca828230 100644
--- a/src/ray/rpc/grpc_server.cc
+++ b/src/ray/rpc/grpc_server.cc
@@ -73,7 +73,7 @@ void GrpcServer::Run() {
   builder.AddChannelArgument(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 0);
   // TODO(hchen): Add options for authentication.
-  use_tls_ = std::strcmp(std::getenv("RAY_SERVER_TLS"), "0") != 0;
+  use_tls_ = std::strcmp(std::getenv("RAY_USE_TLS"), "0") != 0;
   if (use_tls_) {
     std::string server_cert_file = std::string(std::getenv("RAY_TLS_SERVER_CERT"));
     std::string server_key_file = std::string(std::getenv("RAY_TLS_SERVER_KEY"));
 
From a2c49d66fef7e26bb029bcf84e5eb047863ed2ff Mon Sep 17 00:00:00 2001
From: Oscar Knagg
Date: Thu, 9 Sep 2021 16:55:29 +0100
Subject: [PATCH 10/56] Pass environment in test_client_builder.py

---
 python/ray/tests/test_client_builder.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/python/ray/tests/test_client_builder.py b/python/ray/tests/test_client_builder.py
index c325a7188b04a..832ebe478a78a 100644
--- a/python/ray/tests/test_client_builder.py
+++ b/python/ray/tests/test_client_builder.py
@@ -77,15 +77,14 @@ def ping(self):
     ray.get(a.ping.remote())
     print(ray.get_runtime_context().namespace)
 """
-
     anon_driver = template.format(namespace="None")
-    run_string_as_driver(anon_driver)
+    run_string_as_driver(anon_driver, dict(os.environ))
 
     # This second run will fail if the actors don't run in separate anonymous
     # namespaces.
-    run_string_as_driver(anon_driver)
+    run_string_as_driver(anon_driver, dict(os.environ))
 
     run_in_namespace = template.format(namespace="'namespace'")
-    script_namespace = run_string_as_driver(run_in_namespace)
+    script_namespace = run_string_as_driver(run_in_namespace, dict(os.environ))
 
     # The second run fails because the actors are run in the same namespace.
     with pytest.raises(subprocess.CalledProcessError):
         run_string_as_driver(run_in_namespace)
 
From 0b73c38b94ac40d30b058a8fe603f85e8f7c7bfb Mon Sep 17 00:00:00 2001
From: Oscar Knagg
Date: Thu, 9 Sep 2021 16:55:57 +0100
Subject: [PATCH 11/56] Read RAY_USE_TLS in client worker

---
 python/ray/util/client/worker.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py
index 762ec637cf66b..9f8a964de50c7 100644
--- a/python/ray/util/client/worker.py
+++ b/python/ray/util/client/worker.py
@@ -5,6 +5,7 @@
 import base64
 import json
 import logging
+import os
 import time
 import uuid
 import warnings
@@ -98,8 +99,15 @@ def __init__(
         self._conn_state = grpc.ChannelConnectivity.IDLE
         self._converted: Dict[str, ClientStub] = {}
 
-        if secure and _credentials is None:
-            _credentials = grpc.ssl_channel_credentials()
+        # TODO tidy this up
+        secure = secure or (os.environ.get("RAY_USE_TLS") == "1")
+        if secure and _credentials is None :
+            if os.environ.get("RAY_TLS_SERVER_CERT"):
+                with open(os.environ["RAY_TLS_SERVER_CERT"], 'rb') as f:
+                    root_certs = f.read()
+            else:
+                root_certs = None
+            _credentials = grpc.ssl_channel_credentials(root_certs)
 
         if _credentials is not None:
             self.channel = grpc.secure_channel(
 
From ddc874939e595ec6d738d95e9c94eff7b9d635d3 Mon Sep 17 00:00:00 2001
From: Oscar Knagg
Date: Fri, 10 Sep 2021 10:10:01 +0100
Subject: [PATCH 12/56] Unify init_grpc_channel and init_aiogrpc_channel functions

---
 dashboard/agent.py                    |  4 ++--
 dashboard/head.py                     |  5 +++--
 dashboard/modules/actor/actor_head.py |  2 +-
 dashboard/modules/job/job_head.py     |  6 ++----
 dashboard/modules/node/node_head.py   |  2 +-
 dashboard/utils.py                    | 12 ------------
 python/ray/_private/utils.py          | 10 ++++++----
 7 files changed, 15 insertions(+), 26 deletions(-)

diff --git a/dashboard/agent.py b/dashboard/agent.py
index 0597e24b5e41a..11c90d21595f1 100644
--- a/dashboard/agent.py
+++ b/dashboard/agent.py
@@ -90,8 +90,8 @@ def __init__(self,
                     self.grpc_port)
         self.aioredis_client = None
         options = (("grpc.enable_http_proxy", 0), )
-        self.aiogrpc_raylet_channel = dashboard_utils.init_aiogrpc_channel(
-            f"{self.ip}:{self.node_manager_port}", options)
+        self.aiogrpc_raylet_channel = ray._private.utils.init_grpc_channel(
+            f"{self.ip}:{self.node_manager_port}", options, asynchronous=True)
         self.http_session = None
         ip, port = redis_address.split(":")
         self.gcs_client = connect_to_gcs(ip, int(port), redis_password)
diff --git a/dashboard/head.py b/dashboard/head.py
index 54011ed3c0264..ad0b04de53592 100644
--- a/dashboard/head.py
+++ b/dashboard/head.py
@@ -9,6 +9,7 @@
 
 from grpc.experimental import aio as aiogrpc
 
+import ray._private.utils
 import ray._private.services
 import ray.new_dashboard.consts as dashboard_consts
 import ray.new_dashboard.utils as dashboard_utils
@@ -45,8 +46,8 @@ async def make_gcs_grpc_channel(redis_client):
                 ("grpc.max_receive_message_length",
                  ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE),
             )
-            channel = dashboard_utils.init_aiogrpc_channel(
-                gcs_address, options)
+            channel = ray._private.utils.init_grpc_channel(
+                gcs_address, options, asynchronous=True)
             return channel
         except Exception as ex:
             logger.error("Connect to GCS failed: %s, retry...", ex)
diff --git a/dashboard/modules/actor/actor_head.py b/dashboard/modules/actor/actor_head.py
index d719de56c1f85..dea163f1f88b4 100644
--- a/dashboard/modules/actor/actor_head.py
+++ b/dashboard/modules/actor/actor_head.py
@@ -51,7 +51,7 @@ async def _update_stubs(self, change):
             address = "{}:{}".format(node_info["nodeManagerAddress"],
                                      int(node_info["nodeManagerPort"]))
             options = (("grpc.enable_http_proxy", 0), )
-            channel = dashboard_utils.init_aiogrpc_channel(address, options)
+            channel = ray._private.utils.init_grpc_channel(address, options, asynchronous=True)
             stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
             self._stubs[node_id] = stub
 
diff --git a/dashboard/modules/job/job_head.py b/dashboard/modules/job/job_head.py
index 85d8c7cc259c2..b19246ce4188c 100644
--- a/dashboard/modules/job/job_head.py
+++ b/dashboard/modules/job/job_head.py
@@ -1,13 +1,11 @@
-import os
 import json
-import grpc
 import logging
 import asyncio
 
 import aiohttp.web
 from aioredis.pubsub import Receiver
-from grpc.experimental import aio as aiogrpc
 
+import ray._private.utils
 import ray._private.gcs_utils as gcs_utils
 import ray.new_dashboard.utils as dashboard_utils
 from ray.new_dashboard.modules.job import job_consts
@@ -54,7 +52,7 @@ async def submit_job(self, req) -> aiohttp.web.Response:
             ip = DataSource.node_id_to_ip[node_id]
             address = f"{ip}:{ports[1]}"
             options = (("grpc.enable_http_proxy", 0), )
-            channel = dashboard_utils.init_aiogrpc_channel(address, options)
+            channel = ray._private.utils.init_grpc_channel(address, options, asynchronous=True)
 
             stub = job_agent_pb2_grpc.JobAgentServiceStub(channel)
             request = job_agent_pb2.InitializeJobEnvRequest(
diff --git a/dashboard/modules/node/node_head.py b/dashboard/modules/node/node_head.py
index a6acd1b549566..3a9655cb4ea3a 100644
--- a/dashboard/modules/node/node_head.py
+++ b/dashboard/modules/node/node_head.py
@@ -66,7 +66,7 @@ async def _update_stubs(self, change):
             address = "{}:{}".format(node_info["nodeManagerAddress"],
                                      int(node_info["nodeManagerPort"]))
             options = (("grpc.enable_http_proxy", 0), )
-            channel = dashboard_utils.init_aiogrpc_channel(address, options)
+            channel = ray._private.utils.init_grpc_channel(address, options, asynchronous=True)
             stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
             self._stubs[node_id] = stub
 
diff --git a/dashboard/utils.py b/dashboard/utils.py
index 06490e03de93b..6e4753faee76e 100644
--- a/dashboard/utils.py
+++ b/dashboard/utils.py
@@ -693,18 +693,6 @@ async def _looper(*args, **kwargs):
         return _looper
 
     return _wrapper
-
-
-def init_aiogrpc_channel(address, options):
-    if os.environ["RAY_USE_TLS"] == "1":
-        with open(os.environ["RAY_TLS_SERVER_CERT"], 'rb') as f:
-            root_certs = f.read()
-        credentials = grpc.ssl_channel_credentials(root_certs)
-        channel = aiogrpc.secure_channel(address, credentials, options=options)
-    else:
-        channel = aiogrpc.insecure_channel(address, options=options)
-
-    return channel
diff --git a/python/ray/_private/utils.py b/python/ray/_private/utils.py
index 54906aef693de..be8a09588d724 100644
--- a/python/ray/_private/utils.py
+++ b/python/ray/_private/utils.py
@@ -13,10 +13,11 @@
 import tempfile
 import threading
 import time
-from typing import Optional
+from typing import Optional, Sequence, Tuple, Any
 import uuid
 import grpc
 import warnings
+from grpc.experimental import aio as aiogrpc
 
 import inspect
 from inspect import signature
@@ -1107,13 +1108,14 @@ def validate_namespace(namespace: str):
                          "Pass None to not specify a namespace.")
 
 
-def init_grpc_channel(address, options=None):
+def init_grpc_channel(address: str, options: Optional[Sequence[Tuple[str, Any]]] = None, asynchronous: bool = False):
+    grpc_module = aiogrpc if asynchronous else grpc
     if os.environ["RAY_USE_TLS"] == "1":
         with open(os.environ["RAY_TLS_SERVER_CERT"], 'rb') as f:
             root_certs = f.read()
         credentials = grpc.ssl_channel_credentials(root_certs)
-        channel = grpc.secure_channel(address, credentials, options=options)
+        channel = grpc_module.secure_channel(address, credentials, options=options)
     else:
-        channel = grpc.insecure_channel(address, options=options)
+        channel = grpc_module.insecure_channel(address, options=options)
 
     return channel
 
From 966fc49fd16df5d8e8b427dd647eb722a6f7c0ae Mon Sep 17 00:00:00 2001
From: Oscar Knagg
Date: Fri, 10 Sep 2021 10:18:43 +0100
Subject: [PATCH 13/56] Make function to add port to grpc server

---
 dashboard/agent.py                       |  2 +-
 dashboard/head.py                        |  2 +-
 dashboard/utils.py                       | 12 ------------
 python/ray/_private/utils.py             | 13 +++++++++++++
 python/ray/util/client/server/proxier.py | 12 ++----------
 python/ray/util/client/server/server.py  | 11 ++---------
 6 files changed, 19 insertions(+), 33 deletions(-)

diff --git a/dashboard/agent.py b/dashboard/agent.py
index 11c90d21595f1..abc46c6be17df 100644
--- a/dashboard/agent.py
+++ b/dashboard/agent.py
@@ -84,7 +84,7 @@ def __init__(self,
         assert self.ppid > 0
         logger.info("Parent pid is %s", self.ppid)
         self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0), ))
-        self.grpc_port = dashboard_utils.add_port(
+        self.grpc_port = ray._private.utils.add_port_to_grpc_server(
             self.server, f"[::]:{self.dashboard_agent_port}")
         logger.info("Dashboard agent grpc address: %s:%s", self.ip,
                     self.grpc_port)
diff --git a/dashboard/head.py b/dashboard/head.py
index ad0b04de53592..30a458c3f878e 100644
--- a/dashboard/head.py
+++ b/dashboard/head.py
@@ -121,7 +121,7 @@ def __init__(self, http_host, http_port, http_port_retries, redis_address,
         ip, port = redis_address.split(":")
         self.gcs_client = connect_to_gcs(ip, int(port), redis_password)
         self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0), ))
-        self.grpc_port = dashboard_utils.add_port(self.server, "[::]:0")
+        self.grpc_port = ray._private.utils.add_port_to_grpc_server(self.server, "[::]:0")
         logger.info("Dashboard head grpc address: %s:%s", self.ip,
                     self.grpc_port)
 
diff --git a/dashboard/utils.py b/dashboard/utils.py
index 6e4753faee76e..b152d412b724f 100644
--- a/dashboard/utils.py
+++ b/dashboard/utils.py
@@ -691,15 +691,3 @@ async def _looper(*args, **kwargs):
         return _looper
 
     return _wrapper
-
-
-def add_port(server, address):
-    if os.environ["RAY_USE_TLS"] == "1":
-        with open(os.environ["RAY_TLS_SERVER_CERT"], 'rb') as f:
-            root_certs = f.read()
-        with open(os.environ["RAY_TLS_SERVER_KEY"], 'rb') as f:
-            private_key = f.read()
-        credentials = grpc.ssl_server_credentials([(private_key, root_certs)])
-        return server.add_secure_port(address, credentials)
-    else:
-        return server.add_insecure_port(address)
diff --git a/python/ray/_private/utils.py b/python/ray/_private/utils.py
index be8a09588d724..791be2a4c60eb 100644
--- a/python/ray/_private/utils.py
+++ b/python/ray/_private/utils.py
@@ -1119,3 +1119,16 @@ def init_grpc_channel(address: str, options: Optional[Sequence[Tuple[str, Any]]
         channel = grpc_module.insecure_channel(address, options=options)
 
     return channel
+
+
+def add_port_to_grpc_server(server, address):
+    if os.environ["RAY_USE_TLS"] == "1":
+        with open(os.environ["RAY_TLS_SERVER_CERT"], 'rb') as f:
+            root_certs = f.read()
+        with open(os.environ["RAY_TLS_SERVER_KEY"], 'rb') as f:
+            private_key = f.read()
+        credentials = grpc.ssl_server_credentials([(private_key, root_certs)])
+        return server.add_secure_port(address, credentials)
+    else:
+        return server.add_insecure_port(address)
+
diff --git a/python/ray/util/client/server/proxier.py b/python/ray/util/client/server/proxier.py
index 2d36771894f82..e82ef58e66ba9 100644
--- a/python/ray/util/client/server/proxier.py
+++ b/python/ray/util/client/server/proxier.py
@@ -26,7 +26,7 @@
 from ray._private.runtime_env import RuntimeEnvContext
 import ray._private.runtime_env.working_dir as working_dir_pkg
 from ray._private.services import ProcessInfo, start_ray_client_server
-from ray._private.utils import detect_fate_sharing_support
+from ray._private.utils import detect_fate_sharing_support, add_port_to_grpc_server
 
 # Import psutil after ray so the packaged version is used.
 import psutil
@@ -635,15 +635,7 @@ def serve_proxier(connection_str: str,
         data_servicer, server)
     ray_client_pb2_grpc.add_RayletLogStreamerServicer_to_server(
         logs_servicer, server)
-    if os.environ["RAY_USE_TLS"] == "1":
-        with open(os.environ["RAY_TLS_SERVER_CERT"], 'rb') as f:
-            root_certs = f.read()
-        with open(os.environ["RAY_TLS_SERVER_KEY"], 'rb') as f:
-            private_key = f.read()
-        credentials = grpc.ssl_server_credentials([(private_key, root_certs)])
-        server.add_secure_port(connection_str, credentials)
-    else:
-        server.add_insecure_port(connection_str)
+    add_port_to_grpc_server(server, connection_str)
     server.start()
     return ClientServerHandle(
         task_servicer=task_servicer,
diff --git a/python/ray/util/client/server/server.py b/python/ray/util/client/server/server.py
index b0c6380d59385..c0e05601b3ce3 100644
--- a/python/ray/util/client/server/server.py
+++ b/python/ray/util/client/server/server.py
@@ -34,6 +34,7 @@
 from ray.ray_constants import env_integer
 from ray.util.placement_group import PlacementGroup
 from ray._private.client_mode_hook import disable_client_hook
+from ray._private.utils import add_port_to_grpc_server
 
 logger = logging.getLogger(__name__)
 
@@ -607,15 +608,7 @@ def default_connect_handler(job_config: JobConfig = None,
         data_servicer, server)
     ray_client_pb2_grpc.add_RayletLogStreamerServicer_to_server(
         logs_servicer, server)
-    if os.environ["RAY_USE_TLS"] == "1":
-        with open(os.environ["RAY_TLS_SERVER_CERT"], 'rb') as f:
-            root_certs = f.read()
-        with open(os.environ["RAY_TLS_SERVER_KEY"], 'rb') as f:
-            private_key = f.read()
-        credentials = grpc.ssl_server_credentials([(private_key, root_certs)])
-        server.add_secure_port(connection_str, credentials)
-    else:
-        server.add_insecure_port(connection_str)
+    add_port_to_grpc_server(server, connection_str)
     current_handle = ClientServerHandle(
         task_servicer=task_servicer,
         data_servicer=data_servicer,
 
From b173b784123fc61495d2d0e84f30dc7d7a3c4f84 Mon Sep 17 00:00:00 2001
From: Oscar Knagg
Date: Fri, 10 Sep 2021 17:10:19 +0100
Subject: [PATCH 14/56] Upgrade to mTLS

---
 python/ray/_private/utils.py     | 36 +++++++++++++++++++++++++-------
 python/ray/util/client/worker.py | 19 ++++++++++++-----
 src/ray/rpc/grpc_client.h        | 13 +++++++++---
 src/ray/rpc/grpc_server.cc       |  5 +++--
 4 files changed, 56 insertions(+), 17 deletions(-)

diff --git a/python/ray/_private/utils.py b/python/ray/_private/utils.py
index 791be2a4c60eb..1dd098e1b908b 100644
--- a/python/ray/_private/utils.py
+++ b/python/ray/_private/utils.py
@@ -1111,9 +1111,21 @@ def validate_namespace(namespace: str):
 
 def init_grpc_channel(address: str, options: Optional[Sequence[Tuple[str, Any]]] = None, asynchronous: bool = False):
     grpc_module = aiogrpc if asynchronous else grpc
     if os.environ["RAY_USE_TLS"] == "1":
-        with open(os.environ["RAY_TLS_SERVER_CERT"], 'rb') as f:
-            root_certs = f.read()
-        credentials
= grpc.ssl_channel_credentials(root_certs) + with open(os.environ["RAY_TLS_SERVER_CERT"], "rb") as f: + server_cert_chain = f.read() + with open(os.environ["RAY_TLS_SERVER_KEY"], "rb") as f: + private_key = f.read() + if "RAY_TLS_CA_CERT" in os.environ: + with open(os.environ["RAY_TLS_CA_CERT"], "rb") as f: + ca_cert = f.read() + else: + ca_cert = None + + credentials = grpc.ssl_channel_credentials( + certificate_chain=server_cert_chain, + private_key=private_key, + root_certificates=ca_cert + ) channel = grpc_module.secure_channel(address, credentials, options=options) else: channel = grpc_module.insecure_channel(address, options=options) @@ -1123,11 +1135,21 @@ def init_grpc_channel(address: str, options: Optional[Sequence[Tuple[str, Any]]] def add_port_to_grpc_server(server, address): if os.environ["RAY_USE_TLS"] == "1": - with open(os.environ["RAY_TLS_SERVER_CERT"], 'rb') as f: - root_certs = f.read() - with open(os.environ["RAY_TLS_SERVER_KEY"], 'rb') as f: + with open(os.environ["RAY_TLS_SERVER_CERT"], "rb") as f: + server_cert_chain = f.read() + with open(os.environ["RAY_TLS_SERVER_KEY"], "rb") as f: private_key = f.read() - credentials = grpc.ssl_server_credentials([(private_key, root_certs)]) + if "RAY_TLS_CA_CERT" in os.environ: + with open(os.environ["RAY_TLS_CA_CERT"], "rb") as f: + ca_cert = f.read() + else: + ca_cert = None + + credentials = grpc.ssl_server_credentials( + [(private_key, server_cert_chain)], + root_certificates=ca_cert, + require_client_auth=ca_cert is not None + ) return server.add_secure_port(address, credentials) else: return server.add_insecure_port(address) diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index 9f8a964de50c7..2c7f0dc384556 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -102,12 +102,21 @@ def __init__( # TODO tidy this up secure = secure or (os.environ.get("RAY_USE_TLS") == "1") if secure and _credentials is None : - if os.environ.get("RAY_TLS_SERVER_CERT"): - with open(os.environ["RAY_TLS_SERVER_CERT"], 'rb') as f: - root_certs = f.read() + with open(os.environ["RAY_TLS_SERVER_CERT"], "rb") as f: + server_cert_chain = f.read() + with open(os.environ["RAY_TLS_SERVER_KEY"], "rb") as f: + private_key = f.read() + if "RAY_TLS_CA_CERT" in os.environ: + with open(os.environ["RAY_TLS_CA_CERT"], "rb") as f: + ca_cert = f.read() else: - root_certs = None - _credentials = grpc.ssl_channel_credentials(root_certs) + ca_cert = None + + _credentials = grpc.ssl_channel_credentials( + certificate_chain=server_cert_chain, + private_key=private_key, + root_certificates=ca_cert + ) if _credentials is not None: self.channel = grpc.secure_channel( diff --git a/src/ray/rpc/grpc_client.h b/src/ray/rpc/grpc_client.h index 4adde7195807d..792328130302a 100644 --- a/src/ray/rpc/grpc_client.h +++ b/src/ray/rpc/grpc_client.h @@ -108,7 +108,7 @@ class GrpcClient { /// Whether to use TLS. 
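  /// Overridden from the environment by the constructors above.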
bool use_tls_; - std::string ReadFile(std::string filename) { + std::string ReadFile(std::string filename) { std::ifstream t(filename); std::stringstream buffer; buffer << t.rdbuf(); @@ -121,10 +121,17 @@ class GrpcClient { int port) { std::shared_ptr channel; if (use_tls_) { - std::string server_key_file = std::string(std::getenv("RAY_TLS_SERVER_CERT")); - std::string cacert = ReadFile(server_key_file); + std::string server_cert_file = std::string(std::getenv("RAY_TLS_SERVER_CERT")); + std::string server_key_file = std::string(std::getenv("RAY_TLS_SERVER_KEY")); + std::string root_cert_file = std::string(std::getenv("RAY_TLS_CA_CERT")); + std::string server_cert_chain = ReadFile(server_cert_file); + std::string private_key = ReadFile(server_key_file); + std::string cacert = ReadFile(root_cert_file); + grpc::SslCredentialsOptions ssl_opts; ssl_opts.pem_root_certs=cacert; + ssl_opts.pem_private_key=private_key; + ssl_opts.pem_cert_chain=server_cert_chain; auto ssl_creds = grpc::SslCredentials(ssl_opts); channel = grpc::CreateCustomChannel(address + ":" + std::to_string(port), diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc index f6c0dd4f343e0..2238aa47b47cb 100644 --- a/src/ray/rpc/grpc_server.cc +++ b/src/ray/rpc/grpc_server.cc @@ -77,15 +77,16 @@ void GrpcServer::Run() { if (use_tls_) { std::string server_cert_file = std::string(std::getenv("RAY_TLS_SERVER_CERT")); std::string server_key_file = std::string(std::getenv("RAY_TLS_SERVER_KEY")); + std::string root_cert_file = std::string(std::getenv("RAY_TLS_CA_CERT")); // Create credentials from hardcoded location - std::string rootcert = ""; // for verifying clients + std::string rootcert = ReadFile(root_cert_file); std::string servercert = ReadFile(server_cert_file); std::string serverkey = ReadFile(server_key_file); grpc::SslServerCredentialsOptions::PemKeyCertPair pkcp = {serverkey.c_str(), servercert.c_str()}; // grpc::SslServerCredentialsOptions ssl_opts; - grpc::SslServerCredentialsOptions ssl_opts(GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE); + grpc::SslServerCredentialsOptions ssl_opts(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY); ssl_opts.pem_root_certs = rootcert; ssl_opts.pem_key_cert_pairs.push_back(pkcp); From 65361a282da3e3310b910d65110d088cbecfa4cb Mon Sep 17 00:00:00 2001 From: Oscar Knagg Date: Fri, 10 Sep 2021 17:25:55 +0100 Subject: [PATCH 15/56] Function to load certs from env variables --- python/ray/_private/utils.py | 22 +++++++++------------- python/ray/util/client/worker.py | 14 +++----------- 2 files changed, 12 insertions(+), 24 deletions(-) diff --git a/python/ray/_private/utils.py b/python/ray/_private/utils.py index 1dd098e1b908b..eba213d74d342 100644 --- a/python/ray/_private/utils.py +++ b/python/ray/_private/utils.py @@ -1108,8 +1108,7 @@ def validate_namespace(namespace: str): "Pass None to not specify a namespace.") -def init_grpc_channel(address: str, options: Optional[Sequence[Tuple[str, Any]]] = None, asynchronous: bool = False): - grpc_module = aiogrpc if asynchronous else grpc +def load_certs_from_env(): if os.environ["RAY_USE_TLS"] == "1": with open(os.environ["RAY_TLS_SERVER_CERT"], "rb") as f: server_cert_chain = f.read() @@ -1121,6 +1120,13 @@ def init_grpc_channel(address: str, options: Optional[Sequence[Tuple[str, Any]]] else: ca_cert = None + return server_cert_chain, private_key, ca_cert + + +def init_grpc_channel(address: str, options: Optional[Sequence[Tuple[str, Any]]] = None, asynchronous: bool = False): + grpc_module = aiogrpc if asynchronous else grpc + 
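    # TLS enabled: build channel credentials from this node's cert/key (mTLS).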
if os.environ["RAY_USE_TLS"] == "1": + server_cert_chain, private_key, ca_cert = load_certs_from_env() credentials = grpc.ssl_channel_credentials( certificate_chain=server_cert_chain, private_key=private_key, @@ -1135,16 +1141,7 @@ def init_grpc_channel(address: str, options: Optional[Sequence[Tuple[str, Any]]] def add_port_to_grpc_server(server, address): if os.environ["RAY_USE_TLS"] == "1": - with open(os.environ["RAY_TLS_SERVER_CERT"], "rb") as f: - server_cert_chain = f.read() - with open(os.environ["RAY_TLS_SERVER_KEY"], "rb") as f: - private_key = f.read() - if "RAY_TLS_CA_CERT" in os.environ: - with open(os.environ["RAY_TLS_CA_CERT"], "rb") as f: - ca_cert = f.read() - else: - ca_cert = None - + server_cert_chain, private_key, ca_cert = load_certs_from_env() credentials = grpc.ssl_server_credentials( [(private_key, server_cert_chain)], root_certificates=ca_cert, @@ -1153,4 +1150,3 @@ def add_port_to_grpc_server(server, address): return server.add_secure_port(address, credentials) else: return server.add_insecure_port(address) - diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index cfbb2465e2a38..1871dbf40effc 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -31,6 +31,7 @@ from ray.util.client.logsclient import LogstreamClient from ray.util.debug import log_once import ray._private.runtime_env.working_dir as working_dir_pkg +import ray._private.utils if TYPE_CHECKING: from ray.actor import ActorClass @@ -102,17 +103,8 @@ def __init__( # TODO tidy this up secure = secure or (os.environ.get("RAY_USE_TLS") == "1") - if secure and _credentials is None : - with open(os.environ["RAY_TLS_SERVER_CERT"], "rb") as f: - server_cert_chain = f.read() - with open(os.environ["RAY_TLS_SERVER_KEY"], "rb") as f: - private_key = f.read() - if "RAY_TLS_CA_CERT" in os.environ: - with open(os.environ["RAY_TLS_CA_CERT"], "rb") as f: - ca_cert = f.read() - else: - ca_cert = None - + if secure and _credentials is None: + server_cert_chain, private_key, ca_cert = ray._private.utils.load_certs_from_env() _credentials = grpc.ssl_channel_credentials( certificate_chain=server_cert_chain, private_key=private_key, From f19e7a76324017fb04cabb38bdc9b2ab67beac8a Mon Sep 17 00:00:00 2001 From: Oscar Knagg Date: Mon, 13 Sep 2021 17:00:45 +0100 Subject: [PATCH 16/56] Add example cluster yaml which generates self-signed keys --- eks-cluster.yaml | 206 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 206 insertions(+) create mode 100644 eks-cluster.yaml diff --git a/eks-cluster.yaml b/eks-cluster.yaml new file mode 100644 index 0000000000000..6c2c5014f10a7 --- /dev/null +++ b/eks-cluster.yaml @@ -0,0 +1,206 @@ +cluster_name: example-cluster + +max_workers: 63 + +upscaling_speed: 1.0 + +idle_timeout_minutes: 1 + +provider: + type: kubernetes + use_internal_ips: true + namespace: ray + + autoscaler_service_account: + apiVersion: v1 + kind: ServiceAccount + metadata: + name: autoscaler + + autoscaler_role: + kind: Role + apiVersion: rbac.authorization.k8s.io/v1 + metadata: + name: autoscaler + rules: + - apiGroups: [""] + resources: ["pods", "pods/status", "pods/exec"] + verbs: ["get", "watch", "list", "create", "delete", "patch"] + + autoscaler_role_binding: + apiVersion: rbac.authorization.k8s.io/v1 + kind: RoleBinding + metadata: + name: autoscaler + subjects: + - kind: ServiceAccount + name: autoscaler + roleRef: + kind: Role + name: autoscaler + apiGroup: rbac.authorization.k8s.io + + services: + - apiVersion: v1 + kind: Service + 
metadata: + # NOTE: If you're running multiple Ray clusters with services + # on one Kubernetes cluster, they must have unique service + # names. + name: ray-head + spec: + # This selector must match the head node pod's selector below. + selector: + component: ray-head + ports: + - name: client + protocol: TCP + port: 10001 + targetPort: 10001 + - name: dashboard + protocol: TCP + port: 8265 + targetPort: 8265 + - name: ray-serve + protocol: TCP + port: 8000 + targetPort: 8000 + - name: jupyter + protocol: TCP + port: 8888 + targetPort: 8888 + +head_node_type: head_node +available_node_types: + gpu_node: + min_workers: 1 + max_workers: 1 + resources: { "CPU": 8, "GPU": 0 } + node_config: + apiVersion: v1 + kind: Pod + metadata: + # Automatically generates a name for the pod with this prefix. + generateName: ray-gpu-worker- + labels: + component: ray-gpu + spec: + restartPolicy: Never + volumes: + - name: dshm + emptyDir: + medium: Memory + + containers: + - name: ray-node + imagePullPolicy: IfNotPresent + image: rayproject/ray:3bc5f0-py38 + command: ["/bin/bash", "-c", "--"] + args: ["trap : TERM INT; sleep infinity & wait;"] + # This volume allocates shared memory for Ray to use for its plasma + # object store. If you do not provide this, Ray will fall back to + # /tmp which cause slowdowns if is not a shared memory volume. + volumeMounts: + - mountPath: /dev/shm + name: dshm + + resources: + limits: + cpu: 7 + memory: 31G + + # Add in my hacky environment variables to enable TLS + env: + - name: RAY_USE_TLS + value: "1" + - name: RAY_TLS_SERVER_CERT + value: "/home/ray/.ssh/server.crt" + - name: RAY_TLS_SERVER_KEY + value: "/home/ray/.ssh/server.key" + - name: RAY_TLS_CA_CERT + value: "/home/ray/.ssh/ca.crt" + - name: GRPC_VERBOSITY + value: "debug" + + head_node: + min_workers: 0 + max_workers: 0 + resources: { "CPU": 8, "GPU": 0 } + node_config: + apiVersion: v1 + kind: Pod + metadata: + generateName: ray-head- + labels: + component: ray-head + spec: + serviceAccountName: autoscaler + + restartPolicy: Never + + volumes: + - name: dshm + emptyDir: + medium: Memory + + containers: + - name: ray-node + imagePullPolicy: IfNotPresent + image: rayproject/ray:3bc5f0-py38 + # Do not change this command - it keeps the pod alive until it is + # explicitly killed. 
+ command: ["/bin/bash", "-c", "--"] + args: ['trap : TERM INT; sleep infinity & wait;'] + ports: + - containerPort: 6379 # Redis port + - containerPort: 10001 # Used by Ray Client + - containerPort: 8265 # Used by Ray Dashboard + - containerPort: 8888 # Used by Jupyter + + volumeMounts: + - mountPath: /dev/shm + name: dshm + + resources: + limits: + cpu: 7 + memory: 31G + + env: + - name: RAY_USE_TLS + value: "1" + - name: RAY_TLS_SERVER_CERT + value: "/home/ray/.ssh/server.crt" + - name: RAY_TLS_SERVER_KEY + value: "/home/ray/.ssh/server.key" + - name: RAY_TLS_CA_CERT + value: "/home/ray/.ssh/ca.crt" + - name: GRPC_VERBOSITY + value: "debug" + +setup_commands: + - openssl req -newkey rsa:2048 -nodes -keyout /home/ray/.ssh/server.key -subj "/C=CN/ST=GD/L=SZ/O=G-Research/CN=$HOSTNAME" -out /home/ray/.ssh/server.csr + - export SAN=$(python -c "import socket; print(socket.gethostbyname(socket.gethostname()))") && echo subjectAltName=DNS:$SAN,DNS:127.0.0.1,DNS:0.0.0.0,DNS:localhost > /home/ray/.ssh/extfile + - openssl x509 -req -extfile /home/ray/.ssh/extfile -days 365 -in /home/ray/.ssh/server.csr -CA /home/ray/.ssh/ca.crt -CAkey /home/ray/.ssh/ca.key -CAcreateserial -out /home/ray/.ssh/server.crt + - pip install /home/ray/ray-2.0.0.dev0-cp38-cp38-linux_x86_64.whl --force-reinstall + +# After startup, run this on local machine to pull the keys +# ray rsync_down eks-cluster.yaml /home/ray/.ssh/server.crt /home/oscar/.ssh/server.crt +# ray rsync_down eks-cluster.yaml /home/ray/.ssh/server.key /home/oscar/.ssh/server.key + +head_start_ray_commands: + - ray stop + - ulimit -n 65536; ray start --head --autoscaling-config=~/ray_bootstrap_config.yaml --dashboard-host 0.0.0.0 --object-store-memory 16000000000 + +worker_start_ray_commands: + - ray stop + - ulimit -n 65536; ray start --address=$RAY_HEAD_IP:6379 --object-store-memory 16000000000 + +file_mounts: { + # TLS enabled Ray wheel + "/home/ray/ray-2.0.0.dev0-cp38-cp38-linux_x86_64.whl": "/home/oscar/CLionProjects/ray/python/dist/ray-2.0.0.dev0-cp38-cp38-linux_x86_64.whl", + # Certificate authority keys generated locally + "/home/ray/.ssh/ca.crt": "/home/oscar/.ssh/ca.crt", + "/home/ray/.ssh/ca.key": "/home/oscar/.ssh/ca.key", +# "/path2/on/remote/machine": "/path2/on/local/machine", +} \ No newline at end of file From b57c2e2744db9f58642fad3d16fb931d1b038d18 Mon Sep 17 00:00:00 2001 From: Oscar Knagg Date: Tue, 14 Sep 2021 16:24:59 +0100 Subject: [PATCH 17/56] Add TLS auth test --- python/ray/_private/test_utils.py | 94 ++++++++++++++++++++++++++++ python/ray/tests/conftest.py | 22 ++++++- python/ray/tests/test_client_init.py | 8 --- python/ray/tests/test_tls_auth.py | 20 ++++++ 4 files changed, 135 insertions(+), 9 deletions(-) create mode 100644 python/ray/tests/test_tls_auth.py diff --git a/python/ray/_private/test_utils.py b/python/ray/_private/test_utils.py index f8ae53d4351a8..3b9c21fc27390 100644 --- a/python/ray/_private/test_utils.py +++ b/python/ray/_private/test_utils.py @@ -7,15 +7,18 @@ import pathlib import subprocess import sys +import tempfile import time import timeit import socket import math import traceback import logging +import datetime from typing import Optional, Any, List, Dict from contextlib import redirect_stdout, redirect_stderr import yaml +import socket import ray import ray._private.services @@ -698,3 +701,94 @@ async def get_batch(self, except asyncio.TimeoutError: break return batch + + +def generate_self_signed_tls_certs(): + """Create self-signed key/cert pair for testing. 
+ + This method requires the library ``cryptography`` be installed. + """ + try: + from cryptography import x509 + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import hashes, serialization + from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.x509.oid import NameOID + except ImportError: + raise ImportError( + "Using `Security.temporary` requires `cryptography`, please " + "install it using either pip or conda" + ) + key = rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend() + ) + key_contents = key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ).decode() + + ray_interal = x509.Name( + [x509.NameAttribute(NameOID.COMMON_NAME, "ray-internal")] + ) + # This is the same logic used by the GCS server to acquire a private/interal IP + # address to listen on. If we just use localhost + 127.0.0.1 then we won't be able to + # connect to the GCS and will get an error like "No match found for server name: 192.168.X.Y" + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.connect(("8.8.8.8", 80)) + private_ip_address = s.getsockname()[0] + s.close() + altnames = x509.SubjectAlternativeName([ + x509.DNSName(socket.gethostbyname(socket.gethostname())), # Probably 127.0.0.1 + x509.DNSName(private_ip_address), # 192.168.*.* + x509.DNSName("localhost"), + ]) + now = datetime.datetime.utcnow() + cert = ( + x509.CertificateBuilder() + .subject_name(ray_interal) + .issuer_name(ray_interal) + .add_extension(altnames, critical=False) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now) + .not_valid_after(now + datetime.timedelta(days=365)) + .sign(key, hashes.SHA256(), default_backend()) + ) + + cert_contents = cert.public_bytes(serialization.Encoding.PEM).decode() + + return cert_contents, key_contents + + +def setup_tls(): + """Sets up required environment variables for tls""" + print("Creating self-signed certs") + cert, key = generate_self_signed_tls_certs() + print("Created certs") + temp_dir = tempfile.mkdtemp("ray-test-certs") + cert_filepath = os.path.join(temp_dir, "server.crt") + key_filepath = os.path.join(temp_dir, "server.key") + print("Writing certs to {}".format(temp_dir)) + with open(cert_filepath, "w") as fh: + fh.write(cert) + with open(key_filepath, "w") as fh: + fh.write(key) + + print("Setting environment variables") + os.environ["RAY_USE_TLS"] = "1" + os.environ["RAY_TLS_SERVER_CERT"] = cert_filepath + os.environ["RAY_TLS_SERVER_KEY"] = key_filepath + os.environ["RAY_TLS_CA_CERT"] = cert_filepath + + return key_filepath, cert_filepath, temp_dir + + +def teardown_tls(key_filepath, cert_filepath, temp_dir): + os.remove(key_filepath) + os.remove(cert_filepath) + os.removedirs(temp_dir) + os.environ["RAY_USE_TLS"] = "0" + del os.environ["RAY_TLS_SERVER_CERT"] + del os.environ["RAY_TLS_SERVER_KEY"] + del os.environ["RAY_TLS_CA_CERT"] diff --git a/python/ray/tests/conftest.py b/python/ray/tests/conftest.py index a30e1d29b6667..31188544ae1b9 100644 --- a/python/ray/tests/conftest.py +++ b/python/ray/tests/conftest.py @@ -6,11 +6,13 @@ import pytest import subprocess import json +import time import ray from ray.cluster_utils import Cluster from ray._private.services import REDIS_EXECUTABLE, _start_redis_instance -from ray._private.test_utils import init_error_pubsub +from ray._private.test_utils import init_error_pubsub, 
setup_tls, teardown_tls +import ray.util.client.server.server as ray_client_server import ray._private.gcs_utils as gcs_utils @@ -230,6 +232,14 @@ def call_ray_start_with_external_redis(request): subprocess.check_call(["ray", "stop"]) +@pytest.fixture +def init_and_serve(): + server_handle, _ = ray_client_server.init_and_serve("localhost:50051") + yield server_handle + ray_client_server.shutdown_with_server(server_handle.grpc_server) + time.sleep(2) + + @pytest.fixture def call_ray_stop_only(): yield @@ -287,6 +297,16 @@ def log_pubsub(): p.close() +@pytest.fixture +def use_tls(request): + if request.param: + print("Setting up TLS") + key_filepath, cert_filepath, temp_dir = setup_tls() + yield None + if request.param: + print("Tearing down TLS") + teardown_tls(key_filepath, cert_filepath, temp_dir) + """ Object spilling test fixture """ diff --git a/python/ray/tests/test_client_init.py b/python/ray/tests/test_client_init.py index 74c4cca200fea..a474d88ebe724 100644 --- a/python/ray/tests/test_client_init.py +++ b/python/ray/tests/test_client_init.py @@ -40,14 +40,6 @@ def get(self): return self.val -@pytest.fixture -def init_and_serve(): - server_handle, _ = ray_client_server.init_and_serve("localhost:50051") - yield server_handle - ray_client_server.shutdown_with_server(server_handle.grpc_server) - time.sleep(2) - - @pytest.fixture def init_and_serve_lazy(): cluster = ray.cluster_utils.Cluster() diff --git a/python/ray/tests/test_tls_auth.py b/python/ray/tests/test_tls_auth.py new file mode 100644 index 0000000000000..0126d8f6f99f9 --- /dev/null +++ b/python/ray/tests/test_tls_auth.py @@ -0,0 +1,20 @@ +# coding: utf-8 +import os + +import pytest + +import logging + +logger = logging.getLogger(__name__) + + +@pytest.mark.parametrize("use_tls", [True]) +def test_client_connect_to_tls_server(use_tls, init_and_serve): + from ray.util.client import ray + os.environ["RAY_USE_TLS"] = "0" + with pytest.raises(ConnectionError): + ray.connect("localhost:50051") + + os.environ["RAY_USE_TLS"] = "1" + ray.connect("localhost:50051") + From 65f0080b848328fe9726217de5d8991af3ac4409 Mon Sep 17 00:00:00 2001 From: Oscar Knagg Date: Wed, 15 Sep 2021 12:18:24 +0100 Subject: [PATCH 18/56] Add some fixtures to run test_basic.py with TLS auth --- python/ray/_private/test_utils.py | 1 + python/ray/tests/conftest.py | 31 ++++++++++++++++++++----------- python/ray/tests/test_basic.py | 15 ++++++++++----- 3 files changed, 31 insertions(+), 16 deletions(-) diff --git a/python/ray/_private/test_utils.py b/python/ray/_private/test_utils.py index 3b9c21fc27390..8594b1e10c550 100644 --- a/python/ray/_private/test_utils.py +++ b/python/ray/_private/test_utils.py @@ -740,6 +740,7 @@ def generate_self_signed_tls_certs(): s.close() altnames = x509.SubjectAlternativeName([ x509.DNSName(socket.gethostbyname(socket.gethostname())), # Probably 127.0.0.1 + x509.DNSName("127.0.0.1"), x509.DNSName(private_ip_address), # 192.168.*.* x509.DNSName("localhost"), ]) diff --git a/python/ray/tests/conftest.py b/python/ray/tests/conftest.py index 31188544ae1b9..b5cd96df09f54 100644 --- a/python/ray/tests/conftest.py +++ b/python/ray/tests/conftest.py @@ -89,15 +89,17 @@ def ray_start_regular_shared(request): @pytest.fixture( - scope="module", params=[{ - "local_mode": True - }, { - "local_mode": False - }]) + scope="module", params=[ + {"local_mode": True}, + {"local_mode": False}, + {"local_mode": False, "use_tls": True} + ]) def ray_start_shared_local_modes(request): param = getattr(request, "param", {}) - with _ray_start(**param) as 
res: - yield res + use_tls = param.pop("use_tls", False) + with manage_tls(use_tls): + with _ray_start(**param) as res: + yield res @pytest.fixture @@ -297,14 +299,21 @@ def log_pubsub(): p.close() -@pytest.fixture +@contextmanager +def manage_tls(use_tls): + if use_tls: + key_filepath, cert_filepath, temp_dir = setup_tls() + yield use_tls + if use_tls: + teardown_tls(key_filepath, cert_filepath, temp_dir) + + +@pytest.fixture() def use_tls(request): if request.param: - print("Setting up TLS") key_filepath, cert_filepath, temp_dir = setup_tls() - yield None + yield request.param if request.param: - print("Tearing down TLS") teardown_tls(key_filepath, cert_filepath, temp_dir) """ diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 53a9974d6be36..5c95f4ec5568f 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -76,7 +76,8 @@ def test_omp_threads_set(shutdown_only): assert os.environ["OMP_NUM_THREADS"] == "1" -def test_submit_api(shutdown_only): +@pytest.mark.parametrize("use_tls", [False, True], indirect=True) +def test_submit_api(shutdown_only, use_tls): ray.init(num_cpus=2, num_gpus=1, resources={"Custom": 1}) @ray.remote @@ -140,7 +141,8 @@ def method(self): assert ray.get([id1, id2, id3, id4]) == [0, 1, "test", 2] -def test_invalid_arguments(shutdown_only): +@pytest.mark.parametrize("use_tls", [False, True], indirect=True) +def test_invalid_arguments(shutdown_only, use_tls): ray.init(num_cpus=2) for opt in [np.random.randint(-100, -1), np.random.uniform(0, 1)]: @@ -236,7 +238,8 @@ def check(): {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"}) -def test_put_get(shutdown_only): +@pytest.mark.parametrize("use_tls", [False, True], indirect=True) +def test_put_get(shutdown_only, use_tls): ray.init(num_cpus=0) for i in range(100): @@ -265,7 +268,8 @@ def test_put_get(shutdown_only): @pytest.mark.skipif(sys.platform != "linux", reason="Failing on Windows") -def test_wait_timing(shutdown_only): +@pytest.mark.parametrize("use_tls", [False, True], indirect=True) +def test_wait_timing(shutdown_only, use_tls): ray.init(num_cpus=2) @ray.remote @@ -299,7 +303,8 @@ def test_function_descriptor(): assert d.get(python_descriptor2) == 123 -def test_ray_options(shutdown_only): +@pytest.mark.parametrize("use_tls", [False, True], indirect=True) +def test_ray_options(shutdown_only, use_tls): ray.init(num_cpus=10, num_gpus=10, resources={"custom1": 2}) @ray.remote( From b4dc0cad5dae4bef0c0f1caef8b906304fc4c584 Mon Sep 17 00:00:00 2001 From: Oscar Knagg Date: Wed, 15 Sep 2021 12:39:10 +0100 Subject: [PATCH 19/56] Fix test_tls_auth.py --- python/ray/autoscaler/_private/monitor.py | 1 - python/ray/tests/conftest.py | 2 +- python/ray/tests/test_tls_auth.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/python/ray/autoscaler/_private/monitor.py b/python/ray/autoscaler/_private/monitor.py index cdf7427b892f6..63115525a7883 100644 --- a/python/ray/autoscaler/_private/monitor.py +++ b/python/ray/autoscaler/_private/monitor.py @@ -176,7 +176,6 @@ def update_load_metrics(self): request = gcs_service_pb2.GetAllResourceUsageRequest() response = self.gcs_node_resources_stub.GetAllResourceUsage( request, timeout=4) - print(type(response), response, [i for i in dir(response) if not i.startswith("__")]) resources_batch_data = response.resource_usage_data for resource_message in resources_batch_data.batch: diff --git a/python/ray/tests/conftest.py b/python/ray/tests/conftest.py index b5cd96df09f54..c5bbbb2cd814a 100644 --- 
a/python/ray/tests/conftest.py +++ b/python/ray/tests/conftest.py @@ -308,7 +308,7 @@ def manage_tls(use_tls): teardown_tls(key_filepath, cert_filepath, temp_dir) -@pytest.fixture() +@pytest.fixture def use_tls(request): if request.param: key_filepath, cert_filepath, temp_dir = setup_tls() diff --git a/python/ray/tests/test_tls_auth.py b/python/ray/tests/test_tls_auth.py index 0126d8f6f99f9..c59671526ceff 100644 --- a/python/ray/tests/test_tls_auth.py +++ b/python/ray/tests/test_tls_auth.py @@ -8,7 +8,7 @@ logger = logging.getLogger(__name__) -@pytest.mark.parametrize("use_tls", [True]) +@pytest.mark.parametrize("use_tls", [True], indirect=True) def test_client_connect_to_tls_server(use_tls, init_and_serve): from ray.util.client import ray os.environ["RAY_USE_TLS"] = "0" From 16c0cb3a30e3b0fb1d638f547d019994964312c2 Mon Sep 17 00:00:00 2001 From: Oscar Knagg Date: Wed, 15 Sep 2021 13:00:24 +0100 Subject: [PATCH 20/56] Remove duplicated ReadFile function --- src/ray/rpc/common.cc | 31 +++++++++++++++++++++++++++++++ src/ray/rpc/common.h | 22 ++++++++++++++++++++++ src/ray/rpc/grpc_client.h | 16 ++++------------ src/ray/rpc/grpc_server.cc | 16 ++++------------ src/ray/rpc/grpc_server.h | 2 -- 5 files changed, 61 insertions(+), 26 deletions(-) create mode 100644 src/ray/rpc/common.cc create mode 100644 src/ray/rpc/common.h diff --git a/src/ray/rpc/common.cc b/src/ray/rpc/common.cc new file mode 100644 index 0000000000000..0193e6aea2c95 --- /dev/null +++ b/src/ray/rpc/common.cc @@ -0,0 +1,31 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "ray/rpc/common.h" + +namespace ray { +namespace rpc { + +std::string ReadCert(std::string cert_filepath) { + std::ifstream t(cert_filepath); + std::stringstream buffer; + buffer << t.rdbuf(); + return buffer.str(); +}; + +} // namespace rpc +} // namespace ray \ No newline at end of file diff --git a/src/ray/rpc/common.h b/src/ray/rpc/common.h new file mode 100644 index 0000000000000..16954c5ebeb72 --- /dev/null +++ b/src/ray/rpc/common.h @@ -0,0 +1,22 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
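+
+#pragma once
+
+#include <string>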
+ +namespace ray { +namespace rpc { + +// Utility to read cert file from a particular location +std::string ReadCert(std::string cert_filepath); + +} // namespace rpc +} // namespace ray diff --git a/src/ray/rpc/grpc_client.h b/src/ray/rpc/grpc_client.h index 792328130302a..88703ef6607a0 100644 --- a/src/ray/rpc/grpc_client.h +++ b/src/ray/rpc/grpc_client.h @@ -17,13 +17,12 @@ #include #include -#include -#include #include "ray/common/grpc_util.h" #include "ray/common/ray_config.h" #include "ray/common/status.h" #include "ray/rpc/client_call.h" +#include "ray/rpc/common.h" namespace ray { namespace rpc { @@ -108,13 +107,6 @@ class GrpcClient { /// Whether to use TLS. bool use_tls_; - std::string ReadFile(std::string filename) { - std::ifstream t(filename); - std::stringstream buffer; - buffer << t.rdbuf(); - return buffer.str(); - }; - std::shared_ptr BuildChannel( grpc::ChannelArguments argument, std::string address, @@ -124,9 +116,9 @@ class GrpcClient { std::string server_cert_file = std::string(std::getenv("RAY_TLS_SERVER_CERT")); std::string server_key_file = std::string(std::getenv("RAY_TLS_SERVER_KEY")); std::string root_cert_file = std::string(std::getenv("RAY_TLS_CA_CERT")); - std::string server_cert_chain = ReadFile(server_cert_file); - std::string private_key = ReadFile(server_key_file); - std::string cacert = ReadFile(root_cert_file); + std::string server_cert_chain = ReadCert(server_cert_file); + std::string private_key = ReadCert(server_key_file); + std::string cacert = ReadCert(root_cert_file); grpc::SslCredentialsOptions ssl_opts; ssl_opts.pem_root_certs=cacert; diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc index 2238aa47b47cb..338db85599651 100644 --- a/src/ray/rpc/grpc_server.cc +++ b/src/ray/rpc/grpc_server.cc @@ -17,11 +17,10 @@ #include #include -#include -#include #include "ray/common/ray_config.h" #include "ray/rpc/grpc_server.h" +#include "ray/rpc/common.h" #include "ray/stats/metric.h" #include "ray/util/util.h" @@ -48,13 +47,6 @@ GrpcServer::GrpcServer(std::string name, const uint32_t port, int num_threads, b cqs_.resize(num_threads_); } -std::string GrpcServer::ReadFile(std::string filename) { - std::ifstream t(filename); - std::stringstream buffer; - buffer << t.rdbuf(); - return buffer.str(); -}; - void GrpcServer::Run() { uint32_t specified_port = port_; std::string server_address("0.0.0.0:" + std::to_string(port_)); @@ -80,9 +72,9 @@ void GrpcServer::Run() { std::string root_cert_file = std::string(std::getenv("RAY_TLS_CA_CERT")); // Create credentials from hardcoded location - std::string rootcert = ReadFile(root_cert_file); - std::string servercert = ReadFile(server_cert_file); - std::string serverkey = ReadFile(server_key_file); + std::string rootcert = ReadCert(root_cert_file); + std::string servercert = ReadCert(server_cert_file); + std::string serverkey = ReadCert(server_key_file); grpc::SslServerCredentialsOptions::PemKeyCertPair pkcp = {serverkey.c_str(), servercert.c_str()}; // grpc::SslServerCredentialsOptions ssl_opts; diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h index 78eb729fbf064..cde493a69365e 100644 --- a/src/ray/rpc/grpc_server.h +++ b/src/ray/rpc/grpc_server.h @@ -19,8 +19,6 @@ #include #include #include -#include -#include #include "ray/common/asio/instrumented_io_context.h" #include "ray/common/status.h" From 30bebae5432d6292e8b241648c61ef8f71fce3ed Mon Sep 17 00:00:00 2001 From: Oscar Knagg Date: Wed, 15 Sep 2021 13:45:53 +0100 Subject: [PATCH 21/56] Formatting --- dashboard/head.py | 3 ++- 
dashboard/modules/actor/actor_head.py | 3 ++- dashboard/modules/job/job_head.py | 3 ++- dashboard/modules/node/node_head.py | 3 ++- python/ray/_private/test_utils.py | 29 ++++++++++----------------- python/ray/_private/utils.py | 13 ++++++------ python/ray/tests/conftest.py | 15 +++++++++----- python/ray/tests/test_tls_auth.py | 1 - python/ray/util/client/worker.py | 6 +++--- 9 files changed, 39 insertions(+), 37 deletions(-) diff --git a/dashboard/head.py b/dashboard/head.py index 30a458c3f878e..e35a6b6f862f0 100644 --- a/dashboard/head.py +++ b/dashboard/head.py @@ -121,7 +121,8 @@ def __init__(self, http_host, http_port, http_port_retries, redis_address, ip, port = redis_address.split(":") self.gcs_client = connect_to_gcs(ip, int(port), redis_password) self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0), )) - self.grpc_port = ray._private.utils.add_port_to_grpc_server(self.server, "[::]:0") + self.grpc_port = ray._private.utils.add_port_to_grpc_server( + self.server, "[::]:0") logger.info("Dashboard head grpc address: %s:%s", self.ip, self.grpc_port) diff --git a/dashboard/modules/actor/actor_head.py b/dashboard/modules/actor/actor_head.py index dea163f1f88b4..f59023043826b 100644 --- a/dashboard/modules/actor/actor_head.py +++ b/dashboard/modules/actor/actor_head.py @@ -51,7 +51,8 @@ async def _update_stubs(self, change): address = "{}:{}".format(node_info["nodeManagerAddress"], int(node_info["nodeManagerPort"])) options = (("grpc.enable_http_proxy", 0), ) - channel = ray._private.utils.init_grpc_channel(address, options, asynchronous=True) + channel = ray._private.utils.init_grpc_channel( + address, options, asynchronous=True) stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel) self._stubs[node_id] = stub diff --git a/dashboard/modules/job/job_head.py b/dashboard/modules/job/job_head.py index b19246ce4188c..c9527a955b1eb 100644 --- a/dashboard/modules/job/job_head.py +++ b/dashboard/modules/job/job_head.py @@ -52,7 +52,8 @@ async def submit_job(self, req) -> aiohttp.web.Response: ip = DataSource.node_id_to_ip[node_id] address = f"{ip}:{ports[1]}" options = (("grpc.enable_http_proxy", 0), ) - channel = ray._private.utils.init_grpc_channel(address, options, asynchronous=True) + channel = ray._private.utils.init_grpc_channel( + address, options, asynchronous=True) stub = job_agent_pb2_grpc.JobAgentServiceStub(channel) request = job_agent_pb2.InitializeJobEnvRequest( diff --git a/dashboard/modules/node/node_head.py b/dashboard/modules/node/node_head.py index 3a9655cb4ea3a..49c4d6527073c 100644 --- a/dashboard/modules/node/node_head.py +++ b/dashboard/modules/node/node_head.py @@ -66,7 +66,8 @@ async def _update_stubs(self, change): address = "{}:{}".format(node_info["nodeManagerAddress"], int(node_info["nodeManagerPort"])) options = (("grpc.enable_http_proxy", 0), ) - channel = ray._private.utils.init_grpc_channel(address, options, asynchronous=True) + channel = ray._private.utils.init_grpc_channel( + address, options, asynchronous=True) stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel) self._stubs[node_id] = stub diff --git a/python/ray/_private/test_utils.py b/python/ray/_private/test_utils.py index bf7f23733ab27..4c2fd2421a0ab 100644 --- a/python/ray/_private/test_utils.py +++ b/python/ray/_private/test_utils.py @@ -703,11 +703,9 @@ def generate_self_signed_tls_certs(): except ImportError: raise ImportError( "Using `Security.temporary` requires `cryptography`, please " - "install it using either pip or conda" - ) + "install it using either pip or conda") 
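    # Key first; the self-signed certificate below uses it for both subject
    # and issuer, with SANs added so hostname verification passes.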
key = rsa.generate_private_key( - public_exponent=65537, key_size=2048, backend=default_backend() - ) + public_exponent=65537, key_size=2048, backend=default_backend()) key_contents = key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.PKCS8, @@ -715,8 +713,7 @@ def generate_self_signed_tls_certs(): ).decode() ray_interal = x509.Name( - [x509.NameAttribute(NameOID.COMMON_NAME, "ray-internal")] - ) + [x509.NameAttribute(NameOID.COMMON_NAME, "ray-internal")]) # This is the same logic used by the GCS server to acquire a private/interal IP # address to listen on. If we just use localhost + 127.0.0.1 then we won't be able to # connect to the GCS and will get an error like "No match found for server name: 192.168.X.Y" @@ -725,23 +722,19 @@ def generate_self_signed_tls_certs(): private_ip_address = s.getsockname()[0] s.close() altnames = x509.SubjectAlternativeName([ - x509.DNSName(socket.gethostbyname(socket.gethostname())), # Probably 127.0.0.1 + x509.DNSName(socket.gethostbyname( + socket.gethostname())), # Probably 127.0.0.1 x509.DNSName("127.0.0.1"), x509.DNSName(private_ip_address), # 192.168.*.* x509.DNSName("localhost"), ]) now = datetime.datetime.utcnow() - cert = ( - x509.CertificateBuilder() - .subject_name(ray_interal) - .issuer_name(ray_interal) - .add_extension(altnames, critical=False) - .public_key(key.public_key()) - .serial_number(x509.random_serial_number()) - .not_valid_before(now) - .not_valid_after(now + datetime.timedelta(days=365)) - .sign(key, hashes.SHA256(), default_backend()) - ) + cert = (x509.CertificateBuilder() + .subject_name(ray_interal).issuer_name(ray_interal).add_extension( + altnames, critical=False).public_key(key.public_key()) + .serial_number(x509.random_serial_number()).not_valid_before(now) + .not_valid_after(now + datetime.timedelta(days=365)).sign( + key, hashes.SHA256(), default_backend())) cert_contents = cert.public_bytes(serialization.Encoding.PEM).decode() diff --git a/python/ray/_private/utils.py b/python/ray/_private/utils.py index 2acd66c328aed..32b3e34519f80 100644 --- a/python/ray/_private/utils.py +++ b/python/ray/_private/utils.py @@ -1126,16 +1126,18 @@ def load_certs_from_env(): return server_cert_chain, private_key, ca_cert -def init_grpc_channel(address: str, options: Optional[Sequence[Tuple[str, Any]]] = None, asynchronous: bool = False): +def init_grpc_channel(address: str, + options: Optional[Sequence[Tuple[str, Any]]] = None, + asynchronous: bool = False): grpc_module = aiogrpc if asynchronous else grpc if os.environ["RAY_USE_TLS"] == "1": server_cert_chain, private_key, ca_cert = load_certs_from_env() credentials = grpc.ssl_channel_credentials( certificate_chain=server_cert_chain, private_key=private_key, - root_certificates=ca_cert - ) - channel = grpc_module.secure_channel(address, credentials, options=options) + root_certificates=ca_cert) + channel = grpc_module.secure_channel( + address, credentials, options=options) else: channel = grpc_module.insecure_channel(address, options=options) @@ -1148,8 +1150,7 @@ def add_port_to_grpc_server(server, address): credentials = grpc.ssl_server_credentials( [(private_key, server_cert_chain)], root_certificates=ca_cert, - require_client_auth=ca_cert is not None - ) + require_client_auth=ca_cert is not None) return server.add_secure_port(address, credentials) else: return server.add_insecure_port(address) diff --git a/python/ray/tests/conftest.py b/python/ray/tests/conftest.py index c5bbbb2cd814a..25edefd60fd1d 100644 --- a/python/ray/tests/conftest.py 
+++ b/python/ray/tests/conftest.py @@ -89,11 +89,15 @@ def ray_start_regular_shared(request): @pytest.fixture( - scope="module", params=[ - {"local_mode": True}, - {"local_mode": False}, - {"local_mode": False, "use_tls": True} - ]) + scope="module", + params=[{ + "local_mode": True + }, { + "local_mode": False + }, { + "local_mode": False, + "use_tls": True + }]) def ray_start_shared_local_modes(request): param = getattr(request, "param", {}) use_tls = param.pop("use_tls", False) @@ -316,6 +320,7 @@ def use_tls(request): if request.param: teardown_tls(key_filepath, cert_filepath, temp_dir) + """ Object spilling test fixture """ diff --git a/python/ray/tests/test_tls_auth.py b/python/ray/tests/test_tls_auth.py index c59671526ceff..290f92af6a66d 100644 --- a/python/ray/tests/test_tls_auth.py +++ b/python/ray/tests/test_tls_auth.py @@ -17,4 +17,3 @@ def test_client_connect_to_tls_server(use_tls, init_and_serve): os.environ["RAY_USE_TLS"] = "1" ray.connect("localhost:50051") - diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index dd0f72d5c387b..9a984ad348c5b 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -106,12 +106,12 @@ def __init__( # TODO tidy this up if self._secure and _credentials is None: - server_cert_chain, private_key, ca_cert = ray._private.utils.load_certs_from_env() + server_cert_chain, private_key, ca_cert = ray._private.utils.load_certs_from_env( + ) _credentials = grpc.ssl_channel_credentials( certificate_chain=server_cert_chain, private_key=private_key, - root_certificates=ca_cert - ) + root_certificates=ca_cert) if _credentials is not None: self._credentials = _credentials From c551c30f30c8671d8c60b8d3016829c942326b4d Mon Sep 17 00:00:00 2001 From: Oscar Knagg Date: Wed, 15 Sep 2021 13:58:27 +0100 Subject: [PATCH 22/56] Remove EKS cluster YAML --- eks-cluster.yaml | 206 ----------------------------------------------- 1 file changed, 206 deletions(-) delete mode 100644 eks-cluster.yaml diff --git a/eks-cluster.yaml b/eks-cluster.yaml deleted file mode 100644 index 6c2c5014f10a7..0000000000000 --- a/eks-cluster.yaml +++ /dev/null @@ -1,206 +0,0 @@ -cluster_name: example-cluster - -max_workers: 63 - -upscaling_speed: 1.0 - -idle_timeout_minutes: 1 - -provider: - type: kubernetes - use_internal_ips: true - namespace: ray - - autoscaler_service_account: - apiVersion: v1 - kind: ServiceAccount - metadata: - name: autoscaler - - autoscaler_role: - kind: Role - apiVersion: rbac.authorization.k8s.io/v1 - metadata: - name: autoscaler - rules: - - apiGroups: [""] - resources: ["pods", "pods/status", "pods/exec"] - verbs: ["get", "watch", "list", "create", "delete", "patch"] - - autoscaler_role_binding: - apiVersion: rbac.authorization.k8s.io/v1 - kind: RoleBinding - metadata: - name: autoscaler - subjects: - - kind: ServiceAccount - name: autoscaler - roleRef: - kind: Role - name: autoscaler - apiGroup: rbac.authorization.k8s.io - - services: - - apiVersion: v1 - kind: Service - metadata: - # NOTE: If you're running multiple Ray clusters with services - # on one Kubernetes cluster, they must have unique service - # names. - name: ray-head - spec: - # This selector must match the head node pod's selector below. 
- selector: - component: ray-head - ports: - - name: client - protocol: TCP - port: 10001 - targetPort: 10001 - - name: dashboard - protocol: TCP - port: 8265 - targetPort: 8265 - - name: ray-serve - protocol: TCP - port: 8000 - targetPort: 8000 - - name: jupyter - protocol: TCP - port: 8888 - targetPort: 8888 - -head_node_type: head_node -available_node_types: - gpu_node: - min_workers: 1 - max_workers: 1 - resources: { "CPU": 8, "GPU": 0 } - node_config: - apiVersion: v1 - kind: Pod - metadata: - # Automatically generates a name for the pod with this prefix. - generateName: ray-gpu-worker- - labels: - component: ray-gpu - spec: - restartPolicy: Never - volumes: - - name: dshm - emptyDir: - medium: Memory - - containers: - - name: ray-node - imagePullPolicy: IfNotPresent - image: rayproject/ray:3bc5f0-py38 - command: ["/bin/bash", "-c", "--"] - args: ["trap : TERM INT; sleep infinity & wait;"] - # This volume allocates shared memory for Ray to use for its plasma - # object store. If you do not provide this, Ray will fall back to - # /tmp which cause slowdowns if is not a shared memory volume. - volumeMounts: - - mountPath: /dev/shm - name: dshm - - resources: - limits: - cpu: 7 - memory: 31G - - # Add in my hacky environment variables to enable TLS - env: - - name: RAY_USE_TLS - value: "1" - - name: RAY_TLS_SERVER_CERT - value: "/home/ray/.ssh/server.crt" - - name: RAY_TLS_SERVER_KEY - value: "/home/ray/.ssh/server.key" - - name: RAY_TLS_CA_CERT - value: "/home/ray/.ssh/ca.crt" - - name: GRPC_VERBOSITY - value: "debug" - - head_node: - min_workers: 0 - max_workers: 0 - resources: { "CPU": 8, "GPU": 0 } - node_config: - apiVersion: v1 - kind: Pod - metadata: - generateName: ray-head- - labels: - component: ray-head - spec: - serviceAccountName: autoscaler - - restartPolicy: Never - - volumes: - - name: dshm - emptyDir: - medium: Memory - - containers: - - name: ray-node - imagePullPolicy: IfNotPresent - image: rayproject/ray:3bc5f0-py38 - # Do not change this command - it keeps the pod alive until it is - # explicitly killed. 
- command: ["/bin/bash", "-c", "--"] - args: ['trap : TERM INT; sleep infinity & wait;'] - ports: - - containerPort: 6379 # Redis port - - containerPort: 10001 # Used by Ray Client - - containerPort: 8265 # Used by Ray Dashboard - - containerPort: 8888 # Used by Jupyter - - volumeMounts: - - mountPath: /dev/shm - name: dshm - - resources: - limits: - cpu: 7 - memory: 31G - - env: - - name: RAY_USE_TLS - value: "1" - - name: RAY_TLS_SERVER_CERT - value: "/home/ray/.ssh/server.crt" - - name: RAY_TLS_SERVER_KEY - value: "/home/ray/.ssh/server.key" - - name: RAY_TLS_CA_CERT - value: "/home/ray/.ssh/ca.crt" - - name: GRPC_VERBOSITY - value: "debug" - -setup_commands: - - openssl req -newkey rsa:2048 -nodes -keyout /home/ray/.ssh/server.key -subj "/C=CN/ST=GD/L=SZ/O=G-Research/CN=$HOSTNAME" -out /home/ray/.ssh/server.csr - - export SAN=$(python -c "import socket; print(socket.gethostbyname(socket.gethostname()))") && echo subjectAltName=DNS:$SAN,DNS:127.0.0.1,DNS:0.0.0.0,DNS:localhost > /home/ray/.ssh/extfile - - openssl x509 -req -extfile /home/ray/.ssh/extfile -days 365 -in /home/ray/.ssh/server.csr -CA /home/ray/.ssh/ca.crt -CAkey /home/ray/.ssh/ca.key -CAcreateserial -out /home/ray/.ssh/server.crt - - pip install /home/ray/ray-2.0.0.dev0-cp38-cp38-linux_x86_64.whl --force-reinstall - -# After startup, run this on local machine to pull the keys -# ray rsync_down eks-cluster.yaml /home/ray/.ssh/server.crt /home/oscar/.ssh/server.crt -# ray rsync_down eks-cluster.yaml /home/ray/.ssh/server.key /home/oscar/.ssh/server.key - -head_start_ray_commands: - - ray stop - - ulimit -n 65536; ray start --head --autoscaling-config=~/ray_bootstrap_config.yaml --dashboard-host 0.0.0.0 --object-store-memory 16000000000 - -worker_start_ray_commands: - - ray stop - - ulimit -n 65536; ray start --address=$RAY_HEAD_IP:6379 --object-store-memory 16000000000 - -file_mounts: { - # TLS enabled Ray wheel - "/home/ray/ray-2.0.0.dev0-cp38-cp38-linux_x86_64.whl": "/home/oscar/CLionProjects/ray/python/dist/ray-2.0.0.dev0-cp38-cp38-linux_x86_64.whl", - # Certificate authority keys generated locally - "/home/ray/.ssh/ca.crt": "/home/oscar/.ssh/ca.crt", - "/home/ray/.ssh/ca.key": "/home/oscar/.ssh/ca.key", -# "/path2/on/remote/machine": "/path2/on/local/machine", -} \ No newline at end of file From 1fa0fbfb78ff450206639b7390effd265a1f4d9b Mon Sep 17 00:00:00 2001 From: Oscar Knagg Date: Wed, 15 Sep 2021 16:25:51 +0100 Subject: [PATCH 23/56] Don't assume TLS env vars are set --- python/ray/_private/utils.py | 6 +++--- src/ray/rpc/grpc_client.h | 6 +++++- src/ray/rpc/grpc_server.cc | 7 +++++-- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/python/ray/_private/utils.py b/python/ray/_private/utils.py index 32b3e34519f80..37430d928dd92 100644 --- a/python/ray/_private/utils.py +++ b/python/ray/_private/utils.py @@ -1112,7 +1112,7 @@ def validate_namespace(namespace: str): def load_certs_from_env(): - if os.environ["RAY_USE_TLS"] == "1": + if os.environ.get("RAY_USE_TLS", "0") == "1": with open(os.environ["RAY_TLS_SERVER_CERT"], "rb") as f: server_cert_chain = f.read() with open(os.environ["RAY_TLS_SERVER_KEY"], "rb") as f: @@ -1130,7 +1130,7 @@ def init_grpc_channel(address: str, options: Optional[Sequence[Tuple[str, Any]]] = None, asynchronous: bool = False): grpc_module = aiogrpc if asynchronous else grpc - if os.environ["RAY_USE_TLS"] == "1": + if os.environ.get("RAY_USE_TLS", "0") == "1": server_cert_chain, private_key, ca_cert = load_certs_from_env() credentials = grpc.ssl_channel_credentials( 
certificate_chain=server_cert_chain, @@ -1145,7 +1145,7 @@ def init_grpc_channel(address: str, def add_port_to_grpc_server(server, address): - if os.environ["RAY_USE_TLS"] == "1": + if os.environ.get("RAY_USE_TLS", "0") == "1": server_cert_chain, private_key, ca_cert = load_certs_from_env() credentials = grpc.ssl_server_credentials( [(private_key, server_cert_chain)], diff --git a/src/ray/rpc/grpc_client.h b/src/ray/rpc/grpc_client.h index 88703ef6607a0..d7641ce5b370e 100644 --- a/src/ray/rpc/grpc_client.h +++ b/src/ray/rpc/grpc_client.h @@ -55,7 +55,11 @@ class GrpcClient { argument.SetMaxSendMessageSize(::RayConfig::instance().max_grpc_message_size()); argument.SetMaxReceiveMessageSize(::RayConfig::instance().max_grpc_message_size()); - use_tls_ = std::strcmp(std::getenv("RAY_USE_TLS"), "0") != 0; + if (std::getenv("RAY_USE_TLS")) { + use_tls_ = std::strcmp(std::getenv("RAY_USE_TLS"), "0") != 0; + } else { + use_tls_ = false; + } std::shared_ptr channel = BuildChannel(argument, address, port); stub_ = GrpcService::NewStub(channel); diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc index 338db85599651..f28fb5ecbc4dd 100644 --- a/src/ray/rpc/grpc_server.cc +++ b/src/ray/rpc/grpc_server.cc @@ -64,8 +64,11 @@ void GrpcServer::Run() { RayConfig::instance().grpc_keepalive_timeout_ms()); builder.AddChannelArgument(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 0); - // TODO(hchen): Add options for authentication. - use_tls_ = std::strcmp(std::getenv("RAY_USE_TLS"), "0") != 0; + if (std::getenv("RAY_USE_TLS")) { + use_tls_ = std::strcmp(std::getenv("RAY_USE_TLS"), "0") != 0; + } else { + use_tls_ = false; + } if (use_tls_) { std::string server_cert_file = std::string(std::getenv("RAY_TLS_SERVER_CERT")); std::string server_key_file = std::string(std::getenv("RAY_TLS_SERVER_KEY")); From 2b0bc687cfa2d8ea94419f2c7414142c403360e9 Mon Sep 17 00:00:00 2001 From: Oscar Knagg Date: Wed, 15 Sep 2021 18:27:10 +0100 Subject: [PATCH 24/56] Add cryptography requirement to generate testing certs --- python/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/python/requirements.txt b/python/requirements.txt index ee0e4bb0d372c..d3a870db2a581 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -85,3 +85,4 @@ smart_open[s3] tqdm async-exit-stack async-generator +cryptography==3.4.8 From ef5025af0e40fb811c0c9a086c5ffbabe904189b Mon Sep 17 00:00:00 2001 From: Oscar Knagg Date: Wed, 15 Sep 2021 18:28:11 +0100 Subject: [PATCH 25/56] Linting --- dashboard/agent.py | 1 - dashboard/head.py | 1 - dashboard/modules/event/event_agent.py | 1 - dashboard/modules/node/node_head.py | 1 - dashboard/modules/reporter/reporter_head.py | 1 - dashboard/utils.py | 2 -- python/ray/_private/test_utils.py | 7 ++++--- python/ray/autoscaler/_private/monitor.py | 2 -- 8 files changed, 4 insertions(+), 12 deletions(-) diff --git a/dashboard/agent.py b/dashboard/agent.py index abc46c6be17df..561f7d88cc3ca 100644 --- a/dashboard/agent.py +++ b/dashboard/agent.py @@ -9,7 +9,6 @@ import json import time import traceback -import grpc from grpc.experimental import aio as aiogrpc diff --git a/dashboard/head.py b/dashboard/head.py index e35a6b6f862f0..db37a5c1058cb 100644 --- a/dashboard/head.py +++ b/dashboard/head.py @@ -1,6 +1,5 @@ import os import sys -import grpc import socket import asyncio import logging diff --git a/dashboard/modules/event/event_agent.py b/dashboard/modules/event/event_agent.py index 6cf8ca61fbc60..9dbe535737e9b 100644 --- a/dashboard/modules/event/event_agent.py +++ 
b/dashboard/modules/event/event_agent.py @@ -2,7 +2,6 @@ import asyncio import logging from typing import Union -from grpc.experimental import aio as aiogrpc import ray._private.utils as utils import ray.new_dashboard.utils as dashboard_utils diff --git a/dashboard/modules/node/node_head.py b/dashboard/modules/node/node_head.py index 49c4d6527073c..b93220516f53d 100644 --- a/dashboard/modules/node/node_head.py +++ b/dashboard/modules/node/node_head.py @@ -4,7 +4,6 @@ import json import aiohttp.web from aioredis.pubsub import Receiver -from grpc.experimental import aio as aiogrpc import ray._private.utils import ray._private.gcs_utils as gcs_utils diff --git a/dashboard/modules/reporter/reporter_head.py b/dashboard/modules/reporter/reporter_head.py index dd61e921c981f..789acb6e9e8bb 100644 --- a/dashboard/modules/reporter/reporter_head.py +++ b/dashboard/modules/reporter/reporter_head.py @@ -4,7 +4,6 @@ import os import aiohttp.web from aioredis.pubsub import Receiver -from grpc.experimental import aio as aiogrpc import ray import ray.new_dashboard.modules.reporter.reporter_consts as reporter_consts diff --git a/dashboard/utils.py b/dashboard/utils.py index b152d412b724f..2b5b31e872dcb 100644 --- a/dashboard/utils.py +++ b/dashboard/utils.py @@ -12,13 +12,11 @@ import socket import time import traceback -import grpc from abc import ABCMeta, abstractmethod from base64 import b64decode from collections import namedtuple from collections.abc import MutableMapping, Mapping, Sequence from typing import Any -from grpc.experimental import aio as aiogrpc from google.protobuf.json_format import MessageToDict import ray.new_dashboard.consts as dashboard_consts diff --git a/python/ray/_private/test_utils.py b/python/ray/_private/test_utils.py index 4c2fd2421a0ab..d0a27ca72b3fd 100644 --- a/python/ray/_private/test_utils.py +++ b/python/ray/_private/test_utils.py @@ -714,9 +714,10 @@ def generate_self_signed_tls_certs(): ray_interal = x509.Name( [x509.NameAttribute(NameOID.COMMON_NAME, "ray-internal")]) - # This is the same logic used by the GCS server to acquire a private/interal IP - # address to listen on. If we just use localhost + 127.0.0.1 then we won't be able to - # connect to the GCS and will get an error like "No match found for server name: 192.168.X.Y" + # This is the same logic used by the GCS server to acquire a + # private/interal IP address to listen on. 
If we just use localhost + + # 127.0.0.1 then we won't be able to connect to the GCS and will get + # an error like "No match found for server name: 192.168.X.Y" s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s.connect(("8.8.8.8", 80)) private_ip_address = s.getsockname()[0] diff --git a/python/ray/autoscaler/_private/monitor.py b/python/ray/autoscaler/_private/monitor.py index 63115525a7883..c7351213b1e0c 100644 --- a/python/ray/autoscaler/_private/monitor.py +++ b/python/ray/autoscaler/_private/monitor.py @@ -12,8 +12,6 @@ from multiprocessing.synchronize import Event from typing import Optional -import grpc - try: import prometheus_client except ImportError: From de36d6ad3f5a25633004e05d982e0fc1c523c6d7 Mon Sep 17 00:00:00 2001 From: Oscar Knagg Date: Thu, 16 Sep 2021 10:03:12 +0100 Subject: [PATCH 26/56] Fix new_dashboard->dashboard merge --- dashboard/modules/event/event_agent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dashboard/modules/event/event_agent.py b/dashboard/modules/event/event_agent.py index e57a53dcf6cbb..613fd439d1e63 100644 --- a/dashboard/modules/event/event_agent.py +++ b/dashboard/modules/event/event_agent.py @@ -4,8 +4,8 @@ from typing import Union import ray._private.utils as utils -import ray.new_dashboard.utils as dashboard_utils -import ray.new_dashboard.consts as dashboard_consts +import ray.dashboard.utils as dashboard_utils +import ray.dashboard.consts as dashboard_consts import ray.dashboard.utils as dashboard_utils import ray.dashboard.consts as dashboard_consts from ray.ray_constants import env_bool From 92627a849adcdd71070349b30beba506f97a475e Mon Sep 17 00:00:00 2001 From: Oscar Knagg Date: Thu, 16 Sep 2021 11:34:23 +0100 Subject: [PATCH 27/56] Remove possibility of nullptr from RAY_USE_TLS --- src/ray/rpc/grpc_client.h | 18 +++++++++++------- src/ray/rpc/grpc_server.cc | 1 - src/ray/rpc/grpc_server.h | 2 +- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/ray/rpc/grpc_client.h b/src/ray/rpc/grpc_client.h index d7641ce5b370e..52d7658972ac1 100644 --- a/src/ray/rpc/grpc_client.h +++ b/src/ray/rpc/grpc_client.h @@ -55,18 +55,14 @@ argument.SetMaxSendMessageSize(::RayConfig::instance().max_grpc_message_size()); argument.SetMaxReceiveMessageSize(::RayConfig::instance().max_grpc_message_size()); - if (std::getenv("RAY_USE_TLS")) { - use_tls_ = std::strcmp(std::getenv("RAY_USE_TLS"), "0") != 0; - } else { - use_tls_ = false; - } + CheckTlSEnvironmentVariables(); std::shared_ptr<grpc::Channel> channel = BuildChannel(argument, address, port); stub_ = GrpcService::NewStub(channel); } GrpcClient(const std::string &address, const int port, ClientCallManager &call_manager, - int num_threads, bool use_tls = true) + int num_threads, bool use_tls = false) : client_call_manager_(call_manager), use_tls_(use_tls) { grpc::ResourceQuota quota; @@ -77,7 +73,7 @@ class GrpcClient { argument.SetMaxSendMessageSize(::RayConfig::instance().max_grpc_message_size()); argument.SetMaxReceiveMessageSize(::RayConfig::instance().max_grpc_message_size()); - use_tls_ = std::strcmp(std::getenv("RAY_USE_TLS"), "0") != 0; + CheckTlSEnvironmentVariables(); std::shared_ptr<grpc::Channel> channel = BuildChannel(argument, address, port); stub_ = GrpcService::NewStub(channel); @@ -140,6 +136,14 @@ class GrpcClient { return channel; }; + void CheckTlSEnvironmentVariables() { + if (std::getenv("RAY_USE_TLS")) { + use_tls_ = std::strcmp(std::getenv("RAY_USE_TLS"), "0") != 0; + } else { + use_tls_ = false; + }; + } + }; } // namespace rpc diff --git 
a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc index f28fb5ecbc4dd..6ded98169da0e 100644 --- a/src/ray/rpc/grpc_server.cc +++ b/src/ray/rpc/grpc_server.cc @@ -80,7 +80,6 @@ void GrpcServer::Run() { std::string serverkey = ReadCert(server_key_file); grpc::SslServerCredentialsOptions::PemKeyCertPair pkcp = {serverkey.c_str(), servercert.c_str()}; -// grpc::SslServerCredentialsOptions ssl_opts; grpc::SslServerCredentialsOptions ssl_opts(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY); ssl_opts.pem_root_certs = rootcert; ssl_opts.pem_key_cert_pairs.push_back(pkcp); diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h index cde493a69365e..659ba893f9c93 100644 --- a/src/ray/rpc/grpc_server.h +++ b/src/ray/rpc/grpc_server.h @@ -61,7 +61,7 @@ class GrpcServer { /// \param[in] name Name of this server, used for logging and debugging purpose. /// \param[in] port The port to bind this server to. If it's 0, a random available port /// will be chosen. - GrpcServer(std::string name, const uint32_t port, int num_threads = 1, bool use_tls = true, + GrpcServer(std::string name, const uint32_t port, int num_threads = 1, bool use_tls = false, int64_t keepalive_time_ms = 7200000 /*2 hours, grpc default*/); /// Destruct this gRPC server. From d3b47dc10c5e7c83d76191a120d407b73612afd1 Mon Sep 17 00:00:00 2001 From: Oscar Knagg Date: Thu, 16 Sep 2021 11:53:47 +0100 Subject: [PATCH 28/56] clang-format 7.0.0 linting --- src/ray/rpc/grpc_client.h | 29 +++++++++++------------------ src/ray/rpc/grpc_server.cc | 9 +++++---- src/ray/rpc/grpc_server.h | 5 +++-- 3 files changed, 19 insertions(+), 24 deletions(-) diff --git a/src/ray/rpc/grpc_client.h b/src/ray/rpc/grpc_client.h index 52d7658972ac1..5d7887596c0a7 100644 --- a/src/ray/rpc/grpc_client.h +++ b/src/ray/rpc/grpc_client.h @@ -46,8 +46,7 @@ class GrpcClient { public: GrpcClient(const std::string &address, const int port, ClientCallManager &call_manager, bool use_tls = true) - : client_call_manager_(call_manager), - use_tls_(use_tls) { + : client_call_manager_(call_manager), use_tls_(use_tls) { grpc::ChannelArguments argument; // Disable http proxy since it disrupts local connections. TODO(ekl) we should make // this configurable, or selectively set it for known local connections only. @@ -63,8 +62,7 @@ class GrpcClient { GrpcClient(const std::string &address, const int port, ClientCallManager &call_manager, int num_threads, bool use_tls = false) - : client_call_manager_(call_manager), - use_tls_(use_tls) { + : client_call_manager_(call_manager), use_tls_(use_tls) { grpc::ResourceQuota quota; quota.SetMaxThreads(num_threads); grpc::ChannelArguments argument; @@ -107,10 +105,8 @@ class GrpcClient { /// Whether to use TLS. 
bool use_tls_; - std::shared_ptr<grpc::Channel> BuildChannel( - grpc::ChannelArguments argument, - std::string address, - int port) { + std::shared_ptr<grpc::Channel> BuildChannel(grpc::ChannelArguments argument, + std::string address, int port) { std::shared_ptr<grpc::Channel> channel; if (use_tls_) { std::string server_cert_file = std::string(std::getenv("RAY_TLS_SERVER_CERT")); @@ -121,17 +117,15 @@ class GrpcClient { std::string cacert = ReadCert(root_cert_file); grpc::SslCredentialsOptions ssl_opts; - ssl_opts.pem_root_certs=cacert; - ssl_opts.pem_private_key=private_key; - ssl_opts.pem_cert_chain=server_cert_chain; + ssl_opts.pem_root_certs = cacert; + ssl_opts.pem_private_key = private_key; + ssl_opts.pem_cert_chain = server_cert_chain; auto ssl_creds = grpc::SslCredentials(ssl_opts); - channel = - grpc::CreateCustomChannel(address + ":" + std::to_string(port), - ssl_creds, argument); + channel = grpc::CreateCustomChannel(address + ":" + std::to_string(port), ssl_creds, + argument); } else { - channel = - grpc::CreateCustomChannel(address + ":" + std::to_string(port), - grpc::InsecureChannelCredentials(), argument); + channel = grpc::CreateCustomChannel(address + ":" + std::to_string(port), + grpc::InsecureChannelCredentials(), argument); } return channel; }; @@ -143,7 +137,6 @@ class GrpcClient { use_tls_ = false; }; } - }; } // namespace rpc diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc index 6ded98169da0e..3840527fb5a9a 100644 --- a/src/ray/rpc/grpc_server.cc +++ b/src/ray/rpc/grpc_server.cc @@ -19,8 +19,8 @@ #include #include "ray/common/ray_config.h" -#include "ray/rpc/grpc_server.h" #include "ray/rpc/common.h" +#include "ray/rpc/grpc_server.h" #include "ray/stats/metric.h" #include "ray/util/util.h" @@ -36,8 +36,8 @@ DEFINE_stats(grpc_server_req_finished, "Finished request number in grpc server", namespace ray { namespace rpc { -GrpcServer::GrpcServer(std::string name, const uint32_t port, int num_threads, bool use_tls, - int64_t keepalive_time_ms) +GrpcServer::GrpcServer(std::string name, const uint32_t port, int num_threads, + bool use_tls, int64_t keepalive_time_ms) : name_(std::move(name)), port_(port), use_tls_(use_tls), @@ -80,7 +80,8 @@ void GrpcServer::Run() { std::string serverkey = ReadCert(server_key_file); grpc::SslServerCredentialsOptions::PemKeyCertPair pkcp = {serverkey.c_str(), servercert.c_str()}; - grpc::SslServerCredentialsOptions ssl_opts(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY); + grpc::SslServerCredentialsOptions ssl_opts( + GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY); ssl_opts.pem_root_certs = rootcert; ssl_opts.pem_key_cert_pairs.push_back(pkcp); diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h index 659ba893f9c93..826efbdf260bb 100644 --- a/src/ray/rpc/grpc_server.h +++ b/src/ray/rpc/grpc_server.h @@ -61,7 +61,8 @@ class GrpcServer { /// \param[in] name Name of this server, used for logging and debugging purpose. /// \param[in] port The port to bind this server to. If it's 0, a random available port /// will be chosen. - GrpcServer(std::string name, const uint32_t port, int num_threads = 1, bool use_tls = false, + GrpcServer(std::string name, const uint32_t port, int num_threads = 1, + bool use_tls = false, int64_t keepalive_time_ms = 7200000 /*2 hours, grpc default*/); /// Destruct this gRPC server. @@ -91,7 +92,7 @@ class GrpcServer { /// Read a file std::string ReadFile(std::string filename); - /// Get the port of this gRPC server. + /// Get the port of this gRPC server. 
int GetPort() const { return port_; } /// Register a grpc service. Multiple services can be registered to the same server. From 08fc4b0a1319fb170661f6ea3cb3ec8cbe3b1073 Mon Sep 17 00:00:00 2001 From: Oscar Knagg Date: Thu, 16 Sep 2021 12:26:25 +0100 Subject: [PATCH 29/56] Linting --- dashboard/modules/event/event_agent.py | 2 -- python/ray/internal/internal_api.py | 4 ---- python/ray/scripts/scripts.py | 1 - python/ray/tests/test_multi_tenancy.py | 1 - python/ray/util/client/server/proxier.py | 3 ++- python/ray/util/client/server/server.py | 1 - python/ray/util/client/worker.py | 4 ++-- 7 files changed, 4 insertions(+), 12 deletions(-) diff --git a/dashboard/modules/event/event_agent.py b/dashboard/modules/event/event_agent.py index 613fd439d1e63..859df7303d9c3 100644 --- a/dashboard/modules/event/event_agent.py +++ b/dashboard/modules/event/event_agent.py @@ -6,8 +6,6 @@ import ray._private.utils as utils import ray.dashboard.utils as dashboard_utils import ray.dashboard.consts as dashboard_consts -import ray.dashboard.utils as dashboard_utils -import ray.dashboard.consts as dashboard_consts from ray.ray_constants import env_bool from ray.dashboard.utils import async_loop_forever, create_task from ray.dashboard.modules.event import event_consts diff --git a/python/ray/internal/internal_api.py b/python/ray/internal/internal_api.py index a75ac08d13ef2..e81637078956c 100644 --- a/python/ray/internal/internal_api.py +++ b/python/ray/internal/internal_api.py @@ -1,5 +1,3 @@ -import os - import ray import ray._private.services as services import ray.worker @@ -44,7 +42,6 @@ def memory_summary(address=None, def get_store_stats(state, node_manager_address=None, node_manager_port=None): """Returns a formatted string describing memory usage in the cluster.""" - import grpc from ray.core.generated import node_manager_pb2 from ray.core.generated import node_manager_pb2_grpc @@ -84,7 +81,6 @@ def node_stats(node_manager_address=None, include_memory_info=True): """Returns NodeStats object describing memory usage in the cluster.""" - import grpc from ray.core.generated import node_manager_pb2 from ray.core.generated import node_manager_pb2_grpc diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index c5e0424d8a8de..fbab357ae96f8 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -16,7 +16,6 @@ import ray import psutil -import grpc import ray._private.services as services import ray.ray_constants as ray_constants import ray._private.utils diff --git a/python/ray/tests/test_multi_tenancy.py b/python/ray/tests/test_multi_tenancy.py index acbd41b52c42c..1267570d3660b 100644 --- a/python/ray/tests/test_multi_tenancy.py +++ b/python/ray/tests/test_multi_tenancy.py @@ -3,7 +3,6 @@ import sys import time -import grpc import pytest import numpy as np diff --git a/python/ray/util/client/server/proxier.py b/python/ray/util/client/server/proxier.py index e7c43f58e005b..767a6f1bc128a 100644 --- a/python/ray/util/client/server/proxier.py +++ b/python/ray/util/client/server/proxier.py @@ -28,7 +28,8 @@ from ray._private.parameter import RayParams from ray._private.runtime_env import RuntimeEnvContext from ray._private.services import ProcessInfo, start_ray_client_server -from ray._private.utils import detect_fate_sharing_support, add_port_to_grpc_server +from ray._private.utils import (detect_fate_sharing_support, + add_port_to_grpc_server) # Import psutil after ray so the packaged version is used. 
import psutil diff --git a/python/ray/util/client/server/server.py b/python/ray/util/client/server/server.py index f4de77750ebd2..97671dbd84236 100644 --- a/python/ray/util/client/server/server.py +++ b/python/ray/util/client/server/server.py @@ -1,5 +1,4 @@ import logging -import os from concurrent import futures import grpc import base64 diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index a2ecc1f66c03e..63e77c9c4b51d 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -106,8 +106,8 @@ def __init__( # TODO tidy this up if self._secure and _credentials is None: - server_cert_chain, private_key, ca_cert = ray._private.utils.load_certs_from_env( - ) + server_cert_chain, private_key, ca_cert = ray._private.utils\ + .load_certs_from_env() _credentials = grpc.ssl_channel_credentials( certificate_chain=server_cert_chain, private_key=private_key, From a70a355d5f0ed93c5b76d92fd00eefc9613daa3c Mon Sep 17 00:00:00 2001 From: Oscar Knagg Date: Thu, 16 Sep 2021 13:51:35 +0100 Subject: [PATCH 30/56] Fix failing test_grpc_credentials test --- python/ray/util/client/worker.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index 63e77c9c4b51d..a929060028976 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -100,19 +100,10 @@ def __init__( self.server = None self._conn_state = grpc.ChannelConnectivity.IDLE self._converted: Dict[str, ClientStub] = {} - self._secure = secure or (os.environ.get("RAY_USE_TLS") == "1") + self._secure = secure self._conn_str = conn_str self._connection_retries = connection_retries - # TODO tidy this up - if self._secure and _credentials is None: - server_cert_chain, private_key, ca_cert = ray._private.utils\ - .load_certs_from_env() - _credentials = grpc.ssl_channel_credentials( - certificate_chain=server_cert_chain, - private_key=private_key, - root_certificates=ca_cert) - if _credentials is not None: self._credentials = _credentials self._secure = True @@ -143,6 +134,13 @@ def _connect_channel(self) -> None: if self._secure: if self._credentials is not None: credentials = self._credentials + elif os.environ.get("RAY_USE_TLS", "0") == "1": + server_cert_chain, private_key, ca_cert = ray._private.utils \ + .load_certs_from_env() + credentials = grpc.ssl_channel_credentials( + certificate_chain=server_cert_chain, + private_key=private_key, + root_certificates=ca_cert) else: credentials = grpc.ssl_channel_credentials() self.channel = grpc.secure_channel( From cd613dff627b6b7630e75323631311704b98bc80 Mon Sep 17 00:00:00 2001 From: Oscar Knagg Date: Thu, 16 Sep 2021 14:09:05 +0100 Subject: [PATCH 31/56] Make dashboard head classes use async grpc again --- dashboard/modules/actor/actor_head.py | 2 +- dashboard/modules/event/event_agent.py | 4 +++- dashboard/modules/reporter/reporter_head.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/dashboard/modules/actor/actor_head.py b/dashboard/modules/actor/actor_head.py index 95f5629e6246b..c05f61ea55ace 100644 --- a/dashboard/modules/actor/actor_head.py +++ b/dashboard/modules/actor/actor_head.py @@ -182,7 +182,7 @@ async def kill_actor(self, req) -> aiohttp.web.Response: try: options = (("grpc.enable_http_proxy", 0), ) channel = ray._private.utils.init_grpc_channel( - f"{ip_address}:{port}", options=options) + f"{ip_address}:{port}", options=options, asynchronous=True) stub = 
core_worker_pb2_grpc.CoreWorkerServiceStub(channel) await stub.KillActor( diff --git a/dashboard/modules/event/event_agent.py b/dashboard/modules/event/event_agent.py index 859df7303d9c3..2740c19d70549 100644 --- a/dashboard/modules/event/event_agent.py +++ b/dashboard/modules/event/event_agent.py @@ -47,7 +47,9 @@ async def _connect_to_dashboard(self): logger.info("Report events to %s", dashboard_rpc_address) options = (("grpc.enable_http_proxy", 0), ) channel = utils.init_grpc_channel( - dashboard_rpc_address, options=options) + dashboard_rpc_address, + options=options, + asynchronous=True) return event_pb2_grpc.ReportEventServiceStub(channel) except Exception: logger.exception("Connect to dashboard failed.") diff --git a/dashboard/modules/reporter/reporter_head.py b/dashboard/modules/reporter/reporter_head.py index 3ea0138e6033e..beebb29cbfdf9 100644 --- a/dashboard/modules/reporter/reporter_head.py +++ b/dashboard/modules/reporter/reporter_head.py @@ -38,7 +38,7 @@ async def _update_stubs(self, change): ip = DataSource.node_id_to_ip[node_id] options = (("grpc.enable_http_proxy", 0), ) channel = ray._private.utils.init_grpc_channel( - f"{ip}:{ports[1]}", options=options) + f"{ip}:{ports[1]}", options=options, asynchronous=True) stub = reporter_pb2_grpc.ReporterServiceStub(channel) self._stubs[ip] = stub From b296a8a604cab3f76f2866723cc0695aa407a3c0 Mon Sep 17 00:00:00 2001 From: Oscar Knagg Date: Thu, 16 Sep 2021 14:34:01 +0100 Subject: [PATCH 32/56] Add test_tls_auth to BUILD --- python/ray/tests/BUILD | 1 + python/ray/tests/test_tls_auth.py | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 330a8cc55e50a..bf5a0571b7ab6 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -90,6 +90,7 @@ py_test_module_list( "test_stress_sharded.py", "test_tempfile.py", "test_tensorflow.py", + "test_tls_auth.py", "test_ray_debugger.py", ], size = "medium", diff --git a/python/ray/tests/test_tls_auth.py b/python/ray/tests/test_tls_auth.py index 290f92af6a66d..057c2e0b2ae32 100644 --- a/python/ray/tests/test_tls_auth.py +++ b/python/ray/tests/test_tls_auth.py @@ -1,5 +1,6 @@ # coding: utf-8 import os +import sys import pytest @@ -8,6 +9,9 @@ logger = logging.getLogger(__name__) +@pytest.mark.skipif( + sys.platform == "darwin", + reason=("Cryptography doesn't install in Mac build pipeline")) @pytest.mark.parametrize("use_tls", [True], indirect=True) def test_client_connect_to_tls_server(use_tls, init_and_serve): from ray.util.client import ray From 5528b51b8a1b4404227d2c72f2dc69fb38030265 Mon Sep 17 00:00:00 2001 From: Oscar Knagg Date: Mon, 20 Sep 2021 10:18:59 +0100 Subject: [PATCH 33/56] Relax cryptography requirement --- python/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/requirements.txt b/python/requirements.txt index d3a870db2a581..4d0baeaf9ef80 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -85,4 +85,4 @@ smart_open[s3] tqdm async-exit-stack async-generator -cryptography==3.4.8 +cryptography>=3.0.0 From c77d97a48a125f043c74795ba53865affcaa1bff Mon Sep 17 00:00:00 2001 From: Oscar Knagg Date: Mon, 20 Sep 2021 11:59:57 +0100 Subject: [PATCH 34/56] Lint --- src/ray/rpc/common.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ray/rpc/common.cc b/src/ray/rpc/common.cc index 0193e6aea2c95..0adaf83fa8723 100644 --- a/src/ray/rpc/common.cc +++ b/src/ray/rpc/common.cc @@ -28,4 +28,4 @@ std::string ReadCert(std::string cert_filepath) { }; } // 
namespace rpc -} // namespace ray \ No newline at end of file +} // namespace ray From 3cf6271dcb599de25d110cc7e8a5111da122d0c7 Mon Sep 17 00:00:00 2001 From: Oscar Knagg Date: Mon, 20 Sep 2021 15:05:18 +0100 Subject: [PATCH 35/56] Worker._secure looks at env var --- python/ray/util/client/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index d84af807930b2..3e0b1a38c2e03 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -103,7 +103,7 @@ def __init__( self.server = None self._conn_state = grpc.ChannelConnectivity.IDLE self._converted: Dict[str, ClientStub] = {} - self._secure = secure + self._secure = secure or os.environ.get("RAY_USE_TLS", "0") == "1" self._conn_str = conn_str self._connection_retries = connection_retries From ddfa148560c32207e5b31983c4da63279177ce9a Mon Sep 17 00:00:00 2001 From: Oscar Knagg Date: Mon, 27 Sep 2021 13:48:07 +0100 Subject: [PATCH 36/56] Apply changes from ci/travis/lint.sh --- src/ray/rpc/common.cc | 8 +++----- src/ray/rpc/common.h | 8 +++----- src/ray/rpc/grpc_client.h | 4 ++-- 3 files changed, 8 insertions(+), 12 deletions(-) diff --git a/src/ray/rpc/common.cc b/src/ray/rpc/common.cc index 0adaf83fa8723..eef01f3e1e2f5 100644 --- a/src/ray/rpc/common.cc +++ b/src/ray/rpc/common.cc @@ -17,15 +17,13 @@ #include "ray/rpc/common.h" -namespace ray { -namespace rpc { +namespace ray::rpc { -std::string ReadCert(std::string cert_filepath) { +std::string ReadCert(const std::string &cert_filepath) { std::ifstream t(cert_filepath); std::stringstream buffer; buffer << t.rdbuf(); return buffer.str(); }; -} // namespace rpc -} // namespace ray +} // namespace rpc::ray diff --git a/src/ray/rpc/common.h b/src/ray/rpc/common.h index 16954c5ebeb72..929a555a942f6 100644 --- a/src/ray/rpc/common.h +++ b/src/ray/rpc/common.h @@ -12,11 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -namespace ray { -namespace rpc { +namespace ray::rpc { // Utility to read cert file from a particular location -std::string ReadCert(std::string cert_filepath); +std::string ReadCert(const std::string &cert_filepath); -} // namespace rpc -} // namespace ray +} // namespace ray::rpc diff --git a/src/ray/rpc/grpc_client.h b/src/ray/rpc/grpc_client.h index 5d7887596c0a7..cdde388f7fb6e 100644 --- a/src/ray/rpc/grpc_client.h +++ b/src/ray/rpc/grpc_client.h @@ -105,8 +105,8 @@ class GrpcClient { /// Whether to use TLS. 
bool use_tls_; - std::shared_ptr<grpc::Channel> BuildChannel(grpc::ChannelArguments argument, - std::string address, int port) { + std::shared_ptr<grpc::Channel> BuildChannel(const grpc::ChannelArguments &argument, + const std::string &address, int port) { std::shared_ptr<grpc::Channel> channel; if (use_tls_) { std::string server_cert_file = std::string(std::getenv("RAY_TLS_SERVER_CERT")); From 32acd64a34c60b891f6d962426d8fd21739fce01 Mon Sep 17 00:00:00 2001 From: Oscar Knagg Date: Mon, 27 Sep 2021 13:49:25 +0100 Subject: [PATCH 37/56] Skip TLS tests on MacOS --- python/ray/_private/test_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/ray/_private/test_utils.py b/python/ray/_private/test_utils.py index d0a27ca72b3fd..50bb3d13c008b 100644 --- a/python/ray/_private/test_utils.py +++ b/python/ray/_private/test_utils.py @@ -17,6 +17,7 @@ from contextlib import redirect_stdout, redirect_stderr import yaml import socket +import pytest import ray import ray._private.services @@ -744,6 +745,8 @@ def generate_self_signed_tls_certs(): def setup_tls(): """Sets up required environment variables for tls""" + if sys.platform == "darwin": + pytest.skip("Cryptography doesn't install in Mac build pipeline") cert, key = generate_self_signed_tls_certs() temp_dir = tempfile.mkdtemp("ray-test-certs") cert_filepath = os.path.join(temp_dir, "server.crt") From d04fe6de30a8dbd9ab60acabb5cb9e87fa8a200d Mon Sep 17 00:00:00 2001 From: Oscar Knagg Date: Mon, 27 Sep 2021 13:54:27 +0100 Subject: [PATCH 38/56] format.sh changes --- src/ray/rpc/common.cc | 6 +++--- src/ray/rpc/common.h | 2 ++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/ray/rpc/common.cc b/src/ray/rpc/common.cc index eef01f3e1e2f5..7526c1e6efc6f 100644 --- a/src/ray/rpc/common.cc +++ b/src/ray/rpc/common.cc @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "ray/rpc/common.h" + #include <fstream> #include <sstream> -#include "ray/rpc/common.h" - namespace ray::rpc { std::string ReadCert(const std::string &cert_filepath) { @@ -26,4 +26,4 @@ std::string ReadCert(const std::string &cert_filepath) { return buffer.str(); }; -} // namespace rpc::ray +} // namespace ray::rpc diff --git a/src/ray/rpc/common.h b/src/ray/rpc/common.h index 929a555a942f6..314e1eccf382c 100644 --- a/src/ray/rpc/common.h +++ b/src/ray/rpc/common.h @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include <string> + namespace ray::rpc { // Utility to read cert file from a particular location From 53896b33f1f973d57f1f98bccabddf8e2e12d9d0 Mon Sep 17 00:00:00 2001 From: Oscar Knagg Date: Mon, 27 Sep 2021 13:54:27 +0100 Subject: [PATCH 39/56] Address comments --- .bazelrc | 10 +- .buildkite/pipeline.gpu.large.yml | 8 + .buildkite/pipeline.gpu.yml | 10 + .buildkite/pipeline.macos.yml | 10 +- .buildkite/pipeline.yml | 34 +- .buildkite/windows/install/bazel.ps1 | 2 +- .clang-tidy | 41 +- .flake8 | 16 + .github/CODEOWNERS | 12 +- .github/workflows/main.yml | 3 +- .gitpod/Dockerfile | 2 +- BUILD.bazel | 35 +- bazel/ray_deps_setup.bzl | 9 +- benchmarks/object_store/test_object_store.py | 1 + benchmarks/single_node/test_single_node.py | 3 +- ci/asan_tests/run_asan_tests.sh | 8 +- ci/travis/bazel.py | 42 +- ci/travis/ci.sh | 56 ++- ci/travis/format.sh | 4 + ci/travis/install-dependencies.sh | 2 +- ci/travis/test-worker-in-container.sh | 2 +- cpp/BUILD.bazel | 1 + cpp/src/ray/api.cc | 2 +- cpp/src/ray/runtime/abstract_ray_runtime.cc | 2 +- .../ray/runtime/object/native_object_store.cc | 2 +- .../runtime/task/local_mode_task_submitter.cc | 4 +- cpp/src/ray/runtime/task/task_executor.cc | 2 +- cpp/src/ray/runtime/task/task_executor.h | 4 +- cpp/src/ray/util/process_helper.cc | 11 +- dashboard/agent.py | 2 +- dashboard/client/src/pages/job/JobDetail.tsx | 11 + dashboard/client/src/pages/job/index.tsx | 19 +- dashboard/client/src/type/job.d.ts | 2 + dashboard/head.py | 11 +- dashboard/modules/job/job_agent.py | 4 +- .../modules/runtime_env/runtime_env_agent.py | 24 +- dashboard/modules/snapshot/snapshot_head.py | 3 +- .../modules/snapshot/snapshot_schema.json | 4 - dashboard/tests/test_dashboard.py | 2 +- doc/BUILD | 32 ++ doc/Makefile | 2 +- doc/examples/dask_xgboost/README.rst | 1 + doc/examples/dask_xgboost/dask_xgboost.py | 321 ++++++++++++ doc/examples/dask_xgboost/dask_xgboost.yaml | 24 + doc/examples/modin_xgboost/README.rst | 1 + doc/examples/modin_xgboost/modin_xgboost.py | 233 +++++++++ doc/examples/modin_xgboost/modin_xgboost.yaml | 24 + doc/examples/overview.rst | 12 +- doc/kubernetes/ray-cluster.yaml | 4 +- doc/source/advanced.rst | 35 +- doc/source/cluster/config.rst | 62 +++ doc/source/cluster/ray-client.rst | 69 ++- doc/source/conf.py | 6 +- doc/source/configure.rst | 22 + doc/source/data/.gitignore | 1 + doc/source/data/_examples/README.rst | 1 + .../data/_examples/big_data_ingestion.py | 276 +++++++++++ doc/source/data/big_data_ingestion.yaml | 54 ++ doc/source/data/dask-on-ray.rst | 20 +- doc/source/data/dataset-pipeline.rst | 121 ++++- doc/source/data/dataset-tensor-support.rst | 72 +-- doc/source/data/dataset.rst | 8 +- doc/source/data/package-ref.rst | 2 + doc/source/development.rst | 11 +- doc/source/index.rst | 6 +- doc/source/raysgd/raysgd.rst | 2 +- doc/source/raysgd/raysgd_pytorch.rst | 5 +- doc/source/raysgd/raysgd_tensorflow.rst | 5 +- doc/source/raysgd/raysgd_tune.rst | 3 + doc/source/raysgd/v2/api.rst | 21 +- doc/source/raysgd/v2/examples.rst | 3 + .../tune_cifar_pytorch_pbt_example.rst | 6 + doc/source/raysgd/v2/migration-guide.rst | 393 +++++++++++++++ doc/source/raysgd/v2/raysgd.rst | 3 +- doc/source/raysgd/v2/user_guide.rst | 12 +- doc/source/serve/core-apis.rst | 43 +- doc/source/serve/deployment.rst | 10 +- doc/source/serve/ml-models.rst | 15 +- doc/source/tune/_tutorials/_faq.inc | 55 ++- doc/source/tune/api_docs/suggestion.rst | 3 +- doc/source/tune/user-guide.rst | 10 +- java/BUILD.bazel | 5 + java/dependencies.bzl | 2 + .../java/io/ray/runtime/RayNativeRuntime.java 
| 18 +- .../runtime/object/LocalModeObjectStore.java | 2 +- .../ray/runtime/object/NativeObjectStore.java | 6 +- .../io/ray/runtime/object/ObjectRefImpl.java | 2 +- .../io/ray/runtime/object/ObjectStore.java | 4 +- java/serve/pom.xml | 15 + .../src/main/java/io/ray/serve/Constants.java | 6 + .../java/io/ray/serve/DeploymentInfo.java | 38 ++ .../io/ray/serve/DummyBackendReplica.java | 12 + .../main/java/io/ray/serve/HandleOptions.java | 15 + .../src/main/java/io/ray/serve/HttpProxy.java | 161 ++++++ .../main/java/io/ray/serve/ProxyActor.java | 175 +++++++ .../main/java/io/ray/serve/ProxyRouter.java | 72 +++ .../java/io/ray/serve/RayServeConfig.java | 6 + .../java/io/ray/serve/RayServeHandle.java | 73 +++ .../java/io/ray/serve/RayServeMetrics.java | 74 +++ .../java/io/ray/serve/RayServeReplica.java | 211 +++++--- .../io/ray/serve/RayServeWrappedReplica.java | 42 +- .../main/java/io/ray/serve/ReplicaConfig.java | 8 +- .../java/io/ray/serve/ReplicaContext.java | 2 +- .../main/java/io/ray/serve/ReplicaSet.java | 138 ++++++ .../src/main/java/io/ray/serve/Router.java | 64 +++ .../java/io/ray/serve/ServeController.java | 6 + .../main/java/io/ray/serve/ServeProxy.java | 14 + .../main/java/io/ray/serve/api/Client.java | 72 +++ .../src/main/java/io/ray/serve/api/Serve.java | 54 +- .../java/io/ray/serve/poll/KeyListener.java | 2 +- .../io/ray/serve/poll/LongPollClient.java | 69 ++- .../io/ray/serve/poll/LongPollNamespace.java | 4 +- .../java/io/ray/serve/poll/UpdatedObject.java | 33 -- .../io/ray/serve/util/CollectionUtil.java | 10 + .../java/io/ray/serve/util/CommonUtil.java | 13 + .../java/io/ray/serve/util/ReflectUtil.java | 14 + .../io/ray/serve/util/ServeProtoUtil.java | 75 ++- .../java/io/ray/serve/util/SocketUtil.java | 49 ++ .../io/ray/serve/DummyServeController.java | 21 + .../test/java/io/ray/serve/HttpProxyTest.java | 74 +++ .../java/io/ray/serve/ProxyActorTest.java | 110 +++++ .../java/io/ray/serve/ProxyRouterTest.java | 68 +++ .../java/io/ray/serve/RayServeHandleTest.java | 76 +++ .../io/ray/serve/RayServeReplicaTest.java | 46 +- .../java/io/ray/serve/ReplicaSetTest.java | 108 ++++ .../test/java/io/ray/serve/RouterTest.java | 80 +++ .../java/io/ray/serve/api/ClientTest.java | 47 ++ .../test/java/io/ray/serve/api/ServeTest.java | 71 ++- .../java/io/ray/serve/poll/KeyTypeTest.java | 15 +- .../io/ray/serve/poll/LongPollClientTest.java | 29 +- python/build-wheel-windows.sh | 7 + python/ray/_private/client_mode_hook.py | 50 +- python/ray/_private/parameter.py | 16 +- python/ray/_private/runtime_env/__init__.py | 3 - python/ray/_private/runtime_env/conda.py | 4 +- .../ray/_private/runtime_env/conda_utils.py | 15 + python/ray/_private/runtime_env/context.py | 13 +- python/ray/_private/runtime_env/plugin.py | 70 +++ python/ray/_private/runtime_env/validation.py | 458 +++++++++++------ .../ray/_private/runtime_env/working_dir.py | 2 +- python/ray/_private/services.py | 51 +- python/ray/_private/test_utils.py | 68 +-- python/ray/_private/tls_utils.py | 85 ++++ python/ray/_private/utils.py | 28 +- python/ray/_raylet.pxd | 2 +- python/ray/_raylet.pyx | 142 +++--- python/ray/actor.py | 68 +-- python/ray/autoscaler/_private/autoscaler.py | 12 +- python/ray/autoscaler/_private/docker.py | 2 +- .../_private/fake_multi_node/__init__.py | 0 .../_private/fake_multi_node/example.yaml | 55 +++ .../_private/fake_multi_node/node_provider.py | 114 +++++ python/ray/autoscaler/_private/gcp/node.py | 20 +- python/ray/autoscaler/_private/monitor.py | 12 +- .../ray/autoscaler/_private/node_launcher.py | 5 +- 
python/ray/autoscaler/_private/providers.py | 8 + .../_private/resource_demand_scheduler.py | 27 +- python/ray/autoscaler/gcp/tpu.yaml | 18 +- python/ray/autoscaler/node_provider.py | 12 + python/ray/autoscaler/ray-schema.json | 2 +- python/ray/cluster_utils.py | 69 +++ python/ray/cross_language.py | 3 +- python/ray/data/__init__.py | 7 +- python/ray/data/block.py | 24 +- python/ray/data/dataset.py | 390 +++++++++++---- python/ray/data/dataset_pipeline.py | 188 ++++++- python/ray/data/datasource/__init__.py | 4 +- python/ray/data/datasource/datasource.py | 63 ++- .../data/datasource/file_based_datasource.py | 57 ++- .../ray/data/datasource/numpy_datasource.py | 13 +- python/ray/data/examples/demo_infer.py | 2 +- .../ray/data/extensions/tensor_extension.py | 8 +- python/ray/data/impl/arrow_block.py | 20 +- python/ray/data/impl/block_list.py | 7 + python/ray/data/impl/compute.py | 23 +- python/ray/data/impl/lazy_block_list.py | 57 ++- python/ray/data/impl/pipeline_executor.py | 15 +- python/ray/data/impl/progress_bar.py | 12 +- python/ray/data/impl/remote_fn.py | 5 +- python/ray/data/impl/simple_block.py | 4 +- python/ray/data/impl/tensor_block.py | 80 --- python/ray/data/read_api.py | 77 ++- python/ray/data/tests/test_dataset.py | 466 ++++++++++++------ .../ray/data/tests/test_dataset_pipeline.py | 89 +++- python/ray/data/tests/test_raydp_dataset.py | 4 + python/ray/exceptions.py | 6 +- python/ray/experimental/array/remote/core.py | 4 +- python/ray/experimental/internal_kv.py | 12 +- python/ray/experimental/raysort/constants.py | 11 +- python/ray/experimental/raysort/main.py | 369 +++++++++----- python/ray/experimental/raysort/sortlib.py | 8 +- .../ray/experimental/raysort/tracing_utils.py | 127 ++++- python/ray/experimental/raysort/types.py | 12 +- python/ray/includes/common.pxd | 6 +- python/ray/includes/libcoreworker.pxd | 5 +- python/ray/job_config.py | 61 +-- python/ray/node.py | 31 +- python/ray/remote_function.py | 77 +-- python/ray/runtime_context.py | 16 +- python/ray/scripts/scripts.py | 27 +- python/ray/serialization.py | 2 +- python/ray/serve/BUILD | 10 +- python/ray/serve/api.py | 154 ++++-- python/ray/serve/autoscaling_metrics.py | 5 +- python/ray/serve/autoscaling_policy.py | 1 - python/ray/serve/backend_state.py | 77 ++- python/ray/serve/common.py | 5 +- python/ray/serve/config.py | 38 +- python/ray/serve/controller.py | 139 ++++-- python/ray/serve/endpoint_state.py | 1 - python/ray/serve/examples/doc/conda_env.py | 23 +- python/ray/serve/handle.py | 15 +- python/ray/serve/http_proxy.py | 7 +- python/ray/serve/long_poll.py | 26 +- .../serve/{backend_worker.py => replica.py} | 54 +- python/ray/serve/router.py | 10 +- python/ray/serve/storage/checkpoint_path.py | 4 +- python/ray/serve/storage/kv_store.py | 4 +- python/ray/serve/tests/conftest.py | 7 + python/ray/serve/tests/test_advanced.py | 12 +- .../serve/tests/test_autoscaling_metrics.py | 12 +- .../serve/tests/test_autoscaling_policy.py | 52 ++ python/ray/serve/tests/test_backend_state.py | 77 +-- python/ray/serve/tests/test_config.py | 2 + python/ray/serve/tests/test_deploy.py | 57 ++- python/ray/serve/tests/test_get_deployment.py | 31 ++ python/ray/serve/tests/test_handle.py | 27 +- python/ray/serve/tests/test_long_poll.py | 14 + python/ray/serve/tests/test_ray_client.py | 5 +- python/ray/serve/tests/test_regression.py | 2 +- python/ray/serve/tests/test_standalone.py | 9 +- python/ray/sgd/__init__.py | 3 +- python/ray/sgd/callbacks.py | 1 + python/ray/state.py | 14 +- python/ray/tests/BUILD | 13 +- 
python/ray/tests/client_test_utils.py | 17 + python/ray/tests/mock_setup_worker.py | 3 + python/ray/tests/test_advanced.py | 5 +- python/ray/tests/test_advanced_3.py | 13 +- python/ray/tests/test_autoscaler.py | 61 ++- .../tests/test_autoscaler_fake_multinode.py | 58 +++ python/ray/tests/test_autoscaler_yaml.py | 3 + python/ray/tests/test_basic.py | 15 +- python/ray/tests/test_basic_3.py | 11 +- python/ray/tests/test_client.py | 59 ++- python/ray/tests/test_client_compat.py | 33 ++ .../tests/test_client_library_integration.py | 8 +- python/ray/tests/test_client_proxy.py | 10 +- python/ray/tests/test_client_reconnect.py | 9 +- python/ray/tests/test_dashboard.py | 52 +- python/ray/tests/test_distributed_sort.py | 19 +- python/ray/tests/test_failure_2.py | 3 +- python/ray/tests/test_global_state.py | 10 +- python/ray/tests/test_multi_tenancy.py | 12 +- python/ray/tests/test_object_manager.py | 48 +- python/ray/tests/test_output.py | 8 +- python/ray/tests/test_placement_group.py | 7 + python/ray/tests/test_placement_group_3.py | 35 ++ python/ray/tests/test_ray_debugger.py | 7 +- python/ray/tests/test_ray_init.py | 40 ++ .../tests/test_resource_demand_scheduler.py | 130 +++-- python/ray/tests/test_runtime_context.py | 115 +++++ python/ray/tests/test_runtime_env.py | 68 +-- .../ray/tests/test_runtime_env_complicated.py | 137 +++-- python/ray/tests/test_runtime_env_env_vars.py | 244 +++------ python/ray/tests/test_runtime_env_plugin.py | 75 +++ .../ray/tests/test_runtime_env_validation.py | 379 ++++++++++++++ python/ray/tests/test_scheduling.py | 99 +++- python/ray/tests/test_tls_auth.py | 66 ++- python/ray/tests/test_traceback.py | 39 ++ python/ray/tune/BUILD | 2 +- .../ray/tune/analysis/experiment_analysis.py | 22 +- python/ray/tune/commands.py | 7 +- python/ray/tune/durable_trainable.py | 14 +- python/ray/tune/function_runner.py | 11 +- python/ray/tune/logger.py | 17 +- python/ray/tune/progress_reporter.py | 96 +++- python/ray/tune/ray_trial_executor.py | 27 +- python/ray/tune/registry.py | 22 +- python/ray/tune/result.py | 4 + python/ray/tune/schedulers/hyperband.py | 2 + python/ray/tune/schedulers/trial_scheduler.py | 6 + python/ray/tune/suggest/bohb.py | 4 +- python/ray/tune/tests/test_api.py | 8 + python/ray/tune/tests/test_cluster.py | 2 +- python/ray/tune/tests/test_logger.py | 22 - .../ray/tune/tests/test_progress_reporter.py | 200 +++++--- .../ray/tune/tests/test_ray_trial_executor.py | 69 ++- python/ray/tune/tests/test_trial_runner_3.py | 3 +- .../tune/tests/test_trial_runner_callbacks.py | 2 +- python/ray/tune/tests/test_trial_scheduler.py | 1 + .../tune/tests/test_trial_scheduler_pbt.py | 12 +- python/ray/tune/trainable.py | 59 ++- python/ray/tune/trial.py | 65 ++- python/ray/tune/trial_runner.py | 104 +++- python/ray/tune/tune.py | 82 ++- python/ray/tune/utils/util.py | 9 +- python/ray/util/__init__.py | 2 +- python/ray/util/client/__init__.py | 3 +- python/ray/util/client/client_pickler.py | 15 +- python/ray/util/client/options.py | 1 - python/ray/util/client/server/proxier.py | 10 +- python/ray/util/client/server/server.py | 2 +- python/ray/util/client/worker.py | 20 +- python/ray/util/dask/scheduler_utils.py | 5 +- python/ray/util/placement_group.py | 4 +- python/ray/util/sgd/torch/torch_runner.py | 14 +- .../ray/util/sgd/torch/training_operator.py | 42 +- python/ray/util/sgd/v2/BUILD | 27 + python/ray/util/sgd/v2/__init__.py | 4 +- python/ray/util/sgd/v2/backends/backend.py | 26 +- python/ray/util/sgd/v2/backends/horovod.py | 2 + python/ray/util/sgd/v2/backends/torch.py | 2 + 
python/ray/util/sgd/v2/constants.py | 4 + .../v2/examples/tensorflow_mnist_example.py | 4 +- .../tune_cifar_pytorch_pbt_example.py | 200 ++++++++ python/ray/util/sgd/v2/tests/test_backend.py | 4 + python/ray/util/sgd/v2/tests/test_gpu.py | 92 ++++ python/ray/util/sgd/v2/tests/test_trainer.py | 82 +-- python/ray/util/sgd/v2/tests/test_tune.py | 31 +- python/ray/util/tracing/tracing_helper.py | 7 +- python/ray/worker.py | 47 +- python/ray/workers/setup_worker.py | 10 +- python/ray/workflow/common.py | 3 +- python/ray/workflow/execution.py | 7 +- python/ray/workflow/recovery.py | 29 +- python/ray/workflow/step_executor.py | 92 ++-- .../workflow/tests/test_basic_workflows_2.py | 43 +- python/ray/workflow/tests/test_lifetime.py | 26 +- python/ray/workflow/workflow_access.py | 4 +- python/ray/workflow/workflow_context.py | 109 +++- python/ray/workflow/workflow_storage.py | 27 +- python/requirements.txt | 5 +- python/requirements/ml/requirements_rllib.txt | 4 +- python/requirements/requirements_default.txt | 2 +- python/requirements_linters.txt | 1 + python/setup.py | 29 +- release/.buildkite/build_pipeline.py | 1 + release/RELEASE_CHECKLIST.md | 1 + release/RELEASE_PROCESS.rst | 3 + release/alerts/xgboost_tests.py | 4 +- release/e2e.py | 207 +++++--- .../dask_xgboost_app_config.yaml | 5 +- .../golden_notebook_tests.yaml | 21 +- .../modin_xgboost_app_config.yaml | 5 +- .../workloads/dask_xgboost_test.py | 123 +---- .../workloads/modin_xgboost_test.py | 119 +---- .../workloads/torch_tune_serve_test.py | 4 +- .../golden_notebook_tests/workloads/util.py | 49 ++ .../workloads/utils/utils.py | 5 - release/kubernetes_manual_tests/README.md | 25 + release/kubernetes_manual_tests/helm-test.sh | 8 + .../kubernetes_manual_tests/k8s-test-scale.sh | 11 + release/kubernetes_manual_tests/k8s-test.sh | 9 + .../k8s_release_tests.sh | 30 ++ release/long_running_tests/tpl_cpu_1.yaml | 5 + .../large_scale_dask_on_ray_app_config.yaml | 1 - release/nightly_tests/dataset/app_config.yaml | 1 - .../dataset/dataset_shuffle_data_loader.py | 2 +- .../dataset/pipelined_ingestion_app.yaml | 1 - .../dataset/pipelined_training.py | 4 +- .../dataset/pipelined_training_app.yaml | 1 - .../dataset/shuffle_app_config.yaml | 1 - .../decision_tree_app_config.yaml | 1 - .../many_nodes_tests/app_config.yaml | 2 +- release/nightly_tests/nightly_tests.yaml | 25 +- .../placement_group_tests/app_config.yaml | 12 + .../placement_group_tests/cluster.py | 13 + .../placement_group_tests/compute.yaml | 27 + .../placement_group_tests/pg_run.py | 65 +++ .../shuffle/shuffle_app_config.yaml | 2 - .../shuffle_data_loader_app_config.yaml | 1 - .../stress_tests/stress_tests_app_config.yaml | 1 - .../1.7.0/benchmarks/many_actors.txt | 10 + .../1.7.0/benchmarks/many_nodes.txt | 10 + .../1.7.0/benchmarks/many_pgs.txt | 10 + .../1.7.0/benchmarks/many_tasks.txt | 10 + release/release_logs/1.7.0/microbenchmark.txt | 134 +++++ .../1.7.0/scalability/object_store.txt | 10 + .../1.7.0/scalability/single_node.txt | 16 + .../1.7.0/stress_tests/dead_actors.txt | 11 + .../1.7.0/stress_tests/many_tasks.txt | 19 + .../1.7.0/stress_tests/placement_group.txt | 9 + release/serve_tests/serve_tests.yaml | 15 + .../serve_cluster_fault_tolerance.py | 119 +++++ .../workloads/serve_test_cluster_utils.py | 25 +- release/util/pip_download_test.sh | 2 +- rllib/BUILD | 89 +++- rllib/agents/a3c/a3c_tf_policy.py | 2 +- rllib/agents/a3c/a3c_torch_policy.py | 18 +- rllib/agents/a3c/tests/test_a2c.py | 11 +- rllib/agents/a3c/tests/test_a3c.py | 3 +- rllib/agents/ars/tests/test_ars.py 
| 10 +- rllib/agents/cql/cql.py | 3 +- rllib/agents/cql/cql_torch_policy.py | 67 +-- rllib/agents/cql/tests/test_cql.py | 11 +- rllib/agents/ddpg/ddpg_tf_model.py | 12 +- rllib/agents/ddpg/ddpg_tf_policy.py | 8 +- rllib/agents/ddpg/ddpg_torch_model.py | 12 +- rllib/agents/ddpg/ddpg_torch_policy.py | 37 +- rllib/agents/ddpg/tests/test_apex_ddpg.py | 6 +- rllib/agents/ddpg/tests/test_ddpg.py | 8 +- rllib/agents/ddpg/tests/test_td3.py | 3 +- rllib/agents/dqn/apex.py | 3 +- rllib/agents/dqn/dqn.py | 16 +- rllib/agents/dqn/dqn_torch_policy.py | 46 +- rllib/agents/dqn/learner_thread.py | 24 +- rllib/agents/dqn/r2d2.py | 14 +- rllib/agents/dqn/r2d2_tf_policy.py | 6 +- rllib/agents/dqn/r2d2_torch_policy.py | 44 +- rllib/agents/dqn/simple_q_tf_policy.py | 2 +- rllib/agents/dqn/simple_q_torch_policy.py | 17 +- rllib/agents/dqn/tests/test_apex_dqn.py | 15 +- rllib/agents/dqn/tests/test_dqn.py | 4 +- rllib/agents/dqn/tests/test_r2d2.py | 3 +- rllib/agents/dqn/tests/test_simple_q.py | 3 +- rllib/agents/dreamer/dreamer.py | 3 +- rllib/agents/impala/tests/test_impala.py | 12 +- rllib/agents/impala/vtrace_tf_policy.py | 26 +- rllib/agents/impala/vtrace_torch_policy.py | 45 +- rllib/agents/maml/maml.py | 17 +- rllib/agents/maml/tests/test_maml.py | 6 +- rllib/agents/marwil/tests/test_bc.py | 8 +- rllib/agents/marwil/tests/test_marwil.py | 8 +- rllib/agents/mbmpo/mbmpo.py | 17 +- rllib/agents/mbmpo/tests/test_mbmpo.py | 8 +- rllib/agents/pg/pg_torch_policy.py | 14 +- rllib/agents/pg/tests/test_pg.py | 10 +- rllib/agents/ppo/appo_tf_policy.py | 2 +- rllib/agents/ppo/appo_torch_policy.py | 47 +- rllib/agents/ppo/ddppo.py | 15 +- rllib/agents/ppo/ppo.py | 12 +- rllib/agents/ppo/ppo_torch_policy.py | 33 +- rllib/agents/ppo/tests/test_appo.py | 16 +- rllib/agents/ppo/tests/test_ddppo.py | 26 +- rllib/agents/ppo/tests/test_ppo.py | 29 +- rllib/agents/qmix/qmix_policy.py | 2 +- rllib/agents/sac/rnnsac.py | 7 - rllib/agents/sac/rnnsac_torch_policy.py | 32 +- rllib/agents/sac/sac_tf_model.py | 10 +- rllib/agents/sac/sac_tf_policy.py | 8 +- rllib/agents/sac/sac_torch_model.py | 8 +- rllib/agents/sac/sac_torch_policy.py | 62 +-- rllib/agents/sac/tests/test_rnnsac.py | 73 +++ rllib/agents/sac/tests/test_sac.py | 38 +- rllib/agents/tests/test_trainer.py | 3 +- rllib/agents/trainer.py | 299 +++++++---- .../alpha_zero/core/alpha_zero_policy.py | 7 +- rllib/contrib/bandits/agents/policy.py | 2 +- .../bandits/examples/LinTS_train_wheel_env.py | 3 +- rllib/contrib/maddpg/maddpg_policy.py | 2 +- rllib/contrib/sumo/connector.py | 5 +- rllib/env/base_env.py | 7 +- rllib/env/multi_agent_env.py | 3 +- rllib/env/policy_server_input.py | 16 +- rllib/env/remote_vector_env.py | 20 +- rllib/env/tests/test_local_inference.sh | 42 -- .../tests/test_policy_client_server_setup.sh | 63 +++ rllib/env/tests/test_remote_inference.sh | 41 -- rllib/env/tests/test_remote_worker_envs.py | 98 ++++ rllib/env/wrappers/unity3d_env.py | 16 +- .../collectors/simple_list_collector.py | 7 +- rllib/evaluation/metrics.py | 21 +- rllib/evaluation/rollout_worker.py | 30 +- rllib/examples/centralized_critic.py | 2 +- rllib/examples/custom_keras_model.py | 5 +- .../examples/custom_model_loss_and_metrics.py | 12 +- rllib/examples/deterministic_training.py | 6 +- .../env/coin_game_non_vectorized_env.py | 11 +- .../examples/env/coin_game_vectorized_env.py | 9 +- .../env/matrix_sequential_social_dilemma.py | 6 +- rllib/examples/env/random_env.py | 26 +- rllib/examples/pettingzoo_env.py | 18 +- .../remote_vector_env_with_custom_api.py | 3 +- 
.../rock_paper_scissors_multiagent.py | 8 +- rllib/examples/serving/cartpole_client.py | 2 +- rllib/examples/serving/unity3d_client.py | 14 +- .../examples/serving/unity3d_dummy_client.py | 144 ++++++ rllib/examples/serving/unity3d_server.py | 70 ++- rllib/examples/trajectory_view_api.py | 50 +- rllib/execution/common.py | 3 - rllib/execution/learner_thread.py | 25 +- rllib/execution/multi_gpu_learner_thread.py | 68 ++- rllib/execution/rollout_ops.py | 24 +- rllib/execution/train_ops.py | 79 ++- rllib/models/tests/test_preprocessors.py | 6 +- rllib/models/tf/complex_input_net.py | 8 +- rllib/models/torch/complex_input_net.py | 9 +- rllib/models/torch/torch_modelv2.py | 8 + rllib/policy/eager_tf_policy.py | 118 +++-- rllib/policy/policy.py | 114 ++--- rllib/policy/policy_template.py | 3 +- rllib/policy/sample_batch.py | 7 +- .../tests/test_compute_log_likelihoods.py | 2 +- rllib/policy/tf_policy.py | 31 +- rllib/policy/tf_policy_template.py | 9 +- rllib/policy/torch_policy.py | 38 +- rllib/tests/test_exec_api.py | 3 +- rllib/tests/test_supported_multi_agent.py | 26 +- rllib/tests/test_supported_spaces.py | 8 +- rllib/utils/__init__.py | 3 +- .../utils/exploration/stochastic_sampling.py | 20 +- rllib/utils/metrics/__init__.py | 0 rllib/utils/metrics/learner_info.py | 84 ++++ rllib/utils/multi_agent.py | 21 +- rllib/utils/sgd.py | 55 +-- rllib/utils/test_utils.py | 316 +++++++++--- rllib/utils/tf_ops.py | 2 +- rllib/utils/tf_run_builder.py | 5 +- rllib/utils/torch_ops.py | 6 +- .../ray/gcs/gcs_server/gcs_node_manager.h | 1 + .../gcs_placement_group_scheduler.h | 65 +-- .../ray/gcs/gcs_server/gcs_resource_manager.h | 1 + src/mock/ray/gcs/pubsub/gcs_pub_sub.h | 27 + .../gcs/store_client/in_memory_store_client.h | 66 +++ .../ray/gcs/store_client/redis_store_client.h | 67 +++ src/mock/ray/gcs/store_client/store_client.h | 66 +++ src/mock/ray/pubsub/publisher.h | 100 ++++ src/mock/ray/pubsub/subscriber.h | 155 ++++++ src/mock/ray/raylet/node_manager.h | 5 + .../cluster_task_manager_interface.h | 2 - src/mock/ray/raylet_client/raylet_client.h | 47 +- src/mock/ray/rpc/worker/core_worker_client.h | 123 +++++ .../ray/rpc/worker/core_worker_client_pool.h | 23 + src/ray/common/bundle_spec.cc | 25 +- src/ray/common/bundle_spec.h | 6 + src/ray/common/client_connection.cc | 2 +- src/ray/common/constants.h | 3 + src/ray/common/id.h | 1 + src/ray/common/network_util.h | 2 +- src/ray/common/ray_config_def.h | 37 +- src/ray/common/ray_internal_flag_def.h | 3 + src/ray/common/runtime_env_manager.cc | 16 +- src/ray/common/runtime_env_manager.h | 7 +- src/ray/common/task/task.cc | 9 +- src/ray/common/task/task.h | 10 +- src/ray/common/task/task_spec.cc | 32 +- src/ray/common/task/task_spec.h | 18 +- src/ray/common/task/task_util.h | 11 +- src/ray/core_worker/common.h | 24 +- src/ray/core_worker/context.cc | 15 +- src/ray/core_worker/context.h | 9 +- src/ray/core_worker/core_worker.cc | 326 ++++++------ src/ray/core_worker/core_worker.h | 105 ++-- ...io_ray_runtime_object_NativeObjectStore.cc | 5 +- .../io_ray_runtime_object_NativeObjectStore.h | 7 +- ...io_ray_runtime_task_NativeTaskSubmitter.cc | 6 +- src/ray/core_worker/reference_count.h | 2 - src/ray/core_worker/reference_count_test.cc | 6 +- .../memory_store/memory_store.cc | 40 +- .../memory_store/memory_store.h | 16 - src/ray/core_worker/test/core_worker_test.cc | 2 +- .../test/direct_task_transport_mock_test.cc | 4 +- .../test/direct_task_transport_test.cc | 358 +++++++++++--- src/ray/core_worker/test/memory_store_test.cc | 11 +- 
.../transport/dependency_resolver.cc | 6 +- .../transport/direct_task_transport.cc | 93 +++- .../transport/direct_task_transport.h | 40 +- src/ray/gcs/asio.h | 2 +- .../gcs/gcs_client/service_based_accessor.cc | 2 +- .../test/global_state_accessor_test.cc | 1 + .../test/service_based_gcs_client_test.cc | 1 + .../gcs/gcs_server/gcs_actor_distribution.cc | 30 ++ .../gcs/gcs_server/gcs_actor_distribution.h | 17 + src/ray/gcs/gcs_server/gcs_actor_manager.cc | 67 ++- src/ray/gcs/gcs_server/gcs_actor_manager.h | 13 +- src/ray/gcs/gcs_server/gcs_actor_scheduler.cc | 6 +- src/ray/gcs/gcs_server/gcs_actor_scheduler.h | 12 +- src/ray/gcs/gcs_server/gcs_node_manager.cc | 8 +- .../gcs_server/gcs_placement_group_manager.cc | 138 ++++-- .../gcs_server/gcs_placement_group_manager.h | 52 +- .../gcs_placement_group_scheduler.cc | 28 +- .../gcs_placement_group_scheduler.h | 42 +- .../gcs/gcs_server/gcs_resource_manager.cc | 7 +- src/ray/gcs/gcs_server/gcs_server.cc | 19 +- src/ray/gcs/gcs_server/gcs_server.h | 1 - src/ray/gcs/gcs_server/gcs_server_main.cc | 5 +- src/ray/gcs/gcs_server/gcs_table_storage.h | 130 +++-- .../gcs_server/test/gcs_actor_manager_test.cc | 15 +- .../test/gcs_actor_scheduler_mock_test.cc | 139 ++++++ .../test/gcs_based_actor_scheduler_test.cc | 18 +- .../gcs_placement_group_manager_mock_test.cc | 174 +++++++ .../test/gcs_placement_group_manager_test.cc | 47 +- .../gcs_server/test/gcs_server_rpc_test.cc | 1 + .../gcs_server/test/gcs_server_test_util.h | 14 +- src/ray/gcs/pubsub/gcs_pub_sub.h | 3 + src/ray/gcs/redis_context.h | 2 +- src/ray/object_manager/object_buffer_pool.cc | 165 +++++-- src/ray/object_manager/object_buffer_pool.h | 59 ++- src/ray/object_manager/object_manager.cc | 24 +- src/ray/object_manager/object_manager.h | 6 +- src/ray/object_manager/plasma/store.cc | 7 +- src/ray/object_manager/pull_manager.h | 2 +- src/ray/protobuf/agent_manager.proto | 1 + src/ray/protobuf/common.proto | 23 +- src/ray/protobuf/core_worker.proto | 1 + src/ray/protobuf/event.proto | 1 + src/ray/protobuf/gcs.proto | 31 +- src/ray/protobuf/gcs_service.proto | 2 +- src/ray/protobuf/job_agent.proto | 1 + src/ray/protobuf/node_manager.proto | 21 + src/ray/protobuf/object_manager.proto | 1 + src/ray/protobuf/pubsub.proto | 1 + src/ray/protobuf/ray_client.proto | 5 +- src/ray/protobuf/reporter.proto | 1 + src/ray/protobuf/runtime_env_agent.proto | 5 + src/ray/protobuf/serialization.proto | 1 + src/ray/protobuf/serve.proto | 47 +- src/ray/ray_version_script.lds | 1 - src/ray/raylet/agent_manager.cc | 54 +- src/ray/raylet/agent_manager.h | 9 +- src/ray/raylet/main.cc | 3 +- src/ray/raylet/node_manager.cc | 40 +- src/ray/raylet/node_manager.h | 13 +- .../placement_group_resource_manager.cc | 3 + src/ray/raylet/raylet.cc | 7 +- .../scheduling/cluster_resource_data.cc | 5 +- .../raylet/scheduling/cluster_resource_data.h | 4 +- .../scheduling/cluster_resource_scheduler.cc | 3 +- .../raylet/scheduling/cluster_task_manager.cc | 103 ++-- .../raylet/scheduling/cluster_task_manager.h | 22 +- .../cluster_task_manager_interface.h | 21 +- .../scheduling/cluster_task_manager_test.cc | 151 ++++-- src/ray/raylet/scheduling/fixed_point.cc | 96 ---- src/ray/raylet/scheduling/fixed_point.h | 115 +++-- .../raylet/scheduling/scheduling_policy.cc | 20 +- src/ray/raylet/scheduling/scheduling_policy.h | 13 +- .../scheduling/scheduling_policy_test.cc | 36 ++ src/ray/raylet/worker.cc | 2 +- src/ray/raylet/worker_pool.cc | 124 +++-- src/ray/raylet/worker_pool.h | 7 +- src/ray/raylet/worker_pool_test.cc | 14 +- 
src/ray/raylet_client/raylet_client.cc | 33 +- src/ray/raylet_client/raylet_client.h | 28 +- src/ray/rpc/common.cc | 6 +- src/ray/rpc/common.h | 2 + src/ray/rpc/grpc_server.cc | 9 +- src/ray/rpc/grpc_server.h | 11 +- .../rpc/node_manager/node_manager_client.h | 3 + .../rpc/node_manager/node_manager_server.h | 5 + src/ray/rpc/server_call.h | 17 +- src/ray/rpc/test/grpc_server_client_test.cc | 18 +- src/ray/util/event.h | 10 +- src/ray/util/util.h | 44 +- src/ray/util/util_test.cc | 17 + streaming/src/queue/queue_handler.h | 2 +- streaming/src/test/mock_actor.cc | 2 +- .../patches/prometheus-windows-pollfd.patch | 37 +- ...les_boost-undefine-boost_fallthrough.patch | 8 - .../rules_boost-windows-linkopts.patch | 21 +- 650 files changed, 16804 insertions(+), 5489 deletions(-) create mode 100644 .buildkite/pipeline.gpu.large.yml create mode 100644 doc/examples/dask_xgboost/README.rst create mode 100644 doc/examples/dask_xgboost/dask_xgboost.py create mode 100644 doc/examples/dask_xgboost/dask_xgboost.yaml create mode 100644 doc/examples/modin_xgboost/README.rst create mode 100644 doc/examples/modin_xgboost/modin_xgboost.py create mode 100644 doc/examples/modin_xgboost/modin_xgboost.yaml create mode 100644 doc/source/data/.gitignore create mode 100644 doc/source/data/_examples/README.rst create mode 100644 doc/source/data/_examples/big_data_ingestion.py create mode 100644 doc/source/data/big_data_ingestion.yaml create mode 100644 doc/source/raysgd/v2/examples/tune_cifar_pytorch_pbt_example.rst create mode 100644 doc/source/raysgd/v2/migration-guide.rst create mode 100644 java/serve/src/main/java/io/ray/serve/DeploymentInfo.java create mode 100644 java/serve/src/main/java/io/ray/serve/DummyBackendReplica.java create mode 100644 java/serve/src/main/java/io/ray/serve/HandleOptions.java create mode 100644 java/serve/src/main/java/io/ray/serve/HttpProxy.java create mode 100644 java/serve/src/main/java/io/ray/serve/ProxyActor.java create mode 100644 java/serve/src/main/java/io/ray/serve/ProxyRouter.java create mode 100644 java/serve/src/main/java/io/ray/serve/RayServeConfig.java create mode 100644 java/serve/src/main/java/io/ray/serve/RayServeHandle.java create mode 100644 java/serve/src/main/java/io/ray/serve/RayServeMetrics.java create mode 100644 java/serve/src/main/java/io/ray/serve/ReplicaSet.java create mode 100644 java/serve/src/main/java/io/ray/serve/Router.java create mode 100644 java/serve/src/main/java/io/ray/serve/ServeController.java create mode 100644 java/serve/src/main/java/io/ray/serve/ServeProxy.java create mode 100644 java/serve/src/main/java/io/ray/serve/api/Client.java delete mode 100644 java/serve/src/main/java/io/ray/serve/poll/UpdatedObject.java create mode 100644 java/serve/src/main/java/io/ray/serve/util/CollectionUtil.java create mode 100644 java/serve/src/main/java/io/ray/serve/util/CommonUtil.java create mode 100644 java/serve/src/main/java/io/ray/serve/util/SocketUtil.java create mode 100644 java/serve/src/test/java/io/ray/serve/DummyServeController.java create mode 100644 java/serve/src/test/java/io/ray/serve/HttpProxyTest.java create mode 100644 java/serve/src/test/java/io/ray/serve/ProxyActorTest.java create mode 100644 java/serve/src/test/java/io/ray/serve/ProxyRouterTest.java create mode 100644 java/serve/src/test/java/io/ray/serve/RayServeHandleTest.java create mode 100644 java/serve/src/test/java/io/ray/serve/ReplicaSetTest.java create mode 100644 java/serve/src/test/java/io/ray/serve/RouterTest.java create mode 100644 
java/serve/src/test/java/io/ray/serve/api/ClientTest.java create mode 100644 python/ray/_private/runtime_env/plugin.py create mode 100644 python/ray/_private/tls_utils.py create mode 100644 python/ray/autoscaler/_private/fake_multi_node/__init__.py create mode 100644 python/ray/autoscaler/_private/fake_multi_node/example.yaml create mode 100644 python/ray/autoscaler/_private/fake_multi_node/node_provider.py delete mode 100644 python/ray/data/impl/tensor_block.py rename python/ray/serve/{backend_worker.py => replica.py} (91%) create mode 100644 python/ray/sgd/callbacks.py create mode 100644 python/ray/tests/test_autoscaler_fake_multinode.py create mode 100644 python/ray/tests/test_client_compat.py create mode 100644 python/ray/tests/test_runtime_env_plugin.py create mode 100644 python/ray/tests/test_runtime_env_validation.py create mode 100644 python/ray/util/sgd/v2/examples/tune_cifar_pytorch_pbt_example.py create mode 100644 python/ray/util/sgd/v2/tests/test_gpu.py create mode 100644 release/golden_notebook_tests/workloads/util.py delete mode 100644 release/golden_notebook_tests/workloads/utils/utils.py create mode 100644 release/kubernetes_manual_tests/README.md create mode 100755 release/kubernetes_manual_tests/helm-test.sh create mode 100755 release/kubernetes_manual_tests/k8s-test-scale.sh create mode 100755 release/kubernetes_manual_tests/k8s-test.sh create mode 100644 release/kubernetes_manual_tests/k8s_release_tests.sh create mode 100644 release/nightly_tests/placement_group_tests/app_config.yaml create mode 100644 release/nightly_tests/placement_group_tests/cluster.py create mode 100644 release/nightly_tests/placement_group_tests/compute.yaml create mode 100644 release/nightly_tests/placement_group_tests/pg_run.py create mode 100644 release/release_logs/1.7.0/benchmarks/many_actors.txt create mode 100644 release/release_logs/1.7.0/benchmarks/many_nodes.txt create mode 100644 release/release_logs/1.7.0/benchmarks/many_pgs.txt create mode 100644 release/release_logs/1.7.0/benchmarks/many_tasks.txt create mode 100644 release/release_logs/1.7.0/microbenchmark.txt create mode 100644 release/release_logs/1.7.0/scalability/object_store.txt create mode 100644 release/release_logs/1.7.0/scalability/single_node.txt create mode 100644 release/release_logs/1.7.0/stress_tests/dead_actors.txt create mode 100644 release/release_logs/1.7.0/stress_tests/many_tasks.txt create mode 100644 release/release_logs/1.7.0/stress_tests/placement_group.txt create mode 100644 release/serve_tests/workloads/serve_cluster_fault_tolerance.py create mode 100644 rllib/agents/sac/tests/test_rnnsac.py delete mode 100755 rllib/env/tests/test_local_inference.sh create mode 100755 rllib/env/tests/test_policy_client_server_setup.sh delete mode 100755 rllib/env/tests/test_remote_inference.sh create mode 100644 rllib/env/tests/test_remote_worker_envs.py create mode 100644 rllib/examples/serving/unity3d_dummy_client.py create mode 100644 rllib/utils/metrics/__init__.py create mode 100644 rllib/utils/metrics/learner_info.py create mode 100644 src/mock/ray/gcs/pubsub/gcs_pub_sub.h create mode 100644 src/mock/ray/gcs/store_client/in_memory_store_client.h create mode 100644 src/mock/ray/gcs/store_client/redis_store_client.h create mode 100644 src/mock/ray/gcs/store_client/store_client.h create mode 100644 src/mock/ray/pubsub/publisher.h create mode 100644 src/mock/ray/pubsub/subscriber.h create mode 100644 src/mock/ray/rpc/worker/core_worker_client.h create mode 100644 src/mock/ray/rpc/worker/core_worker_client_pool.h create 
mode 100644 src/ray/gcs/gcs_server/test/gcs_actor_scheduler_mock_test.cc create mode 100644 src/ray/gcs/gcs_server/test/gcs_placement_group_manager_mock_test.cc delete mode 100644 src/ray/raylet/scheduling/fixed_point.cc delete mode 100644 thirdparty/patches/rules_boost-undefine-boost_fallthrough.patch diff --git a/.bazelrc b/.bazelrc index a6ebeba272c0f..2e4e7b36d10f9 100644 --- a/.bazelrc +++ b/.bazelrc @@ -14,12 +14,13 @@ build:macos --copt="-g1" build:linux --cxxopt="-std=c++17" build:macos --cxxopt="-std=c++17" build:clang-cl --cxxopt="-std=c++17" -build:msvc --cxxopt="/std:c++17" +build:msvc-cl --cxxopt="/std:c++17" +build:windows --cxxopt="/std:c++17" # This workaround is needed to prevent Bazel from compiling the same file twice (once PIC and once not). build:linux --force_pic build:macos --force_pic build:clang-cl --compiler=clang-cl -build:msvc --compiler=msvc-cl +build:msvc-cl --compiler=msvc-cl # `LC_ALL` and `LANG` is needed for cpp worker tests, because they will call "ray start". # If we don't add them, python's `click` library will raise an error. build --action_env=LC_ALL @@ -38,7 +39,7 @@ build:windows --enable_runfiles build:linux --per_file_copt="-\\.(asm|S)$@-Werror" build:macos --per_file_copt="-\\.(asm|S)$@-Werror" build:clang-cl --per_file_copt="-\\.(asm|S)$@-Werror" -build:msvc --per_file_copt="-\\.(asm|S)$@-WX" +build:msvc-cl --per_file_copt="-\\.(asm|S)$@-WX" # Ignore warnings for protobuf generated files and external projects. build --per_file_copt="\\.pb\\.cc$@-w" build --per_file_copt="-\\.(asm|S)$,external/.*@-w" @@ -51,7 +52,7 @@ build --per_file_copt="-\\.(asm|S)$,external/com_github_grpc_grpc/.*@-DGRPC_BAZE # Don't generate warnings about kernel features we don't need https://github.com/ray-project/ray/issues/6832 build:linux --per_file_copt="-\\.(asm|S)$,external/com_github_grpc_grpc/.*@-DGPR_MANYLINUX1" # Ignore wchar_t -> char conversion warning on MSVC -build:msvc --per_file_copt="external/boost/libs/regex/src/wc_regex_traits\\.cpp@-wd4244" +build:msvc-cl --per_file_copt="external/boost/libs/regex/src/wc_regex_traits\\.cpp@-wd4244" build --http_timeout_scaling=5.0 build --verbose_failures build:iwyu --experimental_action_listener=//:iwyu_cpp @@ -177,6 +178,7 @@ build:debug --strip="never" # Undefined Behavior Sanitizer build:ubsan --strip=never build:ubsan --copt -fsanitize=undefined +build:ubsan --copt -fno-sanitize=vptr build:ubsan --copt -fno-sanitize-recover=all build:ubsan --copt -g build:ubsan --linkopt -fsanitize=undefined diff --git a/.buildkite/pipeline.gpu.large.yml b/.buildkite/pipeline.gpu.large.yml new file mode 100644 index 0000000000000..0bdbca8846841 --- /dev/null +++ b/.buildkite/pipeline.gpu.large.yml @@ -0,0 +1,8 @@ +- label: ":tv: :octopus: SGD GPU tests " + conditions: ["RAY_CI_SGD_AFFECTED"] + commands: + - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/travis/upload_build_info.sh; fi }; trap cleanup EXIT + - SGD_TESTING=1 INSTALL_HOROVOD=1 ./ci/travis/install-dependencies.sh + - pip install -Ur ./python/requirements_ml_docker.txt + - ./ci/travis/env_info.sh + - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=gpu,gpu_only python/ray/util/sgd/... 
diff --git a/.buildkite/pipeline.gpu.yml b/.buildkite/pipeline.gpu.yml index 0c2c14ecf805f..e89aeaa9f2d63 100644 --- a/.buildkite/pipeline.gpu.yml +++ b/.buildkite/pipeline.gpu.yml @@ -1,3 +1,13 @@ +# Todo: Enable once tests are available +#- label: ":tv: :octopus: Tune GPU tests " +# conditions: ["RAY_CI_TUNE_AFFECTED"] +# commands: +# - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/travis/upload_build_info.sh; fi }; trap cleanup EXIT +# - TUNE_TESTING=1 ./ci/travis/install-dependencies.sh +# - pip install -Ur ./python/requirements_ml_docker.txt +# - ./ci/travis/env_info.sh +# - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=-flaky,gpu,gpu_only python/ray/tune/... + - label: ":tv: :brain: RLlib: GPU Examples {A/B}" conditions: ["RAY_CI_RLLIB_AFFECTED"] commands: diff --git a/.buildkite/pipeline.macos.yml b/.buildkite/pipeline.macos.yml index 592347d44007c..e3ba9347c7cc7 100644 --- a/.buildkite/pipeline.macos.yml +++ b/.buildkite/pipeline.macos.yml @@ -64,7 +64,7 @@ steps: commands: - *prelude_commands - TORCH_VERSION=1.6 ./ci/travis/install-dependencies.sh - - bazel test --config=ci --test_env=CI $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=-flaky,-flaky-mac -- + - bazel test --config=ci --test_env=CI $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=-flaky,-flaky-mac,-post_wheel_build -- //:all python/ray/serve/... python/ray/dashboard/... -rllib/... -core_worker_test - *epilogue_commands @@ -82,7 +82,7 @@ steps: - bazel test $(./scripts/bazel_export_options) --config=ci --test_env=CONDA_EXE --test_env=CONDA_PYTHON_EXE --test_env=CONDA_SHLVL --test_env=CONDA_PREFIX --test_env=CONDA_DEFAULT_ENV --test_env=CONDA_PROMPT_MODIFIER --test_env=CI - --test_tag_filters=-kubernetes,-jenkins_only,-medium_size_python_tests_a_to_j,-medium_size_python_tests_k_to_z,-flaky,-flaky-mac + --test_tag_filters=-kubernetes,-medium_size_python_tests_a_to_j,-medium_size_python_tests_k_to_z,-flaky,-flaky-mac python/ray/tests/... - *epilogue_commands @@ -91,7 +91,7 @@ steps: commands: - *prelude_commands - bazel test --config=ci $(./scripts/bazel_export_options) --test_env=CI - --test_tag_filters=-kubernetes,-jenkins_only,medium_size_python_tests_a_to_j,-flaky,-flaky-mac + --test_tag_filters=-kubernetes,medium_size_python_tests_a_to_j,-flaky,-flaky-mac python/ray/tests/... - *epilogue_commands @@ -100,7 +100,7 @@ steps: commands: - *prelude_commands - bazel test --config=ci $(./scripts/bazel_export_options) --test_env=CI - --test_tag_filters=-kubernetes,-jenkins_only,medium_size_python_tests_k_to_z,-flaky,-flaky-mac + --test_tag_filters=-kubernetes,medium_size_python_tests_k_to_z,-flaky,-flaky-mac python/ray/tests/... - *epilogue_commands @@ -110,7 +110,7 @@ steps: - *prelude_commands - ./ci/travis/install-dependencies.sh - bazel test --config=ci $(./scripts/bazel_export_options) - --test_tag_filters=-kubernetes,-jenkins_only,flaky,flaky-mac + --test_tag_filters=-kubernetes,flaky,flaky-mac --test_env=CONDA_EXE --test_env=CONDA_PYTHON_EXE --test_env=CONDA_SHLVL diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index c0f6ccda286df..2941476580fc8 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -182,7 +182,9 @@ - TORCH_VERSION=1.6 ./ci/travis/install-dependencies.sh - ./dashboard/tests/run_ui_tests.sh - bazel test --config=ci $(./scripts/bazel_export_options) python/ray/dashboard/... - - bazel test --config=ci $(./scripts/bazel_export_options) python/ray/serve/... 
+ - bazel test --config=ci $(./scripts/bazel_export_options) + --test_tag_filters=-post_wheel_build + python/ray/serve/... - label: ":python: Minimal install" conditions: ["RAY_CI_PYTHON_AFFECTED"] @@ -208,7 +210,7 @@ # --test_tag_filters=flaky # -- //:all -rllib/... -core_worker_test - bazel test --config=ci $(./scripts/bazel_export_options) - --test_tag_filters=-kubernetes,-jenkins_only,flaky + --test_tag_filters=-kubernetes,flaky --test_env=CONDA_EXE --test_env=CONDA_PYTHON_EXE --test_env=CONDA_SHLVL @@ -220,7 +222,7 @@ commands: - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/travis/upload_build_info.sh; fi }; trap cleanup EXIT - bazel test --config=ci $(./scripts/bazel_export_options) - --test_tag_filters=-kubernetes,-jenkins_only,-medium_size_python_tests_a_to_j,-medium_size_python_tests_k_to_z,-client_tests,-flaky,-post_wheel_build,-worker-container + --test_tag_filters=-kubernetes,-medium_size_python_tests_a_to_j,-medium_size_python_tests_k_to_z,-client_tests,-flaky,-post_wheel_build,-worker-container --test_env=CONDA_EXE --test_env=CONDA_PYTHON_EXE --test_env=CONDA_SHLVL @@ -228,7 +230,7 @@ --test_env=CONDA_DEFAULT_ENV python/ray/tests/... - bazel test --config=ci $(./scripts/bazel_export_options) - --test_tag_filters=-kubernetes,-jenkins_only,client_tests,-flaky + --test_tag_filters=-kubernetes,client_tests,-flaky --test_env=RAY_CLIENT_MODE=1 --test_env=RAY_PROFILING=1 python/ray/tests/... - label: ":python: (Medium A-J)" @@ -236,14 +238,14 @@ commands: - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/travis/upload_build_info.sh; fi }; trap cleanup EXIT - bazel test --config=ci $(./scripts/bazel_export_options) - --test_tag_filters=-kubernetes,-jenkins_only,medium_size_python_tests_a_to_j,-flaky + --test_tag_filters=-kubernetes,medium_size_python_tests_a_to_j,-flaky python/ray/tests/... - label: ":python: (Medium K-Z)" conditions: ["RAY_CI_PYTHON_AFFECTED"] commands: - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/travis/upload_build_info.sh; fi }; trap cleanup EXIT - bazel test --config=ci $(./scripts/bazel_export_options) - --test_tag_filters=-kubernetes,-jenkins_only,medium_size_python_tests_k_to_z,-flaky + --test_tag_filters=-kubernetes,medium_size_python_tests_k_to_z,-flaky python/ray/tests/... - label: ":core: Debug Test" commands: @@ -251,7 +253,7 @@ - pip uninstall -y ray - RAY_DEBUG_BUILD=debug ./ci/travis/ci.sh build - bazel test --config=ci-debug $(./scripts/bazel_export_options) - --test_tag_filters=-kubernetes,-jenkins_only,debug_tests,-flaky + --test_tag_filters=-kubernetes,debug_tests,-flaky python/ray/tests/... - label: ":core: (ASAN tests)" conditions: ["RAY_CI_PYTHON_AFFECTED"] @@ -260,7 +262,7 @@ - RLLIB_TESTING=1 ./ci/travis/install-dependencies.sh - bazel test --config=ci --config=asan $(./scripts/bazel_export_options) --config=asan-buildkite - --test_tag_filters=-kubernetes,-jenkins_only,asan_tests,-flaky + --test_tag_filters=-kubernetes,asan_tests,-flaky --test_env=CONDA_EXE --test_env=CONDA_PYTHON_EXE --test_env=CONDA_SHLVL @@ -462,16 +464,16 @@ commands: - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/travis/upload_build_info.sh; fi }; trap cleanup EXIT - TUNE_TESTING=1 ./ci/travis/install-dependencies.sh - - bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-jenkins_only,-example,-flaky,-py37,-soft_imports python/ray/tune/... 
- - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=example,-tf,-pytorch,-py37,-flaky,-soft_imports python/ray/tune/... + - bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-example,-flaky,-py37,-soft_imports,-gpu_only python/ray/tune/... + - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=example,-tf,-pytorch,-py37,-flaky,-soft_imports,-gpu_only python/ray/tune/... - label: ":octopus: Tune tests and examples {2/2}" conditions: ["RAY_CI_TUNE_AFFECTED"] commands: - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/travis/upload_build_info.sh; fi }; trap cleanup EXIT - TUNE_TESTING=1 ./ci/travis/install-dependencies.sh - - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=tf,-pytorch,-py37,-flaky,-soft_imports python/ray/tune/... - - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=-tf,pytorch,-py37,-flaky,-soft_imports python/ray/tune/... + - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=tf,-pytorch,-py37,-flaky,-soft_imports,-gpu_only python/ray/tune/... + - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=-tf,pytorch,-py37,-flaky,-soft_imports,-gpu_only python/ray/tune/... - label: ":octopus: Tune soft imports test" conditions: ["RAY_CI_TUNE_AFFECTED"] @@ -486,10 +488,10 @@ commands: - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/travis/upload_build_info.sh; fi }; trap cleanup EXIT - SGD_TESTING=1 INSTALL_HOROVOD=1 ./ci/travis/install-dependencies.sh - - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=tf,-pytorch,-py37,-flaky,-client python/ray/util/sgd/... - - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=-tf,pytorch,-py37,-flaky,-client python/ray/util/sgd/... - - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=client_unit_tests --test_env=RAY_CLIENT_MODE=1 python/ray/util/sgd/... - - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only python/ray/util/sgd/v2/... + - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=tf,-pytorch,-py37,-flaky,-client,-gpu_only python/ray/util/sgd/... + - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=-tf,pytorch,-py37,-flaky,-client,-gpu_only python/ray/util/sgd/... + - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=client_unit_tests,-gpu_only --test_env=RAY_CLIENT_MODE=1 python/ray/util/sgd/... + - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=-gpu_only python/ray/util/sgd/v2/... - label: ":octopus: Tune/SGD/Modin/Dask tests and examples. 
Python 3.7" conditions: ["RAY_CI_TUNE_AFFECTED", "RAY_CI_SGD_AFFECTED"] diff --git a/.buildkite/windows/install/bazel.ps1 b/.buildkite/windows/install/bazel.ps1 index 46411cf3810f3..adeee13df7209 100644 --- a/.buildkite/windows/install/bazel.ps1 +++ b/.buildkite/windows/install/bazel.ps1 @@ -1,4 +1,4 @@ -$Env:BAZEL_URL="https://github.com/bazelbuild/bazel/releases/download/3.2.0/bazel-3.2.0-windows-x86_64.zip" +$Env:BAZEL_URL="https://github.com/bazelbuild/bazel/releases/download/4.2.1/bazel-4.2.1-windows-x86_64.zip" Write-Host ('Downloading {0} ...' -f $env:BAZEL_URL); [Net.ServicePointManager]::SecurityProtocol = [Net.SecurityProtocolType]::Tls12; Invoke-WebRequest -Uri $env:BAZEL_URL -OutFile 'bazel.zip'; diff --git a/.clang-tidy b/.clang-tidy index 2aa176da910cc..607f19902f3f4 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -1,27 +1,64 @@ -# Disable the following checks due to frequent false positives, noisiness, -# inconsistent style with existing codebase and other reasons: +# Disable the following checks with reasons in parenthesis: +# +# -bugprone-macro-parentheses (inconsistent style) +# -google-readability-todo (potentially too restrictive) # -misc-non-private-member-variables-in-classes (potentially too restrictive) # -misc-unused-parameters (can be cleaned up in batch and enabled) # -modernize-avoid-c-arrays (too restrictive) +# -modernize-concat-nested-namespaces (inconsistent style) # -modernize-pass-by-value (too restrictive) # -modernize-return-braced-init-list (inconsistent style) # -modernize-use-emplace (more subtle behavior) +# -modernize-use-nodiscard (too much noise) # -modernize-use-trailing-return-type (inconsistent style) +# -modernize-avoid-bind (incorrect conversion) +# -modernize-loop-convert (more subtle behavior) +# -modernize-replace-disallow-copy-and-assign-macro (inconsistent style) +# -modernize-make-unique (doesn't work with private constructor) +# -modernize-make-shared (doesn't work with private constructor) +# Other readability-* rules (potentially too noisy, inconsistent style) +# Other rules not mentioned here or below (not yet evaluated) # # TODO: enable google-* and readability-* families of checks. Checks: > abseil-*, bugprone-*, + -bugprone-macro-parentheses, + google-*, + -google-readability-todo, misc-*, -misc-non-private-member-variables-in-classes, -misc-unused-parameters, modernize-*, -modernize-avoid-c-arrays, + -modernize-concat-nested-namespaces, -modernize-pass-by-value, -modernize-return-braced-init-list, -modernize-use-emplace, + -modernize-use-nodiscard, -modernize-use-trailing-return-type, + -modernize-avoid-bind, + -modernize-loop-convert, + -modernize-replace-disallow-copy-and-assign-macro, + -modernize-make-unique, + -modernize-make-shared, performance-*, + readability-avoid-const-params-in-decls, + readability-braces-around-statements, + readability-const-return-type, + readability-container-size-empty, + readability-delete-null-pointer, + readability-else-after-return, + readability-implicit-bool-conversion, + readability-make-member-function-const, + readability-misleading-indentation, + readability-misplaced-array-index, + readability-named-parameter, + readability-non-const-parameter, + readability-redundant-*, + readability-static-definition-in-anonymous-namespace, + readability-string-compare, + readability-suspicious-call-argument, CheckOptions: # Reduce noisiness of the bugprone-narrowing-conversions check. 
diff --git a/.flake8 b/.flake8 index a4a3510a1bbeb..cb93e3096d3ef 100644 --- a/.flake8 +++ b/.flake8 @@ -24,4 +24,20 @@ ignore = W605 I N + B001 + B002 + B003 + B004 + B005 + B007 + B008 + B009 + B010 + B011 + B012 + B013 + B014 + B015 + B016 + B017 avoid-escape = no diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index c4e254c2dd0f9..3502b7042bf20 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -18,6 +18,9 @@ # Dependencies /python/setup.py @richardliaw @ericl @edoakes +# Formatting tool +/ci/travis/format.sh @richardliaw @ericl @edoakes + # Python worker. #/python/ray/ @ray-project/ray-core-python #!/python/ray/tune/ @ray-project/ray-core-python @@ -30,7 +33,6 @@ /java/*/pom_template.xml @jovany-wang @kfstorm @raulchen /java/api/ @jovany-wang @kfstorm @raulchen - # Ray Client /src/ray/protobuf/ray_client.proto @ijrsvt @ameerhajali @ckw017 @mwtian @@ -39,6 +41,14 @@ # Ray tune. /python/ray/tune/ @ray-project/ray-tune +# Ray data. +/python/ray/data/ @ericl @scv119 +/doc/source/data/ @ericl @scv119 + +# Ray workflows. +/python/ray/workflow/ @ericl @iycheng +/doc/source/workflows/ @ericl @iycheng + # RLlib. #/python/ray/rllib/ @ray-project/rllib diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 8df9fe895df63..9404a4a4d2517 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -26,7 +26,7 @@ jobs: os: windows-2019 python-version: 3.8 # Can be 'msvc' or 'clang-cl' - config: msvc + config: msvc-cl env: BAZEL_CONFIG: ${{ matrix.config }} PYTHON: ${{ matrix.python-version }} @@ -111,7 +111,6 @@ jobs: TRAVIS_COMMIT: ${{ github.sha }} TRAVIS_JOB_ID: ${{ github.run_id }} run: | - # Multi thread in windowns for grpc not working now function clean_up() { echo "Performing cleanup" if [ "${GITHUB_EVENT_NAME}" != "pull_request" ]; then ./ci/travis/upload_build_info.sh; fi diff --git a/.gitpod/Dockerfile b/.gitpod/Dockerfile index 23682c0ed9687..ce2af682e0ed9 100644 --- a/.gitpod/Dockerfile +++ b/.gitpod/Dockerfile @@ -15,7 +15,7 @@ RUN set -x; apt update \ && mv bazel.gpg /etc/apt/trusted.gpg.d/ \ && echo "deb [arch=amd64] https://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list \ && apt update && apt install bazel-3.7.2 -y \ - && pip3 install cython==0.29.0 pytest pandas tree tabulate pexpect sklearn joblib yapf==0.23.0 flake8==3.9.1 mypy==0.782 flake8-quotes setproctitle==1.1.10 psutil \ + && pip3 install cython==0.29.0 pytest pandas tree tabulate pexpect sklearn joblib yapf==0.23.0 flake8==3.9.1 mypy==0.782 flake8-quotes flake8-bugbear==21.9.2 setproctitle==1.1.10 psutil \ && python3 -c 'print("startup --output_base=/workspace/ray/.bazel-cache\nstartup --host_jvm_args=-Xmx1800m\nbuild --jobs=6")' > /etc/bazel.bazelrc RUN update-alternatives --install /usr/local/bin/python python /usr/bin/python3 30 \ diff --git a/BUILD.bazel b/BUILD.bazel index ad6bd083fd4ad..7db7fc20f7cb7 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -414,7 +414,6 @@ cc_library( ], ) + [ "src/ray/raylet/scheduling/cluster_resource_data.cc", - "src/ray/raylet/scheduling/fixed_point.cc", "src/ray/raylet/scheduling/scheduling_ids.cc", ], hdrs = glob( @@ -553,6 +552,7 @@ cc_library( ":pubsub_lib", ":raylet_client_lib", ":worker_rpc", + "@com_google_absl//absl/container:btree", ], ) @@ -623,10 +623,12 @@ cc_library( "src/ray/stats/metric_exporter_client.cc", ], hdrs = [ + "src/ray/stats/metric.h", "src/ray/stats/metric_defs.h", "src/ray/stats/metric_exporter.h", "src/ray/stats/metric_exporter_client.h", "src/ray/stats/stats.h", + 
"src/ray/stats/tag_defs.h", ], copts = COPTS, linkopts = select({ @@ -1181,6 +1183,22 @@ cc_test( ], ) +cc_test( + name = "gcs_placement_group_manager_mock_test", + size = "small", + srcs = [ + "src/ray/gcs/gcs_server/test/gcs_placement_group_manager_mock_test.cc", + ], + copts = COPTS, + tags = ["team:core"], + deps = [ + ":gcs_server_lib", + ":gcs_test_util_lib", + ":ray_mock", + "@com_google_googletest//:gtest_main", + ], +) + cc_test( name = "placement_group_resource_manager_test", size = "small", @@ -1513,6 +1531,21 @@ cc_test( ], ) +# cc_test( +# name = "gcs_actor_scheduler_mock_test", +# size = "small", +# srcs = [ +# "src/ray/gcs/gcs_server/test/gcs_actor_scheduler_mock_test.cc", +# ], +# copts = COPTS, +# tags = ["team:core"], +# deps = [ +# ":gcs_server_lib", +# ":ray_mock", +# "@com_google_googletest//:gtest_main", +# ], +# ) + cc_test( name = "gcs_based_actor_scheduler_test", size = "small", diff --git a/bazel/ray_deps_setup.bzl b/bazel/ray_deps_setup.bzl index 1925aedfa4edb..96131feadba41 100644 --- a/bazel/ray_deps_setup.bzl +++ b/bazel/ray_deps_setup.bzl @@ -151,8 +151,8 @@ def ray_deps_setup(): # declaring it here allows us to avoid patching the latter. name = "boost", build_file = "@com_github_nelhage_rules_boost//:BUILD.boost", - sha256 = "d73a8da01e8bf8c7eda40b4c84915071a8c8a0df4a6734537ddde4a8580524ee", - url = "https://boostorg.jfrog.io/artifactory/main/release/1.71.0/source/boost_1_71_0.tar.bz2", + sha256 = "83bfc1507731a0906e387fc28b7ef5417d591429e51e788417fe9ff025e116b1", + url = "https://boostorg.jfrog.io/artifactory/main/release/1.74.0/source/boost_1_74_0.tar.bz2", patches = [ "//thirdparty/patches:boost-exception-no_warn_typeid_evaluated.patch", ], @@ -161,10 +161,9 @@ def ray_deps_setup(): auto_http_archive( name = "com_github_nelhage_rules_boost", # If you update the Boost version, remember to update the 'boost' rule. 
- url = "https://github.com/nelhage/rules_boost/archive/2613d04ab3d22dfc4543ea0a083d9adeaa0daf09.tar.gz", - sha256 = "512f913240e026099d4ca4a98b1ce8048c99de77fdc8e8584e9e2539ee119ca2", + url = "https://github.com/nelhage/rules_boost/archive/652b21e35e4eeed5579e696da0facbe8dba52b1f.tar.gz", + sha256 = "c1b8b2adc3b4201683cf94dda7eef3fc0f4f4c0ea5caa3ed3feffe07e1fb5b15", patches = [ - "//thirdparty/patches:rules_boost-undefine-boost_fallthrough.patch", "//thirdparty/patches:rules_boost-windows-linkopts.patch", ], ) diff --git a/benchmarks/object_store/test_object_store.py b/benchmarks/object_store/test_object_store.py index 5e251f55f8884..022cb17e8b890 100644 --- a/benchmarks/object_store/test_object_store.py +++ b/benchmarks/object_store/test_object_store.py @@ -65,6 +65,7 @@ def sum(self, arr): if "TEST_OUTPUT_JSON" in os.environ: out_file = open(os.environ["TEST_OUTPUT_JSON"], "w") results = { + "broadcast_time": end - start, "object_size": OBJECT_SIZE, "num_nodes": NUM_NODES, "success": "1" diff --git a/benchmarks/single_node/test_single_node.py b/benchmarks/single_node/test_single_node.py index fb44e7fe29ade..3deaa389de600 100644 --- a/benchmarks/single_node/test_single_node.py +++ b/benchmarks/single_node/test_single_node.py @@ -199,7 +199,8 @@ def test_large_object(): "num_args": MAX_ARGS, "returns_time": returns_time, "num_returns": MAX_RETURNS, - "get_time": MAX_RAY_GET_ARGS, + "get_time": get_time, + "num_get_args": MAX_RAY_GET_ARGS, "queued_time": queued_time, "num_queued": MAX_QUEUED_TASKS, "large_object_time": large_object_time, diff --git a/ci/asan_tests/run_asan_tests.sh b/ci/asan_tests/run_asan_tests.sh index 5f84fe3ff6d40..ea2d4b8a697c5 100755 --- a/ci/asan_tests/run_asan_tests.sh +++ b/ci/asan_tests/run_asan_tests.sh @@ -39,10 +39,10 @@ asan_run() { cd "${RAY_DIR}" # Ray tests - bazel test --test_tag_filters=-jenkins_only --test_output=streamed python/ray/serve/... - bazel test --test_tag_filters=-jenkins_only --test_output=streamed python/ray/dashboard/... - bazel test --test_tag_filters=-jenkins_only --test_output=streamed python/ray/tests/... - bazel test --test_tag_filters=-jenkins_only --test_output=streamed python/ray/tune/... + bazel test --test_output=streamed python/ray/serve/... + bazel test --test_output=streamed python/ray/dashboard/... + bazel test --test_output=streamed python/ray/tests/... + bazel test --test_output=streamed python/ray/tune/... ) } diff --git a/ci/travis/bazel.py b/ci/travis/bazel.py index d731734b6faa9..d462459fc1ead 100755 --- a/ci/travis/bazel.py +++ b/ci/travis/bazel.py @@ -98,35 +98,45 @@ def info(self, *args): return result def aquery(self, *args): - lines = self._call("aquery", "--output=textproto", *args).splitlines() - return textproto_parse(lines, self.encoding, json.JSONEncoder()) + out = self._call("aquery", "--output=jsonproto", *args) + return json.loads(out.decode(self.encoding)) def parse_aquery_shell_calls(aquery_results): """Extracts and yields the command lines representing the genrule() rules from Bazel aquery results. """ - for (key, val) in aquery_results: - if key == "actions": - [mnemonic] = [pair[1] for pair in val if pair[0] == "mnemonic"] - if mnemonic == "Genrule": - yield [pair[1] for pair in val if pair[0] == "arguments"] + for action in aquery_results["actions"]: + if action["mnemonic"] != "Genrule": + continue + yield action["arguments"] def parse_aquery_output_artifacts(aquery_results): """Extracts and yields the file paths representing the output artifact from the provided Bazel aquery results. 
+ + To understand the output of the aquery command in jsonproto format, try: + bazel aquery --include_artifacts=true --output=jsonproto \ + 'mnemonic("Genrule", deps(//:*))' """ + fragments = {} + for fragment in aquery_results["pathFragments"]: + fragments[fragment["id"]] = fragment + artifacts = {} - for (key, val) in aquery_results: - if key == "artifacts": - [artifact_id] = [pair[1] for pair in val if pair[0] == "id"] - [exec_path] = [pair[1] for pair in val if pair[0] == "exec_path"] - artifacts[artifact_id] = exec_path - elif key == "actions": - output_ids = [pair[1] for pair in val if pair[0] == "output_ids"] - for output_id in output_ids: - yield artifacts[output_id] + for artifact in aquery_results["artifacts"]: + artifacts[artifact["id"]] = artifact + + def _path(fragment_id): + fragment = fragments[fragment_id] + parent = _path(fragment["parentId"]) if "parentId" in fragment else [] + return parent + [fragment["label"]] + + for action in aquery_results["actions"]: + for output_id in action["outputIds"]: + path = os.path.join(*_path(artifacts[output_id]["pathFragmentId"])) + yield path def textproto2json(infile, outfile): diff --git a/ci/travis/ci.sh b/ci/travis/ci.sh index 6aa33a22a2000..7faa9ae02a5be 100755 --- a/ci/travis/ci.sh +++ b/ci/travis/ci.sh @@ -139,6 +139,7 @@ test_python() { args+=( python/ray/serve/... python/ray/tests/... + -python/ray/serve:conda_env # runtime_env unsupported on Windows -python/ray/serve:test_api # segfault on windows? https://github.com/ray-project/ray/issues/12541 -python/ray/serve:test_router # timeout -python/ray/serve:test_handle # "fatal error" (?) https://github.com/ray-project/ray/pull/13695 @@ -181,6 +182,7 @@ test_python() { -python/ray/tests:test_ray_init # test_redis_port() seems to fail here, but pass in isolation -python/ray/tests:test_resource_demand_scheduler -python/ray/tests:test_reference_counting # too flaky 9/25/21 + -python/ray/tests:test_runtime_env_plugin # runtime_env not supported on Windows -python/ray/tests:test_runtime_env_env_vars # runtime_env not supported on Windows -python/ray/tests:test_runtime_env_complicated # conda install slow leading to timeout -python/ray/tests:test_stress # timeout @@ -332,7 +334,52 @@ install_ray() { ) } +validate_wheels_commit_str() { + if [ "${OSTYPE}" = msys ]; then + echo "Windows builds do not set the commit string, skipping wheel commit validity check." + return 0 + fi + + if [ -n "${BUILDKITE_COMMIT}" ]; then + EXPECTED_COMMIT=${BUILDKITE_COMMIT:-} + else + EXPECTED_COMMIT=${TRAVIS_COMMIT:-} + fi + + if [ -z "$EXPECTED_COMMIT" ]; then + echo "Could not validate expected wheel commits: TRAVIS_COMMIT is empty." + return 0 + fi + + for whl in .whl/*.whl; do + basename=${whl##*/} + + if [[ "$basename" =~ "_cpp" ]]; then + # cpp wheels cannot be checked this way + echo "Skipping CPP wheel ${basename} for wheel commit validation." + continue + fi + + folder=${basename%%-cp*} + WHL_COMMIT=$(unzip -p "$whl" "${folder}.data/purelib/ray/__init__.py" | grep "__commit__" | awk -F'"' '{print $2}') + + if [ "${WHL_COMMIT}" != "${EXPECTED_COMMIT}" ]; then + echo "Error: Observed wheel commit (${WHL_COMMIT}) is not expected commit (${EXPECTED_COMMIT}). Aborting." + exit 1 + fi + + echo "Wheel ${basename} has the correct commit: ${WHL_COMMIT}" + done + + echo "All wheels passed the sanity check and have the correct wheel commit set." 
+} + build_wheels() { + # Create wheel output directory and empty contents + # If buildkite runners are re-used, wheels from previous builds might be here, so we delete them. + mkdir -p .whl + rm -rf .whl/* || true + case "${OSTYPE}" in linux*) # Mount bazel cache dir to the docker container. @@ -353,7 +400,6 @@ build_wheels() { -e "RAY_DEBUG_BUILD=${RAY_DEBUG_BUILD:-}" ) - if [ -z "${BUILDKITE-}" ]; then # This command should be kept in sync with ray/python/README-building-wheels.md, # except the "${MOUNT_BAZEL_CACHE[@]}" part. @@ -361,19 +407,25 @@ build_wheels() { quay.io/pypa/manylinux2014_x86_64 /ray/python/build-wheel-manylinux2014.sh else rm -rf /ray-mount/* + rm -rf /ray-mount/.whl || true + rm -rf /ray/.whl || true cp -rT /ray /ray-mount - ls /ray-mount + ls -a /ray-mount docker run --rm -v /ray:/ray-mounted ubuntu:focal ls / docker run --rm -v /ray:/ray-mounted ubuntu:focal ls /ray-mounted docker run --rm -w /ray -v /ray:/ray "${MOUNT_BAZEL_CACHE[@]}" \ quay.io/pypa/manylinux2014_x86_64 /ray/python/build-wheel-manylinux2014.sh cp -rT /ray-mount /ray # copy new files back here find . | grep whl # testing + + validate_wheels_commit_str fi ;; darwin*) # This command should be kept in sync with ray/python/README-building-wheels.md. "${WORKSPACE_DIR}"/python/build-wheel-macos.sh + + validate_wheels_commit_str ;; msys*) keep_alive "${WORKSPACE_DIR}"/python/build-wheel-windows.sh diff --git a/ci/travis/format.sh b/ci/travis/format.sh index e31245faad61d..7dbf608d18734 100755 --- a/ci/travis/format.sh +++ b/ci/travis/format.sh @@ -83,6 +83,10 @@ if [[ $(flake8 --version) != *"flake8_quotes"* ]]; then echo "WARNING: Ray uses flake8 with flake8_quotes. Might error without it. Install with: pip install flake8-quotes" fi +if [[ $(flake8 --version) != *"flake8-bugbear"* ]]; then + echo "WARNING: Ray uses flake8 with flake8-bugbear. Might error without it. Install with: pip install flake8-bugbear" +fi + SHELLCHECK_FLAGS=( --exclude=1090 # "Can't follow non-constant source. Use a directive to specify location." --exclude=1091 # "Not following {file} due to some error" diff --git a/ci/travis/install-dependencies.sh b/ci/travis/install-dependencies.sh index 32b39ded1401e..b52f75e8a4164 100755 --- a/ci/travis/install-dependencies.sh +++ b/ci/travis/install-dependencies.sh @@ -408,7 +408,7 @@ install_dependencies() { # RLlib testing with TF 1.x. if [ "${RLLIB_TESTING-}" = 1 ] && { [ -n "${TF_VERSION-}" ] || [ -n "${TFP_VERSION-}" ]; }; then - pip install --upgrade tensorflow-probability=="${TFP_VERSION}" tensorflow=="${TF_VERSION}" gym + pip install --upgrade tensorflow-probability=="${TFP_VERSION}" tensorflow=="${TF_VERSION}" gym==0.19 fi # Additional Tune dependency for Horovod. diff --git a/ci/travis/test-worker-in-container.sh b/ci/travis/test-worker-in-container.sh index 0d5b01eb49043..00caeeb15839f 100644 --- a/ci/travis/test-worker-in-container.sh +++ b/ci/travis/test-worker-in-container.sh @@ -23,7 +23,7 @@ bash ./ci/travis/install-bazel.sh --system # shellcheck disable=SC2046 bazel test --test_timeout 60 --config=ci $(./scripts/bazel_export_options) \ ---test_tag_filters=-kubernetes,-jenkins_only,worker-container,-flaky \ +--test_tag_filters=-kubernetes,worker-container,-flaky \ python/ray/tests/... 
--test_output=all #pytest python/ray/tests/test_actor_in_container.py -s diff --git a/cpp/BUILD.bazel b/cpp/BUILD.bazel index 9d4e7416cda1b..9603c863546c1 100644 --- a/cpp/BUILD.bazel +++ b/cpp/BUILD.bazel @@ -90,6 +90,7 @@ genrule( mkdir -p "$$PY_CPP_DIR/lib/" && cp -f -r $$WORK_DIR/external/msgpack/include/* "$$PY_CPP_DIR/include" && cp -f -r "$$WORK_DIR/external/boost/boost/archive" "$$BOOST_DIR" && + cp -f -r "$$WORK_DIR/external/boost/boost/assert" "$$BOOST_DIR" && cp -f -r "$$WORK_DIR/external/boost/boost/bind" "$$BOOST_DIR" && cp -f -r "$$WORK_DIR/external/boost/boost/callable_traits" "$$BOOST_DIR" && cp -f -r "$$WORK_DIR/external/boost/boost/concept" "$$BOOST_DIR" && diff --git a/cpp/src/ray/api.cc b/cpp/src/ray/api.cc index a1a8c6507541c..ed2b1b89230cd 100644 --- a/cpp/src/ray/api.cc +++ b/cpp/src/ray/api.cc @@ -40,7 +40,7 @@ void Init() { bool IsInitialized() { return is_init_; } void Shutdown() { - // TODO(guyang.sgy): Clean the ray runtime. + // TODO(SongGuyang): Clean the ray runtime. internal::AbstractRayRuntime::DoShutdown(); is_init_ = false; } diff --git a/cpp/src/ray/runtime/abstract_ray_runtime.cc b/cpp/src/ray/runtime/abstract_ray_runtime.cc index 177fae17d3122..db9fac32db4e8 100644 --- a/cpp/src/ray/runtime/abstract_ray_runtime.cc +++ b/cpp/src/ray/runtime/abstract_ray_runtime.cc @@ -145,7 +145,7 @@ InvocationSpec BuildInvocationSpec1(TaskType task_type, InvocationSpec invocation_spec; invocation_spec.task_type = task_type; invocation_spec.task_id = - TaskID::ForFakeTask(); // TODO(Guyang Song): make it from different task + TaskID::ForFakeTask(); // TODO(SongGuyang): make it from different task invocation_spec.remote_function_holder = remote_function_holder; invocation_spec.actor_id = actor; invocation_spec.args = TransformArgs(args); diff --git a/cpp/src/ray/runtime/object/native_object_store.cc b/cpp/src/ray/runtime/object/native_object_store.cc index d9326feb2ae66..7add3b72b73af 100644 --- a/cpp/src/ray/runtime/object/native_object_store.cc +++ b/cpp/src/ray/runtime/object/native_object_store.cc @@ -116,7 +116,7 @@ std::vector NativeObjectStore::Wait(const std::vector &ids, int num_objects, int timeout_ms) { std::vector results; auto &core_worker = CoreWorkerProcess::GetCoreWorker(); - // TODO(guyang.sgy): Support `fetch_local` option in API. + // TODO(SongGuyang): Support `fetch_local` option in API. // Simply set `fetch_local` to be true. ::ray::Status status = core_worker.Wait(ids, num_objects, timeout_ms, &results, true); if (!status.ok()) { diff --git a/cpp/src/ray/runtime/task/local_mode_task_submitter.cc b/cpp/src/ray/runtime/task/local_mode_task_submitter.cc index cb24e9d3a2b8d..40b7845578a74 100644 --- a/cpp/src/ray/runtime/task/local_mode_task_submitter.cc +++ b/cpp/src/ray/runtime/task/local_mode_task_submitter.cc @@ -32,7 +32,7 @@ LocalModeTaskSubmitter::LocalModeTaskSubmitter( ObjectID LocalModeTaskSubmitter::Submit(InvocationSpec &invocation, const ActorCreationOptions &options) { - /// TODO(Guyang Song): Make the information of TaskSpecification more reasonable + /// TODO(SongGuyang): Make the information of TaskSpecification more reasonable /// We just reuse the TaskSpecification class and make the single process mode work. /// Maybe some infomation of TaskSpecification are not reasonable or invalid. /// We will enhance this after implement the cluster mode. 
@@ -82,7 +82,7 @@ ObjectID LocalModeTaskSubmitter::Submit(InvocationSpec &invocation, AbstractRayRuntime *runtime = &local_mode_ray_tuntime_; if (invocation.task_type == TaskType::ACTOR_CREATION_TASK || invocation.task_type == TaskType::ACTOR_TASK) { - /// TODO(Guyang Song): Handle task dependencies. + /// TODO(SongGuyang): Handle task dependencies. /// Execute actor task directly in the main thread because we must guarantee the actor /// task executed by calling order. TaskExecutor::Invoke(task_specification, actor, runtime, actor_contexts_, diff --git a/cpp/src/ray/runtime/task/task_executor.cc b/cpp/src/ray/runtime/task/task_executor.cc index be24fe98d9a27..f0a1e12faaa78 100644 --- a/cpp/src/ray/runtime/task/task_executor.cc +++ b/cpp/src/ray/runtime/task/task_executor.cc @@ -75,7 +75,7 @@ std::shared_ptr TaskExecutor::current_actor_ = nullptr; TaskExecutor::TaskExecutor(AbstractRayRuntime &abstract_ray_tuntime_) : abstract_ray_tuntime_(abstract_ray_tuntime_) {} -// TODO(Guyang Song): Make a common task execution function used for both local mode and +// TODO(SongGuyang): Make a common task execution function used for both local mode and // cluster mode. std::unique_ptr TaskExecutor::Execute(InvocationSpec &invocation) { abstract_ray_tuntime_.GetWorkerContext(); diff --git a/cpp/src/ray/runtime/task/task_executor.h b/cpp/src/ray/runtime/task/task_executor.h index a528f17e03af3..825e5ca52ab20 100644 --- a/cpp/src/ray/runtime/task/task_executor.h +++ b/cpp/src/ray/runtime/task/task_executor.h @@ -16,8 +16,10 @@ #include #include + #include #include + #include "absl/synchronization/mutex.h" #include "invocation_spec.h" #include "ray/common/id.h" @@ -62,7 +64,7 @@ class TaskExecutor { public: TaskExecutor(AbstractRayRuntime &abstract_ray_tuntime_); - /// TODO(Guyang Song): support multiple tasks execution + /// TODO(SongGuyang): support multiple tasks execution std::unique_ptr Execute(InvocationSpec &invocation); static void Invoke( diff --git a/cpp/src/ray/util/process_helper.cc b/cpp/src/ray/util/process_helper.cc index 40f115e646e95..35ecd8123daa2 100644 --- a/cpp/src/ray/util/process_helper.cc +++ b/cpp/src/ray/util/process_helper.cc @@ -12,9 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "process_helper.h" + #include -#include "process_helper.h" #include "ray/util/process.h" #include "ray/util/util.h" #include "src/ray/protobuf/gcs.pb.h" @@ -27,9 +28,9 @@ using ray::core::WorkerType; void ProcessHelper::StartRayNode(const int redis_port, const std::string redis_password, const std::vector &head_args) { - std::vector cmdargs({"ray", "start", "--head", "--port", - std::to_string(redis_port), "--redis-password", - redis_password}); + std::vector cmdargs( + {"ray", "start", "--head", "--port", std::to_string(redis_port), "--redis-password", + redis_password, "--node-ip-address", GetNodeIpAddress()}); if (!head_args.empty()) { cmdargs.insert(cmdargs.end(), head_args.begin(), head_args.end()); } @@ -124,7 +125,7 @@ void ProcessHelper::RayStart(CoreWorkerOptions::TaskExecutionCallback callback) if (!ConfigInternal::Instance().job_id.empty()) { options.job_id = JobID::FromHex(ConfigInternal::Instance().job_id); } else { - /// TODO(Guyang Song): Get next job id from core worker by GCS client. + /// TODO(SongGuyang): Get next job id from core worker by GCS client. /// Random a number to avoid repeated job ids. /// The repeated job ids will lead to task hang when driver connects to a existing /// cluster more than once. 
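The dashboard hunks that follow switch add_port_to_grpc_server from ray._private.utils to the new ray._private.tls_utils module (created by this series; see python/ray/_private/tls_utils.py in the file list above). The module body is not shown in this excerpt, so the following is only a minimal sketch of what such a helper can look like, assuming a RAY_USE_TLS toggle and PEM file paths in RAY_TLS_SERVER_CERT / RAY_TLS_SERVER_KEY; the names and control flow are illustrative assumptions, not the patch's exact implementation:

import os
import grpc

def add_port_to_grpc_server(server, address):
    # Hypothetical toggle; the real module may gate TLS differently.
    if os.environ.get("RAY_USE_TLS", "0") not in ("0", ""):
        # Assumed env vars pointing at PEM-encoded key/cert files.
        with open(os.environ["RAY_TLS_SERVER_KEY"], "rb") as f:
            private_key = f.read()
        with open(os.environ["RAY_TLS_SERVER_CERT"], "rb") as f:
            certificate_chain = f.read()
        credentials = grpc.ssl_server_credentials(
            [(private_key, certificate_chain)])
        return server.add_secure_port(address, credentials)
    return server.add_insecure_port(address)

Centralizing port binding in one helper lets the Python-side gRPC servers (dashboard head and agent, below) pick up a consistent TLS configuration from the environment instead of duplicating credential handling at each call site.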
diff --git a/dashboard/agent.py b/dashboard/agent.py index 7301b4299f95f..f56e76f61fff9 100644 --- a/dashboard/agent.py +++ b/dashboard/agent.py @@ -83,7 +83,7 @@ def __init__(self, assert self.ppid > 0 logger.info("Parent pid is %s", self.ppid) self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0), )) - self.grpc_port = ray._private.utils.add_port_to_grpc_server( + self.grpc_port = ray._private.tls_utils.add_port_to_grpc_server( self.server, f"[::]:{self.dashboard_agent_port}") logger.info("Dashboard agent grpc address: %s:%s", self.ip, self.grpc_port) diff --git a/dashboard/client/src/pages/job/JobDetail.tsx b/dashboard/client/src/pages/job/JobDetail.tsx index b720b9c057de1..892034937f107 100644 --- a/dashboard/client/src/pages/job/JobDetail.tsx +++ b/dashboard/client/src/pages/job/JobDetail.tsx @@ -11,6 +11,7 @@ import { TableRow, Tabs, } from "@material-ui/core"; +import dayjs from "dayjs"; import React from "react"; import { Link, RouteComponentProps } from "react-router-dom"; import ActorTable from "../../components/ActorTable"; @@ -140,6 +141,16 @@ const JobDetailPage = (props: RouteComponentProps<{ id: string }>) => { Driver Pid:{" "} {jobInfo.driverPid} + + StartTime:{" "} + {dayjs(Number(jobInfo.startTime)).format("YYYY/MM/DD HH:mm:ss")} + + + EndTime:{" "} + {jobInfo.endTime > 0 + ? dayjs(Number(jobInfo.endTime)).format("YYYY/MM/DD HH:mm:ss") + : "-"} + {jobInfo.eventUrl && ( Event Link:{" "} diff --git a/dashboard/client/src/pages/job/index.tsx b/dashboard/client/src/pages/job/index.tsx index e52af1ce5ec01..81be74b03e2f4 100644 --- a/dashboard/client/src/pages/job/index.tsx +++ b/dashboard/client/src/pages/job/index.tsx @@ -24,7 +24,14 @@ const useStyles = makeStyles((theme) => ({ }, })); -const columns = ["ID", "DriverIpAddress", "DriverPid", "IsDead", "Timestamp"]; +const columns = [ + "ID", + "DriverIpAddress", + "DriverPid", + "IsDead", + "StartTime", + "EndTime", +]; const JobList = () => { const classes = useStyles(); @@ -98,7 +105,8 @@ const JobList = () => { driverIpAddress, isDead, driverPid, - timestamp, + startTime, + endTime, }) => ( @@ -110,7 +118,12 @@ const JobList = () => { {isDead ? "true" : "false"} - {dayjs(Number(timestamp)).format("YYYY/MM/DD HH:mm:ss")} + {dayjs(Number(startTime)).format("YYYY/MM/DD HH:mm:ss")} + + + {endTime > 0 + ? 
dayjs(Number(endTime)).format("YYYY/MM/DD HH:mm:ss") + : "-"} ), diff --git a/dashboard/client/src/type/job.d.ts b/dashboard/client/src/type/job.d.ts index c5ca4dce874c1..ef9181dd2c92d 100644 --- a/dashboard/client/src/type/job.d.ts +++ b/dashboard/client/src/type/job.d.ts @@ -9,6 +9,8 @@ export type Job = { driverEntry: string; state: string; timestamp: number; + startTime: number; + endTime: number; namespaceId: string; driverPid: number; driverIpAddress: string; diff --git a/dashboard/head.py b/dashboard/head.py index 7d7cb002b652a..c7cc857c5c787 100644 --- a/dashboard/head.py +++ b/dashboard/head.py @@ -7,6 +7,7 @@ import threading from grpc.experimental import aio as aiogrpc +from distutils.version import LooseVersion import ray._private.utils import ray._private.services @@ -120,7 +121,7 @@ def __init__(self, http_host, http_port, http_port_retries, redis_address, ip, port = redis_address.split(":") self.gcs_client = connect_to_gcs(ip, int(port), redis_password) self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0), )) - self.grpc_port = ray._private.utils.add_port_to_grpc_server( + self.grpc_port = ray._private.tls_utils.add_port_to_grpc_server( self.server, "[::]:0") logger.info("Dashboard head grpc address: %s:%s", self.ip, self.grpc_port) @@ -174,8 +175,12 @@ async def run(self): sys.exit(-1) # Create a http session for all modules. - self.http_session = aiohttp.ClientSession( - loop=asyncio.get_event_loop()) + # aiohttp<4.0.0 uses a 'loop' variable, aiohttp>=4.0.0 doesn't anymore + if LooseVersion(aiohttp.__version__) < LooseVersion("4.0.0"): + self.http_session = aiohttp.ClientSession( + loop=asyncio.get_event_loop()) + else: + self.http_session = aiohttp.ClientSession() # Waiting for GCS is ready. self.aiogrpc_gcs_channel = await make_gcs_grpc_channel( diff --git a/dashboard/modules/job/job_agent.py b/dashboard/modules/job/job_agent.py index 34b72462501ab..f56a24db83586 100644 --- a/dashboard/modules/job/job_agent.py +++ b/dashboard/modules/job/job_agent.py @@ -202,7 +202,9 @@ def _gen_driver_code(self): # Per job config job_config_items = { - "worker_env": self._job_info.env, + "runtime_env": { + "env_vars": self._job_info.env + }, "code_search_path": [job_package_dir], } diff --git a/dashboard/modules/runtime_env/runtime_env_agent.py b/dashboard/modules/runtime_env/runtime_env_agent.py index 5151278b1ab26..3c8b9c18bf9f3 100644 --- a/dashboard/modules/runtime_env/runtime_env_agent.py +++ b/dashboard/modules/runtime_env/runtime_env_agent.py @@ -6,6 +6,7 @@ import os import time from typing import Dict, Set +from ray._private.utils import import_attr from ray.core.generated import runtime_env_agent_pb2 from ray.core.generated import runtime_env_agent_pb2_grpc @@ -17,8 +18,8 @@ _internal_kv_initialized) from ray._private.ray_logging import setup_component_logger from ray._private.runtime_env.conda import CondaManager +from ray._private.runtime_env.context import RuntimeEnvContext from ray._private.runtime_env.working_dir import WorkingDirManager -from ray._private.runtime_env import RuntimeEnvContext logger = logging.getLogger(__name__) @@ -78,13 +79,20 @@ def get_or_create_logger(self, job_id: bytes): return self._per_job_logger_cache[job_id] async def CreateRuntimeEnv(self, request, context): - async def _setup_runtime_env(serialized_runtime_env): + async def _setup_runtime_env(serialized_runtime_env, + serialized_allocated_resource_instances): # This function will be ran inside a thread def run_setup_with_logger(): runtime_env: dict = 
json.loads(serialized_runtime_env or "{}") + allocated_resource: dict = json.loads( + serialized_allocated_resource_instances or "{}") # Use a separate logger for each job. per_job_logger = self.get_or_create_logger(request.job_id) + # TODO(chenk008): Add log about allocated_resource to + # avoid lint error. That will be moved to cgroup plugin. + per_job_logger.debug(f"Worker has resource :" + f"{allocated_resource}") context = RuntimeEnvContext( env_vars=runtime_env.get("env_vars")) self._conda_manager.setup( @@ -98,6 +106,15 @@ def run_setup_with_logger(): self._working_dir_uri_to_envs[uri].add( serialized_runtime_env) + # Run setup function from all the plugins + for plugin_class_path in runtime_env.get("plugins", {}).keys(): + plugin_class = import_attr(plugin_class_path) + # TODO(simon): implement uri support + plugin_class.create("uri not implemented", runtime_env, + context) + plugin_class.modify_context("uri not implemented", + runtime_env, context) + return context loop = asyncio.get_event_loop() @@ -138,7 +155,8 @@ def run_setup_with_logger(): for _ in range(runtime_env_consts.RUNTIME_ENV_RETRY_TIMES): try: runtime_env_context = await _setup_runtime_env( - serialized_env) + serialized_env, + request.serialized_allocated_resource_instances) break except Exception as ex: logger.exception("Runtime env creation failed.") diff --git a/dashboard/modules/snapshot/snapshot_head.py b/dashboard/modules/snapshot/snapshot_head.py index 424e41ff45e16..87082f5463147 100644 --- a/dashboard/modules/snapshot/snapshot_head.py +++ b/dashboard/modules/snapshot/snapshot_head.py @@ -73,11 +73,10 @@ async def get_job_info(self): for job_table_entry in reply.job_info_list: job_id = job_table_entry.job_id.hex() config = { - "env_vars": dict(job_table_entry.config.worker_env), "namespace": job_table_entry.config.ray_namespace, "metadata": dict(job_table_entry.config.metadata), "runtime_env": json.loads( - job_table_entry.config.serialized_runtime_env), + job_table_entry.config.runtime_env.serialized_runtime_env), } entry = { "is_dead": job_table_entry.is_dead, diff --git a/dashboard/modules/snapshot/snapshot_schema.json b/dashboard/modules/snapshot/snapshot_schema.json index f660813110f1e..4768c2a5e292c 100644 --- a/dashboard/modules/snapshot/snapshot_schema.json +++ b/dashboard/modules/snapshot/snapshot_schema.json @@ -39,9 +39,6 @@ "config": { "type": "object", "properties": { - "envVars": { - "type": "object" - }, "namespace": { "type": "string" }, @@ -53,7 +50,6 @@ } }, "required": [ - "envVars", "namespace", "metadata", "runtimeEnv" diff --git a/dashboard/tests/test_dashboard.py b/dashboard/tests/test_dashboard.py index ea335c61bad21..6565ea08814cf 100644 --- a/dashboard/tests/test_dashboard.py +++ b/dashboard/tests/test_dashboard.py @@ -107,7 +107,7 @@ def _search_agent(processes): agent_proc.kill() agent_proc.wait() # The agent will be restarted for imports failure. - for x in range(50): + for _ in range(300): agent_proc = _search_agent(raylet_proc.children()) if agent_proc: agent_pids.add(agent_proc.pid) diff --git a/doc/BUILD b/doc/BUILD index eed30be63b145..81c112530ffec 100644 --- a/doc/BUILD +++ b/doc/BUILD @@ -3,6 +3,38 @@ # Please keep these sorted alphabetically, but start with the # root directory. # -------------------------------------------------------------------- + +# Support for Dask has been dropped in 3.6. 
+py_test( + name = "dask_xgboost", + size = "medium", + main = "examples/dask_xgboost/dask_xgboost.py", + srcs = ["examples/dask_xgboost/dask_xgboost.py"], + tags = ["exclusive", "team:ml", "py37"], + args = ["--smoke-test", "--address ''", "--num-actors 4", + "--cpus-per-actor 1", "--num-actors-inference 4", + "--cpus-per-actor-inference 1"] +) + +# Support for Modin has been dropped in 3.6. +py_test( + name = "modin_xgboost", + size = "medium", + main = "examples/modin_xgboost/modin_xgboost.py", + srcs = ["examples/modin_xgboost/modin_xgboost.py"], + tags = ["exclusive", "team:ml", "py37"], + args = ["--smoke-test", "--address ''", "--num-actors 4", + "--cpus-per-actor 1", "--num-actors-inference 4", + "--cpus-per-actor-inference 1"] +) + +py_test( + name = "big_data_ingestion", + size = "small", + srcs = ["source/data/_examples/big_data_ingestion.py"], + tags = ["exclusive", "team:core", "py37"] +) + py_test( name = "plot_hyperparameter", size = "small", diff --git a/doc/Makefile b/doc/Makefile index 3b0914ab942fe..39013f4175b43 100644 --- a/doc/Makefile +++ b/doc/Makefile @@ -6,7 +6,7 @@ SPHINXOPTS = SPHINXBUILD = sphinx-build PAPER = BUILDDIR = _build -AUTOGALLERYDIR= source/auto_examples source/tune/tutorials source/tune/generated_guides +AUTOGALLERYDIR= source/auto_examples source/tune/tutorials source/tune/generated_guides source/data/examples # User-friendly check for sphinx-build ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) diff --git a/doc/examples/dask_xgboost/README.rst b/doc/examples/dask_xgboost/README.rst new file mode 100644 index 0000000000000..8feca331c5d78 --- /dev/null +++ b/doc/examples/dask_xgboost/README.rst @@ -0,0 +1 @@ +:orphan: diff --git a/doc/examples/dask_xgboost/dask_xgboost.py b/doc/examples/dask_xgboost/dask_xgboost.py new file mode 100644 index 0000000000000..d4e50a33faf70 --- /dev/null +++ b/doc/examples/dask_xgboost/dask_xgboost.py @@ -0,0 +1,321 @@ +# flake8: noqa: E501 +""" +XGBoost-Ray with Dask +====================== + +This notebook includes an example workflow using +`XGBoost-Ray `_ and +`Dask `_ for distributed model training, +hyperparameter optimization, and prediction. +""" + +############################################################################### +# Cluster Setup +# ------------- +# +# First, we'll set up our Ray Cluster. The provided ``dask_xgboost.yaml`` +# cluster config can be used to set up an AWS cluster with 64 CPUs. +# +# The following steps assume you are in a directory with both +# ``dask_xgboost.yaml`` and this file saved as ``dask_xgboost.ipynb``. +# +# **Step 1:** Bring up the Ray cluster. +# +# .. code-block:: bash +# +# $ pip install ray boto3 +# $ ray up dask_xgboost.yaml +# +# **Step 2:** Move ``dask_xgboost.ipynb`` to the cluster and start Jupyter. +# +# .. code-block:: bash +# +# $ ray rsync_up dask_xgboost.yaml "./dask_xgboost.ipynb" \ +# "~/dask_xgboost.ipynb" +# $ ray exec dask_xgboost.yaml --port-forward=9999 "jupyter notebook \ +# --port=9999" +# +# You can then access this notebook at the URL that is output: +# ``http://localhost:9999/?token=`` + +############################################################################### +# Python Setup +# ------------ +# +# First, we'll import all the libraries we'll be using. This step also helps us +# verify that the environment is configured correctly. If any of the imports +# are missing, an exception will be raised. 
+ +import argparse +import time + +import dask +import dask.dataframe as dd +from xgboost_ray import RayDMatrix, RayParams, train, predict + +import ray +from ray import tune +from ray.util.dask import ray_dask_get + +############################################################################### +# +# Next, let's parse some arguments. These will be used when executing the ``.py`` +# file, but not the ``.ipynb``. If you are using the interactive notebook, +# you can directly override the arguments manually. + +parser = argparse.ArgumentParser() +parser.add_argument( + "--address", type=str, default="auto", help="The address to use for Ray.") +parser.add_argument( + "--smoke-test", + action="store_true", + help="Read a smaller dataset for quick testing purposes.") +parser.add_argument( + "--num-actors", + type=int, + default=4, + help="Sets number of actors for training.") +parser.add_argument( + "--cpus-per-actor", + type=int, + default=6, + help="The number of CPUs per actor for training.") +parser.add_argument( + "--num-actors-inference", + type=int, + default=16, + help="Sets number of actors for inference.") +parser.add_argument( + "--cpus-per-actor-inference", + type=int, + default=2, + help="The number of CPUs per actor for inference.") +# Ignore -f from ipykernel_launcher +args, _ = parser.parse_known_args() + +############################################################################### +# Override these arguments as needed: + +address = args.address +smoke_test = args.smoke_test +num_actors = args.num_actors +cpus_per_actor = args.cpus_per_actor +num_actors_inference = args.num_actors_inference +cpus_per_actor_inference = args.cpus_per_actor_inference + +############################################################################### +# Connecting to the Ray cluster +# ----------------------------- +# Now, let's connect our Python script to this newly deployed Ray cluster! + +if not ray.is_initialized(): + ray.init(address=address) + +############################################################################### +# Data Preparation +# ----------------- +# We will use the `HIGGS dataset from the UCI Machine Learning dataset +# repository `_. The HIGGS +# dataset consists of 11,000,000 samples and 28 attributes, which is large +# enough to show the benefits of distributed computation. +# +# We set the Dask scheduler to ``ray_dask_get`` to use the `Dask on Ray +# `_ backend. + +LABEL_COLUMN = "label" +if smoke_test: + # Test dataset with only 10,000 records. + FILE_URL = "https://ray-ci-higgs.s3.us-west-2.amazonaws.com/simpleHIGGS" \ + ".csv" +else: + # Full dataset. This may take a couple of minutes to load. + FILE_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases" \ + "/00280/HIGGS.csv.gz" +colnames = [LABEL_COLUMN] + ["feature-%02d" % i for i in range(1, 29)] +dask.config.set(scheduler=ray_dask_get) + +############################################################################### + +load_data_start_time = time.time() + +data = dd.read_csv(FILE_URL, names=colnames) +data = data[sorted(colnames)] +data = data.persist() + +load_data_end_time = time.time() +load_data_duration = load_data_end_time - load_data_start_time +print(f"Dataset loaded in {load_data_duration} seconds.") + +############################################################################### +# With the connection established, we can now create the Dask dataframe. +# +# We will split the data into a training set and an evaluation set using an 80-20 +# proportion. 
+
+train_df, eval_df = data.random_split([0.8, 0.2])
+train_df, eval_df = train_df.persist(), eval_df.persist()
+print(train_df, eval_df)
+
+###############################################################################
+# Distributed Training
+# --------------------
+# The ``train_xgboost`` function contains all of the logic necessary for
+# training using XGBoost-Ray.
+#
+# Distributed training can not only speed up the process, but also allow you
+# to use datasets that are too large to fit in the memory of a single node.
+# With distributed training, the dataset is sharded across different actors
+# running on separate nodes. Those actors communicate with each other to
+# create the final model.
+#
+# First, the dataframes are wrapped in ``RayDMatrix`` objects, which handle
+# data sharding across the cluster. Then, the ``train`` function is called.
+# The evaluation scores will be saved to the ``evals_result`` dictionary. The
+# function returns a tuple of the trained model (booster) and the evaluation
+# scores.
+#
+# The ``ray_params`` variable expects a ``RayParams`` object that contains
+# Ray-specific settings, such as the number of workers.
+
+
+def train_xgboost(config, train_df, test_df, target_column, ray_params):
+    train_set = RayDMatrix(train_df, target_column)
+    test_set = RayDMatrix(test_df, target_column)
+
+    evals_result = {}
+
+    train_start_time = time.time()
+
+    # Train the classifier
+    bst = train(
+        params=config,
+        dtrain=train_set,
+        evals=[(test_set, "eval")],
+        evals_result=evals_result,
+        ray_params=ray_params)
+
+    train_end_time = time.time()
+    train_duration = train_end_time - train_start_time
+    print(f"Total time taken: {train_duration} seconds.")
+
+    model_path = "model.xgb"
+    bst.save_model(model_path)
+    print("Final validation error: {:.4f}".format(
+        evals_result["eval"]["error"][-1]))
+
+    return bst, evals_result
+
+
+###############################################################################
+# We can now pass our Dask dataframes and run the function. We will use
+# ``RayParams`` to specify the number of actors and CPUs to train with.
+#
+# The dataset has to be downloaded onto the cluster, which may take a few
+# minutes.
+
+# standard XGBoost config for classification
+config = {
+    "tree_method": "approx",
+    "objective": "binary:logistic",
+    "eval_metric": ["logloss", "error"],
+}
+
+bst, evals_result = train_xgboost(
+    config, train_df, eval_df, LABEL_COLUMN,
+    RayParams(cpus_per_actor=cpus_per_actor, num_actors=num_actors))
+print(f"Results: {evals_result}")
+
+###############################################################################
+# Hyperparameter optimization
+# ---------------------------
+# If we are not content with the results obtained with default XGBoost
+# parameters, we can use `Ray Tune
+# `_ for cutting-edge
+# distributed hyperparameter tuning. XGBoost-Ray automatically integrates
+# with Ray Tune, meaning we can use the same training function as before.
+#
+# In this workflow, we will tune three hyperparameters - ``eta``, ``subsample``
+# and ``max_depth``. We are using `Tune's samplers to define the search
+# space `_.
+#
+# The experiment configuration is done through ``tune.run``. We set the amount
+# of resources each trial (hyperparameter combination) requires by using the
+# ``get_tune_resources`` method of ``RayParams``. The ``num_samples`` argument
+# controls how many trials will be run in total. In the end, the best
+# combination of hyperparameters evaluated during the experiment will be
+# returned.
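+#
+# Note that we wrap the training function with ``tune.with_parameters`` below.
+# This stores the potentially large Dask dataframes in the Ray object store
+# and passes them to each trial by reference, rather than serializing them
+# into every trial's configuration.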
+# +# By default, Tune will use simple random search. However, Tune also +# provides various `search algorithms +# `_ and +# `schedulers `_ +# to further improve the optimization process. + + +def tune_xgboost(train_df, test_df, target_column): + # Set XGBoost config. + config = { + "tree_method": "approx", + "objective": "binary:logistic", + "eval_metric": ["logloss", "error"], + "eta": tune.loguniform(1e-4, 1e-1), + "subsample": tune.uniform(0.5, 1.0), + "max_depth": tune.randint(1, 9) + } + + ray_params = RayParams( + max_actor_restarts=1, + cpus_per_actor=cpus_per_actor, + num_actors=num_actors) + + tune_start_time = time.time() + + analysis = tune.run( + tune.with_parameters( + train_xgboost, + train_df=train_df, + test_df=test_df, + target_column=target_column, + ray_params=ray_params), + # Use the `get_tune_resources` helper function to set the resources. + resources_per_trial=ray_params.get_tune_resources(), + config=config, + num_samples=10, + metric="eval-error", + mode="min") + + tune_end_time = time.time() + tune_duration = tune_end_time - tune_start_time + print(f"Total time taken: {tune_duration} seconds.") + + accuracy = 1. - analysis.best_result["eval-error"] + print(f"Best model parameters: {analysis.best_config}") + print(f"Best model total accuracy: {accuracy:.4f}") + + return analysis.best_config + + +############################################################################### +# Hyperparameter optimization may take some time to complete. + +tune_xgboost(train_df, eval_df, LABEL_COLUMN) + +############################################################################### +# Prediction +# ---------- +# With the model trained, we can now predict on unseen data. For the +# purposes of this example, we will use the same dataset for prediction as +# for training. +# +# Since prediction is naively parallelizable, distributing it over multiple +# actors can measurably reduce the amount of time needed. + +inference_df = RayDMatrix(data, ignore=[LABEL_COLUMN, "partition"]) +results = predict( + bst, + inference_df, + ray_params=RayParams( + cpus_per_actor=cpus_per_actor_inference, + num_actors=num_actors_inference)) + +print(results) diff --git a/doc/examples/dask_xgboost/dask_xgboost.yaml b/doc/examples/dask_xgboost/dask_xgboost.yaml new file mode 100644 index 0000000000000..e598a115069b6 --- /dev/null +++ b/doc/examples/dask_xgboost/dask_xgboost.yaml @@ -0,0 +1,24 @@ +cluster_name: dask_xgboost + +max_workers: 3 + +provider: + type: aws + region: us-west-1 + +auth: + ssh_user: ubuntu + +available_node_types: + 16_cpu_node: + min_workers: 3 + max_workers: 3 + node_config: + InstanceType: m5.4xlarge + ImageId: latest_dlami + resources: { } + +head_node_type: 16_cpu_node + +setup_commands: + - pip install -U jupyter ray[tune] xgboost_ray[default] dask pandas diff --git a/doc/examples/modin_xgboost/README.rst b/doc/examples/modin_xgboost/README.rst new file mode 100644 index 0000000000000..8feca331c5d78 --- /dev/null +++ b/doc/examples/modin_xgboost/README.rst @@ -0,0 +1 @@ +:orphan: diff --git a/doc/examples/modin_xgboost/modin_xgboost.py b/doc/examples/modin_xgboost/modin_xgboost.py new file mode 100644 index 0000000000000..bcbe6c0968068 --- /dev/null +++ b/doc/examples/modin_xgboost/modin_xgboost.py @@ -0,0 +1,233 @@ +""" +XGBoost-Ray with Modin +====================== + +This notebook includes an example workflow using +`XGBoost-Ray `_ and +`Modin `_ for distributed model +training and prediction. 
+""" + +############################################################################### +# Cluster Setup +# ------------- +# +# First, we'll set up our Ray Cluster. The provided ``modin_xgboost.yaml`` +# cluster config can be used to set up an AWS cluster with 64 CPUs. +# +# The following steps assume you are in a directory with both +# ``modin_xgboost.yaml`` and this file saved as ``modin_xgboost.ipynb``. +# +# **Step 1:** Bring up the Ray cluster. +# +# .. code-block:: bash +# +# $ pip install ray boto3 +# $ ray up modin_xgboost.yaml +# +# **Step 2:** Move ``modin_xgboost.ipynb`` to the cluster and start Jupyter. +# +# .. code-block:: bash +# +# $ ray rsync_up modin_xgboost.yaml "./modin_xgboost.ipynb" \ +# "~/modin_xgboost.ipynb" +# $ ray exec modin_xgboost.yaml --port-forward=9999 "jupyter notebook \ +# --port=9999" +# +# You can then access this notebook at the URL that is output: +# ``http://localhost:9999/?token=`` + +############################################################################### +# Python Setup +# ------------ +# +# First, we'll import all the libraries we'll be using. This step also helps us +# verify that the environment is configured correctly. If any of the imports +# are missing, an exception will be raised. + +import argparse +import time + +import modin.pandas as pd +from modin.experimental.sklearn.model_selection import train_test_split +from xgboost_ray import RayDMatrix, RayParams, train, predict + +import ray + +############################################################################### +# +# Next, let's parse some arguments. This will be used for executing the ``.py`` +# file, but not for the ``.ipynb``. If you are using the interactive notebook, +# you can directly override the arguments manually. + +parser = argparse.ArgumentParser() +parser.add_argument( + "--address", type=str, default="auto", help="The address to use for Ray.") +parser.add_argument( + "--smoke-test", + action="store_true", + help="Read a smaller dataset for quick testing purposes.") +parser.add_argument( + "--num-actors", + type=int, + default=4, + help="Sets number of actors for training.") +parser.add_argument( + "--cpus-per-actor", + type=int, + default=8, + help="The number of CPUs per actor for training.") +parser.add_argument( + "--num-actors-inference", + type=int, + default=16, + help="Sets number of actors for inference.") +parser.add_argument( + "--cpus-per-actor-inference", + type=int, + default=2, + help="The number of CPUs per actor for inference.") +# Ignore -f from ipykernel_launcher +args, _ = parser.parse_known_args() + +############################################################################### +# Override these arguments as needed: + +address = args.address +smoke_test = args.smoke_test +num_actors = args.num_actors +cpus_per_actor = args.cpus_per_actor +num_actors_inference = args.num_actors_inference +cpus_per_actor_inference = args.cpus_per_actor_inference + +############################################################################### +# Connecting to the Ray cluster +# ----------------------------- +# Now, let's connect our Python script to this newly deployed Ray cluster! + +if not ray.is_initialized(): + ray.init(address=address) + +############################################################################### +# Data Preparation +# ----------------- +# We will use the `HIGGS dataset from the UCI Machine Learning dataset +# repository `_. 
The HIGGS
+# dataset consists of 11,000,000 samples and 28 attributes, which is large
+# enough to show the benefits of distributed computation.
+
+LABEL_COLUMN = "label"
+if smoke_test:
+    # Test dataset with only 10,000 records.
+    FILE_URL = "https://ray-ci-higgs.s3.us-west-2.amazonaws.com/simpleHIGGS" \
+               ".csv"
+else:
+    # Full dataset. This may take a couple of minutes to load.
+    FILE_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases" \
+               "/00280/HIGGS.csv.gz"
+
+colnames = [LABEL_COLUMN] + ["feature-%02d" % i for i in range(1, 29)]
+
+###############################################################################
+
+load_data_start_time = time.time()
+
+df = pd.read_csv(FILE_URL, names=colnames)
+
+load_data_end_time = time.time()
+load_data_duration = load_data_end_time - load_data_start_time
+print(f"Dataset loaded in {load_data_duration} seconds.")
+
+###############################################################################
+# Split the data into training and validation sets.
+
+df_train, df_validation = train_test_split(df)
+print(df_train, df_validation)
+
+###############################################################################
+# Distributed Training
+# --------------------
+# The ``train_xgboost`` function contains all of the logic necessary for
+# training using XGBoost-Ray.
+#
+# Distributed training can not only speed up the process, but also allow you
+# to use datasets that are too large to fit in the memory of a single node.
+# With distributed training, the dataset is sharded across different actors
+# running on separate nodes. Those actors communicate with each other to
+# create the final model.
+#
+# First, the dataframes are wrapped in ``RayDMatrix`` objects, which handle
+# data sharding across the cluster. Then, the ``train`` function is called.
+# The evaluation scores will be saved to the ``evals_result`` dictionary. The
+# function returns a tuple of the trained model (booster) and the evaluation
+# scores.
+#
+# The ``ray_params`` variable expects a ``RayParams`` object that contains
+# Ray-specific settings, such as the number of workers.
+
+
+def train_xgboost(config, train_df, test_df, target_column, ray_params):
+    train_set = RayDMatrix(train_df, target_column)
+    test_set = RayDMatrix(test_df, target_column)
+
+    evals_result = {}
+
+    train_start_time = time.time()
+
+    # Train the classifier
+    bst = train(
+        params=config,
+        dtrain=train_set,
+        evals=[(test_set, "eval")],
+        evals_result=evals_result,
+        verbose_eval=False,
+        num_boost_round=100,
+        ray_params=ray_params)
+
+    train_end_time = time.time()
+    train_duration = train_end_time - train_start_time
+    print(f"Total time taken: {train_duration} seconds.")
+
+    model_path = "model.xgb"
+    bst.save_model(model_path)
+    print("Final validation error: {:.4f}".format(
+        evals_result["eval"]["error"][-1]))
+
+    return bst, evals_result
+
+
+###############################################################################
+# We can now pass our Modin dataframes and run the function. We will use
+# ``RayParams`` to specify the number of actors and CPUs to train with.
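+#
+# Note that training will occupy ``num_actors * cpus_per_actor`` CPUs on the
+# cluster for its duration.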
+ +# standard XGBoost config for classification +config = { + "tree_method": "approx", + "objective": "binary:logistic", + "eval_metric": ["logloss", "error"], +} + +bst, evals_result = train_xgboost( + config, df_train, df_validation, LABEL_COLUMN, + RayParams(cpus_per_actor=cpus_per_actor, num_actors=num_actors)) +print(f"Results: {evals_result}") + +############################################################################### +# Prediction +# ---------- +# With the model trained, we can now predict on unseen data. For the +# purposes of this example, we will use the same dataset for prediction as +# for training. +# +# Since prediction is naively parallelizable, distributing it over multiple +# actors can measurably reduce the amount of time needed. + +inference_df = RayDMatrix(df, ignore=[LABEL_COLUMN, "partition"]) +results = predict( + bst, + inference_df, + ray_params=RayParams( + cpus_per_actor=cpus_per_actor_inference, + num_actors=num_actors_inference)) + +print(results) diff --git a/doc/examples/modin_xgboost/modin_xgboost.yaml b/doc/examples/modin_xgboost/modin_xgboost.yaml new file mode 100644 index 0000000000000..914cbdb207af2 --- /dev/null +++ b/doc/examples/modin_xgboost/modin_xgboost.yaml @@ -0,0 +1,24 @@ +cluster_name: modin_xgboost + +max_workers: 3 + +provider: + type: aws + region: us-west-1 + +auth: + ssh_user: ubuntu + +available_node_types: + 16_cpu_node: + min_workers: 3 + max_workers: 3 + node_config: + InstanceType: m5.4xlarge + ImageId: latest_dlami + resources: { } + +head_node_type: 16_cpu_node + +setup_commands: + - pip install -U jupyter ray xgboost_ray[default] modin pandas diff --git a/doc/examples/overview.rst b/doc/examples/overview.rst index 8555799094ef9..be438f3580783 100644 --- a/doc/examples/overview.rst +++ b/doc/examples/overview.rst @@ -61,6 +61,8 @@ Machine Learning Examples plot_lbfgs.rst plot_example-lm.rst plot_newsreader.rst + dask_xgboost/dask_xgboost.rst + modin_xgboost/modin_xgboost.rst .. customgalleryitem:: @@ -86,6 +88,14 @@ Machine Learning Examples :tooltip: Implementing a simple news reader using Ray. :description: :doc:`/auto_examples/plot_newsreader` +.. customgalleryitem:: + :tooltip: Train an XGBoost-Ray model using Dask for data processing. + :description: :doc:`/auto_examples/dask_xgboost/dask_xgboost` + +.. customgalleryitem:: + :tooltip: Train an XGBoost-Ray model using Modin for data processing. + :description: :doc:`/auto_examples/modin_xgboost/modin_xgboost` + .. raw:: html @@ -138,4 +148,4 @@ These are full guides on how you can use Ray with various Machine Learning libra .. customgalleryitem:: :tooltip: Using Ray with PyTorch Lightning. 
:figure: /images/pytorch_lightning_small.png - :description: :doc:`/auto_examples/using-ray-with-pytorch-lightning` \ No newline at end of file + :description: :doc:`/auto_examples/using-ray-with-pytorch-lightning` diff --git a/doc/kubernetes/ray-cluster.yaml b/doc/kubernetes/ray-cluster.yaml index 1b3da82e9ccaa..f4f493152608c 100644 --- a/doc/kubernetes/ray-cluster.yaml +++ b/doc/kubernetes/ray-cluster.yaml @@ -3,7 +3,7 @@ apiVersion: v1 kind: Service metadata: namespace: ray - name: ray-head + name: example-cluster-ray-head spec: ports: - name: client @@ -111,7 +111,7 @@ spec: imagePullPolicy: IfNotPresent command: ["/bin/bash", "-c", "--"] args: - - "ray start --num-cpus=$MY_CPU_REQUEST --address=$RAY_HEAD_SERVICE_HOST:$RAY_HEAD_SERVICE_PORT_REDIS --object-manager-port=12345 --node-manager-port=12346 --block" + - "ray start --num-cpus=$MY_CPU_REQUEST --address=$EXAMPLE_CLUSTER_RAY_HEAD_SERVICE_HOST:$EXAMPLE_CLUSTER_RAY_HEAD_SERVICE_PORT_REDIS --object-manager-port=12345 --node-manager-port=12346 --block" # This volume allocates shared memory for Ray to use for its plasma # object store. If you do not provide this, Ray will fall back to # /tmp which cause slowdowns if is not a shared memory volume. diff --git a/doc/source/advanced.rst b/doc/source/advanced.rst index 75ff25045592e..fa4ff9cffa65c 100644 --- a/doc/source/advanced.rst +++ b/doc/source/advanced.rst @@ -42,17 +42,23 @@ This often occurs for data loading and preprocessing. # hi there! # hi there! -Multi-node synchronization using ``SignalActor`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Multi-node synchronization using an Actor +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -When you have multiple tasks that need to wait on some condition, you can use a ``SignalActor`` to coordinate. +When you have multiple tasks that need to wait on some condition or otherwise +need to synchronize across tasks & actors on a cluster, you can use a central +actor to coordinate among them. Below is an example of using a ``SignalActor`` +that wraps an ``asyncio.Event`` for basic synchronization. .. code-block:: python - # Also available via `from ray._private.test_utils import SignalActor` - import ray import asyncio + import ray + + ray.init() + + # We set num_cpus to zero because this actor will mostly just block on I/O. @ray.remote(num_cpus=0) class SignalActor: def __init__(self): @@ -73,7 +79,6 @@ When you have multiple tasks that need to wait on some condition, you can use a print("go!") - ray.init() signal = SignalActor.remote() tasks = [wait_and_go.remote(signal) for _ in range(4)] print("ready...") @@ -441,7 +446,7 @@ On Mac OS and Linux, Ray 1.4+ supports dynamically setting the runtime environme The ``runtime_env`` is a (JSON-serializable) dictionary that can be passed as an option to tasks and actors, and can also be passed to ``ray.init()``. The runtime environment defines the dependencies required for your workload. -You can specify a runtime environment for your whole job using ``ray.init()`` or Ray Client... +You can specify a runtime environment for your whole job using ``ray.init()`` or Ray Client: .. 
literalinclude:: ../examples/doc_code/runtime_env_example.py :language: python @@ -456,19 +461,20 @@ You can specify a runtime environment for your whole job using ``ray.init()`` or # Using Ray Client ray.init("ray://localhost:10001", runtime_env=runtime_env) -...or specify per-actor or per-task in the ``@ray.remote()`` decorator or by using ``.options()``: +Or specify per-actor or per-task in the ``@ray.remote()`` decorator or by using ``.options()``: .. literalinclude:: ../examples/doc_code/runtime_env_example.py :language: python :start-after: __per_task_per_actor_start__ :end-before: __per_task_per_actor_end__ +Note: specifying within the ``@ray.remote()`` decorator is currently unsupported while using Ray Client; please use ``.options()`` instead in this case. + The ``runtime_env`` is a Python dictionary including one or more of the following arguments: - ``working_dir`` (Path): Specifies the working directory for your job. This must be an existing local directory. It will be cached on the cluster, so the next time you connect with Ray Client you will be able to skip uploading the directory contents. - Furthermore, if you locally make a small change to your directory, the next time you connect only the updated part will be uploaded. - All Ray workers for your job will be started in their node's copy of this working directory. + All Ray workers for your job will be started in their node's local copy of this working directory. - Examples @@ -486,7 +492,7 @@ The ``runtime_env`` is a Python dictionary including one or more of the followin - Example: ``["my_file.txt", "path/to/dir", "*.log"]`` - ``pip`` (List[str] | str): Either a list of pip packages, or a string containing the path to a pip - `“requirements.txt” `_ file. The path may be an absolute path or a relative path. (Note: A relative path will be interpreted relative to ``working_dir`` if ``working_dir`` is specified.) + `“requirements.txt” `_ file. The path may be an absolute path or a relative path. This will be dynamically installed in the ``runtime_env``. To use a library like Ray Serve or Ray Tune, you will need to include ``"ray[serve]"`` or ``"ray[tune]"`` here. @@ -494,7 +500,7 @@ The ``runtime_env`` is a Python dictionary including one or more of the followin - Example: ``"./requirements.txt"`` -- ``conda`` (dict | str): Either (1) a dict representing the conda environment YAML, (2) a string containing the path to a +- ``conda`` (dict | str): Either (1) a dict representing the conda environment YAML, (2) a string containing the absolute or relative path to a `conda “environment.yml” `_ file, or (3) the name of a local conda environment already installed on each node in your cluster (e.g., ``"pytorch_p36"``). In the first two cases, the Ray and Python dependencies will be automatically injected into the environment to ensure compatibility, so there is no need to manually include them. @@ -506,12 +512,15 @@ The ``runtime_env`` is a Python dictionary including one or more of the followin - Example: ``"pytorch_p36"`` - Note: if specifying the path to an "environment.yml" file, you may provide an absolute path or a relative path. A relative path will be interpreted relative to ``working_dir`` if ``working_dir`` is specified. - ``env_vars`` (Dict[str, str]): Environment variables to set. - Example: ``{"OMP_NUM_THREADS": "32", "TF_WARNINGS": "none"}`` +- ``eager_install`` (bool): A boolean indicates whether to install runtime env eagerly before the workers are leased. This flag is set to false by default. 
+ + - Example: ``{"eager_install": True}`` + The runtime environment is inheritable, so it will apply to all tasks/actors within a job and all child tasks/actors of a task or actor, once set. If a child actor or task specifies a new ``runtime_env``, it will be merged with the parent’s ``runtime_env`` via a simple dict update. diff --git a/doc/source/cluster/config.rst b/doc/source/cluster/config.rst index 7ba7e2ccbcbef..867e8398e6985 100644 --- a/doc/source/cluster/config.rst +++ b/doc/source/cluster/config.rst @@ -109,6 +109,8 @@ Provider :ref:`region `: str :ref:`availability_zone `: str :ref:`cache_stopped_nodes `: bool + :ref:`security_group `: + :ref:`Security Group ` .. group-tab:: Azure @@ -130,6 +132,20 @@ Provider :ref:`project_id `: str :ref:`cache_stopped_nodes `: bool +.. _cluster-configuration-security-group-type: + +Security Group +~~~~~~~~~~~~~~ + +.. tabs:: + .. group-tab:: AWS + + .. parsed-literal:: + + :ref:`GroupName `: str + :ref:`IpPermissions `: + - `IpPermission `_ + .. _cluster-configuration-node-types-type: Node types @@ -923,6 +939,52 @@ If enabled, nodes will be *stopped* when the cluster scales down. If disabled, n * **Type:** Boolean * **Default:** ``True`` +.. _cluster-configuration-security-group: + +``provider.security_group`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. tabs:: + .. group-tab:: AWS + + A security group that can be used to specify custom inbound rules. + + * **Required:** No + * **Importance:** Medium + * **Type:** :ref:`Security Group ` + + .. group-tab:: Azure + + Not available. + + .. group-tab:: GCP + + Not available. + + +.. _cluster-configuration-group-name: + +``security_group.GroupName`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The name of the security group. This name must be unique within the VPC. + +* **Required:** No +* **Importance:** Low +* **Type:** String +* **Default:** ``"ray-autoscaler-{cluster-name}"`` + +.. _cluster-configuration-ip-permissions: + +``security_group.IpPermissions`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The inbound rules associated with the security group. + +* **Required:** No +* **Importance:** Medium +* **Type:** `IpPermission `_ + .. _cluster-configuration-node-config: ``available_node_types..node_type.node_config`` diff --git a/doc/source/cluster/ray-client.rst b/doc/source/cluster/ray-client.rst index 550bb75480127..1b9099160f9ae 100644 --- a/doc/source/cluster/ray-client.rst +++ b/doc/source/cluster/ray-client.rst @@ -62,9 +62,9 @@ Step 1: set up your Ray cluster First, you'll want to create a remote Ray cluster. Follow the directions in :ref:`ref-cluster-quick-start` to do this. -If using the `Ray cluster launcher `_, the remote cluster will be listening on port ``10001`` of the head node. If necessary, you can modify this port by setting ``--ray-client-server-port`` to the ``ray start`` `command `_. +If using the :doc:`Ray cluster launcher `, the remote cluster will be listening on port ``10001`` of the head node. If necessary, you can modify this port by setting ``--ray-client-server-port`` to the ``ray start`` `command `_. -If not using the `Ray cluster launcher `_, you can start the "Ray Client Server" manually on the head node of your remote cluster by running the following: +If not using the :doc:`Ray cluster launcher `, you can start the "Ray Client Server" manually on the head node of your remote cluster by running the following: .. 
code-block:: bash @@ -77,6 +77,32 @@ Ensure that the Ray Client port on the head node is reachable from your local ma This means opening that port up by configuring security groups or other access controls (on `EC2 `_) or proxying from your local machine to the cluster (on `K8s `_). +.. tabs:: + .. group-tab:: AWS + + With the Ray cluster launcher, you can configure the security group + to allow inbound access by defining :ref:`cluster-configuration-security-group` + in your `cluster.yaml`. + + .. code-block:: yaml + + # An unique identifier for the head node and workers of this cluster. + cluster_name: minimal_security_group + + # Cloud-provider specific configuration. + provider: + type: aws + region: us-west-2 + security_group: + GroupName: ray_client_security_group + IpPermissions: + - FromPort: 10001 + ToPort: 10001 + IpProtocol: TCP + IpRanges: + # This will enable inbound access from ALL IPv4 addresses. + - CidrIp: 0.0.0.0/0 + Step 3: Run Ray code ~~~~~~~~~~~~~~~~~~~~ @@ -99,8 +125,43 @@ Now, connect to the Ray Cluster with the following and then use Ray like you nor #.... -Connect to multiple ray clusters --------------------------------- +Alternative Approach: SSH Port Forwarding +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +As an alternative to configuring inbound traffic rules, you can also set up +Ray Client via port forwarding. While this approach does require an open SSH +connection, it can be useful in a test environment where the +``head_node_host`` often changes. + +First, open up an SSH connection with your Ray cluster and forward the +listening port (``10001``). + +.. code-block:: bash + + $ ray up cluster.yaml + $ ray attach cluster.yaml -p 10001 + +Then, you can connect to the Ray cluster using ``localhost`` as the +``head_node_host``. + +.. code-block:: python + + import ray + + # This will connect to the cluster via the open SSH session. + ray.init("ray://localhost:10001") + + # Normal Ray code follows + @ray.remote + def do_work(x): + return x ** x + + do_work.remote(2) + + #.... + +Connect to multiple ray clusters (Experimental) +----------------------------------------------- Ray client allows connecting to multiple ray clusters in one Python process. To do this, just pass ``allow_multiple=True`` to ``ray.init``: diff --git a/doc/source/conf.py b/doc/source/conf.py index 05cc18898b7dc..c554dfec1eda9 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -162,10 +162,10 @@ def __getattr__(cls, name): versionwarning_body_selector = "#main-content" sphinx_gallery_conf = { - "examples_dirs": ["../examples", - "tune/_tutorials"], # path to example scripts + "examples_dirs": ["../examples", "tune/_tutorials", + "data/_examples"], # path to example scripts # path where to save generated examples - "gallery_dirs": ["auto_examples", "tune/tutorials"], + "gallery_dirs": ["auto_examples", "tune/tutorials", "data/examples"], "ignore_pattern": "../examples/doc_code/", "plot_gallery": "False", "min_reported_time": sys.maxsize, diff --git a/doc/source/configure.rst b/doc/source/configure.rst index 5e93b2c6e4f82..186255d855373 100644 --- a/doc/source/configure.rst +++ b/doc/source/configure.rst @@ -234,6 +234,28 @@ to localhost when the ray is started using ``ray.init``. See the `Redis security documentation `__ for more information. +TLS Authentication +------------------ + +Ray can be configured to use TLS on it's gRPC channels. 
+This means that connecting to the Ray client on the head node will
+require an appropriate set of credentials and also that data exchanged between
+various processes (client, head, workers) will be encrypted.
+
+Enabling TLS will cause a performance hit due to the extra overhead of mutual
+authentication and encryption.
+Testing has shown that this overhead is large for small workloads and becomes
+relatively smaller for large workloads.
+The exact overhead will depend on the nature of your workload.
+
+TLS is enabled by setting the following environment variables.
+
+- ``RAY_USE_TLS``: Set to ``1`` to enable TLS and ``0`` to disable it. If this is set to ``1`` then all of the environment variables below must be set. Default: ``0``.
+- ``RAY_TLS_SERVER_CERT``: Location of a `certificate file` which is presented to other endpoints so as to achieve mutual authentication.
+- ``RAY_TLS_SERVER_KEY``: Location of a `private key file` which is the cryptographic means to prove to other endpoints that you are the authorized user of a given certificate.
+- ``RAY_TLS_CA_CERT``: Location of a `CA certificate file` which allows TLS to decide whether an endpoint's certificate has been signed by the correct authority.
+
+
 Java Applications
 -----------------
diff --git a/doc/source/data/.gitignore b/doc/source/data/.gitignore
new file mode 100644
index 0000000000000..d838da9865693
--- /dev/null
+++ b/doc/source/data/.gitignore
@@ -0,0 +1 @@
+examples/
diff --git a/doc/source/data/_examples/README.rst b/doc/source/data/_examples/README.rst
new file mode 100644
index 0000000000000..8feca331c5d78
--- /dev/null
+++ b/doc/source/data/_examples/README.rst
@@ -0,0 +1 @@
+:orphan:
diff --git a/doc/source/data/_examples/big_data_ingestion.py b/doc/source/data/_examples/big_data_ingestion.py
new file mode 100644
index 0000000000000..7cc569ea8e161
--- /dev/null
+++ b/doc/source/data/_examples/big_data_ingestion.py
@@ -0,0 +1,276 @@
+# flake8: noqa: E501
+"""
+Example: Large-scale ML Ingest
+=================================================
+
+In this example, you will learn how to build, deploy and scale up a machine
+learning shuffle ingestion pipeline using
+`Ray Dataset `_ and
+`Dataset Pipelines `_.
+
+In particular, we will show you:
+
+* How to build a shuffle ingestion pipeline that loads, shuffles and feeds data
+  into distributed trainers in a few lines of code;
+* How to scale the pipeline from ingesting 100MiB of data to
+  500GiB of data.
+
+.. image:: ../../data/dataset-repeat-2.svg
+    :align: center
+
+"""
+
+###############################################################################
+# Python Setup
+# ------------
+#
+# First, we'll import all of the libraries we'll be using. This step also helps us
+# verify that the environment is configured correctly. If any of the imports
+# are missing, an exception will be raised.
+
+import argparse
+import tempfile
+import time
+from typing import List
+
+import pandas
+import pyarrow
+
+import ray
+from ray.data.dataset_pipeline import DatasetPipeline
+from ray.data.datasource.datasource import RandomIntRowDatasource
+
+#######################################################################
+# Build shuffle ingestion pipeline
+# ----------------------------------
+#
+# A typical machine learning ingestion pipeline consists of the following 4
+# steps:
+#
+# 1. Load the training data from external storage;
+# 2. Iterate over the data for multiple epochs;
+# 3. In each epoch, apply a global shuffle to decorrelate the data;
+# 4. 
In each epoch, split the shuffled data into shards, and feed the shards to
+#    distributed trainers;
+#
+# Let’s see how we implement such a pipeline using Ray Dataset:
+
+
+def create_shuffle_pipeline(training_data_dir: str, num_epochs: int,
+                            num_shards: int) -> List[DatasetPipeline]:
+
+    return ray.data.read_parquet(training_data_dir) \
+        .repeat(num_epochs) \
+        .random_shuffle_each_window() \
+        .split(num_shards, equal=True)
+
+
+############################################################################
+# We’ve now defined a ``create_shuffle_pipeline`` function that creates an
+# ingestion pipeline.
+# It reads from ``training_data_dir`` and iterates ``num_epochs`` times;
+# in each epoch it shuffles the training data and splits it into
+# ``num_shards``.
+
+###############################################################################
+# Feed the pipeline into trainers
+# -----------------------------------
+# Let’s also implement a ``TrainingWorker`` which consumes the shuffled data
+# from each shard.
+#
+# For simplicity, we will define a
+# `Ray Actor `_ that emulates
+# training workers. Specifically,
+#
+# 1. It takes one shard of the shuffle pipeline for training;
+# 2. It iterates over the shard to get a training dataset per epoch;
+# 3. It then consumes the dataset in batches;
+
+
+@ray.remote
+class TrainingWorker:
+    def __init__(self, rank: int, shard: DatasetPipeline):
+        self.rank = rank
+        self.shard = shard
+
+    def train(self):
+        for epoch, training_dataset in enumerate(self.shard.iter_datasets()):
+            # The following code emulates epoch-based SGD training.
+            print(f"Training... worker: {self.rank}, epoch: {epoch}")
+            for i, batch in enumerate(training_dataset.iter_batches()):
+                # TODO: replace this with real training code.
+                pass
+
+
+###########################################################################
+# Let's run it
+# -----------------------------
+#
+# Now let’s run the data pipeline end-to-end:
+#
+# First, let's parse some arguments.
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--large-scale-test",
+    action="store_true",
+    help="Run large scale test (500GiB of data).")
+
+args, _ = parser.parse_known_args()
+
+###############################################################################
+#
+# After that, let's generate 100MiB of Parquet files,
+# create the shuffle pipeline by reading those generated Parquet files,
+# and use training workers to consume the pipeline.
+
+if not args.large_scale_test:
+
+    NUM_TRAINING_WORKERS = 4
+    NUM_EPOCHS = 5
+    NUM_COLUMNS = 10
+    SIZE_100MiB = 100 * 1024 * 1024
+
+    # Create a local Ray cluster.
+    ray.init()
+
+    def generate_example_files(size_bytes: int) -> str:
+        tmpdir = tempfile.mkdtemp()
+        ray.data.read_datasource(
+            RandomIntRowDatasource(),
+            n=size_bytes // 8 // NUM_COLUMNS,
+            num_columns=NUM_COLUMNS).write_parquet(tmpdir)
+        return tmpdir
+
+    example_files_dir = generate_example_files(SIZE_100MiB)
+
+    splits = create_shuffle_pipeline(example_files_dir, NUM_EPOCHS,
+                                     NUM_TRAINING_WORKERS)
+
+    training_workers = [
+        TrainingWorker.remote(rank, shard) for rank, shard in enumerate(splits)
+    ]
+
+    # Let's run the e2e pipeline
+    start = time.time()
+    ray.get([worker.train.remote() for worker in training_workers])
+    print(f"total ingestion time: {int(time.time() - start)}s")
+
+    # -> Write Progress: 100%|████████████████████| 201/201 [00:00<00:00, 228.67it/s]
+    # -> Stage 0:   0%|          | 0/5 [00:00<?, ?it/s]
+    # -> Stage 0:  40%|████      | 2/5 [00:11<00:17,  5.75s/it]
+    # -> Stage 0:  60%|██████    | 3/5 [00:23<00:16,  8.15s/it]
+    # -> ...
+    # -> (TrainingWorker pid=1651600) Training... worker: 2, epoch: 0
+    # -> Stage 0:  80%|████████  | 4/5 [00:35<00:09,  9.59s/it]
+    # -> ...
+    # -> (TrainingWorker pid=1651599) Training... worker: 0, epoch: 1
+    # -> Stage 0: 100%|██████████| 5/5 [00:46<00:00, 10.34s/it]
+    # -> ...
+    # -> (TrainingWorker pid=1651387) Training... worker: 3, epoch: 4
+    # -> total ingestion time: 61s
+
+#################################################################################
+# Scale the shuffle ingestion pipeline
+# --------------------------------------------------------
+#
+# Scaling the shuffle ingestion pipeline is simple. With Ray, we can linearly
+# scale the pipeline from ingesting 100MiB of data to 500GiB of data by adding
+# more machines.
+#
+# To ingest 500GiB of data, we'll set up a Ray Cluster.
+# The provided :download:`big_data_ingestion.yaml <../big_data_ingestion.yaml>`
+# cluster config can be used to set up an AWS cluster with 70 CPU nodes and
+# 16 GPU nodes. Use the following command to bring up the Ray cluster:
+#
+# .. code-block:: bash
+#
+#     $ pip install ray boto3
+#     $ ray up big_data_ingestion.yaml
+#
+# After the cluster is started, let's implement our large scale ingestion test:
+#
+# First, since we are running on a cluster, let's create the pipeline from
+# ``RandomIntRowDatasource`` directly. This way we don't need to set up S3 to
+# store the generated data.


+def create_large_shuffle_pipeline(data_size_bytes: int, num_epochs: int,
+                                  num_columns: int,
+                                  num_shards: int) -> List[DatasetPipeline]:
+    # _spread_resource_prefix is used to ensure tasks are evenly spread to all
+    # CPU nodes.
+    return ray.data.read_datasource(
+        RandomIntRowDatasource(), n=data_size_bytes // 8 // num_columns,
+        num_columns=num_columns,
+        _spread_resource_prefix="node:") \
+        .repeat(num_epochs) \
+        .random_shuffle_each_window(_spread_resource_prefix="node:") \
+        .split(num_shards, equal=True)
+
+
+#################################################################################
+#
+# Now, it's time to implement the 500GiB shuffle ingestion pipeline.
+
+if args.large_scale_test:
+    NUM_TRAINING_WORKERS = 16
+    NUM_EPOCHS = 5
+    NUM_COLUMNS = 10
+    GiB = 1024 * 1024 * 1024
+    SIZE_500GiB = 500 * GiB
+    TOTAL_NUM_NODES = 70 + 16 + 1
+
+    # Use the AWS cluster we just set up.
+    ray.init(address="auto")
+
+    # Wait for cluster nodes to come up.
+    while len(ray.nodes()) < TOTAL_NUM_NODES:
+        print(
+            f"waiting for nodes to start up: {len(ray.nodes())}/{TOTAL_NUM_NODES}"
+        )
+        time.sleep(5)
+
+    splits = create_large_shuffle_pipeline(SIZE_500GiB, NUM_EPOCHS,
+                                           NUM_COLUMNS, NUM_TRAINING_WORKERS)
+
+    # Note that we set num_gpus=1 for the workers so that
+    # they will only run on GPU nodes.
+    training_workers = [
+        TrainingWorker.options(num_gpus=1) \
+            .remote(rank, shard) for rank, shard in enumerate(splits)
+    ]
+
+    start = time.time()
+
+    # Let's run the large scale test.
+    ray.get([worker.train.remote() for worker in training_workers])
+    print(f"total ingestion time: {int(time.time() - start)}s")
+    throughput = SIZE_500GiB * NUM_EPOCHS / (time.time() - start) / GiB
+    print("throughput: {0:0.2f}GiB/s".format(throughput))
+
+#################################################################################
+#
+# Finally, let's run our pipeline on the cluster we just started:
+#
+# .. code-block:: bash
+#
+#     $ ray submit ./big_data_ingestion.yaml ./big_data_ingestion.py --large-scale-test
+#     # -> Connecting to existing Ray cluster at address: 172.31.47.38:6379
+#     # -> waiting for nodes to start up: 1/87
+#     # -> ...
+# # -> waiting for nodes to start up: 87/87 +# # -> Stage 0: 0%| | 0/5 [00:00 Stage 0: 20%|██ | 1/5 [00:00<00:02, 1.77it/s] +# # -> Stage 0: 40%|████ | 2/5 [00:38<00:35, 11.67s/it] +# # -> Stage 0: 60%|██████ | 3/5 [01:13<00:37, 18.83s/it] +# # -> ... +# # -> (TrainingWorker pid=5084, ip=172.31.35.245) Training... worker: 12, epoch: 0 +# # -> Stage 0: 80%|████████ | 4/5 [03:15<00:49, 49.63s/it] +# # -> ... +# # -> (TrainingWorker pid=5076, ip=172.31.40.190) Training... worker: 9, epoch: 1 +# # -> Stage 0: 100%|██████████| 5/5 [05:02<00:00, 67.01s/it] +# # -> ... +# # -> (TrainingWorker pid=5074, ip=172.31.40.190) Training... worker: 0, epoch: 4 +# # -> total ingestion time: 291s +# # -> throughput: 8.56GiB/s diff --git a/doc/source/data/big_data_ingestion.yaml b/doc/source/data/big_data_ingestion.yaml new file mode 100644 index 0000000000000..2609afdf4426d --- /dev/null +++ b/doc/source/data/big_data_ingestion.yaml @@ -0,0 +1,54 @@ +cluster_name: big_data_ingestion.yaml + +max_workers: 86 + +provider: + type: aws + region: us-west-1 + +auth: + ssh_user: ubuntu + +available_node_types: + head: + node_config: + InstanceType: i3.8xlarge + BlockDeviceMappings: + - DeviceName: /dev/sda1 + Ebs: + VolumeSize: 300 + resources: { } + + gpu_nodes: + min_workers: 16 + max_workers: 16 + node_config: + InstanceType: i3.8xlarge + BlockDeviceMappings: + - DeviceName: /dev/sda1 + Ebs: + VolumeSize: 300 + resources: + GPU: 1 + + memory_nodes: + min_workers: 70 + max_workers: 70 + node_config: + InstanceType: i3.8xlarge + BlockDeviceMappings: + - DeviceName: /dev/sda1 + Ebs: + VolumeSize: 300 + resources: { } + +head_node_type: head + +setup_commands: + - pip install -U ray ray[default] pyarrow pandas + +head_start_ray_commands: + - ray start --head --port=6379 --object-manager-port=8076 --object-store-memory=90000000000 --autoscaling-config=~/ray_bootstrap_config.yaml + +worker_start_ray_commands: + - ray start --address=$RAY_HEAD_IP:6379 --object-manager-port=8076 --object-store-memory=90000000000 diff --git a/doc/source/data/dask-on-ray.rst b/doc/source/data/dask-on-ray.rst index 6057b740db441..9e08977bdb16e 100644 --- a/doc/source/data/dask-on-ray.rst +++ b/doc/source/data/dask-on-ray.rst @@ -6,16 +6,16 @@ Dask on Ray `Dask `__ is a Python parallel computing library geared towards scaling analytics and scientific computing workloads. It provides `big data collections `__ that mimic the APIs of -the familiar `NumPy `__ and `Pandas `__ libraries, +the familiar `NumPy `__ and `Pandas `__ libraries, allowing those abstractions to represent -larger-than-memory data and/or allowing operations on that data to be run on a multi-machine cluster, +larger-than-memory data and/or allowing operations on that data to be run on a multi-machine cluster, while also providing automatic data parallelism, smart scheduling, and optimized operations. Operations on these collections create a task graph, which is executed by a scheduler. Ray provides a scheduler for Dask (`dask_on_ray`) which allows you to build data analyses using Dask's collections and execute -the underlying tasks on a Ray cluster. +the underlying tasks on a Ray cluster. `dask_on_ray` uses Dask's scheduler API, which allows you to specify any callable as the scheduler that you would like Dask to use to execute your @@ -30,8 +30,12 @@ workload. 
Using the Dask-on-Ray scheduler, the entire Dask ecosystem can be exec * - Ray Version - Dask Version + * - ``1.7.0`` + - ``2021.9.1`` + * - ``1.6.0`` + - ``2021.8.1`` * - ``1.5.0`` - - ``2021.7.0`` + - ``2021.7.0`` * - ``1.4.1`` - ``2021.6.1`` * - ``1.4.0`` @@ -82,7 +86,7 @@ In this case, there are two recommended setup. # Head node. Set `num_cpus=0` to avoid tasks are being scheduled on a head node. RAY_SCHEDULER_SPREAD_THRESHOLD=0.0 ray start --head --num-cpus=0 - # Worker node. + # Worker node. RAY_SCHEDULER_SPREAD_THRESHOLD=0.0 ray start --address=[head-node-address] Out-of-Core Data Processing @@ -101,10 +105,10 @@ Persist .. _dask-on-ray-persist: -Dask-on-Ray patches `dask.persist() -`__ in order to match `Dask +Dask-on-Ray patches `dask.persist() +`__ in order to match `Dask Distributed's persist semantics -`; namely, calling `dask.persist()` with a Dask-on-Ray +`; namely, calling `dask.persist()` with a Dask-on-Ray scheduler will submit the tasks to the Ray cluster and return Ray futures inlined in the Dask collection. This is nice if you wish to compute some base collection (such as a Dask array), followed by multiple different downstream computations (such as diff --git a/doc/source/data/dataset-pipeline.rst b/doc/source/data/dataset-pipeline.rst index 8b60ca3cb7985..d954df8051eb5 100644 --- a/doc/source/data/dataset-pipeline.rst +++ b/doc/source/data/dataset-pipeline.rst @@ -6,12 +6,12 @@ Overview Datasets execute their transformations synchronously in blocking calls. However, it can be useful to overlap dataset computations with output. This can be done with a `DatasetPipeline `__. -A DatasetPipeline is an unified iterator over a (potentially infinite) sequence of Ray Datasets. Conceptually it is similar to a `Spark DStream `__, but manages execution over a bounded amount of source data instead of an unbounded stream. Ray computes each dataset on-demand and stitches their output together into a single logical data iterator. DatasetPipeline implements most of the same transformation and output methods as Datasets (e.g., map, filter, split, iter_rows, to_torch, etc.). +A DatasetPipeline is an unified iterator over a (potentially infinite) sequence of Ray Datasets, each of which represents a *window* over the original data. Conceptually it is similar to a `Spark DStream `__, but manages execution over a bounded amount of source data instead of an unbounded stream. Ray computes each dataset window on-demand and stitches their output together into a single logical data iterator. DatasetPipeline implements most of the same transformation and output methods as Datasets (e.g., map, filter, split, iter_rows, to_torch, etc.). Creating a DatasetPipeline ~~~~~~~~~~~~~~~~~~~~~~~~~~ -A DatasetPipeline can be constructed in two ways: either by pipelining the execution of an existing Dataset (via ``Dataset.pipeline``), or generating repeats of an existing Dataset (via ``Dataset.repeat``). Similar to Datasets, you can freely pass DatasetPipelines between Ray tasks, actors, and libraries. Get started with this synthetic data example: +A DatasetPipeline can be constructed in two ways: either by pipelining the execution of an existing Dataset (via ``Dataset.window``), or generating repeats of an existing Dataset (via ``Dataset.repeat``). Similar to Datasets, you can freely pass DatasetPipelines between Ray tasks, actors, and libraries. Get started with this synthetic data example: .. 
code-block:: python @@ -30,16 +30,16 @@ A DatasetPipeline can be constructed in two ways: either by pipelining the execu base = ray.data.range(1000000) print(base) # -> Dataset(num_blocks=200, num_rows=1000000, schema=) - pipe = base.pipeline(parallelism=10) + pipe = base.window(blocks_per_window=10) print(pipe) - # -> DatasetPipeline(length=20, num_stages=1) + # -> DatasetPipeline(num_windows=20, num_stages=1) # Applying transforms to pipelines adds more pipeline stages. pipe = pipe.map(func1) pipe = pipe.map(func2) pipe = pipe.map(func3) print(pipe) - # -> DatasetPipeline(length=20, num_stages=4) + # -> DatasetPipeline(num_windows=20, num_stages=4) # Output can be pulled from the pipeline concurrently with its execution. num_rows = 0 @@ -53,8 +53,7 @@ A DatasetPipeline can be constructed in two ways: either by pipelining the execu print("Total num rows", num_rows) # -> Total num rows 1000000 - -You can also create a DatasetPipeline from a custom iterator over dataset creators using ``DatasetPipeline.from_iterable``. For example, this is how you would implement ``Dataset.repeat`` and ``Dataset.pipeline`` using ``from_iterable``: +You can also create a DatasetPipeline from a custom iterator over dataset creators using ``DatasetPipeline.from_iterable``. For example, this is how you would implement ``Dataset.repeat`` and ``Dataset.window`` using ``from_iterable``: .. code-block:: python @@ -66,10 +65,52 @@ You can also create a DatasetPipeline from a custom iterator over dataset creato pipe = DatasetPipeline.from_iterable( [lambda: source, lambda: source, lambda: source, lambda: source]) - # Equivalent to ray.data.range(1000).pipeline(parallelism=10) + # Equivalent to ray.data.range(1000).window(blocks_per_window=10) splits = ray.data.range(1000, parallelism=200).split(20) pipe = DatasetPipeline.from_iterable([lambda s=s: s for s in splits]) +Per-Window Transformations +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +While most Dataset operations are per-row (e.g., map, filter), some operations apply to the Dataset as a whole (e.g., sort, shuffle). When applied to a pipeline, holistic transforms like shuffle are applied separately to each window in the pipeline: + +.. code-block:: python + + # Example of randomly shuffling each window of a pipeline. + ray.data.range(5).repeat(2).random_shuffle_each_window().show_windows() + # -> + # === Window 0 === + # 4 + # 3 + # 1 + # 0 + # 2 + # === Window 1 === + # 2 + # 1 + # 4 + # 0 + # 3 + +You can also apply arbitrary transformations to each window using ``DatasetPipeline.foreach_window()``: + +.. code-block:: python + + # Equivalent transformation using .foreach_window() + ray.data.range(5).repeat(2).foreach_window(lambda w: w.random_shuffle()).show_windows() + # -> + # === Window 0 === + # 1 + # 0 + # 4 + # 2 + # 3 + # === Window 1 === + # 4 + # 2 + # 0 + # 3 + # 1 Example: Pipelined Batch Inference ---------------------------------- @@ -109,28 +150,28 @@ Ignoring the output, the above script has three separate stages: loading, prepro Enabling Pipelining ~~~~~~~~~~~~~~~~~~~ -We can optimize this by *pipelining* the execution of the dataset with the ``.pipeline()`` call, which returns a DatasetPIpeline instead of a Dataset object. The pipeline supports similar transformations to the original Dataset: +We can optimize this by *pipelining* the execution of the dataset with the ``.window()`` call, which returns a DatasetPipeline instead of a Dataset object. The pipeline supports similar transformations to the original Dataset: .. 
code-block:: python # Convert the Dataset into a DatasetPipeline. pipe: DatasetPipeline = ray.data \ .read_binary_files("s3://bucket/image-dir") \ - .pipeline(parallelism=2) + .window(blocks_per_window=2) # The remainder of the steps do not change. pipe = pipe.map(preprocess) pipe = pipe.map_batches(BatchInferModel, compute="actors", batch_size=256, num_gpus=1) pipe.write_json("/tmp/results") -Here we specified ``parallelism=2``, which means that the Dataset is split into smaller sub-Datasets of two blocks each. Each transformation or *stage* of the pipeline is operating over these two-block Datasets in parallel. This means batch inference processing can start as soon as two blocks are read and preprocessed, greatly reducing the GPU idle time: +Here we specified ``blocks_per_window=2``, which means that the Dataset is split into smaller sub-Datasets of two blocks each. Each transformation or *stage* of the pipeline is operating over these two-block Datasets in parallel. This means batch inference processing can start as soon as two blocks are read and preprocessed, greatly reducing the GPU idle time: .. image:: dataset-pipeline-2.svg Tuning Parallelism ~~~~~~~~~~~~~~~~~~ -Tune the throughput vs latency of your pipeline with the ``parallelism`` setting. As a rule of thumb, higher parallelism settings perform better, however ``parallelism == num_blocks`` effectively disables pipelining, since the DatasetPipeline will only contain a single Dataset. The other extreme is setting ``parallelism=1``, which minimizes the latency to initial output but only allows one concurrent transformation task per stage: +Tune the throughput vs latency of your pipeline with the ``blocks_per_window`` setting. As a rule of thumb, higher parallelism settings perform better, however ``blocks_per_window == num_blocks`` effectively disables pipelining, since the DatasetPipeline will only contain a single Dataset. The other extreme is setting ``blocks_per_window=1``, which minimizes the latency to initial output but only allows one concurrent transformation task per stage: .. image:: dataset-pipeline-3.svg @@ -155,7 +196,7 @@ Transformations made prior to the Dataset prior to the call to ``.repeat()`` are pipe: DatasetPipeline = ray.data \ .read_datasource(...) \ .repeat() \ - .random_shuffle() + .random_shuffle_each_window() @ray.remote(num_gpus=1) def train_func(pipe: DatasetPipeline): @@ -184,7 +225,7 @@ Similar to how you can ``.split()`` a Dataset, you can also split a DatasetPipel pipe: DatasetPipeline = ray.data \ .read_parquet("s3://bucket/dir") \ .repeat() \ - .random_shuffle() + .random_shuffle_each_window() @ray.remote(num_gpus=1) class TrainingWorker: @@ -201,3 +242,55 @@ Similar to how you can ``.split()`` a Dataset, you can also split a DatasetPipel **Pipeline**: .. image:: dataset-repeat-2.svg + +Changing Pipeline Structure +--------------------------- + +Sometimes, you may want to change the structure of an existing pipeline. For example, after generating a pipeline with ``ds.window(k)``, you may want to repeat that windowed pipeline ``n`` times. This can be done with ``ds.window(k).repeat(n)``. As another example, suppose you have a repeating pipeline generated with ``ds.repeat(n)``. The windowing of that pipeline can be changed with ``ds.repeat(n).rewindow(k)``. Note the subtle difference in the two examples: the former is repeating a windowed pipeline that has a base window size of ``k``, while the latter is re-windowing a pipeline of initial window size of ``ds.num_blocks()``. 
The latter may produce windows that span multiple copies of the same original data: + +.. code-block:: python + + # Window followed by repeat. + ray.data.range(5) \ + .window(blocks_per_window=2) \ + .repeat(2) \ + .show_windows() + # -> + # === Window 0 === + # 0 + # 1 + # === Window 1 === + # 2 + # 3 + # === Window 2 === + # 4 + # === Window 3 === + # 0 + # 1 + # === Window 4 === + # 2 + # 3 + # === Window 5 === + # 4 + + # Repeat followed by window. + ray.data.range(5) \ + .repeat(2) \ + .rewindow(blocks_per_window=2) \ + .show_windows() + # -> + # === Window 0 === + # 0 + # 1 + # === Window 1 === + # 2 + # 3 + # === Window 2 === + # 4 + # 0 + # === Window 3 === + # 1 + # 2 + # === Window 4 === + # 3 + # 4 diff --git a/doc/source/data/dataset-tensor-support.rst b/doc/source/data/dataset-tensor-support.rst index b8a4ad68eed4e..d2d3ebf40c6f1 100644 --- a/doc/source/data/dataset-tensor-support.rst +++ b/doc/source/data/dataset-tensor-support.rst @@ -3,66 +3,34 @@ Dataset Tensor Support ====================== -Tensor-typed values -------------------- +Tables with tensor columns +-------------------------- + +Datasets supports tables with fixed-shape tensor columns, where each element in the column is a tensor (n-dimensional array) with the same shape. As an example, this allows you to use Pandas and Ray Datasets to read, write, and manipulate e.g., images. All conversions between Pandas, Arrow, and Parquet, and all application of aggregations/operations to the underlying image ndarrays are taken care of by Ray Datasets. + +With our Pandas extension type, :class:`TensorDtype `, and extension array, :class:`TensorArray `, you can do familiar aggregations and arithmetic, comparison, and logical operations on a DataFrame containing a tensor column and the operations will be applied to the underlying tensors as expected. With our Arrow extension type, :class:`ArrowTensorType `, and extension array, :class:`ArrowTensorArray `, you'll be able to import that DataFrame into Ray Datasets and read/write the data from/to the Parquet format. + +Automatic conversion between the Pandas and Arrow extension types/arrays keeps the details under-the-hood, so you only have to worry about casting the column to a tensor column using our Pandas extension type when first ingesting the table into a ``Dataset``, whether from storage or in-memory. All table operations downstream from that cast should work automatically. -Datasets support tensor-typed values, which are represented in-memory as Arrow tensors (i.e., np.ndarray format). Tensor datasets can be read from and written to ``.npy`` files. Here are some examples: +Single-column tensor datasets +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The most basic case is when a dataset only has a single column, which is of tensor type. This kind of dataset can be created with ``.range_tensor()``, and can be read from and written to ``.npy`` files. Here are some examples: .. code-block:: python # Create a Dataset of tensor-typed values. ds = ray.data.range_tensor(10000, shape=(3, 5)) # -> Dataset(num_blocks=200, num_rows=10000, - # schema=) - - ds.map_batches(lambda t: t + 2).show(2) - # -> [[2 2 2 2 2] - # [2 2 2 2 2] - # [2 2 2 2 2]] - # [[3 3 3 3 3] - # [3 3 3 3 3] - # [3 3 3 3 3]] + # schema={value: }) # Save to storage. - ds.write_numpy("/tmp/tensor_out") + ds.write_numpy("/tmp/tensor_out", column="value") # Read from storage. 
ray.data.read_numpy("/tmp/tensor_out") # -> Dataset(num_blocks=200, num_rows=?, - # schema=) - -Tensor datasets are also created whenever an array type is returned from a map function: - -.. code-block:: python - - # Create a dataset of Python integers. - ds = ray.data.range(10) - # -> Dataset(num_blocks=10, num_rows=10, schema=) - - # It is now converted into a Tensor dataset. - ds = ds.map_batches(lambda x: np.array(x)) - # -> Dataset(num_blocks=10, num_rows=10, - # schema=) - -Tensor datasets can also be created from NumPy ndarrays that are already stored in the Ray object store: - -.. code-block:: python - - import numpy as np - - # Create a Dataset from a list of NumPy ndarray objects. - arr1 = np.arange(0, 10) - arr2 = np.arange(10, 20) - ds = ray.data.from_numpy([ray.put(arr1), ray.put(arr2)]) - -Tables with tensor columns --------------------------- - -In addition to tensor datasets, Datasets also supports tables with fixed-shape tensor columns, where each element in the column is a tensor (n-dimensional array) with the same shape. As an example, this allows you to use both Pandas and Ray Datasets to read, write, and manipulate a table with a column of e.g. images (2D arrays), with all conversions between Pandas, Arrow, and Parquet, and all application of aggregations/operations to the underlying image ndarrays, being taken care of by Ray Datasets. - -With our Pandas extension type, :class:`TensorDtype `, and extension array, :class:`TensorArray `, you can do familiar aggregations and arithmetic, comparison, and logical operations on a DataFrame containing a tensor column and the operations will be applied to the underlying tensors as expected. With our Arrow extension type, :class:`ArrowTensorType `, and extension array, :class:`ArrowTensorArray `, you'll be able to import that DataFrame into Ray Datasets and read/write the data from/to the Parquet format. - -Automatic conversion between the Pandas and Arrow extension types/arrays keeps the details under-the-hood, so you only have to worry about casting the column to a tensor column using our Pandas extension type when first ingesting the table into a ``Dataset``, whether from storage or in-memory. All table operations downstream from that cast should work automatically. + # schema={value: }) Reading existing serialized tensor columns ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -87,7 +55,7 @@ If you already have a Parquet dataset with columns containing serialized tensors # Write the dataset to Parquet. The tensor column will be written as an # array of opaque byte blobs. - ds = ray.data.from_pandas([ray.put(df)]) + ds = ray.data.from_pandas([df]) ds.write_parquet(path) # Read the Parquet files into a new Dataset, with the serialized tensors @@ -117,7 +85,7 @@ If your serialized tensors don't fit the above constraints (e.g. they're stored # Write the dataset to Parquet. The tensor column will be written as an # array of opaque byte blobs. - ds = ray.data.from_pandas([ray.put(df)]) + ds = ray.data.from_pandas([df]) ds.write_parquet(path) # Manually deserialize the tensor pickle bytes and cast to our tensor @@ -150,7 +118,7 @@ Now that the tensor column is properly typed and in a ``Dataset``, we can perfor # Arrow and Pandas is now aware of this tensor column, so we can do the # typical DataFrame operations on this column. 
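    # With batch_format="pandas", each batch is handed to the lambda as a pandas DataFrame.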
- ds = ds.map_batches(lambda x: 2 * (x + 1), format="pandas") + ds = ds.map_batches(lambda x: 2 * (x + 1), batch_format="pandas") # -> Map Progress: 100%|████████████████████| 200/200 [00:00<00:00, 1123.54it/s] print(ds) # -> Dataset( @@ -244,7 +212,7 @@ If working with in-memory Pandas DataFrames that you want to analyze, manipulate # In addition to doing Pandas operations on the tensor column, # you can now put the DataFrame directly into a Dataset. - ds = ray.data.from_pandas([ray.put(df)]) + ds = ray.data.from_pandas([df]) # Internally, this column is represented with the corresponding # Arrow tensor extension type. print(ds.schema()) @@ -259,7 +227,7 @@ If working with in-memory Pandas DataFrames that you want to analyze, manipulate # -> one: int64 # two: extension> - read_df = ray.get(read_ds.to_pandas())[0] + read_df = read_ds.to_pandas() print(read_df.dtypes) # -> one int64 # two TensorDtype diff --git a/doc/source/data/dataset.rst b/doc/source/data/dataset.rst index 7142691e5df45..20018765c1a69 100644 --- a/doc/source/data/dataset.rst +++ b/doc/source/data/dataset.rst @@ -16,7 +16,7 @@ Ray Datasets are the standard way to load and exchange data in Ray libraries and Concepts -------- -Ray Datasets implement `Distributed Arrow `__. A Dataset consists of a list of Ray object references to *blocks*. Each block holds a set of items in either an `Arrow table `__, `Arrow tensor `__, or a Python list (for Arrow incompatible objects). Having multiple blocks in a dataset allows for parallel transformation and ingest of the data. +Ray Datasets implement `Distributed Arrow `__. A Dataset consists of a list of Ray object references to *blocks*. Each block holds a set of items in either an `Arrow table `__ or a Python list (for Arrow-incompatible objects). Having multiple blocks in a dataset allows for parallel transformation and ingest of the data. The following figure visualizes a Dataset that has three Arrow table blocks, each block holding 1000 rows: @@ -145,6 +145,10 @@ Datasource Compatibility Matrices Creating Datasets ----------------- +.. tip:: + + Run ``pip install ray[data]`` to get started! + Get started by creating Datasets from synthetic data using ``ray.data.range()`` and ``ray.data.from_items()``. Datasets can hold either plain Python objects (schema is a Python type), or Arrow records (schema is Arrow). .. code-block:: python @@ -198,7 +202,7 @@ Finally, you can create a ``Dataset`` from existing data in the Ray object store # Create a Dataset from a list of Pandas DataFrame objects. pdf = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]}) - ds = ray.data.from_pandas([ray.put(pdf)]) + ds = ray.data.from_pandas([pdf]) # Create a Dataset from a Dask-on-Ray DataFrame. dask_df = dd.from_pandas(pdf, npartitions=10) diff --git a/doc/source/data/package-ref.rst b/doc/source/data/package-ref.rst index 0af38ba8297c7..afdace98bf719 100644 --- a/doc/source/data/package-ref.rst +++ b/doc/source/data/package-ref.rst @@ -15,11 +15,13 @@ Creating a Dataset .. autofunction:: ray.data.read_datasource .. autofunction:: ray.data.from_items .. autofunction:: ray.data.from_arrow +.. autofunction:: ray.data.from_arrow_refs .. autofunction:: ray.data.from_spark .. autofunction:: ray.data.from_dask .. autofunction:: ray.data.from_modin .. autofunction:: ray.data.from_mars .. autofunction:: ray.data.from_pandas +.. autofunction:: ray.data.from_pandas_refs ..
autofunction:: ray.data.from_numpy Dataset API diff --git a/doc/source/development.rst b/doc/source/development.rst index d672c2b3fb5a0..f41b48d14fdef 100644 --- a/doc/source/development.rst +++ b/doc/source/development.rst @@ -100,8 +100,9 @@ Ray can be built from the repository as follows. git clone https://github.com/ray-project/ray.git # Install Bazel. - # (Windows users: please manually place Bazel in your PATH, and point BAZEL_SH to MSYS2's Bash.) ray/ci/travis/install-bazel.sh + # (Windows users: please manually place Bazel in your PATH, and point + # BAZEL_SH to MSYS2's Bash: ``set BAZEL_SH=C:\Program Files\Git\bin\bash.exe``) # Build the dashboard # (requires Node.js, see https://nodejs.org/ for more information). @@ -126,7 +127,7 @@ Building Ray on Windows (full) The following links were correct during the writing of this section. In case the URLs changed, search at the organizations' sites. -- bazel 3.4 (https://github.com/bazelbuild/bazel/releases/tag/3.4.0) +- bazel 4.2 (https://github.com/bazelbuild/bazel/releases/tag/4.2.1) - Microsoft Visual Studio 2019 (or Microsoft Build Tools 2019 - https://visualstudio.microsoft.com/downloads/#build-tools-for-visual-studio-2019) - JDK 15 (https://www.oracle.com/java/technologies/javase-jdk15-downloads.html) - Miniconda 3 (https://docs.conda.io/en/latest/miniconda.html) @@ -149,7 +150,11 @@ The following links were correct during the writing of this section. In case the 3. Define an environment variable BAZEL_SH to point to bash.exe. If git for Windows was installed for all users, bash's path should be ``C:\Program Files\Git\bin\bash.exe``. If git was installed for a single user, adjust the path accordingly. -4. Bazel 3.4 installation. Go to bazel 3.4 release web page and download bazel-3.4.0-windows-x86_64.exe. Copy the exe into the directory of your choice. Define an environment variable BAZEL_PATH to full exe path (example: ``C:\bazel\bazel-3.4.0-windows-x86_64.exe``) +4. Bazel 4.2 installation. Go to bazel 4.2 release web page and download +bazel-4.2.1-windows-x86_64.exe. Copy the exe into the directory of your choice. +Define an environment variable BAZEL_PATH to full exe path (example: +``set BAZEL_PATH=C:\bazel\bazel.exe``). Also add the bazel directory to the +``PATH`` (example: ``set PATH=%PATH%;C:\bazel``) 5. Install cython and pytest: diff --git a/doc/source/index.rst b/doc/source/index.rst index 2024802af37d7..784df20c59e07 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -277,8 +277,9 @@ Papers :caption: Ray Data data/dataset.rst - data/dataset-tensor-support.rst data/dataset-pipeline.rst + data/examples/big_data_ingestion + data/dataset-tensor-support.rst data/package-ref.rst data/dask-on-ray.rst data/mars-on-ray.rst @@ -338,6 +339,7 @@ Papers raysgd/v2/examples.rst raysgd/v2/architecture.rst raysgd/v2/api.rst + raysgd/v2/migration-guide.rst RaySGD v1: Distributed Training Wrappers .. toctree:: @@ -365,7 +367,7 @@ Papers .. toctree:: :hidden: :maxdepth: -1 - :caption: Contributing + :caption: Contributor Guide getting-involved.rst development.rst diff --git a/doc/source/raysgd/raysgd.rst b/doc/source/raysgd/raysgd.rst index 87696e68d6535..55ddcdb389fc1 100644 --- a/doc/source/raysgd/raysgd.rst +++ b/doc/source/raysgd/raysgd.rst @@ -6,7 +6,7 @@ RaySGD: Distributed Training Wrappers .. warning:: This is an older version of Ray SGD. A newer, more light-weight version of Ray SGD is in alpha as of Ray 1.7. - See the documentation :ref:`here `. + See the documentation :ref:`here `. 
To migrate from v1 to v2, you can follow the :ref:`migration guide `. RaySGD is a lightweight library for distributed deep learning, providing thin wrappers around PyTorch and TensorFlow native modules for data parallel training. diff --git a/doc/source/raysgd/raysgd_pytorch.rst b/doc/source/raysgd/raysgd_pytorch.rst index 5e9c1ce099141..635d003e55032 100644 --- a/doc/source/raysgd/raysgd_pytorch.rst +++ b/doc/source/raysgd/raysgd_pytorch.rst @@ -3,13 +3,16 @@ Distributed PyTorch =================== +.. warning:: This is an older version of Ray SGD. A newer, more lightweight version of Ray SGD is in alpha as of Ray 1.7. + See the documentation :ref:`here `. To migrate from v1 to v2, you can follow the :ref:`migration guide `. + The RaySGD ``TorchTrainer`` simplifies distributed model training for PyTorch. .. image:: raysgd-actors.svg :align: center -.. tip:: Get in touch with us if you're using or considering using `RaySGD `_! +.. tip:: Get in touch with us if you're using or considering using `RaySGD `_! The ``TorchTrainer`` is a wrapper around ``torch.distributed.launch`` with a Python API to easily incorporate distributed training into a larger Python application, as opposed to needing to wrap your training code in bash scripts. diff --git a/doc/source/raysgd/raysgd_tensorflow.rst b/doc/source/raysgd/raysgd_tensorflow.rst index f18d7f9ec3924..2cbf01da2e3c3 100644 --- a/doc/source/raysgd/raysgd_tensorflow.rst +++ b/doc/source/raysgd/raysgd_tensorflow.rst @@ -1,6 +1,9 @@ Distributed TensorFlow ====================== +.. warning:: This is an older version of Ray SGD. A newer, more lightweight version of Ray SGD is in alpha as of Ray 1.7. + See the documentation :ref:`here `. To migrate from v1 to v2, you can follow the :ref:`migration guide `. + RaySGD's ``TFTrainer`` simplifies distributed model training for Tensorflow. The ``TFTrainer`` is a wrapper around ``MultiWorkerMirroredStrategy`` with a Python API to easily incorporate distributed training into a larger Python application, as opposed to writing custom logic of setting environments and starting separate processes. Under the hood, ``TFTrainer`` will create *replicas* of your model (controlled by ``num_replicas``), each of which is managed by a Ray actor. @@ -8,7 +11,7 @@ Under the hood, ``TFTrainer`` will create *replicas* of your model (controlled b .. image:: raysgd-actors.svg :align: center -.. tip:: We need your feedback! RaySGD is currently early in its development, and we're hoping to get feedback from people using or considering it. We'd love `to get in touch `_! +.. tip:: Get in touch with us if you're using or considering using `RaySGD `_! ---------- diff --git a/doc/source/raysgd/raysgd_tune.rst b/doc/source/raysgd/raysgd_tune.rst index cacaea0a20c4e..740ff78b0390c 100644 --- a/doc/source/raysgd/raysgd_tune.rst +++ b/doc/source/raysgd/raysgd_tune.rst @@ -3,6 +3,9 @@ RaySGD Hyperparameter Tuning ============================ +.. warning:: This is an older version of Ray SGD. A newer, more lightweight version of Ray SGD is in alpha as of Ray 1.7. + See the documentation :ref:`here `. To migrate from v1 to v2, you can follow the :ref:`migration guide `. + RaySGD integrates with :ref:`Ray Tune ` to easily run distributed hyperparameter tuning experiments with your RaySGD Trainer. PyTorch diff --git a/doc/source/raysgd/v2/api.rst b/doc/source/raysgd/v2/api.rst index fc3028bc9fc19..97b48a26b11ce 100644 --- a/doc/source/raysgd/v2/api.rst +++ b/doc/source/raysgd/v2/api.rst @@ -22,10 +22,8 @@ SGDIterator ..
_sgd-api-backend-config: -BackendConfig ------------- - -.. autoclass:: ray.sgd.BackendConfig +Backend Configurations +---------------------- .. _sgd-api-torch-config: @@ -48,10 +46,14 @@ HorovodConfig .. autoclass:: ray.sgd.HorovodConfig + +Callbacks +--------- + .. _sgd-api-callback: SGDCallback ------------ +~~~~~~~~~~~ .. autoclass:: ray.sgd.SGDCallback :members: @@ -61,19 +63,22 @@ SGDCallback JsonLoggerCallback ~~~~~~~~~~~~~~~~~~ -.. autoclass:: ray.sgd.JsonLoggerCallback +.. autoclass:: ray.sgd.callbacks.JsonLoggerCallback .. _sgd-api-tbx-logger-callback: TBXLoggerCallback ~~~~~~~~~~~~~~~~~ -.. autoclass:: ray.sgd.TBXLoggerCallback +.. autoclass:: ray.sgd.callbacks.TBXLoggerCallback + +Checkpointing +------------- .. _sgd-api-checkpoint-strategy: CheckpointStrategy ------------------- +~~~~~~~~~~~~~~~~~~ .. autoclass:: ray.sgd.CheckpointStrategy diff --git a/doc/source/raysgd/v2/examples.rst b/doc/source/raysgd/v2/examples.rst index a35f394c7593c..3edee334aea2a 100644 --- a/doc/source/raysgd/v2/examples.rst +++ b/doc/source/raysgd/v2/examples.rst @@ -61,6 +61,9 @@ Ray Tune Integration Examples * :doc:`/raysgd/v2/examples/tune_tensorflow_mnist_example`: End-to-end example for tuning a TensorFlow model. +* :doc:`/raysgd/v2/examples/tune_cifar_pytorch_pbt_example`: + End-to-end example for tuning a PyTorch model with PBT. + .. TODO implement these examples! diff --git a/doc/source/raysgd/v2/examples/tune_cifar_pytorch_pbt_example.rst b/doc/source/raysgd/v2/examples/tune_cifar_pytorch_pbt_example.rst new file mode 100644 index 0000000000000..31aabc7ca78ab --- /dev/null +++ b/doc/source/raysgd/v2/examples/tune_cifar_pytorch_pbt_example.rst @@ -0,0 +1,6 @@ +:orphan: + +tune_cifar_pytorch_pbt_example +============================== + +.. literalinclude:: /../../python/ray/util/sgd/v2/examples/tune_cifar_pytorch_pbt_example.py diff --git a/doc/source/raysgd/v2/migration-guide.rst b/doc/source/raysgd/v2/migration-guide.rst new file mode 100644 index 0000000000000..08effe4b25e98 --- /dev/null +++ b/doc/source/raysgd/v2/migration-guide.rst @@ -0,0 +1,393 @@ +.. _sgd-migration: + +Migrating from Ray SGD v1 +========================= + +In Ray 1.7, we are rolling out a new and more streamlined version of Ray SGD. Ray SGD v2 focuses on usability and composability - it has a much simpler API, has support for more deep learning backends, integrates better with other libraries in the Ray ecosystem, and will continue to be actively developed with more features. + +This guide will help you easily migrate existing code from Ray SGD v1 to Ray SGD v2. If you are new to Ray SGD as a whole, you should get started with :ref:`Ray SGD v2 directly `. + +For a full list of features that Ray SGD v2 provides, please check out the :ref:`user guide`. + +.. note:: If you hit any issues with this guide, notice anything missing, or have any feedback on Ray SGD v2 overall, please file a `GitHub issue on the Ray repo `_! + +What are the API differences? +----------------------------- + +There are 3 primary API differences between Ray SGD v1 and v2. + +1. There is a single ``Trainer`` interface for all backends (torch, tensorflow, horovod), and the backend is simply specified via an argument: ``Trainer(backend="torch")``\ , ``Trainer(backend="horovod")``\ , etc. Any features that we add to Ray SGD will be supported for all backends, and there won't be any API divergence like there was with a separate ``TorchTrainer`` and ``TFTrainer``. 2.
The ``TrainingOperator`` and creator functions are replaced by a more natural user-defined training function. You no longer have to make your training logic fit into a restrictive interface. In Ray SGD v2, you simply have to provide a training function that describes the full logic for your training execution, and this will be distributed by Ray SGD v2. + + .. code-block:: python + + from torch.nn.parallel import DistributedDataParallel + from torch import nn, optim + + # Torch Example + def train_func_distributed(): + num_epochs = 3 + model = NeuralNetwork() + model = DistributedDataParallel(model) + loss_fn = nn.MSELoss() + optimizer = optim.SGD(model.parameters(), lr=0.1) + + for epoch in range(num_epochs): + output = model(input) + loss = loss_fn(output, labels) + optimizer.zero_grad() + loss.backward() + optimizer.step() + print(f"epoch: {epoch}, loss: {loss.item()}") + + from ray.sgd import Trainer + + trainer = Trainer(backend="torch", num_workers=4) + trainer.start() + results = trainer.run(train_func_distributed) + trainer.shutdown() + +Currently, this means that you are now responsible for modifying your code to support distributed training (specifying ``DistributedDataParallel`` for ``torch`` or ``MultiWorkerMirroredStrategy`` for ``tensorflow``) as opposed to having this be automatically handled internally. However, we have plans to provide utilities that you can use to automatically handle these recipes for you. + +3. Rather than iteratively calling ``trainer.train()`` or ``trainer.validate()`` for each epoch, in Ray SGD v2 the training function defines the full training execution and is run via ``trainer.run(train_func)``. + +In the following sections, we will guide you through the steps to migrate: + +1. :ref:`sgd-migration-logic` +2. :ref:`Interacting with Trainer state (intermediate metrics, checkpointing) ` +3. :ref:`Hyperparameter Tuning with Ray Tune ` + +.. _sgd-migration-logic: + +Training Logic +-------------- +The main change you will have to make is how you define your training logic. In Ray SGD v1, the API for defining training logic differed for `TorchTrainer` vs. `TFTrainer`, so the steps to migrate will be different for each of these. + +PyTorch +~~~~~~~ +In v1, the training logic is defined through the ``train_epoch`` and ``train_batch`` methods of a ``TrainingOperator`` class which is passed into the ``TorchTrainer``. To migrate to Ray SGD v2, there are 2 options: + +1. If you felt the ``TrainingOperator`` was unnecessary or overly complex, or you had to customize it extensively, you can define your own training function. +2. If you liked having your training logic in the ``TrainingOperator``, you can continue to use the ``TrainingOperator`` with Ray SGD v2. + +**Alternative 1: Custom Training Function** +You can define your own custom training function, and use only the parts from ``TrainingOperator.train_epoch``, ``TrainingOperator.setup``, and ``TrainingOperator.validate`` that are necessary for your application. + +You can see a full example of how to :ref:`port over regular PyTorch DDP code to Ray SGD here `. + +**Alternative 2: Continue to use TrainingOperator** +Alternatively, if you liked having the ``TrainingOperator``, you can define a training function that instantiates your `TrainingOperator` and you can call methods directly on the operator object. + +So instead of + +.. code-block:: python + + from ray.util.sgd import TrainingOperator, TorchTrainer + + class MyTrainingOperator(TrainingOperator): + ...
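+    # In v1, model, optimizer, and data setup plus the train/validate loops all live in this operator class; the trainer below drives it epoch by epoch.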
+ + trainer = TorchTrainer(training_operator_cls=MyTrainingOperator, num_workers=4, use_gpu=True) + + num_epochs = 10 + for _ in range(num_epochs): + trainer.train() + trainer.validate() + + final_model = trainer.get_model() + + +you would do + +.. code-block:: python + + import torch + + from ray.util.sgd import TrainingOperator + from ray.sgd import Trainer + from ray import sgd + + class MyTrainingOperator(TrainingOperator): + ... + + def train_func(config): + device = torch.device(f"cuda:{sgd.local_rank()}" if + torch.cuda.is_available() else "cpu") + if torch.cuda.is_available(): + torch.cuda.set_device(device) + + # Set the args to whatever values you want. + training_operator = MyTrainingOperator( + config=config, + world_rank=sgd.world_rank(), + local_rank=sgd.local_rank(), + is_distributed=True, + device=device, + use_gpu=True, + wrap_ddp=True, + add_dist_sampler=True) + + training_operator.setup(config) + + for idx in range(config["num_epochs"]): + train_loader = training_operator._get_train_loader() + # If using DistributedSampler, set the epoch here. + train_loader.set_epoch(idx) + training_operator.train_epoch(iterator=iter(train_loader), epoch_idx=idx) + + validation_loader = training_operator._get_validation_loader() + training_operator.validate(iterator=iter(validation_loader)) + + if sgd.world_rank() == 0: + return training_operator._get_original_models() + else: + return None + + trainer = Trainer(backend="torch", num_workers=4, use_gpu=True) + trainer.start() + results = trainer.run(train_func, config={"num_epochs": 10}) + final_model = results[0] + +TensorFlow +~~~~~~~~~~ + +The API for ``TFTrainer`` uses creator functions instead of a ``TrainingOperator`` to define the training logic. To port over Ray SGD v1 TensorFlow code to v2, you can do the following: + +.. code-block:: python + + from tensorflow.distribute import MultiWorkerMirroredStrategy + + from ray.sgd import Trainer + from ray import sgd + + def train_func(config): + train_dataset, val_dataset = data_creator(config) + strategy = MultiWorkerMirroredStrategy() + with strategy.scope(): + model = model_creator(config) + + for epoch_idx in range(config["num_epochs"]): + model.fit(train_dataset) + + if sgd.world_rank() == 0: + return model + else: + return None + + trainer = Trainer(backend="tensorflow", num_workers=4) + trainer.start() + model = trainer.run(train_func, config={"num_epochs": 3, ...})[0] + +You can see a full example :ref:`here `. + +.. _sgd-migration-trainer: + +Interacting with the ``Trainer`` +-------------------------------- + +In Ray SGD v1, you can iteratively call ``trainer.train()`` or ``trainer.validate()`` for each epoch, and can then interact with the trainer to get certain state (model, checkpoints, results, etc.). In Ray SGD v2, this is replaced by a single training function that defines the full training & validation loop for all epochs. + +There are 3 ways to get state during or after the training execution: + + +#. Return values from your training function +#. Intermediate results via ``sgd.report()`` +#. Saving & loading checkpoints via ``sgd.save_checkpoint()`` and ``sgd.load_checkpoint()`` + +Return Values +~~~~~~~~~~~~~ + +To get any state from training *after* training has completed, you can simply return it from your training function. The return values from each of the workers will be added to a list and returned from the ``trainer.run()`` call. + +For example, to get the final model: + +**SGD v1** + +..
code-block:: python + + from ray.util.sgd import TorchTrainer, TrainingOperator + + class MyTrainingOperator(TrainingOperator): + ... + + trainer = TorchTrainer(training_operator_cls=MyTrainingOperator, num_workers=2) + + trainer.train() + + trained_model = trainer.get_model() + +**SGD v2** + +.. code-block:: python + + from ray.sgd import Trainer + + def train_func(): + model = Net() + train_loader = MyDataset() + for batch in train_loader: + model.train(batch) + + return model + + trainer = Trainer(backend="torch", num_workers=2) + trainer.start() + results = trainer.run(train_func) + assert len(results) == 2 + trained_model = results[0] + +Intermediate Reporting +~~~~~~~~~~~~~~~~~~~~~~ + +If you want to access any values *during* the training process, you can do so via ``sgd.report()``. You can pass in any values to ``sgd.report()`` and these values from all workers will be sent to any callbacks passed into your ``Trainer``. + +**SGD v1** + +.. code-block:: python + + from ray.util.sgd import TorchTrainer, TrainingOperator + + class MyTrainingOperator(TrainingOperator): + ... + + trainer = TorchTrainer(training_operator_cls=MyTrainingOperator, num_workers=2) + + for _ in range(3): + print(trainer.train(reduce_results=False)) + + +**SGD v2** + +.. code-block:: python + + from ray import sgd + from ray.sgd import Trainer + from ray.sgd.callbacks import SGDCallback + from typing import List, Dict + + class PrintingCallback(SGDCallback): + def handle_result(self, results: List[Dict], **info): + print(results) + + def train_func(): + for i in range(3): + sgd.report(epoch=i) + + trainer = Trainer(backend="torch", num_workers=2) + trainer.start() + result = trainer.run( + train_func, + callbacks=[PrintingCallback()] + ) + # [{'epoch': 0, '_timestamp': 1630471763, '_time_this_iter_s': 0.0020279884338378906, '_training_iteration': 1}, {'epoch': 0, '_timestamp': 1630471763, '_time_this_iter_s': 0.0014922618865966797, '_training_iteration': 1}] + # [{'epoch': 1, '_timestamp': 1630471763, '_time_this_iter_s': 0.0008401870727539062, '_training_iteration': 2}, {'epoch': 1, '_timestamp': 1630471763, '_time_this_iter_s': 0.0007486343383789062, '_training_iteration': 2}] + # [{'epoch': 2, '_timestamp': 1630471763, '_time_this_iter_s': 0.0014500617980957031, '_training_iteration': 3}, {'epoch': 2, '_timestamp': 1630471763, '_time_this_iter_s': 0.0015292167663574219, '_training_iteration': 3}] + trainer.shutdown() + +See the :ref:`v2 User Guide ` for more details. + +Checkpointing +~~~~~~~~~~~~~ + +Finally, you can also use ``sgd.save_checkpoint()`` and ``sgd.load_checkpoint()`` to write checkpoints to disk during the training process, and to load from the most recently saved checkpoint in the case of node failures. + +See the :ref:`Checkpointing ` and :ref:`Fault Tolerance & Elastic Training ` sections on the user guide for more info. + +For example, in order to save checkpoints after every epoch: + +**SGD v1** + +.. code-block:: python + + from ray.util.sgd import TorchTrainer, TrainingOperator + + class MyTrainingOperator(TrainingOperator): + ... + + trainer = TorchTrainer(training_operator_cls=MyTrainingOperator, num_workers=2) + + for _ in range(3): + trainer.train() + trainer.save_checkpoint(checkpoint_dir="~/ray_results") + + +**SGD v2** + +..
code-block:: python + + from ray.sgd import Trainer + from ray import sgd + + def train_func(): + model = Net() + train_loader = MyDataset() + for i in range(3): + for batch in train_loader: + model.train(batch) + sgd.save_checkpoint(epoch=i, model=model.state_dict()) + + trainer = Trainer(backend="torch", num_workers=2) + trainer.start() + trainer.run(train_func) + + +.. _sgd-migration-tune: + +Hyperparameter Tuning with Ray Tune +----------------------------------- + +Ray SGD v2 also comes with an easier-to-use interface for Hyperparameter Tuning with Ray Tune using Tune's function API instead of its Class API. In particular, it is much easier to define custom procedures because the logic is entirely defined by your training function. + +There is a 1:1 mapping between rank 0 worker's ``sgd.report()``\ , ``sgd.save_checkpoint()``\ , and ``sgd.load_checkpoint()`` and ``tune.report()``\ , ``tune.save_checkpoint()``\ , and ``tune.load_checkpoint()``. + +**SGD v1** + +.. code-block:: python + + from ray import tune + from ray.util.sgd import TrainingOperator, TorchTrainer + + class MyTrainingOperator(TrainingOperator): + ... + + def custom_step(trainer, info): + train_stats = trainer.train() + return train_stats + + # TorchTrainable is subclass of BaseTorchTrainable. + TorchTrainable = TorchTrainer.as_trainable( + training_operator_cls=MyTrainingOperator, + num_workers=2, + use_gpu=True, + override_tune_step=custom_step + ) + + analysis = tune.run( + TorchTrainable, + config={"input": tune.grid_search([1, 2, 3])} + ) + + + +**SGD v2** + +.. code-block:: python + + from ray import tune + from ray import sgd + from ray.sgd import Trainer + + def train_func(config): + # In this example, nothing is expected to change over epochs, + # and the output metric is equivalent to the input value. + for _ in range(config["num_epochs"]): + sgd.report(output=config["input"]) + + trainer = Trainer(backend="torch", num_workers=2) + trainable = trainer.to_tune_trainable(train_func) + analysis = tune.run(trainable, config={ + "num_epochs": 2, + "input": tune.grid_search([1, 2, 3]) + }) + print(analysis.get_best_config(metric="output", mode="max")) + # {'num_epochs': 2, 'input': 3} + +For more information see :ref:`sgd-tune`. \ No newline at end of file diff --git a/doc/source/raysgd/v2/raysgd.rst b/doc/source/raysgd/v2/raysgd.rst index 02111cdae1672..a37e583a7fe7e 100644 --- a/doc/source/raysgd/v2/raysgd.rst +++ b/doc/source/raysgd/v2/raysgd.rst @@ -5,6 +5,8 @@ RaySGD: Deep Learning on Ray .. _`issue on GitHub`: https://github.com/ray-project/ray/issues +.. tip:: Get in touch with us if you're using or considering using `RaySGD `_! + RaySGD is a lightweight library for distributed deep learning, allowing you to scale up and speed up training for your deep learning models. @@ -21,7 +23,6 @@ The main features are: `issue on GitHub`_. If you are looking for the previous API documentation, see :ref:`sgd-index`. - Intro to RaySGD --------------- diff --git a/doc/source/raysgd/v2/user_guide.rst b/doc/source/raysgd/v2/user_guide.rst index fe33949342af0..2c34e59dd29f2 100644 --- a/doc/source/raysgd/v2/user_guide.rst +++ b/doc/source/raysgd/v2/user_guide.rst @@ -3,6 +3,8 @@ RaySGD User Guide ================= +.. tip:: Get in touch with us if you're using or considering using `RaySGD `_! + In this guide, we cover examples for the following use cases: * How do I :ref:`port my code ` to using RaySGD? @@ -88,6 +90,7 @@ training.
If you are using GPUs, you need to make sure the CUDA devices are properly set up inside your training function. This involves 3 steps: + 1. Use the local rank to set the default CUDA device for the worker. 2. Move the model to the default CUDA device (or a specific CUDA device). 3. Specify ``device_ids`` when wrapping in ``DistributedDataParallel``. @@ -341,7 +344,8 @@ You can plug all of these into RaySGD with the following interface: .. code-block:: python from ray import sgd - from ray.sgd import SGDCallback, Trainer + from ray.sgd import Trainer + from ray.sgd.callbacks import SGDCallback from typing import List, Dict class PrintingCallback(SGDCallback): @@ -395,7 +399,7 @@ A simple example for creating a callback that will print out results: .. code-block:: python - from ray.sgd import SGDCallback + from ray.sgd.callbacks import SGDCallback class PrintingCallback(SGDCallback): def handle_result(self, results: List[Dict], **info): @@ -635,7 +639,7 @@ Underneath the hood, RaySGD will automatically shard the given dataset. return model trainer = Trainer(num_workers=8, backend="torch") - dataset = ray.data.read_csv("...").filter().pipeline(length=50) + dataset = ray.data.read_csv("...").filter().window(blocks_per_window=50) result = trainer.run( train_func, @@ -738,7 +742,7 @@ A couple caveats: # Declare the specification for training. trainer = Trainer(backend="torch", num_workers=12, use_gpu=True) - dataset = ray.dataset.pipeline() + dataset = ray.dataset.window() # Convert this to a trainable. trainable = trainer.to_tune_trainable(training_func, dataset=dataset) diff --git a/doc/source/serve/core-apis.rst b/doc/source/serve/core-apis.rst index e5130821c98be..2bd1f834c465d 100644 --- a/doc/source/serve/core-apis.rst +++ b/doc/source/serve/core-apis.rst @@ -35,7 +35,14 @@ Deployments can be exposed in two ways: over HTTP or in Python via the :ref:`ser By default, HTTP requests will be forwarded to the ``__call__`` method of the class (or the function) and a ``Starlette Request`` object will be the sole argument. You can also define a deployment that wraps a FastAPI app for more flexible handling of HTTP requests. See :ref:`serve-fastapi-http` for details. -We can also list all available deployments and dynamically get a reference to them: +To serve multiple deployments defined by the same class, use the ``name`` option: + +.. code-block:: python + + MyFirstDeployment.options(name="hello_service").deploy("Hello!") + MyFirstDeployment.options(name="hi_service").deploy("Hi!") + +You can also list all available deployments and dynamically get references to them: .. code-block:: python @@ -238,27 +245,31 @@ Ray Serve supports serving deployments with different (possibly conflicting) Python dependencies. For example, you can simultaneously serve one deployment that uses legacy Tensorflow 1 and another that uses Tensorflow 2. -Currently this is supported on Mac OS and Linux using `conda `_ -via Ray's built-in ``runtime_env`` option for actors. -As with all other actor options, pass these in via ``ray_actor_options`` in -your deployment. -You must have a conda environment set up for each set of -dependencies you want to isolate. If using a multi-node cluster, the -desired conda environment must be present on all nodes. Also, the Python patch version -(e.g. 3.8.10) must be identical on all nodes (this is a requirement for any Ray cluster). -See :ref:`runtime-environments` for details. - -Here's an example script.
For it to work, first create a conda -environment named ``ray-tf1`` with Ray Serve and Tensorflow 1 installed, -and another named ``ray-tf2`` with Ray Serve and Tensorflow 2. The Ray and -Python versions must be the same in both environments. +This is supported on Mac OS and Linux using Ray's :ref:`runtime-environments` feature. +As with all other Ray actor options, pass the runtime environment in via ``ray_actor_options`` in +your deployment. Be sure to first run ``pip install "ray[default]"`` to ensure the +Runtime Environments feature is installed. + +Example: .. literalinclude:: ../../../python/ray/serve/examples/doc/conda_env.py +.. note:: + When using a Ray library (for example, Ray Serve) in a runtime environment, it must + explicitly be included in the dependencies, as in the above example. This is not + required when just using Ray Core. + +.. tip:: + Avoid dynamically installing packages that install from source: these can be slow and + use up all resources while installing, leading to problems with the Ray cluster. Consider + precompiling such packages in a private repository or Docker image. + The dependencies required in the deployment may be different than the dependencies installed in the driver program (the one running Serve API calls). In this case, you should use a delayed import within the class to avoid -importing unavailable packages in the driver. +importing unavailable packages in the driver. This applies even when not +using runtime environments. + Example: .. literalinclude:: ../../../python/ray/serve/examples/doc/imported_backend.py diff --git a/doc/source/serve/deployment.rst b/doc/source/serve/deployment.rst index 200b97f5dc710..a57c152561481 100644 --- a/doc/source/serve/deployment.rst +++ b/doc/source/serve/deployment.rst @@ -25,7 +25,7 @@ to update the Serve instance, you can run another script that connects to the sa All non-detached Serve instances will be started in the current namespace that was specified when connecting to the cluster. If a namespace is specified for a detached Serve instance, it will be used. Otherwise if the current namespace is anonymous, the Serve instance will be started in the ``serve`` namespace. -If ``serve.start()`` is called again in a process in which there is already a running Serve instance, Serve will re-connect to the existing instance (regardless of whether the original instance was detached or not). To reconnect to a Serve instance that exists in the Ray cluster but not in the current process, connect to the cluster with the same namespace that was specified when starting the instance and run ``serve.start()``. +If ``serve.start()`` is called again in a process in which there is already a running Serve instance, Serve will re-connect to the existing instance (regardless of whether the original instance was detached or not). To reconnect to a Serve instance that exists in the Ray cluster but not in the current process, connect to the cluster with the same namespace that was specified when starting the instance and run ``serve.start()``. Deploying on a Single Node ========================== @@ -244,7 +244,7 @@ To automatically include the current deployment and replica in your logs, simply ``logger = logging.getLogger("ray")``, and use ``logger`` within your deployment code: .. 
literalinclude:: ../../../python/ray/serve/examples/doc/snippet_logger.py - :lines: 1, 9, 11-13, 15-16 + :lines: 1, 9, 11-14, 16-17 Querying a Serve endpoint with the above deployment will produce a log line like the following: @@ -290,7 +290,7 @@ Save the following file as ``promtail-local-config.yaml``: job: ray __path__: /tmp/ray/session_latest/logs/*.* -The relevant part for Ray is the ``static_configs`` field, where we have indicated the location of our log files with ``__path__``. +The relevant part for Ray is the ``static_configs`` field, where we have indicated the location of our log files with ``__path__``. The expression ``*.*`` will match all files, but not directories, which would cause an error with Promtail. We will run Loki locally. Grab the default config file for Loki with the following command in your terminal: @@ -334,7 +334,7 @@ Now click "Explore" in the left-side panel. You are ready to run some queries! To filter all these Ray logs for the ones relevant to our deployment, use the following `LogQL `__ query: -.. code-block:: shell +.. code-block:: shell {job="ray"} |= "deployment=Counter" @@ -377,7 +377,7 @@ The following metrics are exposed by Ray Serve: - The number of requests processed by the router. * - ``serve_handle_request_counter`` - The number of requests processed by this ServeHandle. - * - ``serve_deployment_queued_queries`` + * - ``serve_deployment_queued_queries`` - The number of queries for this deployment waiting to be assigned to a replica. To see this in action, run ``ray start --head --metrics-export-port=8080`` in your terminal, and then run the following script: diff --git a/doc/source/serve/ml-models.rst b/doc/source/serve/ml-models.rst index 8fe3330af0498..192207b041ac5 100644 --- a/doc/source/serve/ml-models.rst +++ b/doc/source/serve/ml-models.rst @@ -70,10 +70,10 @@ Integration with Model Registries Ray Serve is flexible. If you can load your model as a Python function or class, then you can scale it up and serve it with Ray Serve. -For example, if you are using the +For example, if you are using the `MLflow Model Registry `_ to manage your models, the following wrapper -class will allow you to load a model using its MLflow `Model URI`: +class will allow you to load a model using its MLflow `Model URI`: .. code-block:: python @@ -93,12 +93,19 @@ class will allow you to load a model using its MLflow `Model URI`: model_uri = "model:/my_registered_model/Production" MLflowDeployment.deploy(model_uri) -.. tip:: +To serve multiple different MLflow models in the same program, use the ``name`` option: + +.. code-block:: python + + MLflowDeployment.options(name="my_mlflow_model_1").deploy(model_uri) + + +.. tip:: The above approach will work for any model registry, not just MLflow. Namely, load the model from the registry in ``__init__``, and forward the request to the model in ``__call__``. -For an even more hands-off and seamless integration with MLflow, check out the +For an even more hands-off and seamless integration with MLflow, check out the `Ray Serve MLflow deployment plugin `__. A full tutorial is available `here `__. diff --git a/doc/source/tune/_tutorials/_faq.inc b/doc/source/tune/_tutorials/_faq.inc index d9bb39e1f94dc..c14a0aa4504cd 100644 --- a/doc/source/tune/_tutorials/_faq.inc +++ b/doc/source/tune/_tutorials/_faq.inc @@ -19,10 +19,18 @@ Deciding on which to use mostly depends on your problem: * How many hyperparameters would you like to tune? * What values are valid for hyperparameters?
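+One property worth checking first is whether your trainable reports *incremental* results.
+With Tune's functional API, that simply means calling ``tune.report()`` once per iteration.
+A minimal sketch (``evaluate`` is a hypothetical stand-in for your own evaluation logic):
+
+.. code-block:: python
+
+    from ray import tune
+
+    def train_fn(config):
+        for step in range(100):
+            score = evaluate(step, config)  # hypothetical evaluation helper
+            # Reporting once per step gives early-stopping schedulers something to act on.
+            tune.report(score=score)
+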
+**If your model returns incremental results** (e.g. results per epoch in deep learning, +results for each added tree in GBDTs, etc.), using early stopping usually allows for sampling +more configurations, as unpromising trials are pruned before they run their full course. +Please note that not all search algorithms can use information from pruned trials. +Early stopping cannot be used without incremental results - in case of the functional API, +that means that ``tune.report()`` has to be called more than once - usually in a loop, as in the sketch above. + **If your model is small**, you can usually try to run many different configurations. A **random search** can be used to generate configurations. You can also grid search over some values. You should probably still use -:ref:`ASHA for early termination of bad trials `. +:ref:`ASHA for early termination of bad trials ` (if your problem +supports early stopping). **If your model is large**, you can try to either use **Bayesian Optimization-based search algorithms** like :ref:`BayesOpt ` or @@ -33,14 +41,19 @@ Alternatively, you can use :ref:`Population Based Training ` works well with few trials, e.g. 8 or even 4. However, this will output a hyperparameter *schedule* rather than one fixed set of hyperparameters. -**If you have a small number of hyperparameters**, Bayesian Optimization-methods -work well. Take a look at :ref:`BOHB ` to combine the -benefits of bayesian optimization with early stopping. +**If you have a small number of hyperparameters**, Bayesian Optimization methods +work well. Take a look at :ref:`BOHB ` or :ref:`Optuna ` +with the :ref:`ASHA ` scheduler to combine the +benefits of Bayesian Optimization with early stopping. **If you only have continuous values for hyperparameters** this will work well -with most Bayesian-Optimization methods. Discrete or categorical variables still +with most Bayesian Optimization methods. Discrete or categorical variables still work, but less well with an increasing number of categories. +**If you have many categorical values for hyperparameters**, consider using random search, +or a TPE-based Bayesian Optimization algorithm such as :ref:`Optuna ` or +:ref:`HyperOpt `. + **Our go-to solution** is usually to use **random search** with :ref:`ASHA for early stopping ` for smaller problems. Use :ref:`BOHB ` for **larger problems** with a **small number of hyperparameters** and :ref:`Population Based Training ` for **larger problems** with a **large number of hyperparameters** @@ -248,6 +261,34 @@ on other nodes as well. Please refer to the :ref:`placement groups documentation ` to learn more about these placement strategies. +Why is my training stuck and Ray reporting that pending actors or tasks cannot be scheduled? +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +This is usually caused by Ray actors or tasks being started by the +trainable without the trainable's resource request accounting for them, leading to a deadlock. +This can also be caused, more stealthily, by using other libraries in the trainable that are +based on Ray, such as Modin. In order to fix the issue, request additional resources for +the trial using :ref:`placement groups `, as outlined in +the section above. + +For example, if your trainable is using Modin dataframes, operations on those will spawn +Ray tasks. By allocating an additional CPU bundle to the trial, those tasks will be able +to run without being starved of resources. + +..
code-block:: python + + import modin.pandas as pd + from ray import tune + + def train_fn(config, checkpoint_dir=None): + # some Modin operations here + tune.report(metric=metric) + + tune.run( + train_fn, + resources_per_trial=tune.PlacementGroupFactory([ + {"CPU": 1}, # this bundle will be used by the trainable itself + {"CPU": 1}, # this bundle will be used by Modin + ], strategy="PACK")) How can I pass further parameter values to my trainable? ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -286,8 +327,8 @@ also works with class trainables. Please see :ref:`here for further details ` and examples. -How can I reproduce experiments -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +How can I reproduce experiments? +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Reproducing experiments and experiment results means that you get the exact same results when running an experiment again and again. To achieve this, the conditions have to be exactly the same each time you run the experiment. diff --git a/doc/source/tune/api_docs/suggestion.rst b/doc/source/tune/api_docs/suggestion.rst index 32728c4ab2273..4795f0c97816f 100644 --- a/doc/source/tune/api_docs/suggestion.rst +++ b/doc/source/tune/api_docs/suggestion.rst @@ -16,6 +16,7 @@ Summary ------- .. list-table:: + :widths: 5 5 2 10 :header-rows: 1 * - SearchAlgorithm @@ -137,8 +138,6 @@ identifier. search_alg2.restore_from_dir( os.path.join("~/my_results", "my-experiment-1")) -.. note:: This is currently not implemented for: AxSearch, TuneBOHB, SigOptSearch, and DragonflySearch. .. _tune-basicvariant: Random search and grid search (tune.suggest.basic_variant.BasicVariantGenerator) diff --git a/doc/source/tune/user-guide.rst b/doc/source/tune/user-guide.rst index a4f522add908e..962d53bdad848 100644 --- a/doc/source/tune/user-guide.rst +++ b/doc/source/tune/user-guide.rst @@ -50,7 +50,8 @@ the respective placement group. If not enough resources are available, this will If your trainable function starts more remote workers, you will need to pass placement groups factory objects to request these resources. See the :class:`PlacementGroupFactory documentation ` -for further information. +for further information. This also applies if you are using other libraries making use of Ray, such +as Modin. Failure to set resources correctly may result in a deadlock, "hanging" the cluster. Using GPUs ~~~~~~~~~~ @@ -870,6 +871,10 @@ These are the environment variables Ray Tune currently considers: Ctrl+C) to gracefully shutdown and do a final checkpoint. Setting this variable to ``1`` will disable signal handling and stop execution right away. Defaults to ``0``. +* **TUNE_FORCE_TRIAL_CLEANUP_S**: By default, Ray Tune will gracefully terminate trials, + letting them finish the current training step and any user-defined cleanup. + Setting this variable to a non-zero, positive integer will cause trials to be forcefully + terminated after a grace period of that many seconds. Defaults to ``0``. * **TUNE_FUNCTION_THREAD_TIMEOUT_S**: Time in seconds the function API waits for threads to finish after instructing them to complete. Defaults to ``2``. * **TUNE_GLOBAL_CHECKPOINT_S**: Time in seconds that limits how often Tune's @@ -903,6 +908,9 @@ These are the environment variables Ray Tune currently considers: to the driver. Enabling this might delay scheduling decisions, as trainables are speculatively continued. Setting this to ``0`` disables result buffering. Defaults to 1000 (results), or to 1 (no buffering) if used with ``checkpoint_at_end``.
+* **TUNE_RESULT_DELIM**: Delimiter used for nested entries in + :class:`ExperimentAnalysis ` dataframes. Defaults to ``.`` (but will be + changed to ``/`` in future versions of Ray). * **TUNE_RESULT_BUFFER_MAX_TIME_S**: Similarly, Ray Tune buffers results up to ``number_of_trial/10`` seconds, but never longer than this value. Defaults to 100 (seconds). * **TUNE_RESULT_BUFFER_MIN_TIME_S**: Additionally, you can specify a minimum time to buffer results. Defaults to 0. diff --git a/java/BUILD.bazel b/java/BUILD.bazel index 2d98abb9402ba..06a974befed02 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -147,9 +147,12 @@ define_java_module( ":io_ray_ray_api", ":io_ray_ray_runtime", ":io_ray_ray_serve", + "@maven//:com_google_code_gson_gson", "@maven//:com_google_guava_guava", "@maven//:com_google_protobuf_protobuf_java", "@maven//:org_apache_commons_commons_lang3", + "@maven//:org_apache_httpcomponents_client5_httpclient5", + "@maven//:org_apache_httpcomponents_core5_httpcore5", "@maven//:org_slf4j_slf4j_api", "@maven//:org_testng_testng", ], @@ -157,9 +160,11 @@ define_java_module( deps = [ ":io_ray_ray_api", ":io_ray_ray_runtime", + "@maven//:com_google_code_gson_gson", "@maven//:com_google_guava_guava", "@maven//:com_google_protobuf_protobuf_java", "@maven//:org_apache_commons_commons_lang3", + "@maven//:org_apache_httpcomponents_core5_httpcore5", "@maven//:org_slf4j_slf4j_api", ], ) diff --git a/java/dependencies.bzl b/java/dependencies.bzl index 9c411a1bd9982..e6bb9e384d1cf 100644 --- a/java/dependencies.bzl +++ b/java/dependencies.bzl @@ -24,6 +24,8 @@ def gen_java_deps(): "com.lmax:disruptor:3.3.4", "org.yaml:snakeyaml:1.26", "net.java.dev.jna:jna:5.5.0", + "org.apache.httpcomponents.client5:httpclient5:5.0.3", + "org.apache.httpcomponents.core5:httpcore5:5.0.2", maven.artifact( group = "org.testng", artifact = "testng", diff --git a/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java index acda82aa6f1d6..172ff78dfa397 100644 --- a/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java +++ b/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java @@ -1,6 +1,7 @@ package io.ray.runtime; import com.google.common.base.Preconditions; +import com.google.gson.Gson; import io.ray.api.BaseActorHandle; import io.ray.api.id.ActorId; import io.ray.api.id.JobId; @@ -10,6 +11,7 @@ import io.ray.runtime.exception.RayIntentionalSystemExitException; import io.ray.runtime.gcs.GcsClient; import io.ray.runtime.gcs.GcsClientOptions; +import io.ray.runtime.generated.Common.RuntimeEnv; import io.ray.runtime.generated.Common.WorkerType; import io.ray.runtime.generated.Gcs.GcsNodeInfo; import io.ray.runtime.generated.Gcs.JobConfig; @@ -20,6 +22,8 @@ import io.ray.runtime.task.TaskExecutor; import io.ray.runtime.util.BinaryFileUtil; import io.ray.runtime.util.JniUtils; +import java.util.HashMap; +import java.util.Map; import java.util.Optional; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReadWriteLock; @@ -102,8 +106,20 @@ public void start() { JobConfig.newBuilder() .setNumJavaWorkersPerProcess(rayConfig.numWorkersPerProcess) .addAllJvmOptions(rayConfig.jvmOptionsForJavaWorker) - .putAllWorkerEnv(rayConfig.workerEnv) .addAllCodeSearchPath(rayConfig.codeSearchPath); + RuntimeEnv.Builder runtimeEnvBuilder = RuntimeEnv.newBuilder(); + if (!rayConfig.workerEnv.isEmpty()) { + // TODO(SongGuyang): Support complete runtime env interface for users.
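+          // e.g. a workerEnv of {"FOO": "bar"} is serialized below as {"env_vars": {"FOO": "bar"}}.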
+ // Set worker env to the serialized runtime env JSON. + Gson gson = new Gson(); + Map<String, Map<String, String>> runtimeEnv = new HashMap<>(); + runtimeEnv.put("env_vars", rayConfig.workerEnv); + String gsonString = gson.toJson(runtimeEnv); + runtimeEnvBuilder.setSerializedRuntimeEnv(gsonString); + } else { + runtimeEnvBuilder.setSerializedRuntimeEnv("{}"); + } + jobConfigBuilder.setRuntimeEnv(runtimeEnvBuilder.build()); serializedJobConfig = jobConfigBuilder.build().toByteArray(); } diff --git a/java/runtime/src/main/java/io/ray/runtime/object/LocalModeObjectStore.java b/java/runtime/src/main/java/io/ray/runtime/object/LocalModeObjectStore.java index 131d71c5fa2f9..fc139985955c9 100644 --- a/java/runtime/src/main/java/io/ray/runtime/object/LocalModeObjectStore.java +++ b/java/runtime/src/main/java/io/ray/runtime/object/LocalModeObjectStore.java @@ -117,7 +117,7 @@ public Address getOwnerAddress(ObjectId id) { } @Override - public byte[] promoteAndGetOwnershipInfo(ObjectId objectId) { + public byte[] getOwnershipInfo(ObjectId objectId) { return new byte[0]; } diff --git a/java/runtime/src/main/java/io/ray/runtime/object/NativeObjectStore.java b/java/runtime/src/main/java/io/ray/runtime/object/NativeObjectStore.java index 7e0ddc5c9aa74..136712c096cd8 100644 --- a/java/runtime/src/main/java/io/ray/runtime/object/NativeObjectStore.java +++ b/java/runtime/src/main/java/io/ray/runtime/object/NativeObjectStore.java @@ -78,8 +78,8 @@ public void removeLocalReference(UniqueId workerId, ObjectId objectId) { } @Override - public byte[] promoteAndGetOwnershipInfo(ObjectId objectId) { - return nativePromoteAndGetOwnershipInfo(objectId.getBytes()); + public byte[] getOwnershipInfo(ObjectId objectId) { + return nativeGetOwnershipInfo(objectId.getBytes()); } @Override @@ -132,7 +132,7 @@ private static native List nativeWait( private static native byte[] nativeGetOwnerAddress(byte[] objectId); - private static native byte[] nativePromoteAndGetOwnershipInfo(byte[] objectId); + private static native byte[] nativeGetOwnershipInfo(byte[] objectId); private static native void nativeRegisterOwnershipInfoAndResolveFuture( byte[] objectId, byte[] outerObjectId, byte[] ownerAddress); diff --git a/java/runtime/src/main/java/io/ray/runtime/object/ObjectRefImpl.java b/java/runtime/src/main/java/io/ray/runtime/object/ObjectRefImpl.java index cb9b35becd02d..a352ca22632ef 100644 --- a/java/runtime/src/main/java/io/ray/runtime/object/ObjectRefImpl.java +++ b/java/runtime/src/main/java/io/ray/runtime/object/ObjectRefImpl.java @@ -63,7 +63,7 @@ public void writeExternal(ObjectOutput out) throws IOException { out.writeObject(this.getId()); out.writeObject(this.getType()); RayRuntimeInternal runtime = (RayRuntimeInternal) Ray.internal(); - byte[] ownerAddress = runtime.getObjectStore().promoteAndGetOwnershipInfo(this.getId()); + byte[] ownerAddress = runtime.getObjectStore().getOwnershipInfo(this.getId()); out.writeInt(ownerAddress.length); out.write(ownerAddress); ObjectSerializer.addContainedObjectId(this.getId()); diff --git a/java/runtime/src/main/java/io/ray/runtime/object/ObjectStore.java b/java/runtime/src/main/java/io/ray/runtime/object/ObjectStore.java index d61694fab7e93..6db39cc1e4bd6 100644 --- a/java/runtime/src/main/java/io/ray/runtime/object/ObjectStore.java +++ b/java/runtime/src/main/java/io/ray/runtime/object/ObjectStore.java @@ -224,12 +224,12 @@ public WaitResult wait( public abstract Address getOwnerAddress(ObjectId id); /** - * Promote the given object to the underlying object store, and get the ownership info.
+ * Get the ownership info. * * @param objectId The ID of the object * @return the serialized ownership address */ - public abstract byte[] promoteAndGetOwnershipInfo(ObjectId objectId); + public abstract byte[] getOwnershipInfo(ObjectId objectId); /** * Add a reference to an ObjectID that will be deserialized. This will also start the process to diff --git a/java/serve/pom.xml b/java/serve/pom.xml index d945f8fe83172..7291d4ec79666 100644 --- a/java/serve/pom.xml +++ b/java/serve/pom.xml @@ -27,6 +27,11 @@ <artifactId>ray-runtime</artifactId> <version>${project.version}</version> </dependency> + <dependency> + <groupId>com.google.code.gson</groupId> + <artifactId>gson</artifactId> + <version>2.8.5</version> + </dependency> <dependency> <groupId>com.google.guava</groupId> <artifactId>guava</artifactId> @@ -42,6 +47,16 @@ <artifactId>commons-lang3</artifactId> <version>3.4</version> </dependency> + <dependency> + <groupId>org.apache.httpcomponents.client5</groupId> + <artifactId>httpclient5</artifactId> + <version>5.0.3</version> + </dependency> + <dependency> + <groupId>org.apache.httpcomponents.core5</groupId> + <artifactId>httpcore5</artifactId> + <version>5.0.2</version> + </dependency> <dependency> <groupId>org.slf4j</groupId> <artifactId>slf4j-api</artifactId> diff --git a/java/serve/src/main/java/io/ray/serve/Constants.java b/java/serve/src/main/java/io/ray/serve/Constants.java index 2d8ac4f702839..1ca1739f8d734 100644 --- a/java/serve/src/main/java/io/ray/serve/Constants.java +++ b/java/serve/src/main/java/io/ray/serve/Constants.java @@ -16,4 +16,10 @@ public class Constants { /** Name of controller listen_for_change method. */ public static final String CONTROLLER_LISTEN_FOR_CHANGE_METHOD = "listen_for_change"; + + public static final String SERVE_CONTROLLER_NAME = "SERVE_CONTROLLER_ACTOR"; + + public static final String DEFAULT_CALL_METHOD = "call"; + + public static final String UTF8 = "UTF-8"; } diff --git a/java/serve/src/main/java/io/ray/serve/DeploymentInfo.java b/java/serve/src/main/java/io/ray/serve/DeploymentInfo.java new file mode 100644 index 0000000000000..2ab02deeeeaeb --- /dev/null +++ b/java/serve/src/main/java/io/ray/serve/DeploymentInfo.java @@ -0,0 +1,38 @@ +package io.ray.serve; + +import java.io.Serializable; + +public class DeploymentInfo implements Serializable { + + private static final long serialVersionUID = -4198364411759931955L; + + private byte[] backendConfig; + + private ReplicaConfig replicaConfig; + + private byte[] backendVersion; + + public byte[] getBackendConfig() { + return backendConfig; + } + + public void setBackendConfig(byte[] backendConfig) { + this.backendConfig = backendConfig; + } + + public ReplicaConfig getReplicaConfig() { + return replicaConfig; + } + + public void setReplicaConfig(ReplicaConfig replicaConfig) { + this.replicaConfig = replicaConfig; + } + + public byte[] getBackendVersion() { + return backendVersion; + } + + public void setBackendVersion(byte[] backendVersion) { + this.backendVersion = backendVersion; + } +} diff --git a/java/serve/src/main/java/io/ray/serve/DummyBackendReplica.java b/java/serve/src/main/java/io/ray/serve/DummyBackendReplica.java new file mode 100644 index 0000000000000..874a71c26d6db --- /dev/null +++ b/java/serve/src/main/java/io/ray/serve/DummyBackendReplica.java @@ -0,0 +1,12 @@ +package io.ray.serve; + +import java.util.concurrent.atomic.AtomicInteger; + +public class DummyBackendReplica { + + private AtomicInteger counter = new AtomicInteger(); + + public String call() { + return String.valueOf(counter.incrementAndGet()); + } +} diff --git a/java/serve/src/main/java/io/ray/serve/HandleOptions.java b/java/serve/src/main/java/io/ray/serve/HandleOptions.java new file mode 100644 index 0000000000000..e301332976ea3 --- /dev/null +++ b/java/serve/src/main/java/io/ray/serve/HandleOptions.java @@ -0,0 +1,15 @@ +package io.ray.serve; + +/** Options for each ServeHandle instance. These fields are immutable.
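+ * Currently only the name of the replica method to invoke is configurable.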
*/ +public class HandleOptions { + + private String methodName = "call"; + + public String getMethodName() { + return methodName; + } + + public void setMethodName(String methodName) { + this.methodName = methodName; + } +} diff --git a/java/serve/src/main/java/io/ray/serve/HttpProxy.java b/java/serve/src/main/java/io/ray/serve/HttpProxy.java new file mode 100644 index 0000000000000..809337e75d902 --- /dev/null +++ b/java/serve/src/main/java/io/ray/serve/HttpProxy.java @@ -0,0 +1,161 @@ +package io.ray.serve; + +import com.google.common.collect.ImmutableMap; +import io.ray.api.Ray; +import io.ray.runtime.metric.Count; +import io.ray.runtime.metric.Metrics; +import io.ray.runtime.metric.TagKey; +import io.ray.runtime.serializer.MessagePackSerializer; +import io.ray.serve.util.LogUtil; +import io.ray.serve.util.SocketUtil; +import java.io.IOException; +import java.net.HttpURLConnection; +import java.net.InetAddress; +import java.nio.charset.Charset; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import org.apache.hc.core5.http.ClassicHttpRequest; +import org.apache.hc.core5.http.ClassicHttpResponse; +import org.apache.hc.core5.http.HttpEntity; +import org.apache.hc.core5.http.HttpException; +import org.apache.hc.core5.http.impl.bootstrap.HttpServer; +import org.apache.hc.core5.http.impl.bootstrap.ServerBootstrap; +import org.apache.hc.core5.http.io.HttpRequestHandler; +import org.apache.hc.core5.http.io.entity.ByteArrayEntity; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.apache.hc.core5.http.io.entity.StringEntity; +import org.apache.hc.core5.http.protocol.HttpContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class HttpProxy implements ServeProxy { + + private static final Logger LOGGER = LoggerFactory.getLogger(HttpProxy.class); + + public static final String PROXY_NAME = "HTTP_PROXY"; + + public static final String PROXY_HTTP_PORT = "ray.serve.proxy.http.port"; + + public static final String PROXY_HTTP_METHODS = "ray.serve.proxy.http.methods"; + + private int port; + + private Count requestCounter; + + private HttpServer httpServer; + + private ProxyRouter proxyRouter; + + private Object asyncContext = Ray.getAsyncContext(); + + @Override + public void init(Map config, ProxyRouter proxyRouter) { + this.port = + Optional.ofNullable(config) + .map(conf -> conf.get(PROXY_HTTP_PORT)) + .map(httpPort -> Integer.valueOf(httpPort)) + .orElse(SocketUtil.findAvailableTcpPort(8000)); + this.proxyRouter = proxyRouter; + RayServeMetrics.execute( + () -> + this.requestCounter = + Metrics.count() + .name("serve_num_http_requests") + .description("The number of HTTP requests processed.") + .unit("") + .tags(new HashMap<>()) + .register()); + startupHttpServer(port); + LOGGER.info("Proxy {} has been started with port:{}", getName(), this.port); + } + + private void startupHttpServer(int port) { + try { + this.httpServer = + ServerBootstrap.bootstrap() + .setListenerPort(port) + .register("*", new ServeHttpHandler()) + .registerVirtual( + InetAddress.getLocalHost().getHostAddress(), "*", new ServeHttpHandler()) + .create(); + this.httpServer.start(); + } catch (Throwable e) { + String errMsg = + LogUtil.format( + "Proxy {} failed to startup HTTP server on port {}.", getName(), this.port); + LOGGER.error(errMsg); + throw new RayServeException(errMsg, e); + } + } + + @Override + public String getName() { + return PROXY_NAME; + } + + private class ServeHttpHandler implements HttpRequestHandler { + + @Override + public void 
handle( + ClassicHttpRequest request, ClassicHttpResponse response, HttpContext context) + throws HttpException, IOException { + + Ray.setAsyncContext(asyncContext); + + int code = HttpURLConnection.HTTP_OK; + Object result = null; + String route = request.getPath(); + try { + RayServeMetrics.execute( + () -> + requestCounter.update( + 1.0, + ImmutableMap.of( + new TagKey(RayServeMetrics.TAG_ROUTE), + route))); // TODO the old tag will be covered, it may be a bug. + + Object[] parameters = null; + HttpEntity httpEntity = request.getEntity(); + if (null == httpEntity) { + parameters = new Object[0]; + } else { + byte[] body = EntityUtils.toByteArray(httpEntity); + parameters = MessagePackSerializer.decode(body, Object[].class); + } + + RayServeHandle rayServeHandle = proxyRouter.matchRoute(route); + if (rayServeHandle == null) { + code = HttpURLConnection.HTTP_NOT_FOUND; + } else { + result = rayServeHandle.remote(parameters).get(); + } + + } catch (Throwable e) { + LOGGER.error("HTTP Proxy failed to process request.", e); + code = HttpURLConnection.HTTP_INTERNAL_ERROR; + } finally { + response.setCode(code); + if (code == HttpURLConnection.HTTP_NOT_FOUND) { + response.setEntity( + new StringEntity( + LogUtil.format( + "Path '{}' not found. Please ping http://.../-/routes for route table.", + route), + Charset.forName(Constants.UTF8))); + } else if (result != null) { + response.setEntity( + new ByteArrayEntity(MessagePackSerializer.encode(result).getLeft(), null)); + } + } + } + } + + public int getPort() { + return port; + } + + public ProxyRouter getProxyRouter() { + return proxyRouter; + } +} diff --git a/java/serve/src/main/java/io/ray/serve/ProxyActor.java b/java/serve/src/main/java/io/ray/serve/ProxyActor.java new file mode 100644 index 0000000000000..ac5d1cf870ea9 --- /dev/null +++ b/java/serve/src/main/java/io/ray/serve/ProxyActor.java @@ -0,0 +1,175 @@ +package io.ray.serve; + +import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; +import io.ray.api.BaseActorHandle; +import io.ray.api.Ray; +import io.ray.serve.api.Serve; +import io.ray.serve.generated.EndpointInfo; +import io.ray.serve.generated.EndpointSet; +import io.ray.serve.poll.KeyListener; +import io.ray.serve.poll.KeyType; +import io.ray.serve.poll.LongPollClient; +import io.ray.serve.poll.LongPollNamespace; +import io.ray.serve.util.CollectionUtil; +import io.ray.serve.util.LogUtil; +import io.ray.serve.util.ReflectUtil; +import java.lang.reflect.InvocationTargetException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.ServiceLoader; +import java.util.concurrent.ConcurrentHashMap; +import org.apache.commons.lang3.StringUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ProxyActor { + + private static final Logger LOGGER = LoggerFactory.getLogger(ProxyActor.class); + + private Map config; + + private Map proxies = new ConcurrentHashMap<>(); + + /** Used only for displaying the route table. Key: route, value: endpoint. */ + private volatile Map routeInfo = new HashMap<>(); + + private LongPollClient longPollClient; + + private ProxyRouter proxyRouter = new ProxyRouter(); + + public ProxyActor(String controllerName, Map config) { + this.config = config; + + // Set the controller name so that serve will connect to the controller instance this proxy is + // running in. 
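+    // Illustrative sketch (hypothetical, not part of this patch): an extra proxy
+    // is assumed to be plugged in through the "ray.serve.proxy.class" config key
+    // consumed by startupProxy() below; the class name MyProxy is invented here:
+    //   public class MyProxy implements ServeProxy {
+    //     @Override
+    //     public void init(Map<String, String> config, ProxyRouter proxyRouter) {
+    //       // start a listener here and dispatch matched requests via
+    //       // proxyRouter.matchRoute(route).remote(parameters)
+    //     }
+    //   }
+    //   config.put(RayServeConfig.PROXY_CLASS, MyProxy.class.getName());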
+ Serve.setInternalReplicaContext(null, null, controllerName, null); + + Optional optional = Ray.getActor(controllerName); + Preconditions.checkState(optional.isPresent(), "Controller does not exist"); + + Map keyListeners = new HashMap<>(); + keyListeners.put( + new KeyType(LongPollNamespace.ROUTE_TABLE, null), endpoints -> updateRoutes(endpoints)); + this.longPollClient = new LongPollClient(optional.get(), keyListeners); + this.longPollClient.start(); + this.run(); + } + + private void run() { + startupProxy(); + registerServiceDiscovery(); + } + + private void startupProxy() { + + List serveProxies = null; + + // Get proxy instances according to class names. + String proxyClassNames = config != null ? config.get(RayServeConfig.PROXY_CLASS) : null; + if (StringUtils.isNotBlank(proxyClassNames)) { + try { + serveProxies = ReflectUtil.getInstancesByClassNames(proxyClassNames, ServeProxy.class); + } catch (ClassNotFoundException + | InstantiationException + | IllegalAccessException + | IllegalArgumentException + | InvocationTargetException + | NoSuchMethodException + | SecurityException e) { + String errorMsg = + LogUtil.format("Failed to initialize proxies by class names : {}", proxyClassNames); + LOGGER.error(errorMsg, e); + throw new RayServeException(errorMsg, e); + } + } + + // Get proxy instances through SPI. + if (CollectionUtil.isEmpty(serveProxies)) { + List spiProxies = new ArrayList<>(); + ServiceLoader serviceLoader = ServiceLoader.load(ServeProxy.class); + serviceLoader.forEach(serveProxy -> spiProxies.add(serveProxy)); + serveProxies = spiProxies; + } + + // Set the default proxy if proxies still empty. + if (CollectionUtil.isEmpty(serveProxies)) { + serveProxies = Lists.newArrayList(new HttpProxy()); + } + + if (!CollectionUtil.isEmpty(serveProxies)) { + for (ServeProxy serveProxy : serveProxies) { + if (proxies.containsKey(serveProxy.getName())) { + String errorMsg = + LogUtil.format( + "Proxy {} name {} is duplicate with proxy {} name {}", + serveProxy.getClass().getName(), + serveProxy.getName(), + proxies.get(serveProxy.getName()).getClass().getName(), + proxies.get(serveProxy.getName()).getName()); + LOGGER.error(errorMsg); + throw new RayServeException(errorMsg); + } + proxies.put(serveProxy.getName(), serveProxy); + serveProxy.init(config, proxyRouter); + LOGGER.info("Proxy actor initialized proxy: {}", serveProxy.getName()); + } + } + } + + public void registerServiceDiscovery() { + proxies.forEach((key, value) -> value.registerServiceDiscovery()); + } + + public void updateRoutes(Object endpoints) { + Map endpointInfos = ((EndpointSet) endpoints).getEndpointsMap(); + Map routeInfo = new HashMap<>(); + if (endpointInfos != null) { + endpointInfos.forEach( + (key, value) -> + routeInfo.put( + StringUtils.isNotBlank(value.getRoute()) ? 
value.getRoute() : key, value));
+    }
+    this.routeInfo = routeInfo;
+    this.proxyRouter.updateRoutes(endpointInfos);
+  }
+
+  public void ready() {
+    return;
+  }
+
+  public void blockUntilEndpointExists(String endpoint, double timeoutS) {
+    long timeoutMs = (long) (timeoutS * 1000);
+    long startTime = System.currentTimeMillis();
+    while (true) {
+      if (System.currentTimeMillis() - startTime > timeoutMs) {
+        throw new RayServeException(
+            LogUtil.format("Waited {}s for endpoint {} to propagate.", timeoutS, endpoint));
+      }
+      for (EndpointInfo endpointInfo : routeInfo.values()) {
+        if (StringUtils.equals(endpointInfo.getEndpointName(), endpoint)) {
+          return;
+        }
+      }
+      try {
+        Thread.sleep(200);
+      } catch (InterruptedException e) {
+        LOGGER.error(
+            "Sleep was interrupted while waiting for endpoint {} to exist.", endpoint, e);
+      }
+    }
+  }
+
+  public ProxyRouter getProxyRouter() {
+    return proxyRouter;
+  }
+
+  public Map<String, ServeProxy> getProxies() {
+    return proxies;
+  }
+}
diff --git a/java/serve/src/main/java/io/ray/serve/ProxyRouter.java b/java/serve/src/main/java/io/ray/serve/ProxyRouter.java
new file mode 100644
index 0000000000000..041da46bfee08
--- /dev/null
+++ b/java/serve/src/main/java/io/ray/serve/ProxyRouter.java
@@ -0,0 +1,72 @@
+package io.ray.serve;
+
+import io.ray.serve.api.Serve;
+import io.ray.serve.generated.EndpointInfo;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.commons.lang3.StringUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** Default common router for proxies to match incoming routes. */
+public class ProxyRouter {
+
+  private static final Logger LOGGER = LoggerFactory.getLogger(ProxyRouter.class);
+
+  /** Key: route, value: endpoint. */
+  private Map<String, EndpointInfo> routeInfo = new HashMap<>();
+
+  /** Key: endpointName, value: handle. */
+  private Map<String, RayServeHandle> handles = new ConcurrentHashMap<>();
+
+  public void updateRoutes(Map<String, EndpointInfo> endpoints) {
+    LOGGER.info("Got updated endpoints: {}.", endpoints);
+
+    Set<String> existingHandles = new HashSet<>(handles.keySet());
+    Map<String, EndpointInfo> routeInfo = new HashMap<>();
+
+    if (endpoints != null) {
+      for (Map.Entry<String, EndpointInfo> entry : endpoints.entrySet()) {
+        String route =
+            StringUtils.isNotBlank(entry.getValue().getRoute())
+                ? entry.getValue().getRoute()
+                : entry.getKey();
+        routeInfo.put(route, entry.getValue());
+
+        if (handles.containsKey(entry.getKey())) {
+          existingHandles.remove(entry.getKey());
+        } else {
+          handles.put(entry.getKey(), Serve.getGlobalClient().getHandle(entry.getKey(), true));
+        }
+      }
+    }
+
+    this.routeInfo = routeInfo;
+    for (String endpoint : existingHandles) {
+      handles.remove(endpoint);
+    }
+    LOGGER.info("The final route info: {}.", routeInfo);
+  }
+
+  /**
+   * Return the handle whose registered route exactly matches the given route.
+   *
+   * @param route route to match against.
+   * @return serve handle (RayServeHandle) if found, else null.
+   */
+  public RayServeHandle matchRoute(String route) {
+    EndpointInfo endpointInfo = routeInfo.get(route);
+    return endpointInfo == null ?
null : handles.get(endpointInfo.getEndpointName());
+  }
+
+  public Map<String, EndpointInfo> getRouteInfo() {
+    return routeInfo;
+  }
+
+  public Map<String, RayServeHandle> getHandles() {
+    return handles;
+  }
+}
diff --git a/java/serve/src/main/java/io/ray/serve/RayServeConfig.java b/java/serve/src/main/java/io/ray/serve/RayServeConfig.java
new file mode 100644
index 0000000000000..5762aae40be4e
--- /dev/null
+++ b/java/serve/src/main/java/io/ray/serve/RayServeConfig.java
@@ -0,0 +1,6 @@
+package io.ray.serve;
+
+public class RayServeConfig {
+
+  public static final String PROXY_CLASS = "ray.serve.proxy.class";
+}
diff --git a/java/serve/src/main/java/io/ray/serve/RayServeHandle.java b/java/serve/src/main/java/io/ray/serve/RayServeHandle.java
new file mode 100644
index 0000000000000..abcf6ac5abdf2
--- /dev/null
+++ b/java/serve/src/main/java/io/ray/serve/RayServeHandle.java
@@ -0,0 +1,73 @@
+package io.ray.serve;
+
+import com.google.common.collect.ImmutableMap;
+import io.ray.api.BaseActorHandle;
+import io.ray.api.ObjectRef;
+import io.ray.runtime.metric.Count;
+import io.ray.runtime.metric.Metrics;
+import io.ray.serve.generated.RequestMetadata;
+import org.apache.commons.lang3.RandomStringUtils;
+
+public class RayServeHandle {
+
+  private String endpointName;
+
+  private HandleOptions handleOptions;
+
+  private String handleTag;
+
+  private Count requestCounter;
+
+  private Router router;
+
+  public RayServeHandle(
+      BaseActorHandle controllerHandle,
+      String endpointName,
+      HandleOptions handleOptions,
+      Router router) {
+    this.endpointName = endpointName;
+    this.handleOptions = handleOptions != null ? handleOptions : new HandleOptions();
+    this.handleTag = endpointName + "#" + RandomStringUtils.randomAlphabetic(6);
+    this.router = router != null ? router : new Router(controllerHandle, endpointName);
+    RayServeMetrics.execute(
+        () ->
+            this.requestCounter =
+                Metrics.count()
+                    .name(RayServeMetrics.SERVE_HANDLE_REQUEST_COUNTER.getName())
+                    .description(RayServeMetrics.SERVE_HANDLE_REQUEST_COUNTER.getDescription())
+                    .unit("")
+                    .tags(
+                        ImmutableMap.of(
+                            RayServeMetrics.TAG_HANDLE,
+                            handleTag,
+                            RayServeMetrics.TAG_ENDPOINT,
+                            endpointName))
+                    .register());
+  }
+
+  /**
+   * Returns a Ray ObjectRef whose results can be waited for or retrieved using ray.wait or ray.get
+   * (or ``await object_ref``), respectively.
+   *
+   * @param parameters The input parameters of the specified method to invoke on the backend.
+   * @return ray.ObjectRef
+   */
+  public ObjectRef<Object> remote(Object[] parameters) {
+    RayServeMetrics.execute(() -> requestCounter.inc(1.0));
+    RequestMetadata.Builder requestMetadata = RequestMetadata.newBuilder();
+    requestMetadata.setRequestId(RandomStringUtils.randomAlphabetic(10));
+    requestMetadata.setEndpoint(endpointName);
+    requestMetadata.setCallMethod(
+        handleOptions != null ?
handleOptions.getMethodName() : Constants.DEFAULT_CALL_METHOD); + return router.assignRequest(requestMetadata.build(), parameters); + } + + public RayServeHandle setMethodName(String methodName) { + handleOptions.setMethodName(methodName); + return this; + } + + public Router getRouter() { + return router; + } +} diff --git a/java/serve/src/main/java/io/ray/serve/RayServeMetrics.java b/java/serve/src/main/java/io/ray/serve/RayServeMetrics.java new file mode 100644 index 0000000000000..f7b1fac730da9 --- /dev/null +++ b/java/serve/src/main/java/io/ray/serve/RayServeMetrics.java @@ -0,0 +1,74 @@ +package io.ray.serve; + +import io.ray.api.Ray; + +public enum RayServeMetrics { + SERVE_HANDLE_REQUEST_COUNTER( + "serve_handle_request_counter", + "The number of handle.remote() calls that have been made on this handle."), + + SERVE_NUM_ROUTER_REQUESTS( + "serve_num_router_requests", "The number of requests processed by the router."), + + SERVE_DEPLOYMENT_QUEUED_QUERIES( + "serve_deployment_queued_queries", + "The current number of queries to this deployment waiting to be assigned to a replica."), + + SERVE_BACKEND_REQUEST_COUNTER( + "serve_backend_request_counter", + "The number of queries that have been processed in this replica."), + + SERVE_BACKEND_ERROR_COUNTER( + "serve_backend_error_counter", + "The number of exceptions that have occurred in this replica."), + + SERVE_BACKEND_REPLICA_STARTS( + "serve_backend_replica_starts", + "The number of times this replica has been restarted due to failure."), + + SERVE_BACKEND_PROCESSING_LATENCY_MS( + "serve_backend_processing_latency_ms", "The latency for queries to be processed."), + + SERVE_REPLICA_PROCESSING_QUERIES( + "serve_replica_processing_queries", "The current number of queries being processed."), + ; + + public static final String TAG_HANDLE = "handle"; + + public static final String TAG_ENDPOINT = "endpoint"; + + public static final String TAG_DEPLOYMENT = "deployment"; + + public static final String TAG_ROUTE = "route"; + + public static final String TAG_BACKEND = "backend"; + + public static final String TAG_REPLICA = "replica"; + + private static final boolean isMetricsEnabled = + Ray.isInitialized() && !Ray.getRuntimeContext().isSingleProcess(); + + private String name; + + private String description; + + private RayServeMetrics(String name, String description) { + this.name = name; + this.description = description; + } + + public static void execute(Runnable runnable) { + if (!isMetricsEnabled) { + return; + } + runnable.run(); + } + + public String getName() { + return name; + } + + public String getDescription() { + return description; + } +} diff --git a/java/serve/src/main/java/io/ray/serve/RayServeReplica.java b/java/serve/src/main/java/io/ray/serve/RayServeReplica.java index 9949115fbbd72..259c8555cf3e4 100644 --- a/java/serve/src/main/java/io/ray/serve/RayServeReplica.java +++ b/java/serve/src/main/java/io/ray/serve/RayServeReplica.java @@ -1,16 +1,16 @@ package io.ray.serve; import com.google.common.collect.ImmutableMap; +import com.google.protobuf.ByteString; import io.ray.api.BaseActorHandle; -import io.ray.api.Ray; import io.ray.runtime.metric.Count; import io.ray.runtime.metric.Gauge; import io.ray.runtime.metric.Histogram; -import io.ray.runtime.metric.MetricConfig; import io.ray.runtime.metric.Metrics; import io.ray.runtime.serializer.MessagePackSerializer; import io.ray.serve.api.Serve; import io.ray.serve.generated.BackendConfig; +import io.ray.serve.generated.BackendVersion; import 
io.ray.serve.generated.RequestWrapper; import io.ray.serve.poll.KeyListener; import io.ray.serve.poll.KeyType; @@ -18,7 +18,6 @@ import io.ray.serve.poll.LongPollNamespace; import io.ray.serve.util.LogUtil; import io.ray.serve.util.ReflectUtil; -import io.ray.serve.util.ServeProtoUtil; import java.lang.reflect.Method; import java.util.HashMap; import java.util.Map; @@ -41,8 +40,6 @@ public class RayServeReplica { private Object callable; - private boolean metricsRegistered = false; - private Count requestCounter; private Count errorCounter; @@ -55,13 +52,20 @@ public class RayServeReplica { private LongPollClient longPollClient; + private BackendVersion version; + + private boolean isDeleted = false; + public RayServeReplica( - Object callable, BackendConfig backendConfig, BaseActorHandle actorHandle) { + Object callable, + BackendConfig backendConfig, + BackendVersion version, + BaseActorHandle actorHandle) { this.backendTag = Serve.getReplicaContext().getBackendTag(); this.replicaTag = Serve.getReplicaContext().getReplicaTag(); this.callable = callable; this.config = backendConfig; - this.reconfigure(ServeProtoUtil.parseUserConfig(backendConfig)); + this.version = version; Map keyListeners = new HashMap<>(); keyListeners.put( @@ -73,55 +77,84 @@ public RayServeReplica( } private void registerMetrics() { - if (!Ray.isInitialized() || Ray.getRuntimeContext().isSingleProcess()) { - return; - } - - Metrics.init(MetricConfig.DEFAULT_CONFIG); - requestCounter = - Metrics.count() - .name("serve_backend_request_counter") - .description("The number of queries that have been processed in this replica.") - .unit("") - .tags(ImmutableMap.of("backend", backendTag, "replica", replicaTag)) - .register(); - - errorCounter = - Metrics.count() - .name("serve_backend_error_counter") - .description("The number of exceptions that have occurred in this replica.") - .unit("") - .tags(ImmutableMap.of("backend", backendTag, "replica", replicaTag)) - .register(); - - restartCounter = - Metrics.count() - .name("serve_backend_replica_starts") - .description("The number of times this replica has been restarted due to failure.") - .unit("") - .tags(ImmutableMap.of("backend", backendTag, "replica", replicaTag)) - .register(); - - processingLatencyTracker = - Metrics.histogram() - .name("serve_backend_processing_latency_ms") - .description("The latency for queries to be processed.") - .unit("") - .boundaries(Constants.DEFAULT_LATENCY_BUCKET_MS) - .tags(ImmutableMap.of("backend", backendTag, "replica", replicaTag)) - .register(); - - numProcessingItems = - Metrics.gauge() - .name("serve_replica_processing_queries") - .description("The current number of queries being processed.") - .unit("") - .tags(ImmutableMap.of("backend", backendTag, "replica", replicaTag)) - .register(); - - metricsRegistered = true; - - restartCounter.inc(1.0); + RayServeMetrics.execute( + () -> + requestCounter = + Metrics.count() + .name(RayServeMetrics.SERVE_BACKEND_REQUEST_COUNTER.getName()) + .description(RayServeMetrics.SERVE_BACKEND_REQUEST_COUNTER.getDescription()) + .unit("") + .tags( + ImmutableMap.of( + RayServeMetrics.TAG_BACKEND, + backendTag, + RayServeMetrics.TAG_REPLICA, + replicaTag)) + .register()); + + RayServeMetrics.execute( + () -> + errorCounter = + Metrics.count() + .name(RayServeMetrics.SERVE_BACKEND_ERROR_COUNTER.getName()) + .description(RayServeMetrics.SERVE_BACKEND_ERROR_COUNTER.getDescription()) + .unit("") + .tags( + ImmutableMap.of( + RayServeMetrics.TAG_BACKEND, + backendTag, + RayServeMetrics.TAG_REPLICA, + 
replicaTag)) + .register()); + + RayServeMetrics.execute( + () -> + restartCounter = + Metrics.count() + .name(RayServeMetrics.SERVE_BACKEND_REPLICA_STARTS.getName()) + .description(RayServeMetrics.SERVE_BACKEND_REPLICA_STARTS.getDescription()) + .unit("") + .tags( + ImmutableMap.of( + RayServeMetrics.TAG_BACKEND, + backendTag, + RayServeMetrics.TAG_REPLICA, + replicaTag)) + .register()); + + RayServeMetrics.execute( + () -> + processingLatencyTracker = + Metrics.histogram() + .name(RayServeMetrics.SERVE_BACKEND_PROCESSING_LATENCY_MS.getName()) + .description( + RayServeMetrics.SERVE_BACKEND_PROCESSING_LATENCY_MS.getDescription()) + .unit("") + .boundaries(Constants.DEFAULT_LATENCY_BUCKET_MS) + .tags( + ImmutableMap.of( + RayServeMetrics.TAG_BACKEND, + backendTag, + RayServeMetrics.TAG_REPLICA, + replicaTag)) + .register()); + + RayServeMetrics.execute( + () -> + numProcessingItems = + Metrics.gauge() + .name(RayServeMetrics.SERVE_REPLICA_PROCESSING_QUERIES.getName()) + .description(RayServeMetrics.SERVE_REPLICA_PROCESSING_QUERIES.getDescription()) + .unit("") + .tags( + ImmutableMap.of( + RayServeMetrics.TAG_BACKEND, + backendTag, + RayServeMetrics.TAG_REPLICA, + replicaTag)) + .register()); + + RayServeMetrics.execute(() -> restartCounter.inc(1.0)); } public Object handleRequest(Query request) { @@ -130,7 +163,7 @@ public Object handleRequest(Query request) { "Replica {} received request {}", replicaTag, request.getMetadata().getRequestId()); numOngoingRequests.incrementAndGet(); - reportMetrics(() -> numProcessingItems.update(numOngoingRequests.get())); + RayServeMetrics.execute(() -> numProcessingItems.update(numOngoingRequests.get())); Object result = invokeSingle(request); numOngoingRequests.decrementAndGet(); @@ -157,10 +190,10 @@ private Object invokeSingle(Query requestItem) { Object[] args = parseRequestItem(requestItem); methodToCall = getRunnerMethod(requestItem.getMetadata().getCallMethod(), args); Object result = methodToCall.invoke(callable, args); - reportMetrics(() -> requestCounter.inc(1.0)); + RayServeMetrics.execute(() -> requestCounter.inc(1.0)); return result; } catch (Throwable e) { - reportMetrics(() -> errorCounter.inc(1.0)); + RayServeMetrics.execute(() -> errorCounter.inc(1.0)); throw new RayServeException( LogUtil.format( "Replica {} failed to invoke method {}", @@ -168,7 +201,8 @@ private Object invokeSingle(Query requestItem) { methodToCall == null ? "unknown" : methodToCall.getName()), e); } finally { - reportMetrics(() -> processingLatencyTracker.update(System.currentTimeMillis() - start)); + RayServeMetrics.execute( + () -> processingLatencyTracker.update(System.currentTimeMillis() - start)); } } @@ -209,10 +243,12 @@ private Method getRunnerMethod(String methodName, Object[] args) { * Perform graceful shutdown. Trigger a graceful shutdown protocol that will wait for all the * queued tasks to be completed and return to the controller. */ - public void drainPendingQueries() { + public synchronized boolean prepareForShutdown() { while (true) { + // Sleep first because we want to make sure all the routers receive the notification to remove + // this replica first. 
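+      // Illustrative sketch (hypothetical, not part of this patch): the controller
+      // side is assumed to drive this shutdown protocol roughly as follows, where
+      // "replica" is an invented ActorHandle<RayServeWrappedReplica>:
+      //   ObjectRef<Boolean> done =
+      //       replica.task(RayServeWrappedReplica::prepareForShutdown).remote();
+      //   Boolean drained = done.get(); // true once no requests remain in flight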
try {
-        Thread.sleep((long) (config.getExperimentalGracefulShutdownWaitLoopS() * 1000));
+        Thread.sleep((long) (config.getGracefulShutdownWaitLoopS() * 1000));
       } catch (InterruptedException e) {
         LOGGER.error(
             "Replica {} was interrupted in sleep when draining pending queries", replicaTag);
@@ -220,13 +256,27 @@ public void drainPendingQueries() {
       if (numOngoingRequests.get() == 0) {
         break;
       } else {
-        LOGGER.debug(
+        LOGGER.info(
            "Waiting for an additional {}s to shut down because there are {} ongoing requests.",
-            config.getExperimentalGracefulShutdownWaitLoopS(),
+            config.getGracefulShutdownWaitLoopS(),
            numOngoingRequests.get());
       }
     }
-    Ray.exitActor();
+
+    // Explicitly call the del method to trigger clean up. We set isDeleted = true after
+    // successfully calling it so the destructor is called only once.
+    try {
+      if (!isDeleted) {
+        ReflectUtil.getMethod(callable.getClass(), "del").invoke(callable);
+      }
+    } catch (NoSuchMethodException e) {
+      LOGGER.warn("Deployment {} has no del method.", backendTag);
+    } catch (Throwable e) {
+      LOGGER.error("Exception during graceful shutdown of replica.", e);
+    } finally {
+      isDeleted = true;
+    }
+    return true;
   }

   /**
@@ -234,28 +284,34 @@ public void drainPendingQueries() {
    *
    * @param userConfig new user's configuration
    */
-  private void reconfigure(Object userConfig) {
-    if (userConfig == null) {
-      return;
+  public BackendVersion reconfigure(Object userConfig) {
+    BackendVersion.Builder builder = BackendVersion.newBuilder();
+    builder.setCodeVersion(version.getCodeVersion());
+    if (userConfig != null) {
+      builder.setUserConfig(ByteString.copyFrom((byte[]) userConfig));
     }
+    version = builder.build();
+
     try {
       Method reconfigureMethod =
           ReflectUtil.getMethod(
               callable.getClass(),
               Constants.BACKEND_RECONFIGURE_METHOD,
-              userConfig); // TODO cache reconfigureMethod
+              userConfig != null
+                  ?
MessagePackSerializer.decode((byte[]) userConfig, Object[].class) + : new Object[0]); // TODO cache reconfigure method reconfigureMethod.invoke(callable, userConfig); } catch (NoSuchMethodException e) { - throw new RayServeException( - LogUtil.format( - "user_config specified but backend {} missing {} method", - backendTag, - Constants.BACKEND_RECONFIGURE_METHOD)); + LOGGER.warn( + "user_config specified but backend {} missing {} method", + backendTag, + Constants.BACKEND_RECONFIGURE_METHOD); } catch (Throwable e) { throw new RayServeException( LogUtil.format("Backend {} failed to reconfigure user_config {}", backendTag, userConfig), e); } + return version; } /** @@ -265,12 +321,9 @@ private void reconfigure(Object userConfig) { */ private void updateBackendConfigs(Object newConfig) { config = (BackendConfig) newConfig; - reconfigure(((BackendConfig) newConfig).getUserConfig()); } - private void reportMetrics(Runnable runnable) { - if (metricsRegistered) { - runnable.run(); - } + public BackendVersion getVersion() { + return version; } } diff --git a/java/serve/src/main/java/io/ray/serve/RayServeWrappedReplica.java b/java/serve/src/main/java/io/ray/serve/RayServeWrappedReplica.java index 9ccc6c6f7a448..53e0854044c71 100644 --- a/java/serve/src/main/java/io/ray/serve/RayServeWrappedReplica.java +++ b/java/serve/src/main/java/io/ray/serve/RayServeWrappedReplica.java @@ -7,6 +7,7 @@ import io.ray.runtime.serializer.MessagePackSerializer; import io.ray.serve.api.Serve; import io.ray.serve.generated.BackendConfig; +import io.ray.serve.generated.BackendVersion; import io.ray.serve.generated.RequestMetadata; import io.ray.serve.util.ReflectUtil; import io.ray.serve.util.ServeProtoUtil; @@ -27,6 +28,7 @@ public RayServeWrappedReplica( String backendDef, byte[] initArgsbytes, byte[] backendConfigBytes, + byte[] backendVersionBytes, String controllerName) throws ClassNotFoundException, NoSuchMethodException, InstantiationException, IllegalAccessException, IllegalArgumentException, InvocationTargetException, IOException { @@ -52,7 +54,26 @@ public RayServeWrappedReplica( Serve.setInternalReplicaContext(backendTag, replicaTag, controllerName, callable); // Construct worker replica. - backend = new RayServeReplica(callable, backendConfig, optional.get()); + backend = + new RayServeReplica( + callable, + backendConfig, + ServeProtoUtil.parseBackendVersion(backendVersionBytes), + optional.get()); + } + + public RayServeWrappedReplica( + String backendTag, String replicaTag, DeploymentInfo deploymentInfo, String controllerName) + throws ClassNotFoundException, NoSuchMethodException, InstantiationException, + IllegalAccessException, IllegalArgumentException, InvocationTargetException, IOException { + this( + backendTag, + replicaTag, + deploymentInfo.getReplicaConfig().getBackendDef(), + deploymentInfo.getReplicaConfig().getInitArgs(), + deploymentInfo.getBackendConfig(), + deploymentInfo.getBackendVersion(), + controllerName); } private Object[] parseInitArgs(byte[] initArgsbytes, BackendConfig backendConfig) @@ -101,8 +122,21 @@ public void ready() { return; } - /** Wait until there is no request in processing. It is used for stopping replica gracefully. */ - public void drainPendingQueries() { - backend.drainPendingQueries(); + /** + * Wait until there is no request in processing. It is used for stopping replica gracefully. + * + * @return true if it is ready for shutdown. 
+ */ + public boolean prepareForShutdown() { + return backend.prepareForShutdown(); + } + + public byte[] reconfigure(Object userConfig) { + BackendVersion backendVersion = backend.reconfigure(userConfig); + return backendVersion.toByteArray(); + } + + public byte[] getVersion() { + return backend.getVersion().toByteArray(); } } diff --git a/java/serve/src/main/java/io/ray/serve/ReplicaConfig.java b/java/serve/src/main/java/io/ray/serve/ReplicaConfig.java index ff19348098027..a24ceea124963 100644 --- a/java/serve/src/main/java/io/ray/serve/ReplicaConfig.java +++ b/java/serve/src/main/java/io/ray/serve/ReplicaConfig.java @@ -12,13 +12,13 @@ public class ReplicaConfig implements Serializable { private String backendDef; - private Object[] initArgs; + private byte[] initArgs; private Map rayActorOptions; private Map resource; - public ReplicaConfig(String backendDef, Object[] initArgs, Map rayActorOptions) { + public ReplicaConfig(String backendDef, byte[] initArgs, Map rayActorOptions) { this.backendDef = backendDef; this.initArgs = initArgs; this.rayActorOptions = rayActorOptions; @@ -89,11 +89,11 @@ public void setBackendDef(String backendDef) { this.backendDef = backendDef; } - public Object[] getInitArgs() { + public byte[] getInitArgs() { return initArgs; } - public void setInitArgs(Object[] initArgs) { + public void setInitArgs(byte[] initArgs) { this.initArgs = initArgs; } diff --git a/java/serve/src/main/java/io/ray/serve/ReplicaContext.java b/java/serve/src/main/java/io/ray/serve/ReplicaContext.java index 10c62cf7eb411..7bd768f7cdd53 100644 --- a/java/serve/src/main/java/io/ray/serve/ReplicaContext.java +++ b/java/serve/src/main/java/io/ray/serve/ReplicaContext.java @@ -3,7 +3,7 @@ /** Stores data for Serve API calls from within the user's backend code. */ public class ReplicaContext { - private String backendTag; + private String backendTag; // TODO deployment private String replicaTag; diff --git a/java/serve/src/main/java/io/ray/serve/ReplicaSet.java b/java/serve/src/main/java/io/ray/serve/ReplicaSet.java new file mode 100644 index 0000000000000..1c7e757bba449 --- /dev/null +++ b/java/serve/src/main/java/io/ray/serve/ReplicaSet.java @@ -0,0 +1,138 @@ +package io.ray.serve; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Sets; +import io.ray.api.ActorHandle; +import io.ray.api.ObjectRef; +import io.ray.api.Ray; +import io.ray.runtime.metric.Gauge; +import io.ray.runtime.metric.Metrics; +import io.ray.runtime.metric.TagKey; +import io.ray.serve.generated.ActorSet; +import io.ray.serve.generated.BackendConfig; +import io.ray.serve.util.CollectionUtil; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.commons.lang3.RandomUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Data structure representing a set of replica actor handles. 
*/
+public class ReplicaSet {
+
+  private static final Logger LOGGER = LoggerFactory.getLogger(ReplicaSet.class);
+
+  private volatile int maxConcurrentQueries = 8;
+
+  private final Map<ActorHandle<RayServeWrappedReplica>, Set<ObjectRef<Object>>> inFlightQueries;
+
+  private AtomicInteger numQueuedQueries = new AtomicInteger();
+
+  private Gauge numQueuedQueriesGauge;
+
+  public ReplicaSet(String backendTag) {
+    this.inFlightQueries = new ConcurrentHashMap<>();
+    RayServeMetrics.execute(
+        () ->
+            this.numQueuedQueriesGauge =
+                Metrics.gauge()
+                    .name(RayServeMetrics.SERVE_DEPLOYMENT_QUEUED_QUERIES.getName())
+                    .description(RayServeMetrics.SERVE_DEPLOYMENT_QUEUED_QUERIES.getDescription())
+                    .unit("")
+                    .tags(ImmutableMap.of(RayServeMetrics.TAG_DEPLOYMENT, backendTag))
+                    .register());
+  }
+
+  public void setMaxConcurrentQueries(Object backendConfig) {
+    int newValue = ((BackendConfig) backendConfig).getMaxConcurrentQueries();
+    if (newValue != this.maxConcurrentQueries) {
+      this.maxConcurrentQueries = newValue;
+      LOGGER.info("ReplicaSet: changing max_concurrent_queries to {}", newValue);
+    }
+  }
+
+  public int getMaxConcurrentQueries() {
+    return maxConcurrentQueries;
+  }
+
+  @SuppressWarnings("unchecked")
+  public synchronized void updateWorkerReplicas(Object actorSet) {
+    List<String> actorNames = ((ActorSet) actorSet).getNamesList();
+    Set<ActorHandle<RayServeWrappedReplica>> workerReplicas = new HashSet<>();
+    if (!CollectionUtil.isEmpty(actorNames)) {
+      actorNames.forEach(
+          name ->
+              workerReplicas.add((ActorHandle<RayServeWrappedReplica>) Ray.getActor(name).get()));
+    }
+
+    Set<ActorHandle<RayServeWrappedReplica>> added =
+        new HashSet<>(Sets.difference(workerReplicas, inFlightQueries.keySet()));
+    Set<ActorHandle<RayServeWrappedReplica>> removed =
+        new HashSet<>(Sets.difference(inFlightQueries.keySet(), workerReplicas));
+
+    added.forEach(actorHandle -> inFlightQueries.put(actorHandle, Sets.newConcurrentHashSet()));
+    removed.forEach(actorHandle -> inFlightQueries.remove(actorHandle));
+
+    if (added.size() > 0 || removed.size() > 0) {
+      LOGGER.info("ReplicaSet: +{}, -{} replicas.", added.size(), removed.size());
+    }
+  }
+
+  /**
+   * Given a query, submit it to a replica and return the object ref. This method will keep track
+   * of the in-flight queries for each replica and only send a query to available replicas
+   * (determined by the backend's max_concurrent_queries value).
+   *
+   * @param query the incoming query.
+   * @return ray.ObjectRef
+   */
+  public ObjectRef<Object> assignReplica(Query query) {
+    String endpoint = query.getMetadata().getEndpoint();
+    numQueuedQueries.incrementAndGet();
+    RayServeMetrics.execute(
+        () ->
+            numQueuedQueriesGauge.update(
+                numQueuedQueries.get(),
+                TagKey.tagsFromMap(ImmutableMap.of(RayServeMetrics.TAG_ENDPOINT, endpoint))));
+    ObjectRef<Object> assignedRef =
+        tryAssignReplica(query); // TODO control concurrency using maxConcurrentQueries
+    numQueuedQueries.decrementAndGet();
+    RayServeMetrics.execute(
+        () ->
+            numQueuedQueriesGauge.update(
+                numQueuedQueries.get(),
+                TagKey.tagsFromMap(ImmutableMap.of(RayServeMetrics.TAG_ENDPOINT, endpoint))));
+    return assignedRef;
+  }
+
+  /**
+   * Try to assign the query to a replica; return the object ref if it succeeds, or null if the
+   * query cannot be assigned to any replica.
+   *
+   * @param query the incoming query.
+   * @return ray.ObjectRef
+   */
+  private ObjectRef<Object> tryAssignReplica(Query query) {
+
+    List<ActorHandle<RayServeWrappedReplica>> handles = new ArrayList<>(inFlightQueries.keySet());
+    if (CollectionUtil.isEmpty(handles)) {
+      throw new RayServeException("ReplicaSet found no replica.");
+    }
+    int randomIndex = RandomUtils.nextInt(0, handles.size());
+    ActorHandle<RayServeWrappedReplica> replica =
+        handles.get(randomIndex); // TODO control concurrency using maxConcurrentQueries
+    LOGGER.debug("Assigned query {} to replica {}.", query.getMetadata().getRequestId(), replica);
+    return replica
+        .task(RayServeWrappedReplica::handleRequest, query.getMetadata(), query.getArgs())
+        .remote();
+  }
+
+  public Map<ActorHandle<RayServeWrappedReplica>, Set<ObjectRef<Object>>> getInFlightQueries() {
+    return inFlightQueries;
+  }
+}
diff --git a/java/serve/src/main/java/io/ray/serve/Router.java b/java/serve/src/main/java/io/ray/serve/Router.java
new file mode 100644
index 0000000000000..5ef339d77767c
--- /dev/null
+++ b/java/serve/src/main/java/io/ray/serve/Router.java
@@ -0,0 +1,64 @@
+package io.ray.serve;
+
+import com.google.common.collect.ImmutableMap;
+import io.ray.api.BaseActorHandle;
+import io.ray.api.ObjectRef;
+import io.ray.runtime.metric.Count;
+import io.ray.runtime.metric.Metrics;
+import io.ray.serve.generated.RequestMetadata;
+import io.ray.serve.poll.KeyListener;
+import io.ray.serve.poll.KeyType;
+import io.ray.serve.poll.LongPollClient;
+import io.ray.serve.poll.LongPollNamespace;
+import java.util.HashMap;
+import java.util.Map;
+
+/** The Router processes incoming queries: it chooses a backend and assigns a replica. */
+public class Router {
+
+  private ReplicaSet replicaSet;
+
+  private Count numRouterRequests;
+
+  private LongPollClient longPollClient;
+
+  public Router(BaseActorHandle controllerHandle, String backendTag) {
+    this.replicaSet = new ReplicaSet(backendTag);
+
+    RayServeMetrics.execute(
+        () ->
+            this.numRouterRequests =
+                Metrics.count()
+                    .name(RayServeMetrics.SERVE_NUM_ROUTER_REQUESTS.getName())
+                    .description(RayServeMetrics.SERVE_NUM_ROUTER_REQUESTS.getDescription())
+                    .unit("")
+                    .tags(ImmutableMap.of(RayServeMetrics.TAG_DEPLOYMENT, backendTag))
+                    .register());
+
+    Map<KeyType, KeyListener> keyListeners = new HashMap<>();
+    keyListeners.put(
+        new KeyType(LongPollNamespace.BACKEND_CONFIGS, backendTag),
+        backendConfig -> replicaSet.setMaxConcurrentQueries(backendConfig)); // cross language
+    keyListeners.put(
+        new KeyType(LongPollNamespace.REPLICA_HANDLES, backendTag),
+        workerReplicas -> replicaSet.updateWorkerReplicas(workerReplicas)); // cross language
+    this.longPollClient = new LongPollClient(controllerHandle, keyListeners);
+    this.longPollClient.start();
+  }
+
+  /**
+   * Assign a query and return an object ref representing the result.
+   *
+   * @param requestMetadata the metadata of incoming queries.
+   * @param requestArgs the request body of incoming queries.
+ * @return ray.ObjectRef + */ + public ObjectRef assignRequest(RequestMetadata requestMetadata, Object[] requestArgs) { + RayServeMetrics.execute(() -> numRouterRequests.inc(1.0)); + return replicaSet.assignReplica(new Query(requestMetadata, requestArgs)); + } + + public ReplicaSet getReplicaSet() { + return replicaSet; + } +} diff --git a/java/serve/src/main/java/io/ray/serve/ServeController.java b/java/serve/src/main/java/io/ray/serve/ServeController.java new file mode 100644 index 0000000000000..1589f4c73b4c2 --- /dev/null +++ b/java/serve/src/main/java/io/ray/serve/ServeController.java @@ -0,0 +1,6 @@ +package io.ray.serve; + +public interface ServeController { + + byte[] getAllEndpoints(); +} diff --git a/java/serve/src/main/java/io/ray/serve/ServeProxy.java b/java/serve/src/main/java/io/ray/serve/ServeProxy.java new file mode 100644 index 0000000000000..532a2413f9ba5 --- /dev/null +++ b/java/serve/src/main/java/io/ray/serve/ServeProxy.java @@ -0,0 +1,14 @@ +package io.ray.serve; + +import java.util.Map; + +public interface ServeProxy { + + void init(Map config, ProxyRouter proxyRouter); + + default String getName() { + return getClass().getName(); + } + + default void registerServiceDiscovery() {} +} diff --git a/java/serve/src/main/java/io/ray/serve/api/Client.java b/java/serve/src/main/java/io/ray/serve/api/Client.java new file mode 100644 index 0000000000000..e5c63b5c8e184 --- /dev/null +++ b/java/serve/src/main/java/io/ray/serve/api/Client.java @@ -0,0 +1,72 @@ +package io.ray.serve.api; + +import io.ray.api.ActorHandle; +import io.ray.api.BaseActorHandle; +import io.ray.api.PyActorHandle; +import io.ray.api.function.PyActorMethod; +import io.ray.serve.RayServeException; +import io.ray.serve.RayServeHandle; +import io.ray.serve.ServeController; +import io.ray.serve.generated.EndpointInfo; +import io.ray.serve.util.LogUtil; +import io.ray.serve.util.ServeProtoUtil; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class Client { + + private static final Logger LOGGER = LoggerFactory.getLogger(Client.class); + + private BaseActorHandle controller; + + private Map handleCache = new ConcurrentHashMap<>(); + + public Client(BaseActorHandle controller, String controllerName, boolean detached) { + this.controller = controller; + } + + /** + * Retrieve RayServeHandle for service endpoint to invoke it from Python. + * + * @param endpointName A registered service endpoint. + * @param missingOk If true, then Serve won't check the endpoint is registered. False by default. 
+   * @return the RayServeHandle of the given endpoint.
+   */
+  @SuppressWarnings("unchecked")
+  public RayServeHandle getHandle(String endpointName, boolean missingOk) {
+
+    String cacheKey = endpointName + "_" + missingOk;
+    if (handleCache.containsKey(cacheKey)) {
+      return handleCache.get(cacheKey);
+    }
+
+    Map<String, EndpointInfo> endpoints = null;
+    if (controller instanceof PyActorHandle) {
+      endpoints =
+          ServeProtoUtil.parseEndpointSet(
+              (byte[])
+                  ((PyActorHandle) controller)
+                      .task(PyActorMethod.of("get_all_endpoints"))
+                      .remote()
+                      .get());
+    } else {
+      LOGGER.warn("Client only supports the Python controller for now.");
+      endpoints =
+          ServeProtoUtil.parseEndpointSet(
+              ((ActorHandle<ServeController>) controller)
+                  .task(ServeController::getAllEndpoints)
+                  .remote()
+                  .get());
+    }
+
+    if (!missingOk && (endpoints == null || !endpoints.containsKey(endpointName))) {
+      throw new RayServeException(LogUtil.format("Endpoint {} does not exist.", endpointName));
+    }
+
+    RayServeHandle handle = new RayServeHandle(controller, endpointName, null, null);
+    handleCache.put(cacheKey, handle);
+    return handle;
+  }
+}
diff --git a/java/serve/src/main/java/io/ray/serve/api/Serve.java b/java/serve/src/main/java/io/ray/serve/api/Serve.java
index 8133e5bd7f23e..3b2c0ed7a2833 100644
--- a/java/serve/src/main/java/io/ray/serve/api/Serve.java
+++ b/java/serve/src/main/java/io/ray/serve/api/Serve.java
@@ -1,12 +1,20 @@
 package io.ray.serve.api;

+import com.google.common.base.Preconditions;
+import io.ray.api.BaseActorHandle;
+import io.ray.api.Ray;
+import io.ray.serve.Constants;
 import io.ray.serve.RayServeException;
 import io.ray.serve.ReplicaContext;
+import io.ray.serve.util.LogUtil;
+import java.util.Optional;

 /** Ray Serve global API. TODO: will be enriched in the Java SDK/API PR. */
 public class Serve {

-  public static ReplicaContext INTERNAL_REPLICA_CONTEXT;
+  private static ReplicaContext INTERNAL_REPLICA_CONTEXT;
+
+  private static Client GLOBAL_CLIENT;

   /**
    * Set replica information to global context.
@@ -18,11 +26,14 @@ public class Serve {
    */
   public static void setInternalReplicaContext(
       String backendTag, String replicaTag, String controllerName, Object servableObject) {
-    // TODO singleton.
     INTERNAL_REPLICA_CONTEXT =
         new ReplicaContext(backendTag, replicaTag, controllerName, servableObject);
   }

+  public static void setInternalReplicaContext(ReplicaContext replicaContext) {
+    INTERNAL_REPLICA_CONTEXT = replicaContext;
+  }
+
   /**
    * Get the global replica context.
    *
@@ -35,4 +46,43 @@ public static ReplicaContext getReplicaContext() {
     }
     return INTERNAL_REPLICA_CONTEXT;
   }
+
+  public static Client getGlobalClient() {
+    if (GLOBAL_CLIENT != null) {
+      return GLOBAL_CLIENT;
+    }
+    synchronized (Client.class) {
+      if (GLOBAL_CLIENT != null) {
+        return GLOBAL_CLIENT;
+      }
+      return connect();
+    }
+  }
+
+  public static void setGlobalClient(Client client) {
+    GLOBAL_CLIENT = client;
+  }
+
+  public static Client connect() {
+
+    if (!Ray.isInitialized()) {
+      Ray.init();
+    }
+
+    String controllerName =
+        INTERNAL_REPLICA_CONTEXT != null
+            ? INTERNAL_REPLICA_CONTEXT.getInternalControllerName()
+            : Constants.SERVE_CONTROLLER_NAME;
+
+    Optional<BaseActorHandle> optional = Ray.getActor(controllerName);
+    Preconditions.checkState(
+        optional.isPresent(),
+        LogUtil.format(
+            "There is no instance running on this Ray cluster. "
+                + "Please call `serve.start(detached=True)` to start one."));
+
+    Client client = new Client(optional.get(), controllerName, true);
+    setGlobalClient(client);
+    return client;
+  }
 }
diff --git a/java/serve/src/main/java/io/ray/serve/poll/KeyListener.java b/java/serve/src/main/java/io/ray/serve/poll/KeyListener.java
index 91e9ceca04723..514193e28c37d 100644
--- a/java/serve/src/main/java/io/ray/serve/poll/KeyListener.java
+++ b/java/serve/src/main/java/io/ray/serve/poll/KeyListener.java
@@ -4,5 +4,5 @@
 @FunctionalInterface
 public interface KeyListener {

-  void notifyChanged(Object object);
+  void notifyChanged(Object updatedObject);
 }
diff --git a/java/serve/src/main/java/io/ray/serve/poll/LongPollClient.java b/java/serve/src/main/java/io/ray/serve/poll/LongPollClient.java
index 4017be3af9db9..308391254e109 100644
--- a/java/serve/src/main/java/io/ray/serve/poll/LongPollClient.java
+++ b/java/serve/src/main/java/io/ray/serve/poll/LongPollClient.java
@@ -1,6 +1,7 @@
 package io.ray.serve.poll;

 import com.google.common.base.Preconditions;
+import com.google.protobuf.InvalidProtocolBufferException;
 import io.ray.api.BaseActorHandle;
 import io.ray.api.ObjectRef;
 import io.ray.api.PyActorHandle;
@@ -8,8 +9,16 @@
 import io.ray.runtime.exception.RayActorException;
 import io.ray.runtime.exception.RayTaskException;
 import io.ray.serve.Constants;
+import io.ray.serve.RayServeException;
+import io.ray.serve.generated.ActorSet;
+import io.ray.serve.generated.UpdatedObject;
+import io.ray.serve.util.LogUtil;
+import io.ray.serve.util.ServeProtoUtil;
+import java.util.HashMap;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.Function;
+import org.apache.commons.lang3.builder.ReflectionToStringBuilder;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

@@ -33,6 +42,26 @@ public class LongPollClient {
   /** An async thread to post the callback into. */
   private Thread pollThread;

+  private static final Map<LongPollNamespace, Function<byte[], Object>> DESERIALIZERS =
+      new HashMap<>();
+
+  static {
+    DESERIALIZERS.put(
+        LongPollNamespace.BACKEND_CONFIGS, body -> ServeProtoUtil.parseBackendConfig(body));
+    DESERIALIZERS.put(
+        LongPollNamespace.ROUTE_TABLE, body -> ServeProtoUtil.parseEndpointSet(body));
+    DESERIALIZERS.put(
+        LongPollNamespace.REPLICA_HANDLES,
+        body -> {
+          try {
+            return ActorSet.parseFrom(body);
+          } catch (InvalidProtocolBufferException e) {
+            throw new RayServeException(
+                LogUtil.format("Failed to parse ActorSet from protobuf bytes."), e);
+          }
+        });
+  }
+
   public LongPollClient(BaseActorHandle hostActor, Map<KeyType, KeyListener> keyListeners) {

     Preconditions.checkArgument(keyListeners != null && keyListeners.size() != 0);
@@ -51,7 +80,7 @@ public LongPollClient(BaseActorHandle hostActor, Map<KeyType, KeyListener> keyLi
           try {
             pollNext();
           } catch (RayActorException e) {
-            LOGGER.debug("LongPollClient failed to connect to host. Shutting down.");
+            LOGGER.error("LongPollClient failed to connect to host. Shutting down.");
             break;
           } catch (RayTaskException e) {
             LOGGER.error("LongPollHost errored", e);
@@ -71,24 +100,44 @@ public void start() {
     pollThread.start();
   }

-  /** Poll the update. */
-  @SuppressWarnings("unchecked")
-  public void pollNext() {
+  /**
+   * Poll the update.
+   *
+   * @throws InvalidProtocolBufferException if the protobuf deserialization fails.
+ */ + public void pollNext() throws InvalidProtocolBufferException { currentRef = ((PyActorHandle) hostActor) .task(PyActorMethod.of(Constants.CONTROLLER_LISTEN_FOR_CHANGE_METHOD), snapshotIds) .remote(); - processUpdate((Map) currentRef.get()); + processUpdate(ServeProtoUtil.parseUpdatedObjects((byte[]) currentRef.get())); } public void processUpdate(Map updates) { - - LOGGER.debug("LongPollClient received updates for keys: {}", updates.keySet()); - + if (updates == null || updates.isEmpty()) { + LOGGER.info("LongPollClient received nothing."); + return; + } + LOGGER.info("LongPollClient received updates for keys: {}", updates.keySet()); for (Map.Entry entry : updates.entrySet()) { - objectSnapshots.put(entry.getKey(), entry.getValue().getObjectSnapshot()); + KeyType keyType = entry.getKey(); + UpdatedObject updatedObject = entry.getValue(); + + Object objectSnapshot = + DESERIALIZERS + .get(keyType.getLongPollNamespace()) + .apply(updatedObject.getObjectSnapshot().toByteArray()); + + if (LOGGER.isDebugEnabled()) { + LOGGER.debug( + "The updated object for key {} is {}", + keyType, + ReflectionToStringBuilder.toString(objectSnapshot)); + } + + keyListeners.get(entry.getKey()).notifyChanged(objectSnapshot); + objectSnapshots.put(entry.getKey(), objectSnapshot); snapshotIds.put(entry.getKey(), entry.getValue().getSnapshotId()); - keyListeners.get(entry.getKey()).notifyChanged(entry.getValue().getObjectSnapshot()); } } diff --git a/java/serve/src/main/java/io/ray/serve/poll/LongPollNamespace.java b/java/serve/src/main/java/io/ray/serve/poll/LongPollNamespace.java index 466af829167e8..71b3a2e8baa1e 100644 --- a/java/serve/src/main/java/io/ray/serve/poll/LongPollNamespace.java +++ b/java/serve/src/main/java/io/ray/serve/poll/LongPollNamespace.java @@ -4,9 +4,7 @@ public enum LongPollNamespace { REPLICA_HANDLES, - TRAFFIC_POLICIES, - BACKEND_CONFIGS, - ROUTE_TABLE + ROUTE_TABLE; } diff --git a/java/serve/src/main/java/io/ray/serve/poll/UpdatedObject.java b/java/serve/src/main/java/io/ray/serve/poll/UpdatedObject.java deleted file mode 100644 index 3f3ddc63c1ae2..0000000000000 --- a/java/serve/src/main/java/io/ray/serve/poll/UpdatedObject.java +++ /dev/null @@ -1,33 +0,0 @@ -package io.ray.serve.poll; - -import java.io.Serializable; - -/** The updated object that long poll client received. */ -public class UpdatedObject implements Serializable { - - private static final long serialVersionUID = 6245682414826079438L; - - private Object objectSnapshot; - - /** - * The identifier for the object's version. There is not sequential relation among different - * object's snapshot_ids. 
-  private int snapshotId;
-
-  public Object getObjectSnapshot() {
-    return objectSnapshot;
-  }
-
-  public void setObjectSnapshot(Object objectSnapshot) {
-    this.objectSnapshot = objectSnapshot;
-  }
-
-  public int getSnapshotId() {
-    return snapshotId;
-  }
-
-  public void setSnapshotId(int snapshotId) {
-    this.snapshotId = snapshotId;
-  }
-}
diff --git a/java/serve/src/main/java/io/ray/serve/util/CollectionUtil.java b/java/serve/src/main/java/io/ray/serve/util/CollectionUtil.java
new file mode 100644
index 0000000000000..cd66932f48276
--- /dev/null
+++ b/java/serve/src/main/java/io/ray/serve/util/CollectionUtil.java
@@ -0,0 +1,10 @@
+package io.ray.serve.util;
+
+import java.util.Collection;
+
+public class CollectionUtil {
+
+  public static <T> boolean isEmpty(Collection<T> collection) {
+    return collection == null || collection.isEmpty();
+  }
+}
diff --git a/java/serve/src/main/java/io/ray/serve/util/CommonUtil.java b/java/serve/src/main/java/io/ray/serve/util/CommonUtil.java
new file mode 100644
index 0000000000000..a32ee212196d8
--- /dev/null
+++ b/java/serve/src/main/java/io/ray/serve/util/CommonUtil.java
@@ -0,0 +1,13 @@
+package io.ray.serve.util;
+
+import org.apache.commons.lang3.StringUtils;
+
+public class CommonUtil {
+
+  public static String formatActorName(String controllerName, String actorName) {
+    if (StringUtils.isBlank(controllerName)) {
+      return actorName;
+    }
+    return controllerName + ":" + actorName;
+  }
+}
diff --git a/java/serve/src/main/java/io/ray/serve/util/ReflectUtil.java b/java/serve/src/main/java/io/ray/serve/util/ReflectUtil.java
index 5de1142433008..ae449dd714733 100644
--- a/java/serve/src/main/java/io/ray/serve/util/ReflectUtil.java
+++ b/java/serve/src/main/java/io/ray/serve/util/ReflectUtil.java
@@ -2,6 +2,7 @@
 import java.lang.reflect.Constructor;
 import java.lang.reflect.Executable;
+import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.Method;
 import java.util.ArrayList;
 import java.util.List;
@@ -178,4 +179,17 @@ public static List<String> getMethodStrings(Class targetClass) {
     }
     return methodStrings;
   }
+
+  @SuppressWarnings("unchecked")
+  public static <T> List<T> getInstancesByClassNames(String classNames, Class<T> cls)
+      throws ClassNotFoundException, InstantiationException, IllegalAccessException,
+          IllegalArgumentException, InvocationTargetException, NoSuchMethodException,
+          SecurityException {
+    String[] classNameArray = StringUtils.split(classNames, ";");
+    List<T> instances = new ArrayList<>();
+    for (String className : classNameArray) {
+      instances.add((T) Class.forName(className).getConstructor().newInstance());
+    }
+    return instances;
+  }
 }
diff --git a/java/serve/src/main/java/io/ray/serve/util/ServeProtoUtil.java b/java/serve/src/main/java/io/ray/serve/util/ServeProtoUtil.java
index b1d02a046063e..1a1c0c082d3f8 100644
--- a/java/serve/src/main/java/io/ray/serve/util/ServeProtoUtil.java
+++ b/java/serve/src/main/java/io/ray/serve/util/ServeProtoUtil.java
@@ -2,26 +2,42 @@
 import com.google.common.base.Preconditions;
 import com.google.common.collect.Lists;
+import com.google.gson.Gson;
 import com.google.protobuf.InvalidProtocolBufferException;
 import io.ray.runtime.serializer.MessagePackSerializer;
+import io.ray.serve.Constants;
 import io.ray.serve.RayServeException;
 import io.ray.serve.generated.BackendConfig;
 import io.ray.serve.generated.BackendLanguage;
+import io.ray.serve.generated.BackendVersion;
+import io.ray.serve.generated.EndpointInfo;
+import io.ray.serve.generated.EndpointSet;
+import io.ray.serve.generated.LongPollResult;
import io.ray.serve.generated.RequestMetadata; import io.ray.serve.generated.RequestWrapper; +import io.ray.serve.generated.UpdatedObject; +import io.ray.serve.poll.KeyType; +import java.util.HashMap; +import java.util.Map; import org.apache.commons.lang3.StringUtils; public class ServeProtoUtil { - public static BackendConfig parseBackendConfig(byte[] backendConfigBytes) - throws InvalidProtocolBufferException { + private static final Gson GSON = new Gson(); + + public static BackendConfig parseBackendConfig(byte[] backendConfigBytes) { // Get a builder from BackendConfig(bytes) or create a new one. BackendConfig.Builder builder = null; if (backendConfigBytes == null) { builder = BackendConfig.newBuilder(); } else { - BackendConfig backendConfig = BackendConfig.parseFrom(backendConfigBytes); + BackendConfig backendConfig = null; + try { + backendConfig = BackendConfig.parseFrom(backendConfigBytes); + } catch (InvalidProtocolBufferException e) { + throw new RayServeException("Failed to parse BackendConfig from protobuf bytes.", e); + } if (backendConfig == null) { builder = BackendConfig.newBuilder(); } else { @@ -40,12 +56,12 @@ public static BackendConfig parseBackendConfig(byte[] backendConfigBytes) builder.setMaxConcurrentQueries(100); } - if (builder.getExperimentalGracefulShutdownWaitLoopS() == 0) { - builder.setExperimentalGracefulShutdownWaitLoopS(2); + if (builder.getGracefulShutdownWaitLoopS() == 0) { + builder.setGracefulShutdownWaitLoopS(2); } - if (builder.getExperimentalGracefulShutdownTimeoutS() == 0) { - builder.setExperimentalGracefulShutdownTimeoutS(20); + if (builder.getGracefulShutdownTimeoutS() == 0) { + builder.setGracefulShutdownTimeoutS(20); } if (builder.getBackendLanguage() == BackendLanguage.UNRECOGNIZED) { @@ -84,7 +100,7 @@ public static RequestMetadata parseRequestMetadata(byte[] requestMetadataBytes) // Set default values. 
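     // Illustrative expectation (hypothetical call, not part of this hunk): metadata
     // whose call_method is blank is assumed to come back with the default, e.g.
     //   RequestMetadata md = ServeProtoUtil.parseRequestMetadata(bytes);
     //   md.getCallMethod(); // "call", i.e. Constants.DEFAULT_CALL_METHOD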
     if (StringUtils.isBlank(builder.getCallMethod())) {
-      builder.setCallMethod("call");
+      builder.setCallMethod(Constants.DEFAULT_CALL_METHOD);
     }
 
     return builder.build();
@@ -108,4 +124,47 @@ public static RequestWrapper parseRequestWrapper(byte[] httpRequestWrapperBytes)
     return builder.build();
   }
+
+  public static Map<KeyType, UpdatedObject> parseUpdatedObjects(byte[] longPollResultBytes)
+      throws InvalidProtocolBufferException {
+    if (longPollResultBytes == null) {
+      return null;
+    }
+    LongPollResult longPollResult = LongPollResult.parseFrom(longPollResultBytes);
+    Map<String, UpdatedObject> updatedObjects = longPollResult.getUpdatedObjectsMap();
+    if (updatedObjects == null || updatedObjects.isEmpty()) {
+      return null;
+    }
+    Map<KeyType, UpdatedObject> updates = new HashMap<>(updatedObjects.size());
+    updatedObjects.forEach(
+        (key, value) -> updates.put(ServeProtoUtil.GSON.fromJson(key, KeyType.class), value));
+    return updates;
+  }
+
+  public static Map<String, EndpointInfo> parseEndpointSet(byte[] endpointSetBytes) {
+    if (endpointSetBytes == null) {
+      return null;
+    }
+    EndpointSet endpointSet = null;
+    try {
+      endpointSet = EndpointSet.parseFrom(endpointSetBytes);
+    } catch (InvalidProtocolBufferException e) {
+      throw new RayServeException("Failed to parse EndpointSet from protobuf bytes.", e);
+    }
+    if (endpointSet == null) {
+      return null;
+    }
+    return endpointSet.getEndpointsMap();
+  }
+
+  public static BackendVersion parseBackendVersion(byte[] backendVersionBytes) {
+    if (backendVersionBytes == null) {
+      return null;
+    }
+    try {
+      return BackendVersion.parseFrom(backendVersionBytes);
+    } catch (InvalidProtocolBufferException e) {
+      throw new RayServeException("Failed to parse BackendVersion from protobuf bytes.", e);
+    }
+  }
 }
diff --git a/java/serve/src/main/java/io/ray/serve/util/SocketUtil.java b/java/serve/src/main/java/io/ray/serve/util/SocketUtil.java
new file mode 100644
index 0000000000000..ab93a6e152210
--- /dev/null
+++ b/java/serve/src/main/java/io/ray/serve/util/SocketUtil.java
@@ -0,0 +1,49 @@
+package io.ray.serve.util;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.net.ServerSocket;
+
+public class SocketUtil {
+
+  public static final int PORT_RANGE_MAX = 65535;
+
+  public static int findAvailableTcpPort(int minPort) {
+    int portRange = PORT_RANGE_MAX - minPort;
+    int candidatePort = minPort;
+    int searchCounter = 0;
+    while (!isPortAvailable(candidatePort)) {
+      candidatePort++;
+      if (++searchCounter > portRange) {
+        throw new IllegalStateException(
+            String.format(
+                "Could not find an available tcp port in the range [%d, %d] after %d attempts.",
+                minPort, PORT_RANGE_MAX, searchCounter));
+      }
+    }
+    return candidatePort;
+  }
+
+  public static boolean isPortAvailable(int port) {
+    ServerSocket socket;
+    try {
+      socket = new ServerSocket();
+    } catch (IOException e) {
+      throw new IllegalStateException("Unable to create ServerSocket.", e);
+    }
+
+    try {
+      InetSocketAddress sa = new InetSocketAddress(port);
+      socket.bind(sa);
+      return true;
+    } catch (IOException ex) {
+      return false;
+    } finally {
+      try {
+        socket.close();
+      } catch (IOException ex) {
+        // ignore this exception for now
+      }
+    }
+  }
+}
diff --git a/java/serve/src/test/java/io/ray/serve/DummyServeController.java b/java/serve/src/test/java/io/ray/serve/DummyServeController.java
new file mode 100644
index 0000000000000..6ee319a477898
--- /dev/null
+++ b/java/serve/src/test/java/io/ray/serve/DummyServeController.java
@@ -0,0 +1,21 @@
+package io.ray.serve;
+
+import io.ray.serve.generated.EndpointInfo;
+import io.ray.serve.generated.EndpointSet;
+import java.util.Map;
+
+public class DummyServeController implements ServeController {
+
+  private Map<String, EndpointInfo> endpoints;
+
+  @Override
+  public byte[] getAllEndpoints() {
+    EndpointSet.Builder builder = EndpointSet.newBuilder();
+    builder.putAllEndpoints(endpoints);
+    return builder.build().toByteArray();
+  }
+
+  public void setEndpoints(Map<String, EndpointInfo> endpoints) {
+    this.endpoints = endpoints;
+  }
+}
diff --git a/java/serve/src/test/java/io/ray/serve/HttpProxyTest.java b/java/serve/src/test/java/io/ray/serve/HttpProxyTest.java
new file mode 100644
index 0000000000000..5166603662c82
--- /dev/null
+++ b/java/serve/src/test/java/io/ray/serve/HttpProxyTest.java
@@ -0,0 +1,74 @@
+package io.ray.serve;
+
+import io.ray.api.ActorHandle;
+import io.ray.api.Ray;
+import io.ray.serve.api.Serve;
+import io.ray.serve.generated.EndpointInfo;
+import io.ray.serve.util.CommonUtil;
+import java.io.IOException;
+import java.net.HttpURLConnection;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.commons.lang3.RandomStringUtils;
+import org.apache.hc.client5.http.classic.HttpClient;
+import org.apache.hc.client5.http.classic.methods.HttpPost;
+import org.apache.hc.client5.http.impl.classic.CloseableHttpResponse;
+import org.apache.hc.client5.http.impl.classic.HttpClientBuilder;
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+public class HttpProxyTest {
+
+  @Test
+  public void test() throws IOException {
+
+    boolean inited = Ray.isInitialized();
+    Ray.init();
+
+    try {
+      String controllerName =
+          CommonUtil.formatActorName(
+              Constants.SERVE_CONTROLLER_NAME, RandomStringUtils.randomAlphabetic(6));
+      String endpointName = "HTTPProxyTest";
+      String route = "/route";
+
+      // Controller
+      ActorHandle<DummyServeController> controllerHandle =
+          Ray.actor(DummyServeController::new).setName(controllerName).remote();
+
+      Map<String, EndpointInfo> endpointInfos = new HashMap<>();
+      endpointInfos.put(
+          endpointName,
+          EndpointInfo.newBuilder().setEndpointName(endpointName).setRoute(route).build());
+      controllerHandle.task(DummyServeController::setEndpoints, endpointInfos).remote();
+
+      Serve.setInternalReplicaContext(null, null, controllerName, null);
+
+      // ProxyRouter updates routes.
+      ProxyRouter proxyRouter = new ProxyRouter();
+      proxyRouter.updateRoutes(endpointInfos);
+
+      // HTTP proxy.
+      HttpProxy httpProxy = new HttpProxy();
+      httpProxy.init(null, proxyRouter);
+
+      // Send request.
+      HttpClient httpClient = HttpClientBuilder.create().build();
+      HttpPost httpPost = new HttpPost("http://localhost:" + httpProxy.getPort() + route);
+      try (CloseableHttpResponse httpResponse =
+          (CloseableHttpResponse) httpClient.execute(httpPost)) {
+
+        // No Backend replica, so error is expected.
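+        // The router knows this route, but no replica actor was ever started for the
+        // endpoint, so the proxied call fails and the proxy answers with a 500.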
+ int status = httpResponse.getCode(); + Assert.assertEquals(status, HttpURLConnection.HTTP_INTERNAL_ERROR); + } + + } finally { + if (!inited) { + Ray.shutdown(); + } + Serve.setInternalReplicaContext(null); + Serve.setGlobalClient(null); + } + } +} diff --git a/java/serve/src/test/java/io/ray/serve/ProxyActorTest.java b/java/serve/src/test/java/io/ray/serve/ProxyActorTest.java new file mode 100644 index 0000000000000..6b1daa11b1141 --- /dev/null +++ b/java/serve/src/test/java/io/ray/serve/ProxyActorTest.java @@ -0,0 +1,110 @@ +package io.ray.serve; + +import io.ray.api.ActorHandle; +import io.ray.api.Ray; +import io.ray.runtime.serializer.MessagePackSerializer; +import io.ray.serve.api.Serve; +import io.ray.serve.generated.ActorSet; +import io.ray.serve.generated.BackendConfig; +import io.ray.serve.generated.BackendVersion; +import io.ray.serve.generated.EndpointInfo; +import io.ray.serve.util.CommonUtil; +import java.io.IOException; +import java.net.HttpURLConnection; +import java.util.HashMap; +import java.util.Map; +import org.apache.commons.lang3.RandomStringUtils; +import org.apache.hc.client5.http.classic.HttpClient; +import org.apache.hc.client5.http.classic.methods.HttpPost; +import org.apache.hc.client5.http.impl.classic.CloseableHttpResponse; +import org.apache.hc.client5.http.impl.classic.HttpClientBuilder; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class ProxyActorTest { + + @Test + public void test() throws IOException { + boolean inited = Ray.isInitialized(); + Ray.init(); + + try { + String prefix = "ProxyActorTest"; + String controllerName = + CommonUtil.formatActorName( + Constants.SERVE_CONTROLLER_NAME, RandomStringUtils.randomAlphabetic(6)); + String backendTag = prefix; + String replicaTag = prefix; + String endpointName = prefix; + String route = "/route"; + String version = "v1"; + + // Controller + ActorHandle controller = + Ray.actor(DummyServeController::new).setName(controllerName).remote(); + Map endpointInfos = new HashMap<>(); + endpointInfos.put( + endpointName, + EndpointInfo.newBuilder().setEndpointName(endpointName).setRoute(route).build()); + controller.task(DummyServeController::setEndpoints, endpointInfos).remote(); + + // Replica + DeploymentInfo deploymentInfo = new DeploymentInfo(); + deploymentInfo.setBackendConfig(BackendConfig.newBuilder().build().toByteArray()); + deploymentInfo.setBackendVersion( + BackendVersion.newBuilder().setCodeVersion(version).build().toByteArray()); + deploymentInfo.setReplicaConfig( + new ReplicaConfig(DummyBackendReplica.class.getName(), null, new HashMap<>())); + + ActorHandle replica = + Ray.actor( + RayServeWrappedReplica::new, + backendTag, + replicaTag, + deploymentInfo, + controllerName) + .setName(replicaTag) + .remote(); + replica.task(RayServeWrappedReplica::ready).remote(); + + // ProxyActor + ProxyActor proxyActor = new ProxyActor(controllerName, null); + proxyActor.getProxyRouter().updateRoutes(endpointInfos); + proxyActor + .getProxyRouter() + .getHandles() + .get(endpointName) + .getRouter() + .getReplicaSet() + .updateWorkerReplicas(ActorSet.newBuilder().addNames(replicaTag).build()); + + // Send request. 
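+      // The POST below travels HttpProxy -> ProxyRouter -> RayServeHandle -> replica;
+      // the replica's return value comes back MessagePack-encoded in the response body.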
+ HttpClient httpClient = HttpClientBuilder.create().build(); + HttpPost httpPost = + new HttpPost( + "http://localhost:" + + ((HttpProxy) proxyActor.getProxies().get(HttpProxy.PROXY_NAME)).getPort() + + route); + try (CloseableHttpResponse httpResponse = + (CloseableHttpResponse) httpClient.execute(httpPost)) { + + int status = httpResponse.getCode(); + Assert.assertEquals(status, HttpURLConnection.HTTP_OK); + Object result = + MessagePackSerializer.decode( + EntityUtils.toByteArray(httpResponse.getEntity()), Object.class); + + Assert.assertNotNull(result); + Assert.assertEquals("1", result.toString()); + } + + } finally { + if (!inited) { + Ray.shutdown(); + } + Serve.setInternalReplicaContext(null); + Serve.setGlobalClient(null); + } + } +} diff --git a/java/serve/src/test/java/io/ray/serve/ProxyRouterTest.java b/java/serve/src/test/java/io/ray/serve/ProxyRouterTest.java new file mode 100644 index 0000000000000..03535a0575a79 --- /dev/null +++ b/java/serve/src/test/java/io/ray/serve/ProxyRouterTest.java @@ -0,0 +1,68 @@ +package io.ray.serve; + +import io.ray.api.ActorHandle; +import io.ray.api.Ray; +import io.ray.serve.api.Serve; +import io.ray.serve.generated.EndpointInfo; +import io.ray.serve.util.CommonUtil; +import java.util.HashMap; +import java.util.Map; +import org.apache.commons.lang3.RandomStringUtils; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class ProxyRouterTest { + + @Test + public void test() { + + boolean inited = Ray.isInitialized(); + Ray.init(); + + try { + String prefix = "ProxyRouterTest"; + String controllerName = + CommonUtil.formatActorName( + Constants.SERVE_CONTROLLER_NAME, RandomStringUtils.randomAlphabetic(6)); + String endpointName1 = prefix + "_1"; + String endpointName2 = prefix + "_2"; + String route1 = "/route1"; + + // Controller + ActorHandle controllerHandle = + Ray.actor(DummyServeController::new).setName(controllerName).remote(); + Map endpointInfos = new HashMap<>(); + endpointInfos.put( + endpointName1, + EndpointInfo.newBuilder().setEndpointName(endpointName1).setRoute(route1).build()); + endpointInfos.put( + endpointName2, EndpointInfo.newBuilder().setEndpointName(endpointName2).build()); + controllerHandle.task(DummyServeController::setEndpoints, endpointInfos).remote(); + + Serve.setInternalReplicaContext(null, null, controllerName, null); + + // ProxyRouter updates routes. + ProxyRouter proxyRouter = new ProxyRouter(); + proxyRouter.updateRoutes(endpointInfos); + + // Check result. 
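+      // endpointName1 was registered with an explicit route, so it is keyed by
+      // "/route1"; endpointName2 has no route and is keyed by its endpoint name.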
+ Map routeInfo = proxyRouter.getRouteInfo(); + Assert.assertNotNull(routeInfo); + Assert.assertNotNull(routeInfo.get(route1)); + Assert.assertEquals(routeInfo.get(route1).getRoute(), route1); + Assert.assertEquals(routeInfo.get(route1).getEndpointName(), endpointName1); + Assert.assertNotNull(routeInfo.get(endpointName2)); + Assert.assertEquals(routeInfo.get(endpointName2).getEndpointName(), endpointName2); + Map handles = proxyRouter.getHandles(); + Assert.assertNotNull(handles); + Assert.assertNotNull(handles.get(endpointName1)); + Assert.assertNotNull(handles.get(endpointName2)); + } finally { + if (!inited) { + Ray.shutdown(); + } + Serve.setInternalReplicaContext(null); + Serve.setGlobalClient(null); + } + } +} diff --git a/java/serve/src/test/java/io/ray/serve/RayServeHandleTest.java b/java/serve/src/test/java/io/ray/serve/RayServeHandleTest.java new file mode 100644 index 0000000000000..9e4ac68b612fd --- /dev/null +++ b/java/serve/src/test/java/io/ray/serve/RayServeHandleTest.java @@ -0,0 +1,76 @@ +package io.ray.serve; + +import io.ray.api.ActorHandle; +import io.ray.api.ObjectRef; +import io.ray.api.Ray; +import io.ray.runtime.serializer.MessagePackSerializer; +import io.ray.serve.generated.ActorSet; +import io.ray.serve.generated.BackendConfig; +import io.ray.serve.generated.BackendLanguage; +import io.ray.serve.generated.BackendVersion; +import java.util.HashMap; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class RayServeHandleTest { + + @Test + public void test() { + boolean inited = Ray.isInitialized(); + Ray.init(); + + try { + String backendTag = "RayServeHandleTest"; + String controllerName = backendTag + "_controller"; + String replicaTag = backendTag + "_replica"; + String actorName = replicaTag; + String version = "v1"; + + // Controller + ActorHandle controllerHandle = + Ray.actor(DummyServeController::new).setName(controllerName).remote(); + + // Replica + BackendConfig.Builder backendConfigBuilder = BackendConfig.newBuilder(); + backendConfigBuilder.setBackendLanguage(BackendLanguage.JAVA); + byte[] backendConfigBytes = backendConfigBuilder.build().toByteArray(); + + Object[] initArgs = new Object[] {backendTag, replicaTag, controllerName, new Object()}; + byte[] initArgsBytes = MessagePackSerializer.encode(initArgs).getLeft(); + + DeploymentInfo deploymentInfo = new DeploymentInfo(); + deploymentInfo.setBackendConfig(backendConfigBytes); + deploymentInfo.setBackendVersion( + BackendVersion.newBuilder().setCodeVersion(version).build().toByteArray()); + deploymentInfo.setReplicaConfig( + new ReplicaConfig("io.ray.serve.ReplicaContext", initArgsBytes, new HashMap<>())); + + ActorHandle replicaHandle = + Ray.actor( + RayServeWrappedReplica::new, + backendTag, + replicaTag, + deploymentInfo, + controllerName) + .setName(actorName) + .remote(); + replicaHandle.task(RayServeWrappedReplica::ready).remote(); + + // RayServeHandle + RayServeHandle rayServeHandle = + new RayServeHandle(controllerHandle, backendTag, null, null) + .setMethodName("getBackendTag"); + ActorSet.Builder builder = ActorSet.newBuilder(); + builder.addNames(actorName); + rayServeHandle.getRouter().getReplicaSet().updateWorkerReplicas(builder.build()); + + // remote + ObjectRef resultRef = rayServeHandle.remote(null); + Assert.assertEquals((String) resultRef.get(), backendTag); + } finally { + if (!inited) { + Ray.shutdown(); + } + } + } +} diff --git a/java/serve/src/test/java/io/ray/serve/RayServeReplicaTest.java 
b/java/serve/src/test/java/io/ray/serve/RayServeReplicaTest.java index 7cc7746ff165c..065b74ac1fc0e 100644 --- a/java/serve/src/test/java/io/ray/serve/RayServeReplicaTest.java +++ b/java/serve/src/test/java/io/ray/serve/RayServeReplicaTest.java @@ -6,9 +6,12 @@ import io.ray.runtime.serializer.MessagePackSerializer; import io.ray.serve.generated.BackendConfig; import io.ray.serve.generated.BackendLanguage; +import io.ray.serve.generated.BackendVersion; import io.ray.serve.generated.RequestMetadata; import io.ray.serve.generated.RequestWrapper; import java.io.IOException; +import java.util.HashMap; +import org.apache.commons.lang3.RandomStringUtils; import org.testng.Assert; import org.testng.annotations.Test; @@ -17,7 +20,6 @@ public class RayServeReplicaTest { @SuppressWarnings("unused") @Test public void test() throws IOException { - boolean inited = Ray.isInitialized(); Ray.init(); @@ -25,38 +27,40 @@ public void test() throws IOException { String controllerName = "RayServeReplicaTest"; String backendTag = "b_tag"; String replicaTag = "r_tag"; + String version = "v1"; - ActorHandle controllerHandle = - Ray.actor(ReplicaContext::new, backendTag, replicaTag, controllerName, new Object()) - .setName(controllerName) - .remote(); + ActorHandle controllerHandle = + Ray.actor(DummyServeController::new).setName(controllerName).remote(); BackendConfig.Builder backendConfigBuilder = BackendConfig.newBuilder(); backendConfigBuilder.setBackendLanguage(BackendLanguage.JAVA); - byte[] backendConfigBytes = backendConfigBuilder.build().toByteArray(); - Object[] initArgs = new Object[] {backendTag, replicaTag, controllerName, new Object()}; - byte[] initArgsBytes = MessagePackSerializer.encode(initArgs).getLeft(); + DeploymentInfo deploymentInfo = new DeploymentInfo(); + deploymentInfo.setBackendConfig(backendConfigBytes); + deploymentInfo.setBackendVersion( + BackendVersion.newBuilder().setCodeVersion(version).build().toByteArray()); + deploymentInfo.setReplicaConfig( + new ReplicaConfig("io.ray.serve.ReplicaContext", initArgsBytes, new HashMap<>())); + ActorHandle backendHandle = Ray.actor( RayServeWrappedReplica::new, backendTag, replicaTag, - "io.ray.serve.ReplicaContext", - initArgsBytes, - backendConfigBytes, + deploymentInfo, controllerName) .remote(); + // ready backendHandle.task(RayServeWrappedReplica::ready).remote(); + // handle request RequestMetadata.Builder requestMetadata = RequestMetadata.newBuilder(); - requestMetadata.setRequestId("RayServeReplicaTest"); + requestMetadata.setRequestId(RandomStringUtils.randomAlphabetic(10)); requestMetadata.setCallMethod("getBackendTag"); - RequestWrapper.Builder requestWrapper = RequestWrapper.newBuilder(); ObjectRef resultRef = @@ -66,8 +70,22 @@ public void test() throws IOException { requestMetadata.build().toByteArray(), requestWrapper.build().toByteArray()) .remote(); - Assert.assertEquals((String) resultRef.get(), backendTag); + + // reconfigure + ObjectRef versionRef = + backendHandle.task(RayServeWrappedReplica::reconfigure, (Object) null).remote(); + Assert.assertEquals(BackendVersion.parseFrom(versionRef.get()).getCodeVersion(), version); + + // get version + versionRef = backendHandle.task(RayServeWrappedReplica::getVersion).remote(); + Assert.assertEquals(BackendVersion.parseFrom(versionRef.get()).getCodeVersion(), version); + + // prepare for shutdown + ObjectRef shutdownRef = + backendHandle.task(RayServeWrappedReplica::prepareForShutdown).remote(); + Assert.assertTrue(shutdownRef.get()); + } finally { if (!inited) { Ray.shutdown(); 
diff --git a/java/serve/src/test/java/io/ray/serve/ReplicaSetTest.java b/java/serve/src/test/java/io/ray/serve/ReplicaSetTest.java
new file mode 100644
index 0000000000000..513d27e4bb6b1
--- /dev/null
+++ b/java/serve/src/test/java/io/ray/serve/ReplicaSetTest.java
@@ -0,0 +1,108 @@
+package io.ray.serve;
+
+import io.ray.api.ActorHandle;
+import io.ray.api.ObjectRef;
+import io.ray.api.Ray;
+import io.ray.runtime.serializer.MessagePackSerializer;
+import io.ray.serve.generated.ActorSet;
+import io.ray.serve.generated.BackendConfig;
+import io.ray.serve.generated.BackendLanguage;
+import io.ray.serve.generated.BackendVersion;
+import io.ray.serve.generated.RequestMetadata;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+import org.apache.commons.lang3.RandomStringUtils;
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+public class ReplicaSetTest {
+
+  private String backendTag = "ReplicaSetTest";
+
+  @Test
+  public void setMaxConcurrentQueriesTest() {
+    ReplicaSet replicaSet = new ReplicaSet(backendTag);
+    BackendConfig.Builder builder = BackendConfig.newBuilder();
+    builder.setMaxConcurrentQueries(200);
+
+    replicaSet.setMaxConcurrentQueries(builder.build());
+    Assert.assertEquals(replicaSet.getMaxConcurrentQueries(), 200);
+  }
+
+  @Test
+  public void updateWorkerReplicasTest() {
+    ReplicaSet replicaSet = new ReplicaSet(backendTag);
+    ActorSet.Builder builder = ActorSet.newBuilder();
+
+    replicaSet.updateWorkerReplicas(builder.build());
+    Map<ActorHandle<RayServeWrappedReplica>, Set<ObjectRef<Object>>> inFlightQueries =
+        replicaSet.getInFlightQueries();
+    Assert.assertTrue(inFlightQueries.isEmpty());
+  }
+
+  @SuppressWarnings("unused")
+  @Test
+  public void assignReplicaTest() {
+    boolean inited = Ray.isInitialized();
+    Ray.init();
+
+    try {
+      String controllerName = backendTag + "_controller";
+      String replicaTag = backendTag + "_replica";
+      String actorName = replicaTag;
+      String version = "v1";
+
+      // Controller
+      ActorHandle<DummyServeController> controllerHandle =
+          Ray.actor(DummyServeController::new).setName(controllerName).remote();
+
+      // Replica
+      BackendConfig.Builder backendConfigBuilder = BackendConfig.newBuilder();
+      backendConfigBuilder.setBackendLanguage(BackendLanguage.JAVA);
+      byte[] backendConfigBytes = backendConfigBuilder.build().toByteArray();
+
+      Object[] initArgs = new Object[] {backendTag, replicaTag, controllerName, new Object()};
+      byte[] initArgsBytes = MessagePackSerializer.encode(initArgs).getLeft();
+
+      DeploymentInfo deploymentInfo = new DeploymentInfo();
+      deploymentInfo.setBackendConfig(backendConfigBytes);
+      deploymentInfo.setBackendVersion(
+          BackendVersion.newBuilder().setCodeVersion(version).build().toByteArray());
+      deploymentInfo.setReplicaConfig(
+          new ReplicaConfig("io.ray.serve.ReplicaContext", initArgsBytes, new HashMap<>()));
+
+      ActorHandle<RayServeWrappedReplica> replicaHandle =
+          Ray.actor(
+                  RayServeWrappedReplica::new,
+                  backendTag,
+                  replicaTag,
+                  deploymentInfo,
+                  controllerName)
+              .setName(actorName)
+              .remote();
+      replicaHandle.task(RayServeWrappedReplica::ready).remote();
+
+      // ReplicaSet
+      ReplicaSet replicaSet = new ReplicaSet(backendTag);
+      ActorSet.Builder builder = ActorSet.newBuilder();
+      builder.addNames(actorName);
+      replicaSet.updateWorkerReplicas(builder.build());
+
+      // assign
+
+      RequestMetadata.Builder requestMetadata = RequestMetadata.newBuilder();
+      requestMetadata.setRequestId(RandomStringUtils.randomAlphabetic(10));
+      requestMetadata.setCallMethod("getBackendTag");
+
+      Query query = new Query(requestMetadata.build(), null);
+      ObjectRef<Object> resultRef =
replicaSet.assignReplica(query); + + Assert.assertEquals((String) resultRef.get(), backendTag); + } finally { + if (!inited) { + Ray.shutdown(); + } + } + } +} diff --git a/java/serve/src/test/java/io/ray/serve/RouterTest.java b/java/serve/src/test/java/io/ray/serve/RouterTest.java new file mode 100644 index 0000000000000..3312179912e38 --- /dev/null +++ b/java/serve/src/test/java/io/ray/serve/RouterTest.java @@ -0,0 +1,80 @@ +package io.ray.serve; + +import io.ray.api.ActorHandle; +import io.ray.api.ObjectRef; +import io.ray.api.Ray; +import io.ray.runtime.serializer.MessagePackSerializer; +import io.ray.serve.generated.ActorSet; +import io.ray.serve.generated.BackendConfig; +import io.ray.serve.generated.BackendLanguage; +import io.ray.serve.generated.BackendVersion; +import io.ray.serve.generated.RequestMetadata; +import java.util.HashMap; +import org.apache.commons.lang3.RandomStringUtils; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class RouterTest { + + @Test + public void test() { + boolean inited = Ray.isInitialized(); + Ray.init(); + + try { + String backendTag = "RouterTest"; + String controllerName = backendTag + "_controller"; + String replicaTag = backendTag + "_replica"; + String actorName = replicaTag; + String version = "v1"; + + // Controller + ActorHandle controllerHandle = + Ray.actor(DummyServeController::new).setName(controllerName).remote(); + + // Replica + BackendConfig.Builder backendConfigBuilder = BackendConfig.newBuilder(); + backendConfigBuilder.setBackendLanguage(BackendLanguage.JAVA); + byte[] backendConfigBytes = backendConfigBuilder.build().toByteArray(); + + Object[] initArgs = new Object[] {backendTag, replicaTag, controllerName, new Object()}; + byte[] initArgsBytes = MessagePackSerializer.encode(initArgs).getLeft(); + + DeploymentInfo deploymentInfo = new DeploymentInfo(); + deploymentInfo.setBackendConfig(backendConfigBytes); + deploymentInfo.setBackendVersion( + BackendVersion.newBuilder().setCodeVersion(version).build().toByteArray()); + deploymentInfo.setReplicaConfig( + new ReplicaConfig("io.ray.serve.ReplicaContext", initArgsBytes, new HashMap<>())); + + ActorHandle replicaHandle = + Ray.actor( + RayServeWrappedReplica::new, + backendTag, + replicaTag, + deploymentInfo, + controllerName) + .setName(actorName) + .remote(); + replicaHandle.task(RayServeWrappedReplica::ready).remote(); + + // Router + Router router = new Router(controllerHandle, backendTag); + ActorSet.Builder builder = ActorSet.newBuilder(); + builder.addNames(actorName); + router.getReplicaSet().updateWorkerReplicas(builder.build()); + + // assign + RequestMetadata.Builder requestMetadata = RequestMetadata.newBuilder(); + requestMetadata.setRequestId(RandomStringUtils.randomAlphabetic(10)); + requestMetadata.setCallMethod("getBackendTag"); + + ObjectRef resultRef = router.assignRequest(requestMetadata.build(), null); + Assert.assertEquals((String) resultRef.get(), backendTag); + } finally { + if (!inited) { + Ray.shutdown(); + } + } + } +} diff --git a/java/serve/src/test/java/io/ray/serve/api/ClientTest.java b/java/serve/src/test/java/io/ray/serve/api/ClientTest.java new file mode 100644 index 0000000000000..c3489bc1a1a19 --- /dev/null +++ b/java/serve/src/test/java/io/ray/serve/api/ClientTest.java @@ -0,0 +1,47 @@ +package io.ray.serve.api; + +import io.ray.api.ActorHandle; +import io.ray.api.Ray; +import io.ray.serve.DummyServeController; +import io.ray.serve.RayServeHandle; +import io.ray.serve.generated.EndpointInfo; +import java.util.HashMap; 
+import java.util.Map; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class ClientTest { + + @Test + public void getHandleTest() { + + boolean inited = Ray.isInitialized(); + Ray.init(); + + try { + String prefix = "ClientTest"; + String controllerName = prefix + "_controller"; + String endpointName = prefix + "_endpoint"; + + // Controller. + ActorHandle controllerHandle = + Ray.actor(DummyServeController::new).setName(controllerName).remote(); + + // Mock endpoints. + Map endpoints = new HashMap<>(); + endpoints.put(endpointName, EndpointInfo.newBuilder().setEndpointName(endpointName).build()); + controllerHandle.task(DummyServeController::setEndpoints, endpoints).remote(); + + // Client. + Client client = new Client(controllerHandle, controllerName, true); + + // Get handle. + RayServeHandle rayServeHandle = client.getHandle(endpointName, false); + Assert.assertNotNull(rayServeHandle); + } finally { + if (!inited) { + Ray.shutdown(); + } + } + } +} diff --git a/java/serve/src/test/java/io/ray/serve/api/ServeTest.java b/java/serve/src/test/java/io/ray/serve/api/ServeTest.java index b63a709a167de..cf470e8ce2248 100644 --- a/java/serve/src/test/java/io/ray/serve/api/ServeTest.java +++ b/java/serve/src/test/java/io/ray/serve/api/ServeTest.java @@ -1,7 +1,12 @@ package io.ray.serve.api; -import io.ray.serve.RayServeException; +import io.ray.api.ActorHandle; +import io.ray.api.Ray; +import io.ray.serve.Constants; +import io.ray.serve.DummyServeController; import io.ray.serve.ReplicaContext; +import io.ray.serve.util.CommonUtil; +import org.apache.commons.lang3.RandomStringUtils; import org.testng.Assert; import org.testng.annotations.Test; @@ -10,31 +15,53 @@ public class ServeTest { @Test public void replicaContextTest() { - ReplicaContext preContext = Serve.INTERNAL_REPLICA_CONTEXT; - ReplicaContext replicaContext; - - // Test null replica context. - Serve.INTERNAL_REPLICA_CONTEXT = null; try { - replicaContext = Serve.getReplicaContext(); - Assert.assertTrue(false, "expect RayServeException"); - } catch (RayServeException e) { + // Test context setting and getting. + String backendTag = "backendTag"; + String replicaTag = "replicaTag"; + String controllerName = "controllerName"; + Object servableObject = new Object(); + Serve.setInternalReplicaContext(backendTag, replicaTag, controllerName, servableObject); + ReplicaContext replicaContext = Serve.getReplicaContext(); + Assert.assertNotNull(replicaContext, "no replica context"); + Assert.assertEquals(replicaContext.getBackendTag(), backendTag); + Assert.assertEquals(replicaContext.getReplicaTag(), replicaTag); + Assert.assertEquals(replicaContext.getInternalControllerName(), controllerName); + } finally { + // Recover context. + Serve.setInternalReplicaContext(null); } + } - // Test context setting and getting. 
- String backendTag = "backendTag"; - String replicaTag = "replicaTag"; - String controllerName = "controllerName"; - Object servableObject = new Object(); - Serve.setInternalReplicaContext(backendTag, replicaTag, controllerName, servableObject); - - replicaContext = Serve.getReplicaContext(); - Assert.assertNotNull(replicaContext, "no replica context"); - Assert.assertEquals(replicaContext.getBackendTag(), backendTag); - Assert.assertEquals(replicaContext.getReplicaTag(), replicaTag); - Assert.assertEquals(replicaContext.getInternalControllerName(), controllerName); + @SuppressWarnings("unused") + @Test + public void getGlobalClientTest() { + boolean inited = Ray.isInitialized(); + Ray.init(); + try { + Client client = null; + try { + client = Serve.getGlobalClient(); + Assert.assertTrue(false, "Expect IllegalStateException here!"); + } catch (IllegalStateException e) { + } + Assert.assertNull(client); - Serve.INTERNAL_REPLICA_CONTEXT = preContext; + String controllerName = + CommonUtil.formatActorName( + Constants.SERVE_CONTROLLER_NAME, RandomStringUtils.randomAlphabetic(6)); + ActorHandle actorHandle = + Ray.actor(DummyServeController::new).setName(controllerName).remote(); + Serve.setInternalReplicaContext(null, null, controllerName, null); + client = Serve.getGlobalClient(); + Assert.assertNotNull(client); + } finally { + if (!inited) { + Ray.shutdown(); + } + Serve.setInternalReplicaContext(null); + Serve.setGlobalClient(null); + } } } diff --git a/java/serve/src/test/java/io/ray/serve/poll/KeyTypeTest.java b/java/serve/src/test/java/io/ray/serve/poll/KeyTypeTest.java index 628f5ff4a89c4..710ad97128ede 100644 --- a/java/serve/src/test/java/io/ray/serve/poll/KeyTypeTest.java +++ b/java/serve/src/test/java/io/ray/serve/poll/KeyTypeTest.java @@ -1,12 +1,15 @@ package io.ray.serve.poll; +import com.google.gson.Gson; import org.testng.Assert; import org.testng.annotations.Test; public class KeyTypeTest { + private static final Gson GSON = new Gson(); + @Test - public void test() { + public void hashTest() { KeyType k1 = new KeyType(LongPollNamespace.BACKEND_CONFIGS, "k1"); KeyType k2 = new KeyType(LongPollNamespace.BACKEND_CONFIGS, "k1"); KeyType k3 = new KeyType(LongPollNamespace.BACKEND_CONFIGS, null); @@ -28,4 +31,14 @@ public void test() { Assert.assertNotEquals(k1.hashCode(), k4.hashCode()); Assert.assertFalse(k1.equals(k4)); } + + @Test + public void jsonTest() { + KeyType k1 = new KeyType(LongPollNamespace.BACKEND_CONFIGS, "k1"); + String json = GSON.toJson(k1); + + KeyType k2 = GSON.fromJson(json, KeyType.class); + Assert.assertEquals(k1, k2); + Assert.assertEquals(k1.hashCode(), k2.hashCode()); + } } diff --git a/java/serve/src/test/java/io/ray/serve/poll/LongPollClientTest.java b/java/serve/src/test/java/io/ray/serve/poll/LongPollClientTest.java index 3d172d87bedc7..7ee254806fad3 100644 --- a/java/serve/src/test/java/io/ray/serve/poll/LongPollClientTest.java +++ b/java/serve/src/test/java/io/ray/serve/poll/LongPollClientTest.java @@ -1,5 +1,8 @@ package io.ray.serve.poll; +import com.google.protobuf.ByteString; +import io.ray.serve.generated.BackendConfig; +import io.ray.serve.generated.UpdatedObject; import java.util.HashMap; import java.util.Map; import org.testng.Assert; @@ -10,25 +13,35 @@ public class LongPollClientTest { @Test public void test() throws Throwable { + String[] a = new String[] {"test"}; + + // Construct a listener map. 
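+    // The listener below writes each update's numReplicas into a[0] so that the
+    // callback's side effect can be asserted at the end of the test.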
     KeyType keyType = new KeyType(LongPollNamespace.BACKEND_CONFIGS, "backendTag");
-    int[] a = new int[] {0};
     Map<KeyType, KeyListener> keyListeners = new HashMap<>();
-    keyListeners.put(keyType, (object) -> a[0] = (Integer) object);
+    keyListeners.put(
+        keyType, (object) -> a[0] = String.valueOf(((BackendConfig) object).getNumReplicas()));
+
+    // Initialize LongPollClient.
     LongPollClient longPollClient = new LongPollClient(null, keyListeners);
 
+    // Construct updated object.
+    BackendConfig.Builder backendConfig = BackendConfig.newBuilder();
+    backendConfig.setNumReplicas(20);
     int snapshotId = 10;
-    int objectSnapshot = 20;
-    UpdatedObject updatedObject = new UpdatedObject();
+    UpdatedObject.Builder updatedObject = UpdatedObject.newBuilder();
     updatedObject.setSnapshotId(snapshotId);
-    updatedObject.setObjectSnapshot(objectSnapshot);
+    updatedObject.setObjectSnapshot(ByteString.copyFrom(backendConfig.build().toByteArray()));
 
+    // Process update.
     Map<KeyType, UpdatedObject> updates = new HashMap<>();
-    updates.put(keyType, updatedObject);
+    updates.put(keyType, updatedObject.build());
     longPollClient.processUpdate(updates);
 
+    // Validation.
     Assert.assertEquals(longPollClient.getSnapshotIds().get(keyType).intValue(), snapshotId);
     Assert.assertEquals(
-        ((Integer) longPollClient.getObjectSnapshots().get(keyType)).intValue(), objectSnapshot);
-    Assert.assertEquals(a[0], objectSnapshot);
+        ((BackendConfig) longPollClient.getObjectSnapshots().get(keyType)).getNumReplicas(),
+        backendConfig.getNumReplicas());
+    Assert.assertEquals(a[0], String.valueOf(backendConfig.getNumReplicas()));
   }
 }
diff --git a/python/build-wheel-windows.sh b/python/build-wheel-windows.sh
index cb36f901bd61c..c7c282acaa421 100755
--- a/python/build-wheel-windows.sh
+++ b/python/build-wheel-windows.sh
@@ -81,6 +81,13 @@ build_wheel_windows() {
   unset PYTHON2_BIN_PATH PYTHON3_BIN_PATH  # make sure these aren't set by some chance
   install_ray
   cd "${WORKSPACE_DIR}"/python
+  # Set the commit SHA in __init__.py.
+  if [ -n "$TRAVIS_COMMIT" ]; then
+    sed -i.bak "s/{{RAY_COMMIT_SHA}}/$TRAVIS_COMMIT/g" ray/__init__.py && rm ray/__init__.py.bak
+  else
+    echo "TRAVIS_COMMIT variable not set - required to populate ray.__commit__."
+    exit 1
+  fi
   # build ray wheel
   python setup.py --quiet bdist_wheel
   # build ray-cpp wheel
diff --git a/python/ray/_private/client_mode_hook.py b/python/ray/_private/client_mode_hook.py
index e6abf5f5a98f0..ef3df68206303 100644
--- a/python/ray/_private/client_mode_hook.py
+++ b/python/ray/_private/client_mode_hook.py
@@ -1,6 +1,6 @@
 import os
 from contextlib import contextmanager
-from functools import wraps
+from functools import partial, wraps
 import threading
 
 # Attr set on func defs to mark they have been converted to client mode.
@@ -15,6 +15,8 @@
 is_client_mode_enabled_by_default = is_client_mode_enabled
 os.environ.update({"RAY_CLIENT_MODE": "0"})
 
+is_init_called = False
+
 # Local setting of whether to ignore client hook conversion. This defaults
 # to TRUE and is disabled when the underlying 'real' Ray function is needed.
 _client_hook_status_on_thread = threading.local()
@@ -75,13 +77,27 @@ def enable_client_mode():
         _explicitly_disable_client_mode()
 
 
-def client_mode_hook(func):
-    """Decorator for ray module methods to delegate to ray client"""
+def client_mode_hook(func=None, *, auto_init: bool):
+    """Decorator for whether to use the 'regular' ray version of a function,
+    or the Ray Client version of that function.
+
+    Args:
+        func (callable): This function. This is set when this function is used
+            as a decorator.
+        auto_init (bool): Whether `ray.init()` should be transparently
+            called when the wrapped function is invoked. This should be
+            `False` for functions that are part of the initialization path
+            (e.g. `init` or `is_initialized`) and for functions that do not
+            require Ray to be initialized (e.g., KV operations, `shutdown`);
+            it should be `True` for everything else.
+    """
+    if func is None:
+        return partial(client_mode_hook, auto_init=auto_init)
+
     from ray.util.client import ray
 
     @wraps(func)
     def wrapper(*args, **kwargs):
-        if client_mode_should_convert():
+        if client_mode_should_convert(auto_init=auto_init):
             # Legacy code
             # we only convert init function if RAY_CLIENT_MODE=1
             if func.__name__ != "init" or is_client_mode_enabled_by_default:
@@ -91,13 +107,23 @@ def wrapper(*args, **kwargs):
     return wrapper
 
 
-def client_mode_should_convert():
-    # This is for testing with RAY_CLIENT_MODE.
-    # When RAY_CLIENT_MODE=1, it means that for all the tests
-    # will run with client mode.
-    # is_client_mode_enabled will be set to be off when client is off
+def client_mode_should_convert(*, auto_init: bool):
+    """Determines whether functions should be converted to client mode and
+    whether Ray should be auto-initialized.
+
+    NOTE: `auto_init` must happen before we branch into regular ray or client
+    code because the initialization may result in either mode.
+    """
+    if auto_init:
+        import ray
+        if os.environ.get("RAY_ENABLE_AUTO_CONNECT",
+                          "") != "0" and not ray.is_initialized():
+            ray.init()
+
+    # `is_client_mode_enabled_by_default` is used for testing with
+    # `RAY_CLIENT_MODE=1`. This flag means all tests run with client mode.
     return (is_client_mode_enabled or is_client_mode_enabled_by_default) and \
-        _get_client_hook_status_on_thread()
+        _get_client_hook_status_on_thread()
 
 
 def client_mode_wrap(func):
@@ -115,7 +141,9 @@ def wrapper(*args, **kwargs):
 
     @wraps(func)
     def wrapper(*args, **kwargs):
-        if client_mode_should_convert():
+        # Directly pass this through since `client_mode_wrap` is for
+        # Placement Group APIs
+        if client_mode_should_convert(auto_init=True):
             f = ray.remote(num_cpus=0)(func)
             ref = f.remote(*args, **kwargs)
             return ray.get(ref)
diff --git a/python/ray/_private/parameter.py b/python/ray/_private/parameter.py
index 4303808609a48..ac727fc2dec01 100644
--- a/python/ray/_private/parameter.py
+++ b/python/ray/_private/parameter.py
@@ -72,8 +72,8 @@ class RayParams:
             be created.
         worker_path (str): The path of the source code that will be run by
             the worker.
-        setup_worker_path (str): The path of the Python file that will run
-            worker_setup_hook to set up the environment for the worker process.
+        setup_worker_path (str): The path of the Python file that will set up
+            the environment for the worker process.
         huge_pages: Boolean flag indicating whether to start the Object
            Store with hugetlbfs support. Requires plasma_directory.
         include_dashboard: Boolean flag indicating whether to start the web
@@ -116,6 +116,7 @@ class RayParams:
         ray_debugger_external (bool): If true, make the Ray debugger for a
             worker available externally to the node it is running on. This will
            bind on 0.0.0.0 instead of localhost.
+        env_vars (dict): Override environment variables for the raylet.
""" def __init__(self, @@ -168,7 +169,7 @@ def __init__(self, metrics_export_port=None, tracing_startup_hook=None, no_monitor=False, - lru_evict=False): + env_vars=None): self.object_ref_seed = object_ref_seed self.external_addresses = external_addresses self.redis_address = redis_address @@ -215,18 +216,11 @@ def __init__(self, self.start_initial_python_workers_for_first_job = ( start_initial_python_workers_for_first_job) self.ray_debugger_external = ray_debugger_external + self.env_vars = env_vars self._system_config = _system_config or {} self._enable_object_reconstruction = enable_object_reconstruction self._check_usage() - # Set the internal config options for LRU eviction. - if lru_evict: - raise DeprecationWarning( - "The lru_evict flag is deprecated as Ray natively " - "supports object spilling. Please read " - "https://docs.ray.io/en/master/memory-management.html#object-spilling " # noqa - "for more details.") - # Set the internal config options for object reconstruction. if enable_object_reconstruction: # Turn off object pinning. diff --git a/python/ray/_private/runtime_env/__init__.py b/python/ray/_private/runtime_env/__init__.py index 20401cb96f021..e69de29bb2d1d 100644 --- a/python/ray/_private/runtime_env/__init__.py +++ b/python/ray/_private/runtime_env/__init__.py @@ -1,3 +0,0 @@ -from ray._private.runtime_env.context import RuntimeEnvContext # noqa: F401 -from ray._private.runtime_env.validation import ( # noqa: F401 - override_task_or_actor_runtime_env, RuntimeEnvDict) # noqa: F401 diff --git a/python/ray/_private/runtime_env/conda.py b/python/ray/_private/runtime_env/conda.py index d9c810b89b75b..92bc3d8cb1139 100644 --- a/python/ray/_private/runtime_env/conda.py +++ b/python/ray/_private/runtime_env/conda.py @@ -12,9 +12,9 @@ from pathlib import Path import ray -from ray._private.runtime_env import RuntimeEnvContext from ray._private.runtime_env.conda_utils import (get_conda_activate_commands, get_or_create_conda_env) +from ray._private.runtime_env.context import RuntimeEnvContext from ray._private.utils import (get_wheel_filename, get_master_wheel_url, get_release_wheel_url, try_to_create_directory) @@ -81,7 +81,7 @@ def get_conda_dict(runtime_env, runtime_env_dir) -> Optional[Dict[Any, Any]]: else: return None if runtime_env.get("pip"): - requirements_txt = runtime_env["pip"] + requirements_txt = "\n".join(runtime_env["pip"]) + "\n" pip_hash = hashlib.sha1(requirements_txt.encode("utf-8")).hexdigest() pip_hash_str = f"pip-generated-{pip_hash}" diff --git a/python/ray/_private/runtime_env/conda_utils.py b/python/ray/_private/runtime_env/conda_utils.py index 5d61c9e8c5f45..2339da036b60c 100644 --- a/python/ray/_private/runtime_env/conda_utils.py +++ b/python/ray/_private/runtime_env/conda_utils.py @@ -126,6 +126,21 @@ def get_or_create_conda_env(conda_env_path: str, return env_name +def get_conda_env_list() -> list: + """ + Get conda env list. 
+ """ + conda_path = get_conda_bin_executable("conda") + try: + exec_cmd([conda_path, "--help"], throw_on_error=False) + except EnvironmentError: + raise ValueError(f"Could not find Conda executable at {conda_path}.") + _, stdout, _ = exec_cmd([conda_path, "env", "list", "--json"]) + envs = json.loads(stdout)["envs"] + print(f"Conda env len {len(envs)}") + return envs + + class ShellCommandException(Exception): pass diff --git a/python/ray/_private/runtime_env/context.py b/python/ray/_private/runtime_env/context.py index af3409f310ca5..c5db64437ce2d 100644 --- a/python/ray/_private/runtime_env/context.py +++ b/python/ray/_private/runtime_env/context.py @@ -4,9 +4,13 @@ import sys from typing import Dict, List, Optional +from ray.util.annotations import DeveloperAPI +from ray.core.generated.common_pb2 import Language + logger = logging.getLogger(__name__) +@DeveloperAPI class RuntimeEnvContext: """A context used to describe the created runtime env.""" @@ -31,10 +35,13 @@ def serialize(self) -> str: def deserialize(json_string): return RuntimeEnvContext(**json.loads(json_string)) - def exec_worker(self, passthrough_args: List[str]): + def exec_worker(self, passthrough_args: List[str], language: Language): os.environ.update(self.env_vars) - exec_command = " ".join([f"exec {self.py_executable}"] + - passthrough_args) + if language == Language.PYTHON: + executable = f"exec {self.py_executable}" + else: + executable = "exec" + exec_command = " ".join([executable] + passthrough_args) command_str = " && ".join(self.command_prefix + [exec_command]) logger.info(f"Exec'ing worker with command: {command_str}") os.execvp("bash", ["bash", "-c", command_str]) diff --git a/python/ray/_private/runtime_env/plugin.py b/python/ray/_private/runtime_env/plugin.py new file mode 100644 index 0000000000000..5e411c141fc08 --- /dev/null +++ b/python/ray/_private/runtime_env/plugin.py @@ -0,0 +1,70 @@ +from abc import ABC, abstractstaticmethod + +from ray.util.annotations import DeveloperAPI +from ray._private.runtime_env.context import RuntimeEnvContext + + +@DeveloperAPI +class RuntimeEnvPlugin(ABC): + @abstractstaticmethod + def validate(runtime_env_dict: dict) -> str: + """Validate user entry and returns a URI uniquely describing resource. + + This method will be called at ``f.options(runtime_env=...)`` or + ``ray.init(runtime_env=...)`` time and it should check the runtime env + dictionary for any errors. For example, it can raise "TypeError: + expected string for "conda" field". + + Args: + runtime_env_dict(dict): the entire dictionary passed in by user. + + Returns: + uri(str): a URI uniquely describing this resource (e.g., a hash of + the conda spec). + """ + raise NotImplementedError() + + def create(uri: str, runtime_env_dict: dict, + ctx: RuntimeEnvContext) -> float: + """Create and install the runtime environment. + + Gets called in the runtime env agent at install time. The URI can be + used as a caching mechanism. + + Args: + uri(str): a URI uniquely describing this resource. + runtime_env_dict(dict): the entire dictionary passed in by user. + ctx(RuntimeEnvContext): auxiliary information supplied by Ray. + + Returns: + the disk space taken up by this plugin installation for this + environment. e.g. for working_dir, this downloads the files to the + local node. + """ + return 0 + + def modify_context(uri: str, runtime_env_dict: dict, + ctx: RuntimeEnvContext) -> None: + """Modify context to change worker startup behavior. 
+
+        For example, you can use this to prepend a "cd " command to the
+        worker startup, or to add new environment variables.
+
+        Args:
+            uri(str): a URI uniquely describing this resource.
+            runtime_env_dict(dict): the entire dictionary passed in by user.
+            ctx(RuntimeEnvContext): auxiliary information supplied by Ray.
+        """
+        return
+
+    def delete(uri: str, ctx: RuntimeEnvContext) -> float:
+        """Delete the runtime environment for the given URI.
+
+        Args:
+            uri(str): a URI uniquely describing this resource.
+            ctx(RuntimeEnvContext): auxiliary information supplied by Ray.
+
+        Returns:
+            the amount of space reclaimed by the deletion.
+        """
+        return 0
diff --git a/python/ray/_private/runtime_env/validation.py b/python/ray/_private/runtime_env/validation.py
index e113e4151424d..0e41bb6b30bd0 100644
--- a/python/ray/_private/runtime_env/validation.py
+++ b/python/ray/_private/runtime_env/validation.py
@@ -1,12 +1,15 @@
+import copy
 import json
 import logging
 import os
 from pathlib import Path
 import sys
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional, Set, Union
 import yaml
 
 import ray
+from ray._private.runtime_env.plugin import RuntimeEnvPlugin
+from ray._private.utils import import_attr
 
 # We need to setup this variable before
 # using this module
@@ -20,19 +23,198 @@
 
 GCS_STORAGE_MAX_SIZE = 100 * 1024 * 1024  # 100MiB
 
-class RuntimeEnvDict:
-    """Parses and validates the runtime env dictionary from the user.
+def parse_and_validate_working_dir(working_dir: str,
+                                   is_task_or_actor: bool = False) -> str:
+    """Parses and validates a user-provided 'working_dir' option.
 
-    Attributes:
+    The working_dir may not be specified per-task or per-actor.
+
+    Otherwise, it should be a valid path to a local directory.
+    """
+    assert working_dir is not None
+
+    if is_task_or_actor:
+        raise NotImplementedError(
+            "Overriding working_dir for tasks and actors isn't supported. "
+            "Please use ray.init(runtime_env={'working_dir': ...}) "
+            "to configure the environment per-job instead.")
+    elif not isinstance(working_dir, str):
+        raise TypeError("`working_dir` must be a string, got "
+                        f"{type(working_dir)}.")
+    elif not Path(working_dir).is_dir():
+        raise ValueError(
+            f"working_dir {working_dir} is not a valid directory.")
+
+    return working_dir
+
+
+def parse_and_validate_conda(conda: Union[str, dict],
+                             is_task_or_actor: bool = False
+                             ) -> Union[str, dict]:
+    """Parses and validates a user-provided 'conda' option.
+
+    Conda can be one of three cases:
+        1) A dictionary describing the env. This is passed through directly.
+        2) A string referring to a preinstalled conda env.
+        3) A string pointing to a local conda YAML file. This is detected
+           by looking for a '.yaml' or '.yml' suffix. In this case, the file
+           will be read as YAML and passed through as a dictionary.
+    """
+    assert conda is not None
+
+    result = None
+    if sys.platform == "win32":
+        raise NotImplementedError("The 'conda' field in runtime_env "
+                                  "is not currently supported on "
+                                  "Windows.")
+    elif isinstance(conda, str):
+        yaml_file = Path(conda)
+        if yaml_file.suffix in (".yaml", ".yml"):
+            if not yaml_file.is_file():
+                raise ValueError(f"Can't find conda YAML file {yaml_file}.")
+            try:
+                result = yaml.safe_load(yaml_file.read_text())
+            except Exception as e:
+                raise ValueError(
+                    f"Failed to read conda file {yaml_file}: {e}.")
+        else:
+            # Assume it's a pre-existing conda environment name.
+            result = conda
+    elif isinstance(conda, dict):
+        result = conda
+    else:
+        raise TypeError("runtime_env['conda'] must be of type str or "
+                        f"dict, got {type(conda)}.")
+
+    return result
+
+
+def parse_and_validate_pip(pip: Union[str, List[str]],
+                           is_task_or_actor: bool = False
+                           ) -> Optional[List[str]]:
+    """Parses and validates a user-provided 'pip' option.
+
+    Pip can be one of two cases:
+        1) A List[str] describing the requirements. This is passed through.
+        2) A string pointing to a local requirements file. In this case, the
+           file contents will be read and split into a list.
+    """
+    assert pip is not None
+
+    result = None
+    if sys.platform == "win32":
+        raise NotImplementedError("The 'pip' field in runtime_env "
+                                  "is not currently supported on "
+                                  "Windows.")
+    elif isinstance(pip, str):
+        # We have been given a path to a requirements.txt file.
+        pip_file = Path(pip)
+        if not pip_file.is_file():
+            raise ValueError(f"{pip_file} is not a valid file")
+        result = pip_file.read_text().strip().split("\n")
+    elif isinstance(pip, list) and all(isinstance(dep, str) for dep in pip):
+        if len(pip) == 0:
+            result = None
+        else:
+            result = pip
+    else:
+        raise TypeError("runtime_env['pip'] must be of type str or "
+                        f"List[str], got {type(pip)}")
+
+    return result
+
+
+def parse_and_validate_uris(uris: List[str],
+                            is_task_or_actor: bool = False) -> List[str]:
+    """Parses and validates a user-provided 'uris' option.
+
+    These are passed through without validation (for now).
+    """
+    assert uris is not None
+    return uris
+
+
+def parse_and_validate_container(container: List[str],
+                                 is_task_or_actor: bool = False) -> List[str]:
+    """Parses and validates a user-provided 'container' option.
+
+    This is passed through without validation (for now).
+    """
+    assert container is not None
+    return container
+
+
+def parse_and_validate_excludes(excludes: List[str],
+                                is_task_or_actor: bool = False) -> List[str]:
+    """Parses and validates a user-provided 'excludes' option.
+
+    This is validated to verify that it is of type List[str].
+
+    If an empty list is passed, we return `None` for consistency.
+    """
+    assert excludes is not None
+
+    if isinstance(excludes, list) and len(excludes) == 0:
+        return None
+
+    if (isinstance(excludes, list)
+            and all(isinstance(path, str) for path in excludes)):
+        return excludes
+    else:
+        raise TypeError("runtime_env['excludes'] must be of type "
+                        f"List[str], got {type(excludes)}")
+
+
+def parse_and_validate_env_vars(env_vars: Dict[str, str],
+                                is_task_or_actor: bool = False
+                                ) -> Optional[Dict[str, str]]:
+    """Parses and validates a user-provided 'env_vars' option.
+
+    This is validated to verify that all keys and values are strings.
+
+    If an empty dictionary is passed, we return `None` for consistency.
+    """
+    assert env_vars is not None
+    if len(env_vars) == 0:
+        return None
+
+    if not (isinstance(env_vars, dict) and all(
+            isinstance(k, str) and isinstance(v, str)
+            for (k, v) in env_vars.items())):
+        raise TypeError("runtime_env['env_vars'] must be of type "
+                        "Dict[str, str]")
+
+    return env_vars
+
+
+# Dictionary mapping runtime_env options to the functions used to parse and
+# validate them.
+OPTION_TO_VALIDATION_FN = { + "working_dir": parse_and_validate_working_dir, + "excludes": parse_and_validate_excludes, + "conda": parse_and_validate_conda, + "pip": parse_and_validate_pip, + "uris": parse_and_validate_uris, + "env_vars": parse_and_validate_env_vars, + "container": parse_and_validate_container, +} + + +class ParsedRuntimeEnv(dict): + """An internal wrapper for runtime_env that is parsed and validated. + + This should be constructed from user-provided input (the API runtime_env) + and used everywhere that the runtime_env is passed around internally. + + All options in the resulting dictionary will have non-None values. + + Currently supported options: working_dir (Path): Specifies the working directory of the worker. This can either be a local directory or zip file. Examples: "." # cwd "local_project.zip" # archive is unpacked into directory - py_modules (List[Path]): Similar to working_dir, but specifies python - modules to add to the `sys.path`. - Examples: - ["/path/to/other_module", "/other_path/local_project.zip"] + uris (List[str]): A list of URIs that define the working_dir. pip (List[str] | str): Either a list of pip packages, or a string containing the path to a pip requirements.txt file. conda (dict | str): Either the conda YAML config, the name of a @@ -64,170 +246,136 @@ class RuntimeEnvDict: {"OMP_NUM_THREADS": "32", "TF_WARNINGS": "none"} """ + known_fields: Set[str] = { + "working_dir", "conda", "pip", "uris", "containers", "excludes", + "env_vars", "_ray_release", "_ray_commit", "_inject_current_ray", + "plugins" + } + def __init__(self, - runtime_env_json: dict, - working_dir: Optional[str] = None): - # Simple dictionary with all options validated. This will always - # contain all supported keys; values will be set to None if - # unspecified. However, if all values are None this is set to {}. - self._dict = dict() - - if "working_dir" in runtime_env_json: - self._dict["working_dir"] = runtime_env_json["working_dir"] - if not isinstance(self._dict["working_dir"], str): - raise TypeError("`working_dir` must be a string. 
Type " - f"{type(self._dict['working_dir'])} received.") - working_dir = Path(self._dict["working_dir"]).absolute() - else: - self._dict["working_dir"] = None - working_dir = Path(working_dir).absolute() if working_dir else None - - self._dict["conda"] = None - if "conda" in runtime_env_json: - if sys.platform == "win32": - raise NotImplementedError("The 'conda' field in runtime_env " - "is not currently supported on " - "Windows.") - conda = runtime_env_json["conda"] - if isinstance(conda, str): - yaml_file = Path(conda) - if yaml_file.suffix in (".yaml", ".yml"): - if working_dir and not yaml_file.is_absolute(): - yaml_file = working_dir / yaml_file - if not yaml_file.is_file(): - raise ValueError( - f"Can't find conda YAML file {yaml_file}") - try: - self._dict["conda"] = yaml.safe_load( - yaml_file.read_text()) - except Exception as e: - raise ValueError( - f"Invalid conda file {yaml_file} with error {e}") - else: - logger.info( - f"Using preinstalled conda environment: {conda}") - self._dict["conda"] = conda - elif isinstance(conda, dict): - self._dict["conda"] = conda - elif conda is not None: - raise TypeError("runtime_env['conda'] must be of type str or " - "dict") - - self._dict["pip"] = None - if "pip" in runtime_env_json: - if sys.platform == "win32": - raise NotImplementedError("The 'pip' field in runtime_env " - "is not currently supported on " - "Windows.") - if ("conda" in runtime_env_json - and runtime_env_json["conda"] is not None): - raise ValueError( - "The 'pip' field and 'conda' field of " - "runtime_env cannot both be specified.\n" - f"specified pip field: {runtime_env_json['pip']}\n" - f"specified conda field: {runtime_env_json['conda']}\n" - "To use pip with conda, please only set the 'conda' " - "field, and specify your pip dependencies " - "within the conda YAML config dict: see " - "https://conda.io/projects/conda/en/latest/" - "user-guide/tasks/manage-environments.html" - "#create-env-file-manually") - pip = runtime_env_json["pip"] - if isinstance(pip, str): - # We have been given a path to a requirements.txt file. - pip_file = Path(pip) - if working_dir and not pip_file.is_absolute(): - pip_file = working_dir / pip_file - if not pip_file.is_file(): - raise ValueError(f"{pip_file} is not a valid file") - self._dict["pip"] = pip_file.read_text() - elif isinstance(pip, list) and all( - isinstance(dep, str) for dep in pip): - # Construct valid pip requirements.txt from list of packages. - self._dict["pip"] = "\n".join(pip) + "\n" - else: - raise TypeError("runtime_env['pip'] must be of type str or " - "List[str]") - - if "uris" in runtime_env_json: - self._dict["uris"] = runtime_env_json["uris"] - - if "container" in runtime_env_json: - self._dict["container"] = runtime_env_json["container"] - - self._dict["env_vars"] = None - if "env_vars" in runtime_env_json: - env_vars = runtime_env_json["env_vars"] - self._dict["env_vars"] = env_vars - if not (isinstance(env_vars, dict) and all( - isinstance(k, str) and isinstance(v, str) - for (k, v) in env_vars.items())): - raise TypeError("runtime_env['env_vars'] must be of type" - "Dict[str, str]") - - if "_ray_release" in runtime_env_json: - self._dict["_ray_release"] = runtime_env_json["_ray_release"] - - if "_ray_commit" in runtime_env_json: - self._dict["_ray_commit"] = runtime_env_json["_ray_commit"] + runtime_env: Dict[str, Any], + is_task_or_actor: bool = False, + _validate: bool = True): + super().__init__() + + # Blindly trust that the runtime_env has already been validated. 
+        # This is dangerous and should only be used internally (e.g., on the
+        # deserialization codepath).
+        if not _validate:
+            self.update(runtime_env)
+            return
+
+        if runtime_env.get("conda") and runtime_env.get("pip"):
+            raise ValueError(
+                "The 'pip' field and 'conda' field of "
+                "runtime_env cannot both be specified.\n"
+                f"specified pip field: {runtime_env['pip']}\n"
+                f"specified conda field: {runtime_env['conda']}\n"
+                "To use pip with conda, please only set the 'conda' "
+                "field, and specify your pip dependencies "
+                "within the conda YAML config dict: see "
+                "https://conda.io/projects/conda/en/latest/"
+                "user-guide/tasks/manage-environments.html"
+                "#create-env-file-manually")
+
+        for option, validate_fn in OPTION_TO_VALIDATION_FN.items():
+            option_val = runtime_env.get(option)
+            if option_val is not None:
+                validated_option_val = validate_fn(
+                    option_val, is_task_or_actor=is_task_or_actor)
+                if validated_option_val is not None:
+                    self[option] = validated_option_val
+
+        if "_ray_release" in runtime_env:
+            self["_ray_release"] = runtime_env["_ray_release"]
+
+        if "_ray_commit" in runtime_env:
+            self["_ray_commit"] = runtime_env["_ray_commit"]
         else:
-            if self._dict.get("pip") or self._dict.get("conda"):
-                self._dict["_ray_commit"] = ray.__commit__
+            if self.get("pip") or self.get("conda"):
+                self["_ray_commit"] = ray.__commit__
 
         # Used for testing wheels that have not yet been merged into master.
         # If this is set to True, then we do not inject Ray into the conda
         # or pip dependencies.
-        if os.environ.get("RAY_RUNTIME_ENV_LOCAL_DEV_MODE"):
-            runtime_env_json["_inject_current_ray"] = True
-        if "_inject_current_ray" in runtime_env_json:
-            self._dict["_inject_current_ray"] = runtime_env_json[
-                "_inject_current_ray"]
+        if "_inject_current_ray" in runtime_env:
+            self["_inject_current_ray"] = runtime_env["_inject_current_ray"]
+        elif "RAY_RUNTIME_ENV_LOCAL_DEV_MODE" in os.environ:
+            self["_inject_current_ray"] = True
+
+        if "plugins" in runtime_env:
+            self["plugins"] = dict()
+            for class_path, plugin_field in runtime_env["plugins"].items():
+                plugin_class: RuntimeEnvPlugin = import_attr(class_path)
+                if not issubclass(plugin_class, RuntimeEnvPlugin):
+                    # TODO(simon): move the interface to public once ready.
+                    raise TypeError(
+                        f"{class_path} must inherit from "
+                        "ray._private.runtime_env.plugin.RuntimeEnvPlugin.")
+                # TODO(simon): implement uri support.
+                _ = plugin_class.validate(runtime_env)
+                # Validation passed, add the entry to the parsed runtime env.
+                self["plugins"][class_path] = plugin_field
 
-        # TODO(ekl) we should have better schema validation here.
-        # TODO(ekl) support py_modules
-        # TODO(architkulkarni) support docker
+        unknown_fields = (
+            set(runtime_env.keys()) - ParsedRuntimeEnv.known_fields)
+        if len(unknown_fields):
+            logger.warning(
+                "The following unknown entries in the runtime_env dictionary "
+                f"will be ignored: {unknown_fields}. If you intended to use "
+                "them as plugins, they must be nested in the `plugins` field.")
 
         # TODO(architkulkarni) This is to make it easy for the worker caching
         # code in C++ to check if the env is empty without deserializing and
         # parsing it. We should use a less confusing approach here.
-        if all(val is None for val in self._dict.values()):
-            self._dict = {}
+        if all(val is None for val in self.values()):
+            self.clear()
 
-    def get_parsed_dict(self) -> dict:
-        return self._dict
+    @classmethod
+    def deserialize(cls, serialized: str) -> "ParsedRuntimeEnv":
+        return cls(json.loads(serialized), _validate=False)
 
     def serialize(self) -> str:
-        # Use sort_keys=True because we will use the output as a key to cache
-        # workers by, so we need the serialization to be independent of the
-        # dict order.
-        return json.dumps(self._dict, sort_keys=True)
-
-    def set_uris(self, uris):
-        self._dict["uris"] = uris
+        # Sort the keys so we can compare the serialized strings for equality.
+        return json.dumps(self, sort_keys=True)
 
 
 def override_task_or_actor_runtime_env(
-        runtime_env: Optional[Dict[str, Any]],
-        parent_runtime_env: Dict[str, Any]) -> Dict[str, Any]:
-    if runtime_env:
-        if runtime_env.get("working_dir"):
-            raise NotImplementedError(
-                "Overriding working_dir for actors is not supported. "
-                "Please use ray.init(runtime_env={'working_dir': ...}) "
-                "to configure per-job environment instead.")
-        # NOTE(edoakes): this is sort of hacky, but we pass in the parent
-        # working_dir here so the relative path to a requirements.txt file
-        # works. The right solution would be to merge the runtime_env with the
-        # parent runtime env before validation.
-        runtime_env_dict = RuntimeEnvDict(
-            runtime_env, working_dir=parent_runtime_env.get(
-                "working_dir")).get_parsed_dict()
-    else:
-        runtime_env_dict = {}
+        child_runtime_env: ParsedRuntimeEnv,
+        parent_runtime_env: ParsedRuntimeEnv) -> ParsedRuntimeEnv:
+    """Merge the given child runtime env with the parent runtime env.
+
+    If running in a driver, the parent runtime env comes from the
+    JobConfig. Otherwise, we are running in a worker for an actor or
+    task, and the parent runtime env comes from the current TaskSpec.
+
+    By default, the child runtime env inherits non-specified options from the
+    parent. There are two exceptions to this:
+        - working_dir is not inherited (only URIs).
+        - The env_vars dictionaries are merged, so environment variables
+          not specified by the child are still inherited from the parent.
+
+    Returns:
+        The resulting merged ParsedRuntimeEnv.
+    """
+    assert child_runtime_env is not None
+    assert parent_runtime_env is not None
+
+    # Override environment variables.
+    result_env_vars = copy.deepcopy(parent_runtime_env.get("env_vars") or {})
+    child_env_vars = child_runtime_env.get("env_vars") or {}
+    result_env_vars.update(child_env_vars)
+
+    # Inherit all other non-specified options from the parent.
+    result = copy.deepcopy(parent_runtime_env)
+    result.update(child_runtime_env)
+    if len(result_env_vars) > 0:
+        result["env_vars"] = result_env_vars
+    if "working_dir" in result:
+        del result["working_dir"]  # working_dir should not be in child env.
 
-    # If per-actor URIs aren't specified, override them with those in the
-    # job config.
-    if "uris" not in runtime_env_dict and "uris" in parent_runtime_env:
-        runtime_env_dict["uris"] = parent_runtime_env.get("uris")
+    # NOTE(architkulkarni): This allows worker caching code in C++ to
+    # check if a runtime env is empty without deserializing it.
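+    # An illustration of the merge semantics above (hypothetical values,
+    # not from the codebase):
+    #
+    #   parent = ParsedRuntimeEnv({"pip": ["requests"],
+    #                              "env_vars": {"A": "1"}})
+    #   child = ParsedRuntimeEnv({"env_vars": {"B": "2"}},
+    #                            is_task_or_actor=True)
+    #   merged = override_task_or_actor_runtime_env(child, parent)
+    #
+    # Here `merged` inherits the parent's validated "pip" option, merges
+    # env_vars into {"A": "1", "B": "2"}, and drops any parent "working_dir";
+    # serialize() then yields a sorted-key JSON string that deserialize()
+    # round-trips without re-validation.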
+ assert all(val is not None for val in result.values()) - return runtime_env_dict + return result diff --git a/python/ray/_private/runtime_env/working_dir.py b/python/ray/_private/runtime_env/working_dir.py index 964cf4aafcf5d..e5034caf74a27 100644 --- a/python/ray/_private/runtime_env/working_dir.py +++ b/python/ray/_private/runtime_env/working_dir.py @@ -15,7 +15,7 @@ _internal_kv_initialized) from ray.job_config import JobConfig from ray._private.thirdparty.pathspec import PathSpec -from ray._private.runtime_env import RuntimeEnvContext +from ray._private.runtime_env.context import RuntimeEnvContext default_logger = logging.getLogger(__name__) diff --git a/python/ray/_private/services.py b/python/ray/_private/services.py index 8d135129b78fa..4bf6c9bc9e055 100644 --- a/python/ray/_private/services.py +++ b/python/ray/_private/services.py @@ -21,6 +21,7 @@ import ray import ray.ray_constants as ray_constants import redis +from ray.core.generated.common_pb2 import Language # Import psutil and colorama after ray so the packaged version is used. import colorama @@ -398,6 +399,11 @@ def node_ip_address_from_perspective(address): def get_node_ip_address(address="8.8.8.8:53"): if ray.worker._global_node is not None: return ray.worker._global_node.node_ip_address + if sys.platform == "darwin": + # Due to the mac osx firewall, + # we use loopback ip as the ip address + # to prevent security popups. + return "127.0.0.1" return node_ip_address_from_perspective(address) @@ -866,7 +872,8 @@ def start_redis(node_ip_address, stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, fate_share=fate_share, - port_denylist=port_denylist) + port_denylist=port_denylist, + listen_to_localhost_only=(node_ip_address == "127.0.0.1")) processes.append(p) redis_address = address(node_ip_address, port) primary_redis_client = redis.StrictRedis( @@ -922,7 +929,8 @@ def start_redis(node_ip_address, stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, fate_share=fate_share, - port_denylist=port_denylist) + port_denylist=port_denylist, + listen_to_localhost_only=(node_ip_address == "127.0.0.1")) processes.append(p) shard_address = address(node_ip_address, redis_shard_port) @@ -944,7 +952,8 @@ def _start_redis_instance(executable, password=None, redis_max_memory=None, fate_share=None, - port_denylist=None): + port_denylist=None, + listen_to_localhost_only=False): """Start a single Redis server. Notes: @@ -970,6 +979,9 @@ def _start_redis_instance(executable, will start LRU eviction of entries. port_denylist (set): A set of denylist ports that shouldn't be used when allocating a new port. + listen_to_localhost_only (bool): Redis server only listens to + localhost (127.0.0.1) if it's true, + otherwise it listens to all network interfaces. Returns: A tuple of the port used by Redis and ProcessInfo for the process that @@ -990,6 +1002,8 @@ def _start_redis_instance(executable, raise ValueError("Spaces not permitted in redis password.") command += ["--requirepass", password] command += (["--port", str(port), "--loglevel", "warning"]) + if listen_to_localhost_only: + command += ["--bind", "127.0.0.1"] process_info = start_ray_process( command, ray_constants.PROCESS_TYPE_REDIS_SERVER, @@ -1347,7 +1361,8 @@ def start_raylet(redis_address, start_initial_python_workers_for_first_job=False, max_bytes=0, backup_count=0, - ray_debugger_external=False): + ray_debugger_external=False, + env_updates=None): """Start a raylet, which is a combined local scheduler and object manager. 
Args: @@ -1360,8 +1375,8 @@ def start_raylet(redis_address, to. worker_path (str): The path of the Python file that new worker processes will execute. - setup_worker_path (str): The path of the Python file that will run - worker_setup_hook to set up the environment for the worker process. + setup_worker_path (str): The path of the Python file that will set up + the environment for the worker process. temp_dir (str): The path of the temporary directory Ray will use. session_dir (str): The path of this session. resource_dir(str): The path of resource of this session . @@ -1393,6 +1408,8 @@ def start_raylet(redis_address, RotatingFileHandler's backupCount. ray_debugger_external (bool): True if the Ray debugger should be made available externally to this node. + env_updates (dict): Environment variable overrides. + Returns: ProcessInfo for the process that was started. """ @@ -1437,6 +1454,7 @@ def start_raylet(redis_address, redis_password, session_dir, node_ip_address, + setup_worker_path, ) else: java_worker_command = [] @@ -1567,7 +1585,8 @@ def check_should_start_agent(): use_perftools_profiler=("RAYLET_PERFTOOLS_PATH" in os.environ), stdout_file=stdout_file, stderr_file=stderr_file, - fate_share=fate_share) + fate_share=fate_share, + env_updates=env_updates) return process_info @@ -1591,6 +1610,7 @@ def build_java_worker_command( redis_password, session_dir, node_ip_address, + setup_worker_path, ): """This method assembles the command used to start a Java worker. @@ -1602,6 +1622,8 @@ def build_java_worker_command( redis_password (str): The password of connect to redis. session_dir (str): The path of this session. node_ip_address (str): The ip address for this node. + setup_worker_path (str): The path of the Python file that will set up + the environment for the worker process. Returns: The command string for starting Java worker. 
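+
+        For example (illustrative paths and property values only), the
+        assembled command now has the shape:
+            python /path/to/setup_worker.py java -Dray.home=... \
+                -Dray.logging.dir=... -Dray.session-dir=... ...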
""" @@ -1626,7 +1648,9 @@ def build_java_worker_command( pairs.append(("ray.home", RAY_HOME)) pairs.append(("ray.logging.dir", os.path.join(session_dir, "logs"))) pairs.append(("ray.session-dir", session_dir)) - command = ["java"] + ["-D{}={}".format(*pair) for pair in pairs] + command = [sys.executable] + [setup_worker_path] + ["java"] + [ + "-D{}={}".format(*pair) for pair in pairs + ] # Add ray jars path to java classpath ray_jars = os.path.join(get_ray_jars_dir(), "*") @@ -1908,9 +1932,14 @@ def start_ray_client_server( ray_constants.SETUP_WORKER_FILENAME) command = [ - sys.executable, setup_worker_path, "-m", "ray.util.client.server", - f"--redis-address={redis_address}", f"--port={ray_client_server_port}", - f"--mode={server_type}" + sys.executable, + setup_worker_path, + "-m", + "ray.util.client.server", + f"--redis-address={redis_address}", + f"--port={ray_client_server_port}", + f"--mode={server_type}", + f"--language={Language.Name(Language.PYTHON)}", ] if redis_password: command.append(f"--redis-password={redis_password}") diff --git a/python/ray/_private/test_utils.py b/python/ray/_private/test_utils.py index 50bb3d13c008b..8da4ac9f03e69 100644 --- a/python/ray/_private/test_utils.py +++ b/python/ray/_private/test_utils.py @@ -7,24 +7,25 @@ import pathlib import subprocess import sys -import tempfile import time import timeit import math import traceback -import datetime from typing import Optional, Any, List, Dict from contextlib import redirect_stdout, redirect_stderr import yaml import socket import pytest +import tempfile import ray import ray._private.services import ray._private.utils import ray._private.gcs_utils as gcs_utils +from ray._private.tls_utils import generate_self_signed_tls_certs from ray.util.queue import Queue, _QueueActor, Empty from ray.scripts.scripts import main as ray_main + try: from prometheus_client.parser import text_string_to_metric_families except (ImportError, ModuleNotFoundError): @@ -690,57 +691,11 @@ async def get_batch(self, return batch -def generate_self_signed_tls_certs(): - """Create self-signed key/cert pair for testing. - - This method requires the library ``cryptography`` be installed. - """ - try: - from cryptography import x509 - from cryptography.hazmat.backends import default_backend - from cryptography.hazmat.primitives import hashes, serialization - from cryptography.hazmat.primitives.asymmetric import rsa - from cryptography.x509.oid import NameOID - except ImportError: - raise ImportError( - "Using `Security.temporary` requires `cryptography`, please " - "install it using either pip or conda") - key = rsa.generate_private_key( - public_exponent=65537, key_size=2048, backend=default_backend()) - key_contents = key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption(), - ).decode() - - ray_interal = x509.Name( - [x509.NameAttribute(NameOID.COMMON_NAME, "ray-internal")]) - # This is the same logic used by the GCS server to acquire a - # private/interal IP address to listen on. 
If we just use localhost + - # 127.0.0.1 then we won't be able to connect to the GCS and will get - # an error like "No match found for server name: 192.168.X.Y" - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - s.connect(("8.8.8.8", 80)) - private_ip_address = s.getsockname()[0] - s.close() - altnames = x509.SubjectAlternativeName([ - x509.DNSName(socket.gethostbyname( - socket.gethostname())), # Probably 127.0.0.1 - x509.DNSName("127.0.0.1"), - x509.DNSName(private_ip_address), # 192.168.*.* - x509.DNSName("localhost"), - ]) - now = datetime.datetime.utcnow() - cert = (x509.CertificateBuilder() - .subject_name(ray_interal).issuer_name(ray_interal).add_extension( - altnames, critical=False).public_key(key.public_key()) - .serial_number(x509.random_serial_number()).not_valid_before(now) - .not_valid_after(now + datetime.timedelta(days=365)).sign( - key, hashes.SHA256(), default_backend())) - - cert_contents = cert.public_bytes(serialization.Encoding.PEM).decode() - - return cert_contents, key_contents +def is_placement_group_removed(pg): + table = ray.util.placement_group_table(pg) + if "state" not in table: + return False + return table["state"] == "REMOVED" def setup_tls(): @@ -772,10 +727,3 @@ def teardown_tls(key_filepath, cert_filepath, temp_dir): del os.environ["RAY_TLS_SERVER_CERT"] del os.environ["RAY_TLS_SERVER_KEY"] del os.environ["RAY_TLS_CA_CERT"] - - -def is_placement_group_removed(pg): - table = ray.util.placement_group_table(pg) - if "state" not in table: - return False - return table["state"] == "REMOVED" diff --git a/python/ray/_private/tls_utils.py b/python/ray/_private/tls_utils.py new file mode 100644 index 0000000000000..8344d86c30c4b --- /dev/null +++ b/python/ray/_private/tls_utils.py @@ -0,0 +1,85 @@ +import datetime +import os +import socket + +import grpc + + +def generate_self_signed_tls_certs(): + """Create self-signed key/cert pair for testing. + + This method requires the library ``cryptography`` be installed. + """ + try: + from cryptography import x509 + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import hashes, serialization + from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.x509.oid import NameOID + except ImportError: + raise ImportError( + "Using `Security.temporary` requires `cryptography`, please " + "install it using either pip or conda") + key = rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend()) + key_contents = key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ).decode() + + ray_interal = x509.Name( + [x509.NameAttribute(NameOID.COMMON_NAME, "ray-internal")]) + # This is the same logic used by the GCS server to acquire a + # private/interal IP address to listen on. 
If we just use localhost + + # 127.0.0.1 then we won't be able to connect to the GCS and will get + # an error like "No match found for server name: 192.168.X.Y" + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.connect(("8.8.8.8", 80)) + private_ip_address = s.getsockname()[0] + s.close() + altnames = x509.SubjectAlternativeName([ + x509.DNSName(socket.gethostbyname( + socket.gethostname())), # Probably 127.0.0.1 + x509.DNSName("127.0.0.1"), + x509.DNSName(private_ip_address), # 192.168.*.* + x509.DNSName("localhost"), + ]) + now = datetime.datetime.utcnow() + cert = (x509.CertificateBuilder().subject_name(ray_interal).issuer_name( + ray_interal).add_extension(altnames, critical=False).public_key( + key.public_key()).serial_number( + x509.random_serial_number()).not_valid_before(now) + .not_valid_after(now + datetime.timedelta(days=365)).sign( + key, hashes.SHA256(), default_backend())) + + cert_contents = cert.public_bytes(serialization.Encoding.PEM).decode() + + return cert_contents, key_contents + + +def add_port_to_grpc_server(server, address): + if os.environ.get("RAY_USE_TLS", "0") == "1": + server_cert_chain, private_key, ca_cert = load_certs_from_env() + credentials = grpc.ssl_server_credentials( + [(private_key, server_cert_chain)], + root_certificates=ca_cert, + require_client_auth=ca_cert is not None) + return server.add_secure_port(address, credentials) + else: + return server.add_insecure_port(address) + + +def load_certs_from_env(): + if os.environ.get("RAY_USE_TLS", "0") == "1": + with open(os.environ["RAY_TLS_SERVER_CERT"], "rb") as f: + server_cert_chain = f.read() + with open(os.environ["RAY_TLS_SERVER_KEY"], "rb") as f: + private_key = f.read() + if "RAY_TLS_CA_CERT" in os.environ: + with open(os.environ["RAY_TLS_CA_CERT"], "rb") as f: + ca_cert = f.read() + else: + ca_cert = None + + return server_cert_chain, private_key, ca_cert diff --git a/python/ray/_private/utils.py b/python/ray/_private/utils.py index 37430d928dd92..50fe38ed65f74 100644 --- a/python/ray/_private/utils.py +++ b/python/ray/_private/utils.py @@ -27,6 +27,7 @@ import ray import ray._private.gcs_utils as gcs_utils import ray.ray_constants as ray_constants +from ray._private.tls_utils import load_certs_from_env # Import psutil after ray so the packaged version is used. 
import psutil @@ -1111,21 +1112,6 @@ def validate_namespace(namespace: str): "Pass None to not specify a namespace.") -def load_certs_from_env(): - if os.environ.get("RAY_USE_TLS", "0") == "1": - with open(os.environ["RAY_TLS_SERVER_CERT"], "rb") as f: - server_cert_chain = f.read() - with open(os.environ["RAY_TLS_SERVER_KEY"], "rb") as f: - private_key = f.read() - if "RAY_TLS_CA_CERT" in os.environ: - with open(os.environ["RAY_TLS_CA_CERT"], "rb") as f: - ca_cert = f.read() - else: - ca_cert = None - - return server_cert_chain, private_key, ca_cert - - def init_grpc_channel(address: str, options: Optional[Sequence[Tuple[str, Any]]] = None, asynchronous: bool = False): @@ -1142,15 +1128,3 @@ def init_grpc_channel(address: str, channel = grpc_module.insecure_channel(address, options=options) return channel - - -def add_port_to_grpc_server(server, address): - if os.environ.get("RAY_USE_TLS", "0") == "1": - server_cert_chain, private_key, ca_cert = load_certs_from_env() - credentials = grpc.ssl_server_credentials( - [(private_key, server_cert_chain)], - root_certificates=ca_cert, - require_client_auth=ca_cert is not None) - return server.add_secure_port(address, credentials) - else: - return server.add_insecure_port(address) diff --git a/python/ray/_raylet.pxd b/python/ray/_raylet.pxd index 5c79c3b796459..4326d6cf943a9 100644 --- a/python/ray/_raylet.pxd +++ b/python/ray/_raylet.pxd @@ -114,7 +114,7 @@ cdef class CoreWorker: object async_event_loop object plasma_event_handler object job_config - object current_runtime_env_dict + object current_runtime_env c_bool is_local_mode cdef _create_put_buffer(self, shared_ptr[CBuffer] &metadata, diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index fdb9a7f51fef0..bbc064ec92938 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -7,18 +7,18 @@ from cpython.exc cimport PyErr_CheckSignals import asyncio -import copy import gc import inspect -import threading -import traceback -import time import logging +import msgpack import os import pickle +import setproctitle import sys +import threading +import time +import traceback import _thread -import setproctitle from libc.stdint cimport ( int32_t, @@ -100,13 +100,6 @@ from ray.includes.ray_config cimport RayConfig from ray.includes.global_state_accessor cimport CGlobalStateAccessor import ray -import ray._private.gcs_utils as gcs_utils -from ray import external_storage -from ray._private.async_compat import ( - sync_to_async, get_new_event_loop) -import ray._private.memory_monitor as memory_monitor -import ray.ray_constants as ray_constants -import ray._private.profiling as profiling from ray.exceptions import ( RayActorError, RayError, @@ -117,11 +110,15 @@ from ray.exceptions import ( TaskCancelledError, AsyncioActorExit, ) +from ray import external_storage +import ray.ray_constants as ray_constants +from ray._private.async_compat import sync_to_async, get_new_event_loop +from ray._private.client_mode_hook import disable_client_hook +import ray._private.gcs_utils as gcs_utils +from ray._private.runtime_env.validation import ParsedRuntimeEnv +import ray._private.memory_monitor as memory_monitor +import ray._private.profiling as profiling from ray._private.utils import decode -from ray._private.client_mode_hook import ( - disable_client_hook, -) -import msgpack cimport cpython @@ -1353,8 +1350,8 @@ cdef class CoreWorker: int64_t placement_group_bundle_index, c_bool placement_group_capture_child_tasks, c_string debugger_breakpoint, - runtime_env_dict, - 
override_environment_variables + c_string serialized_runtime_env, + runtime_env_uris, ): cdef: unordered_map[c_string, double] c_resources @@ -1362,15 +1359,10 @@ cdef class CoreWorker: c_vector[unique_ptr[CTaskArg]] args_vector CPlacementGroupID c_placement_group_id = \ placement_group_id.native() - c_string c_serialized_runtime_env - unordered_map[c_string, c_string] \ - c_override_environment_variables = \ - override_environment_variables + c_vector[c_string] c_runtime_env_uris = runtime_env_uris c_vector[CObjectReference] return_refs with self.profile_event(b"submit_task"): - c_serialized_runtime_env = \ - self.prepare_runtime_env(runtime_env_dict) prepare_resources(resources, &c_resources) ray_function = CRayFunction( language.lang, function_descriptor.descriptor) @@ -1383,8 +1375,8 @@ cdef class CoreWorker: ray_function, args_vector, CTaskOptions( name, num_returns, c_resources, b"", - c_serialized_runtime_env, - c_override_environment_variables), + serialized_runtime_env, + c_runtime_env_uris), max_retries, retry_exceptions, c_pair[CPlacementGroupID, int64_t]( c_placement_group_id, placement_group_bundle_index), @@ -1410,8 +1402,8 @@ cdef class CoreWorker: int64_t placement_group_bundle_index, c_bool placement_group_capture_child_tasks, c_string extension_data, - runtime_env_dict, - override_environment_variables + c_string serialized_runtime_env, + runtime_env_uris, ): cdef: CRayFunction ray_function @@ -1422,14 +1414,9 @@ cdef class CoreWorker: CActorID c_actor_id CPlacementGroupID c_placement_group_id = \ placement_group_id.native() - c_string c_serialized_runtime_env - unordered_map[c_string, c_string] \ - c_override_environment_variables = \ - override_environment_variables + c_vector[c_string] c_runtime_env_uris = runtime_env_uris with self.profile_event(b"submit_task"): - c_serialized_runtime_env = \ - self.prepare_runtime_env(runtime_env_dict) prepare_resources(resources, &c_resources) prepare_resources(placement_resources, &c_placement_resources) ray_function = CRayFunction( @@ -1449,8 +1436,8 @@ cdef class CoreWorker: c_placement_group_id, placement_group_bundle_index), placement_group_capture_child_tasks, - c_serialized_runtime_env, - c_override_environment_variables), + serialized_runtime_env, + c_runtime_env_uris), extension_data, &c_actor_id)) @@ -1725,12 +1712,11 @@ cdef class CoreWorker: return CCoreWorkerProcess.GetCoreWorker().GetOwnerAddress( c_object_id).SerializeAsString() - def serialize_and_promote_object_ref(self, ObjectRef object_ref): + def serialize_object_ref(self, ObjectRef object_ref): cdef: CObjectID c_object_id = object_ref.native() CAddress c_owner_address = CAddress() c_string serialized_object_status - CCoreWorkerProcess.GetCoreWorker().PromoteObjectToPlasma(c_object_id) CCoreWorkerProcess.GetCoreWorker().GetOwnershipInfo( c_object_id, &c_owner_address, &serialized_object_status) return (object_ref, @@ -1861,19 +1847,20 @@ cdef class CoreWorker: return (CCoreWorkerProcess.GetCoreWorker().GetWorkerContext() .CurrentActorIsAsync()) - def get_current_runtime_env_dict(self): + def get_current_runtime_env(self) -> ParsedRuntimeEnv: # This should never change, so we can safely cache it to avoid ser/de - if self.current_runtime_env_dict is None: + if self.current_runtime_env is None: if self.is_driver: - self.current_runtime_env_dict = \ - json.loads(self.get_job_config().serialized_runtime_env) + job_config = self.get_job_config() + serialized_env = job_config.runtime_env.serialized_runtime_env else: - self.current_runtime_env_dict = json.loads( - 
CCoreWorkerProcess.GetCoreWorker() - .GetWorkerContext() - .GetCurrentSerializedRuntimeEnv() - ) - return self.current_runtime_env_dict + serialized_env = CCoreWorkerProcess.GetCoreWorker() \ + .GetWorkerContext().GetCurrentSerializedRuntimeEnv() + + self.current_runtime_env = ParsedRuntimeEnv.deserialize( + serialized_env) + + return self.current_runtime_env def is_exiting(self): return CCoreWorkerProcess.GetCoreWorker().IsExiting() @@ -1901,6 +1888,26 @@ cdef class CoreWorker: return ref_counts + def get_actor_call_stats(self): + cdef: + unordered_map[c_string, c_vector[uint64_t]] c_tasks_count + + c_tasks_count = ( + CCoreWorkerProcess.GetCoreWorker().GetActorCallStats()) + it = c_tasks_count.begin() + + tasks_count = dict() + while it != c_tasks_count.end(): + func_name = dereference(it).first + counters = dereference(it).second + tasks_count[func_name] = { + "pending": counters[0], + "running": counters[1], + "finished": counters[2], + } + postincrement(it) + return tasks_count + def set_get_async_callback(self, ObjectRef object_ref, callback): cpython.Py_INCREF(callback) CCoreWorkerProcess.GetCoreWorker().GetAsync( @@ -1925,45 +1932,6 @@ cdef class CoreWorker: self.job_config.ParseFromString(c_job_config.SerializeAsString()) return self.job_config - def prepare_runtime_env(self, runtime_env_dict: dict) -> str: - """Merge the given new runtime env with the current runtime env. - - If running in a driver, the current runtime env comes from the - JobConfig. Otherwise, we are running in a worker for an actor or - task, and the current runtime env comes from the current TaskSpec. - - The child's runtime env dict is merged with the parents via a simple - dict update, except for runtime_env["env_vars"], which is merged - with runtime_env["env_vars"] of the parent rather than overwriting it. - This is so that env vars set in the parent propagate to child actors - and tasks even if a new env var is set in the child. - - Args: - runtime_env_dict (dict): A runtime env for a child actor or task. - Returns: - The resulting merged JSON-serialized runtime env. - """ - - result_dict = copy.deepcopy(self.get_current_runtime_env_dict()) - - result_env_vars = copy.deepcopy(result_dict.get("env_vars") or {}) - child_env_vars = runtime_env_dict.get("env_vars") or {} - result_env_vars.update(child_env_vars) - - result_dict.update(runtime_env_dict) - result_dict["env_vars"] = result_env_vars - - # NOTE(architkulkarni): This allows worker caching code in C++ to - # check if a runtime env is empty without deserializing it. 
- if result_dict["env_vars"] == {}: - result_dict["env_vars"] = None - if all(val is None for val in result_dict.values()): - result_dict = {} - - # TODO(architkulkarni): We should just use RuntimeEnvDict here - # so all the serialization and validation is done in one place - return json.dumps(result_dict, sort_keys=True) - def get_task_submission_stats(self): cdef: int64_t num_tasks_submitted diff --git a/python/ray/actor.py b/python/ray/actor.py index faec5fccc7dd7..f228389da72e0 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -5,7 +5,8 @@ import ray.ray_constants as ray_constants import ray._raylet import ray._private.signature as signature -import ray._private.runtime_env as runtime_support +from ray._private.runtime_env.validation import ( + override_task_or_actor_runtime_env, ParsedRuntimeEnv) import ray.worker from ray.util.annotations import PublicAPI from ray.util.placement_group import ( @@ -31,7 +32,7 @@ @PublicAPI -@client_mode_hook +@client_mode_hook(auto_init=False) def method(*args, **kwargs): """Annotate an actor method. @@ -388,11 +389,17 @@ class DerivedActorClass(cls, modified_class): PythonFunctionDescriptor.from_class( modified_class.__ray_actor_class__) + # Parse local pip/conda config files here. If we instead did it in + # .remote(), it would get run in the Ray Client server, which runs on + # a remote node where the files aren't available. + new_runtime_env = ParsedRuntimeEnv( + runtime_env or {}, is_task_or_actor=True) + self.__ray_metadata__ = ActorClassMetadata( Language.PYTHON, modified_class, actor_creation_function_descriptor, class_id, max_restarts, max_task_retries, num_cpus, num_gpus, memory, object_store_memory, - resources, accelerator_type, runtime_env) + resources, accelerator_type, new_runtime_env) return self @@ -403,10 +410,15 @@ def _ray_from_function_descriptor( resources, accelerator_type, runtime_env): self = ActorClass.__new__(ActorClass) + # Parse local pip/conda config files here. If we instead did it in + # .remote(), it would get run in the Ray Client server, which runs on + # a remote node where the files aren't available. + new_runtime_env = ParsedRuntimeEnv( + runtime_env or {}, is_task_or_actor=True) self.__ray_metadata__ = ActorClassMetadata( language, None, actor_creation_function_descriptor, None, max_restarts, max_task_retries, num_cpus, num_gpus, memory, - object_store_memory, resources, accelerator_type, runtime_env) + object_store_memory, resources, accelerator_type, new_runtime_env) return self @@ -442,8 +454,7 @@ def options(self, placement_group="default", placement_group_bundle_index=-1, placement_group_capture_child_tasks=None, - runtime_env=None, - override_environment_variables=None): + runtime_env=None): """Configures and overrides the actor instantiation parameters. The arguments are the same as those that can be passed @@ -464,6 +475,12 @@ def method(self): actor_cls = self + # Parse local pip/conda config files here. If we instead did it in + # .remote(), it would get run in the Ray Client server, which runs on + # a remote node where the files aren't available. 
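+        # For example (hypothetical), a relative path such as
+        #   SomeActor.options(runtime_env={"pip": "./requirements.txt"})
+        # must be resolved and read here, where the file actually exists.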
+ new_runtime_env = ParsedRuntimeEnv( + runtime_env or {}, is_task_or_actor=True) + class ActorOptionWrapper: def remote(self, *args, **kwargs): return actor_cls._remote( @@ -485,9 +502,7 @@ def remote(self, *args, **kwargs): placement_group_bundle_index=placement_group_bundle_index, placement_group_capture_child_tasks=( placement_group_capture_child_tasks), - runtime_env=runtime_env, - override_environment_variables=( - override_environment_variables)) + runtime_env=new_runtime_env) return ActorOptionWrapper() @@ -510,8 +525,7 @@ def _remote(self, placement_group="default", placement_group_bundle_index=-1, placement_group_capture_child_tasks=None, - runtime_env=None, - override_environment_variables=None): + runtime_env=None): """Create an actor. This method allows more flexibility than the remote method because @@ -557,9 +571,6 @@ def _remote(self, this actor or task and its children (see :ref:`runtime-environments` for details). This API is in beta and may change before becoming stable. - override_environment_variables: Environment variables to override - and/or introduce for this actor. This is a dictionary mapping - variable names to their values. Returns: A handle to the newly created actor. @@ -584,7 +595,7 @@ def _remote(self, if max_concurrency < 1: raise ValueError("max_concurrency must be >= 1") - if client_mode_should_convert(): + if client_mode_should_convert(auto_init=True): return client_mode_convert_actor( self, args, @@ -605,9 +616,7 @@ def _remote(self, placement_group_bundle_index=placement_group_bundle_index, placement_group_capture_child_tasks=( placement_group_capture_child_tasks), - runtime_env=runtime_env, - override_environment_variables=( - override_environment_variables)) + runtime_env=runtime_env) worker = ray.worker.global_worker worker.check_connected() @@ -723,18 +732,16 @@ def _remote(self, creation_args = signature.flatten_args(function_signature, args, kwargs) - if runtime_env is None: + if runtime_env and not isinstance(runtime_env, ParsedRuntimeEnv): + runtime_env = ParsedRuntimeEnv(runtime_env) + elif isinstance(runtime_env, ParsedRuntimeEnv): + pass + else: runtime_env = meta.runtime_env - job_runtime_env = worker.core_worker.get_current_runtime_env_dict() - runtime_env_dict = runtime_support.override_task_or_actor_runtime_env( - runtime_env, job_runtime_env) - - if override_environment_variables: - logger.warning("override_environment_variables is deprecated and " - "will be removed in Ray 1.6. Please use " - ".options(runtime_env={'env_vars': {...}}).remote()" - "instead.") + parent_runtime_env = worker.core_worker.get_current_runtime_env() + parsed_runtime_env = override_task_or_actor_runtime_env( + runtime_env, parent_runtime_env) actor_id = worker.core_worker.create_actor( meta.language, @@ -754,9 +761,8 @@ def _remote(self, placement_group_capture_child_tasks, # Store actor_method_cpu in actor handle's extension data. 
extension_data=str(actor_method_cpu), - runtime_env_dict=runtime_env_dict, - override_environment_variables=override_environment_variables - or dict()) + serialized_runtime_env=parsed_runtime_env.serialize(), + runtime_env_uris=parsed_runtime_env.get("uris") or []) actor_handle = ActorHandle( meta.language, diff --git a/python/ray/autoscaler/_private/autoscaler.py b/python/ray/autoscaler/_private/autoscaler.py index 3b26b845e8070..b153cff37d259 100644 --- a/python/ray/autoscaler/_private/autoscaler.py +++ b/python/ray/autoscaler/_private/autoscaler.py @@ -165,6 +165,12 @@ def read_fn(): self.disable_node_updaters = self.config["provider"].get( "disable_node_updaters", False) + # Disable launch config checking if true. + # This is set in the fake_multinode situations where there isn't any + # meaningful node "type" to enforce. + self.disable_launch_config_check = self.config["provider"].get( + "disable_launch_config_check", False) + # Node launchers self.launch_queue = queue.Queue() self.pending_launches = ConcurrentCounter() @@ -485,7 +491,8 @@ def _report_pending_infeasible(self, unfulfilled: List[ResourceDict]): pending = [] infeasible = [] for bundle in unfulfilled: - placement_group = any("_group_" in k for k in bundle) + placement_group = any( + "_group_" in k or k == "bundle" for k in bundle) if placement_group: continue if self.resource_demand_scheduler.is_feasible(bundle): @@ -627,7 +634,6 @@ def _keep_worker_of_node_type(self, node_id: NodeID, Return KeepOrTerminate.decide_later otherwise. - Args: node_type_counts(Dict[NodeType, int]): The non_terminated node types counted so far. @@ -757,6 +763,8 @@ def reset(self, errors_fatal=False): "Error parsing config.") def launch_config_ok(self, node_id): + if self.disable_launch_config_check: + return True node_tags = self.provider.node_tags(node_id) tag_launch_conf = node_tags.get(TAG_RAY_LAUNCH_CONFIG) node_type = node_tags.get(TAG_RAY_USER_NODE_TYPE) diff --git a/python/ray/autoscaler/_private/docker.py b/python/ray/autoscaler/_private/docker.py index 8d94759549217..92dd16ad5001f 100644 --- a/python/ray/autoscaler/_private/docker.py +++ b/python/ray/autoscaler/_private/docker.py @@ -18,7 +18,7 @@ def _check_docker_file_mounts(file_mounts: Dict[str, str]) -> None: if Path(local).is_file(): cli_logger.warning( f"File Mount: ({remote}:{local}) refers to a file.\n To ensure" - "this mount updates properly, please use a directory.") + " this mount updates properly, please use a directory.") def validate_docker_config(config: Dict[str, Any]) -> None: diff --git a/python/ray/autoscaler/_private/fake_multi_node/__init__.py b/python/ray/autoscaler/_private/fake_multi_node/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/python/ray/autoscaler/_private/fake_multi_node/example.yaml b/python/ray/autoscaler/_private/fake_multi_node/example.yaml new file mode 100644 index 0000000000000..497f89647c25f --- /dev/null +++ b/python/ray/autoscaler/_private/fake_multi_node/example.yaml @@ -0,0 +1,55 @@ +# Example command to start a cluster with this config: +# +# RAY_FAKE_CLUSTER=1 ray start --autoscaling-config=example.yaml --head --block +# +# Alternatively, you can programmatically create a fake autoscaling cluster +# using ray.cluster_utils.AutoscalingCluster. 
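+#
+# A minimal programmatic sketch (hypothetical resource values, mirroring
+# this file's node types):
+#
+#   from ray.cluster_utils import AutoscalingCluster
+#   cluster = AutoscalingCluster(
+#       head_resources={"CPU": 8},
+#       worker_node_types={
+#           "ray.worker.cpu": {
+#               "resources": {"CPU": 1},
+#               "node_config": {},
+#               "min_workers": 0,
+#               "max_workers": 4,
+#           },
+#       })
+#   cluster.start()  # connect with ray.init("auto"); cluster.shutdown() when done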
+cluster_name: fake_multinode +max_workers: 8 +provider: + type: fake_multinode + use_node_id_as_ip: True + disable_node_updaters: True + disable_launch_config_check: True +available_node_types: + ray.head.default: + # You must set this manually to your "head" node resources!! The head + # node is launched via `ray start` and hence the autoscaler cannot + # configure its resources. The resources specified for its node type + # must line up with what Ray detects/is configured with on start. + resources: + CPU: 8 # <-- set this to num CPUs used/detected in `ray start` + GPU: 0 # <-- set this to num GPUs used/detected in `ray start` + node_config: {} + max_workers: 0 + ray.worker.cpu: + resources: + CPU: 1 + object_store_memory: 1000000000 + node_config: {} + min_workers: 0 + max_workers: 4 + ray.worker.gpu: + resources: + CPU: 4 + GPU: 1 + object_store_memory: 1000000000 + node_config: {} + min_workers: 0 + max_workers: 2 +head_node_type: ray.head.default +auth: {} +upscaling_speed: 1.0 +idle_timeout_minutes: 0.1 +docker: {} +initialization_commands: [] +setup_commands: [] +head_setup_commands: [] +worker_setup_commands: [] +head_start_ray_commands: [] +worker_start_ray_commands: [] +file_mounts: {} +cluster_synced_files: [] +file_mounts_sync_continuously: false +rsync_exclude: [] +rsync_filter: [] diff --git a/python/ray/autoscaler/_private/fake_multi_node/node_provider.py b/python/ray/autoscaler/_private/fake_multi_node/node_provider.py new file mode 100644 index 0000000000000..71650d845b1e5 --- /dev/null +++ b/python/ray/autoscaler/_private/fake_multi_node/node_provider.py @@ -0,0 +1,114 @@ +import logging +import os +import json + +import ray +from ray.autoscaler.node_provider import NodeProvider +from ray.autoscaler.tags import (TAG_RAY_NODE_KIND, NODE_KIND_HEAD, + NODE_KIND_WORKER, TAG_RAY_USER_NODE_TYPE, + TAG_RAY_NODE_NAME, TAG_RAY_NODE_STATUS, + STATUS_UP_TO_DATE) + +logger = logging.getLogger(__name__) + +# We generate the node ids deterministically in the fake node provider, so that +# we can associate launched nodes with their resource reports. IDs increment +# starting with fffff*00000 for the head node, fffff*00001, etc. for workers. +FAKE_HEAD_NODE_ID = "fffffffffffffffffffffffffffffffffffffffffffffffffff00000" +FAKE_HEAD_NODE_TYPE = "ray.head.default" + + +class FakeMultiNodeProvider(NodeProvider): + """A node provider that implements multi-node on a single machine. 
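+    Worker "nodes" are emulated by launching additional local Ray node
+    processes on the same machine (see create_node_with_resources below).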
+ + This is used for laptop mode testing of autoscaling functionality.""" + + def __init__(self, provider_config, cluster_name): + NodeProvider.__init__(self, provider_config, cluster_name) + if "RAY_FAKE_CLUSTER" not in os.environ: + raise RuntimeError( + "FakeMultiNodeProvider requires ray to be started with " + "RAY_FAKE_CLUSTER=1 ray start ...") + self._nodes = { + FAKE_HEAD_NODE_ID: { + "tags": { + TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_USER_NODE_TYPE: FAKE_HEAD_NODE_TYPE, + TAG_RAY_NODE_NAME: FAKE_HEAD_NODE_ID, + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, + } + }, + } + self._next_node_id = 0 + + def _next_hex_node_id(self): + self._next_node_id += 1 + base = "fffffffffffffffffffffffffffffffffffffffffffffffffff" + return base + str(self._next_node_id).zfill(5) + + def non_terminated_nodes(self, tag_filters): + nodes = [] + for node_id in self._nodes: + tags = self.node_tags(node_id) + ok = True + for k, v in tag_filters.items(): + if tags.get(k) != v: + ok = False + if ok: + nodes.append(node_id) + return nodes + + def is_running(self, node_id): + return node_id in self._nodes + + def is_terminated(self, node_id): + return node_id not in self._nodes + + def node_tags(self, node_id): + return self._nodes[node_id]["tags"] + + def external_ip(self, node_id): + return node_id + + def internal_ip(self, node_id): + return node_id + + def set_node_tags(self, node_id, tags): + raise AssertionError("Readonly node provider cannot be updated") + + def create_node_with_resources(self, node_config, tags, count, resources): + node_type = tags[TAG_RAY_USER_NODE_TYPE] + next_id = self._next_hex_node_id() + ray_params = ray._private.parameter.RayParams( + min_worker_port=0, + max_worker_port=0, + dashboard_port=None, + num_cpus=resources.pop("CPU", 0), + num_gpus=resources.pop("GPU", 0), + object_store_memory=resources.pop("object_store_memory", None), + resources=resources, + redis_address="{}:6379".format( + ray._private.services.get_node_ip_address()), + env_vars={ + "RAY_OVERRIDE_NODE_ID_FOR_TESTING": next_id, + "RAY_OVERRIDE_RESOURCES": json.dumps(resources), + }) + node = ray.node.Node( + ray_params, head=False, shutdown_at_exit=False, spawn_reaper=False) + self._nodes[next_id] = { + "tags": { + TAG_RAY_NODE_KIND: NODE_KIND_WORKER, + TAG_RAY_USER_NODE_TYPE: node_type, + TAG_RAY_NODE_NAME: next_id, + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, + }, + "node": node + } + + def terminate_node(self, node_id): + node = self._nodes.pop(node_id)["node"] + node.kill_all_processes(check_alive=False, allow_graceful=True) + + @staticmethod + def bootstrap_config(cluster_config): + return cluster_config diff --git a/python/ray/autoscaler/_private/gcp/node.py b/python/ray/autoscaler/_private/gcp/node.py index 93a9933ddc186..69a456ac56c0e 100644 --- a/python/ray/autoscaler/_private/gcp/node.py +++ b/python/ray/autoscaler/_private/gcp/node.py @@ -437,8 +437,26 @@ def create_instance(self, "name": name }) + # Allow Google Compute Engine instance templates. + # + # Config example: + # + # ... + # node_config: + # sourceInstanceTemplate: global/instanceTemplates/worker-16 + # machineType: e2-standard-16 + # ... + # + # node_config parameters override matching template parameters, if any. 
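+    # For example (hypothetical), a machineType given in node_config takes
+    # precedence over the machineType baked into the worker-16 template
+    # above.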
+ # + # https://cloud.google.com/compute/docs/instance-templates + # https://cloud.google.com/compute/docs/reference/rest/v1/instances/insert + source_instance_template = config.pop("sourceInstanceTemplate", None) + operation = self.resource.instances().insert( - project=self.project_id, zone=self.availability_zone, + project=self.project_id, + zone=self.availability_zone, + sourceInstanceTemplate=source_instance_template, body=config).execute() if wait_for_operation: diff --git a/python/ray/autoscaler/_private/monitor.py b/python/ray/autoscaler/_private/monitor.py index 172bd5b74b57d..b19a8d04c7032 100644 --- a/python/ray/autoscaler/_private/monitor.py +++ b/python/ray/autoscaler/_private/monitor.py @@ -27,6 +27,8 @@ from ray.autoscaler._private.load_metrics import LoadMetrics from ray.autoscaler._private.constants import \ AUTOSCALER_MAX_RESOURCE_DEMAND_VECTOR_SIZE +from ray.autoscaler._private.fake_multi_node.node_provider import \ + FAKE_HEAD_NODE_ID from ray.autoscaler._private.util import DEBUG_AUTOSCALING_STATUS, \ DEBUG_AUTOSCALING_ERROR, format_readonly_node_type @@ -162,7 +164,10 @@ def __init__(self, head_node_ip = redis_address.split(":")[0] self.redis_address = redis_address self.redis_password = redis_password - self.load_metrics = LoadMetrics(local_ip=head_node_ip) + if os.environ.get("RAY_FAKE_CLUSTER"): + self.load_metrics = LoadMetrics(local_ip=FAKE_HEAD_NODE_ID) + else: + self.load_metrics = LoadMetrics(local_ip=head_node_ip) self.last_avail_resources = None self.event_summarizer = EventSummarizer() self.prefix_cluster_info = prefix_cluster_info @@ -223,7 +228,7 @@ def update_load_metrics(self): request = gcs_service_pb2.GetAllResourceUsageRequest() response = self.gcs_node_resources_stub.GetAllResourceUsage( - request, timeout=4) + request, timeout=60) resources_batch_data = response.resource_usage_data # Tell the readonly node provider what nodes to report. 
@@ -244,8 +249,7 @@ def update_load_metrics(self): resource_message.node_id.hex()) resources = {} for k, v in resource_message.resources_total.items(): - if not k.startswith("node:"): - resources[k] = v + resources[k] = v mirror_node_types[node_type] = { "resources": resources, "node_config": {}, diff --git a/python/ray/autoscaler/_private/node_launcher.py b/python/ray/autoscaler/_private/node_launcher.py index aad9cee326973..803b8df37d881 100644 --- a/python/ray/autoscaler/_private/node_launcher.py +++ b/python/ray/autoscaler/_private/node_launcher.py @@ -47,6 +47,8 @@ def _launch_node(self, config: Dict[str, Any], count: int, if node_type: launch_config.update( config["available_node_types"][node_type]["node_config"]) + resources = copy.deepcopy( + config["available_node_types"][node_type]["resources"]) launch_hash = hash_launch_conf(launch_config, config["auth"]) self.log("Launching {} nodes, type {}.".format(count, node_type)) node_config = copy.deepcopy(config.get("worker_nodes", {})) @@ -64,7 +66,8 @@ def _launch_node(self, config: Dict[str, Any], count: int, node_tags[TAG_RAY_USER_NODE_TYPE] = node_type node_config.update(launch_config) launch_start_time = time.time() - self.provider.create_node(node_config, node_tags, count) + self.provider.create_node_with_resources(node_config, node_tags, count, + resources) launch_time = time.time() - launch_start_time for _ in range(count): # Note: when launching multiple nodes we observe the time it diff --git a/python/ray/autoscaler/_private/providers.py b/python/ray/autoscaler/_private/providers.py index e60eb441e1414..343350817f512 100644 --- a/python/ray/autoscaler/_private/providers.py +++ b/python/ray/autoscaler/_private/providers.py @@ -56,6 +56,12 @@ def _import_readonly(provider_config): return ReadOnlyNodeProvider +def _import_fake_multinode(provider_config): + from ray.autoscaler._private.fake_multi_node.node_provider import \ + FakeMultiNodeProvider + return FakeMultiNodeProvider + + def _import_kubernetes(provider_config): from ray.autoscaler._private._kubernetes.node_provider import \ KubernetesNodeProvider @@ -117,6 +123,7 @@ def _import_external(provider_config): _NODE_PROVIDERS = { "local": _import_local, + "fake_multinode": _import_fake_multinode, "readonly": _import_readonly, "aws": _import_aws, "gcp": _import_gcp, @@ -129,6 +136,7 @@ def _import_external(provider_config): _PROVIDER_PRETTY_NAMES = { "readonly": "Readonly (Manual Cluster Setup)", + "fake_multinode": "Fake Multinode", "local": "Local", "aws": "AWS", "gcp": "GCP", diff --git a/python/ray/autoscaler/_private/resource_demand_scheduler.py b/python/ray/autoscaler/_private/resource_demand_scheduler.py index 517f49f63281c..f055a01769714 100644 --- a/python/ray/autoscaler/_private/resource_demand_scheduler.py +++ b/python/ray/autoscaler/_private/resource_demand_scheduler.py @@ -116,7 +116,8 @@ def is_feasible(self, bundle: ResourceDict) -> bool: for node_type, config in self.node_types.items(): max_of_type = config.get("max_workers", 0) node_resources = config["resources"] - if max_of_type > 0 and _fits(node_resources, bundle): + if (node_type == self.head_node_type or max_of_type > 0) and _fits( + node_resources, bundle): return True return False @@ -764,7 +765,11 @@ def _utilization_score(node_resources: ResourceDict, return None fittable = [] + resource_types = set() for r in resources: + for k, v in r.items(): + if v > 0: + resource_types.add(k) if _fits(remaining, r): fittable.append(r) _inplace_subtract(remaining, r) @@ -772,12 +777,15 @@ def 
_utilization_score(node_resources: ResourceDict, return None util_by_resources = [] + num_matching_resource_types = 0 for k, v in node_resources.items(): # Don't divide by zero. if v < 1: # Could test v == 0 on the nose, but v < 1 feels safer. # (Note that node resources are integers.) continue + if k in resource_types: + num_matching_resource_types += 1 util = (v - remaining[k]) / v util_by_resources.append(v * (util**3)) @@ -785,9 +793,11 @@ def _utilization_score(node_resources: ResourceDict, if not util_by_resources: return None - # Prioritize using all resources first, then prioritize overall balance + # Prioritize matching multiple resource types first, then prioritize + # using all resources, then prioritize overall balance # of multiple resources. - return (min(util_by_resources), np.mean(util_by_resources)) + return (num_matching_resource_types, min(util_by_resources), + np.mean(util_by_resources)) def get_bin_pack_residual(node_resources: List[ResourceDict], @@ -818,7 +828,16 @@ def get_bin_pack_residual(node_resources: List[ResourceDict], nodes = copy.deepcopy(node_resources) # List of nodes that cannot be used again due to strict spread. used = [] - for demand in resource_demands: + # We order the resource demands in the following way: + # More complex demands first. + # Break ties: heavier demands first. + # Break ties: lexicographically (to ensure stable ordering). + for demand in sorted( + resource_demands, + key=lambda demand: (len(demand.values()), + sum(demand.values()), + sorted(demand.items())), + reverse=True): found = False node = None for i in range(len(nodes)): diff --git a/python/ray/autoscaler/gcp/tpu.yaml b/python/ray/autoscaler/gcp/tpu.yaml index 34726cb2205b4..a963e62c1898d 100644 --- a/python/ray/autoscaler/gcp/tpu.yaml +++ b/python/ray/autoscaler/gcp/tpu.yaml @@ -32,9 +32,9 @@ available_node_types: # Support for TPU pods will be added in the future. acceleratorType: v2-8 runtimeVersion: v2-alpha - # Uncomment to use preemptible TPUs - # schedulingConfig: - # preemptible: true + schedulingConfig: + # Set to false to use non-preemptible TPUs + preemptible: true provider: type: gcp @@ -51,15 +51,21 @@ head_node_type: ray_head_default # Compute instances have python 3.7, but TPUs have 3.8 - need to update # Install Jax and other dependencies on the Compute head node head_setup_commands: - - conda create -y -n "ray" python=3.8.5 && sudo update-alternatives --install /opt/conda/bin/python python /opt/conda/envs/ray/bin/python 10 && sudo update-alternatives --install /opt/conda/bin/pip pip /opt/conda/envs/ray/bin/pip 10 - - export PATH="$PATH:/opt/conda/envs/ray/bin" && echo 'export PATH="$PATH:/opt/conda/envs/ray/bin"' >> ~/.bashrc - - python -m pip install --upgrade "jax[cpu]==0.2.14" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + # Two first lines are a workaround for ssh timing out + - sleep 2 + - sleep 2 + - sudo chown -R $(whoami) /opt/conda/* + - conda create -y -n "ray" python=3.8.5 + - conda activate ray && echo 'conda activate ray' >> ~/.bashrc + - python -m pip install --upgrade pip + - python -m pip install --upgrade "jax[cpu]==0.2.14" - python -m pip install --upgrade fabric dataclasses optax==0.0.6 git+https://github.com/deepmind/dm-haiku google-api-python-client cryptography tensorboardX ray[default] - python -m pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-2.0.0.dev0-cp38-cp38-manylinux2014_x86_64.whl - git clone https://github.com/Yard1/swarm-jax.git && cd swarm-jax && python -m pip install . 
# Install Jax and other dependencies on TPU worker_setup_commands: + - pip3 install --upgrade pip - pip3 install --upgrade "jax[tpu]==0.2.14" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - pip3 install --upgrade fabric dataclasses optax==0.0.6 git+https://github.com/deepmind/dm-haiku tensorboardX ray[default] - python3 -c "import jax; jax.device_count(); jax.numpy.add(1, 1)" # test if Jax has been installed correctly diff --git a/python/ray/autoscaler/node_provider.py b/python/ray/autoscaler/node_provider.py index c912cd772456d..3340910592ba3 100644 --- a/python/ray/autoscaler/node_provider.py +++ b/python/ray/autoscaler/node_provider.py @@ -124,6 +124,18 @@ def create_node(self, node_config: Dict[str, Any], tags: Dict[str, str], """ raise NotImplementedError + def create_node_with_resources( + self, node_config: Dict[str, Any], tags: Dict[str, str], + count: int, + resources: Dict[str, float]) -> Optional[Dict[str, Any]]: + """Create nodes with a given resource config. + + This is the method actually called by the autoscaler. Prefer to + implement this when possible directly, otherwise it delegates to the + create_node() implementation. + """ + return self.create_node(node_config, tags, count) + def set_node_tags(self, node_id: str, tags: Dict[str, str]) -> None: """Sets the tag values (string dict) for the specified node.""" raise NotImplementedError diff --git a/python/ray/autoscaler/ray-schema.json b/python/ray/autoscaler/ray-schema.json index 64e67c7c8f42b..3e9c791720f04 100644 --- a/python/ray/autoscaler/ray-schema.json +++ b/python/ray/autoscaler/ray-schema.json @@ -56,7 +56,7 @@ }, "idle_timeout_minutes": { "description": "If a node is idle for this many minutes, it will be removed.", - "type": "integer", + "type": "number", "minimum": 0 }, "provider": { diff --git a/python/ray/cluster_utils.py b/python/ray/cluster_utils.py index 965492b6eafcf..80ef1f0ecd40a 100644 --- a/python/ray/cluster_utils.py +++ b/python/ray/cluster_utils.py @@ -1,4 +1,9 @@ import logging +import json +import yaml +import os +import subprocess +import tempfile import time import ray @@ -8,6 +13,70 @@ logger = logging.getLogger(__name__) +class AutoscalingCluster: + """Create a local autoscaling cluster for testing. + + See test_autoscaler_fake_multinode.py for an end-to-end example. + """ + + def __init__(self, head_resources: dict, worker_node_types: dict): + """Create the cluster. + + Args: + head_resources: resources of the head node, including CPU. + worker_node_types: autoscaler node types config for worker nodes. + """ + base_config = yaml.safe_load( + open( + os.path.join( + os.path.dirname(ray.__file__), + "autoscaler/_private/fake_multi_node/example.yaml"))) + base_config["available_node_types"] = worker_node_types + base_config["available_node_types"]["ray.head.default"] = { + "resources": head_resources, + "node_config": {}, + "max_workers": 0, + } + self._head_resources = head_resources + self._config = base_config + self._process = None + + def start(self): + """Start the cluster. + + After this call returns, you can connect to the cluster with + ray.init("auto"). 
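+
+        A typical (illustrative) session:
+            cluster.start()
+            ray.init("auto")
+            ...  # run workloads; workers scale within the configured limits
+            ray.shutdown()
+            cluster.shutdown()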
+ """ + subprocess.check_call(["ray", "stop", "--force"]) + fake_config = tempfile.mktemp() + with open(fake_config, "w") as f: + f.write(json.dumps(self._config)) + cmd = [ + "ray", "start", "--autoscaling-config={}".format(fake_config), + "--head", "--block" + ] + if "CPU" in self._head_resources: + cmd.append("--num-cpus={}".format(self._head_resources.pop("CPU"))) + if "GPU" in self._head_resources: + cmd.append("--num-gpus={}".format(self._head_resources.pop("GPU"))) + if self._head_resources: + cmd.append("--resources='{}'".format( + json.dumps(self._head_resources))) + env = os.environ.copy() + env.update({ + "AUTOSCALER_UPDATE_INTERVAL_S": "1", + "RAY_FAKE_CLUSTER": "1" + }) + self._process = subprocess.Popen(cmd, env=env) + time.sleep(5) # TODO(ekl) wait for it properly + + def shutdown(self): + """Terminate the cluster.""" + if self._process: + self._process.kill() + subprocess.check_call(["ray", "stop", "--force"]) + + class Cluster: def __init__(self, initialize_head=False, diff --git a/python/ray/cross_language.py b/python/ray/cross_language.py index eae3939ffa11e..f3ad1de0a3903 100644 --- a/python/ray/cross_language.py +++ b/python/ray/cross_language.py @@ -79,7 +79,8 @@ def java_function(class_name, function_name): None, # max_calls, None, # max_retries, None, # retry_exceptions, - None) # runtime_env + None, # runtime_env + None) # placement_group @PublicAPI(stability="beta") diff --git a/python/ray/data/__init__.py b/python/ray/data/__init__.py index 521add717220c..c6e411fadf86b 100644 --- a/python/ray/data/__init__.py +++ b/python/ray/data/__init__.py @@ -1,7 +1,8 @@ from ray.data.read_api import from_items, range, range_arrow, \ range_tensor, read_parquet, read_json, read_csv, read_binary_files, \ - from_dask, from_modin, from_mars, from_pandas, from_numpy, from_arrow, \ - from_spark, read_datasource, read_numpy, read_text + from_dask, from_modin, from_mars, from_pandas, from_pandas_refs, \ + from_numpy, from_arrow, from_arrow_refs, from_spark, read_datasource, \ + read_numpy, read_text from ray.data.datasource import Datasource, ReadTask from ray.data.dataset import Dataset from ray.data.impl.progress_bar import set_progress_bars @@ -18,10 +19,12 @@ "from_dask", "from_items", "from_arrow", + "from_arrow_refs", "from_mars", "from_modin", "from_numpy", "from_pandas", + "from_pandas_refs", "from_spark", "range", "range_arrow", diff --git a/python/ray/data/block.py b/python/ray/data/block.py index 35b99780c5e0d..e7edab74863ad 100644 --- a/python/ray/data/block.py +++ b/python/ray/data/block.py @@ -16,8 +16,8 @@ # Represents a batch of records to be stored in the Ray object store. # # Block data can be accessed in a uniform way via ``BlockAccessors`` such as -# ``SimpleBlockAccessor``, ``ArrowBlockAccessor``, and ``TensorBlockAccessor``. -Block = Union[List[T], np.ndarray, "pyarrow.Table", bytes] +# ``SimpleBlockAccessor`` and ``ArrowBlockAccessor``. +Block = Union[List[T], "pyarrow.Table", bytes] @DeveloperAPI @@ -52,8 +52,8 @@ class BlockAccessor(Generic[T]): as a top-level Ray object, without a wrapping class (issue #17186). There are three types of block accessors: ``SimpleBlockAccessor``, which - operates over a plain Python list, ``ArrowBlockAccessor``, for - ``pyarrow.Table`` type blocks, and ``TensorBlockAccessor``, for tensors. + operates over a plain Python list, and ``ArrowBlockAccessor`` for + ``pyarrow.Table`` type blocks. 
""" def num_rows(self) -> int: @@ -85,12 +85,16 @@ def to_pandas(self) -> "pandas.DataFrame": """Convert this block into a Pandas dataframe.""" raise NotImplementedError - def to_numpy(self) -> np.ndarray: - """Convert this block into a NumPy ndarray.""" + def to_numpy(self, column: str = None) -> np.ndarray: + """Convert this block (or column of block) into a NumPy ndarray. + + Args: + column: Name of column to convert, or None. + """ raise NotImplementedError - def to_arrow(self) -> Union["pyarrow.Table", "pyarrow.Tensor"]: - """Convert this block into an Arrow table or tensor.""" + def to_arrow(self) -> "pyarrow.Table": + """Convert this block into an Arrow table.""" raise NotImplementedError def size_bytes(self) -> int: @@ -136,10 +140,6 @@ def for_block(block: Block) -> "BlockAccessor[T]": from ray.data.impl.simple_block import \ SimpleBlockAccessor return SimpleBlockAccessor(block) - elif isinstance(block, np.ndarray): - from ray.data.impl.tensor_block import \ - TensorBlockAccessor - return TensorBlockAccessor(block) else: raise TypeError("Not a block type: {}".format(block)) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 11d0a13c9cbae..0b8a7fad6ca50 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -51,8 +51,7 @@ class Dataset(Generic[T]): Datasets are implemented as a list of ``ObjectRef[Block]``. The block also determines the unit of parallelism. The default block type is the - ``pyarrow.Table``. Tensor objects are held in ``np.ndarray`` blocks, - and other Arrow-incompatible objects are held in ``list`` blocks. + ``pyarrow.Table``. Arrow-incompatible objects are held in ``list`` blocks. Since Datasets are just lists of Ray object refs, they can be passed between Ray tasks and actors just like any other object. Datasets support @@ -169,7 +168,7 @@ def map_batches(self, tasks, or "actors" to use an autoscaling Ray actor pool. batch_format: Specify "native" to use the native block format, "pandas" to select ``pandas.DataFrame`` as the batch format, - or "pyarrow" to select ``pyarrow.Table/Tensor``. + or "pyarrow" to select ``pyarrow.Table``. ray_remote_args: Additional resource requirements to request from ray (e.g., num_gpus=1 to request GPUs for the map tasks). """ @@ -205,19 +204,15 @@ def transform(block: Block) -> Block: "or 'pyarrow', got: {}".format(batch_format)) applied = fn(view) - if (isinstance(applied, list) or isinstance(applied, pa.Table) - or isinstance(applied, np.ndarray)): + if isinstance(applied, list) or isinstance(applied, pa.Table): applied = applied elif isinstance(applied, pd.core.frame.DataFrame): applied = pa.Table.from_pandas(applied) - elif isinstance(applied, pa.Tensor): - applied = applied.to_numpy() else: raise ValueError("The map batches UDF returned a type " f"{type(applied)}, which is not allowed. " "The return type must be either list, " - "pandas.DataFrame, np.ndarray, " - "pyarrow.Tensor, or pyarrow.Table") + "pandas.DataFrame, or pyarrow.Table") builder.add_block(applied) return builder.build() @@ -352,8 +347,13 @@ def random_shuffle( Returns: The shuffled dataset. """ + curr_num_blocks = self.num_blocks() + # Handle empty dataset. 
+        if curr_num_blocks == 0:
+            return self
+
         if num_blocks is None:
-            num_blocks = self.num_blocks()
+            num_blocks = curr_num_blocks
         new_blocks = simple_shuffle(
             self._move_blocks() if _move else self._blocks,
             num_blocks,
@@ -402,24 +402,150 @@ def split(self,
         if n <= 0:
             raise ValueError(f"The number of splits {n} is not positive.")
 
-        if n > self.num_blocks() and equal:
-            raise NotImplementedError(
-                f"The number of splits {n} > the number of dataset blocks "
-                f"{self.num_blocks()}, yet an equal split was requested.")
-
         if locality_hints and len(locality_hints) != n:
             raise ValueError(
                 f"The length of locality_hints {len(locality_hints)} "
                 "doesn't equal the number of splits {n}.")
 
-        # TODO(ekl) we could do better than truncation here. This could be a
-        # problem if block sizes are very skewed.
-        def equalize(splits: List[Dataset[T]]) -> List[Dataset[T]]:
+        def _partition_splits(splits: List[Dataset[T]], part_size: int,
+                              counts_cache: Dict[str, int]):
+            """Partition splits into two sets: splits that are smaller than
+            the target size and splits that are larger than the target size.
+            """
+            splits = sorted(splits, key=lambda s: counts_cache[s._get_uuid()])
+            idx = next(i for i, split in enumerate(splits)
+                       if counts_cache[split._get_uuid()] >= part_size)
+            return splits[:idx], splits[idx:]
+
+        def _equalize_larger_splits(splits: List[Dataset[T]], target_size: int,
+                                    counts_cache: Dict[str, int],
+                                    num_splits_required: int):
+            """Split each split into one or more subsplits that are each the
+            target size, with at most one leftover split that's smaller
+            than the target size.
+
+            This assumes that the given splits are sorted in ascending order.
+            """
+            new_splits = []
+            leftovers = []
+            for split in splits:
+                size = counts_cache[split._get_uuid()]
+                if size == target_size:
+                    new_splits.append(split)
+                    continue
+                split_indices = list(range(target_size, size, target_size))
+                split_splits = split.split_at_indices(split_indices)
+                last_split_size = split_splits[-1].count()
+                if last_split_size < target_size:
+                    # Last split is smaller than the target size, save it for
+                    # our unioning of small splits.
+                    leftover = split_splits.pop()
+                    leftovers.append(leftover)
+                    counts_cache[leftover._get_uuid()] = leftover.count()
+                if len(new_splits) + len(split_splits) >= num_splits_required:
+                    # Short-circuit if the new splits will make us reach the
+                    # desired number of splits.
+                    new_splits.extend(
+                        split_splits[:num_splits_required - len(new_splits)])
+                    break
+                new_splits.extend(split_splits)
+            return new_splits, leftovers
+
+        def _equalize_smaller_splits(
+                splits: List[Dataset[T]], target_size: int,
+                counts_cache: Dict[str, int], num_splits_required: int):
+            """Union small splits up to the target split size.
+
+            This assumes that the given splits are sorted in ascending order.
+            """
+            new_splits = []
+            union_buffer = []
+            union_buffer_size = 0
+            low = 0
+            high = len(splits) - 1
+            while low <= high:
+                # Union small splits up to the target split size.
+                low_split = splits[low]
+                low_count = counts_cache[low_split._get_uuid()]
+                high_split = splits[high]
+                high_count = counts_cache[high_split._get_uuid()]
+                if union_buffer_size + high_count <= target_size:
+                    # Try to add the larger split to the union buffer first.
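To picture the partitioning step that _partition_splits above performs, here is a standalone sketch with plain ints standing in for per-split row counts (not the actual helper):

    sizes = [3, 10, 4, 7]      # row counts of four candidate splits
    target = sum(sizes) // 4   # target_size = 6
    ordered = sorted(sizes)    # [3, 4, 7, 10]
    idx = next(i for i, s in enumerate(ordered) if s >= target)
    smaller, larger = ordered[:idx], ordered[idx:]  # [3, 4], [7, 10]
    # Larger splits are then cut down to the target size and smaller
    # ones unioned up to it, as in the helpers above and below.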
+ union_buffer.append(high_split) + union_buffer_size += high_count + high -= 1 + elif union_buffer_size + low_count <= target_size: + union_buffer.append(low_split) + union_buffer_size += low_count + low += 1 + else: + # Neither the larger nor smaller split fit in the union + # buffer, so we split the smaller split into a subsplit + # that will fit into the union buffer and a leftover + # subsplit that we add back into the candidate split list. + diff = target_size - union_buffer_size + diff_split, new_low_split = low_split.split_at_indices( + [diff]) + union_buffer.append(diff_split) + union_buffer_size += diff + # We overwrite the old low split and don't advance the low + # pointer since (1) the old low split can be discarded, + # (2) the leftover subsplit is guaranteed to be smaller + # than the old low split, and (3) the low split should be + # the smallest split in the candidate split list, which is + # this subsplit. + splits[low] = new_low_split + counts_cache[new_low_split._get_uuid()] = low_count - diff + if union_buffer_size == target_size: + # Once the union buffer is full, we union together the + # splits. + assert len(union_buffer) > 1, union_buffer + first_ds = union_buffer[0] + new_split = first_ds.union(*union_buffer[1:]) + new_splits.append(new_split) + # Clear the union buffer. + union_buffer = [] + union_buffer_size = 0 + if len(new_splits) == num_splits_required: + # Short-circuit if we've reached the desired number of + # splits. + break + return new_splits + + def equalize(splits: List[Dataset[T]], + num_splits: int) -> List[Dataset[T]]: if not equal: return splits - lower_bound = min([s.count() for s in splits]) - assert lower_bound > 0, splits - return [s.limit(lower_bound) for s in splits] + counts = {s._get_uuid(): s.count() for s in splits} + total_rows = sum(counts.values()) + # Number of rows for each split. + target_size = total_rows // num_splits + + # Partition splits. + smaller_splits, larger_splits = _partition_splits( + splits, target_size, counts) + if len(smaller_splits) == 0 and num_splits < len(splits): + # All splits are already equal. + return splits + + # Split larger splits. + new_splits, leftovers = _equalize_larger_splits( + larger_splits, target_size, counts, num_splits) + # Short-circuit if we've already reached the desired number of + # splits. + if len(new_splits) == num_splits: + return new_splits + # Add leftovers to small splits and re-sort. + smaller_splits += leftovers + smaller_splits = sorted( + smaller_splits, key=lambda s: counts[s._get_uuid()]) + + # Union smaller splits. + new_splits_small = _equalize_smaller_splits( + smaller_splits, target_size, counts, + num_splits - len(new_splits)) + new_splits.extend(new_splits_small) + return new_splits block_refs = list(self._blocks) metadata_mapping = { @@ -433,7 +559,8 @@ def equalize(splits: List[Dataset[T]]) -> List[Dataset[T]]: BlockList( list(blocks), [metadata_mapping[b] for b in blocks])) for blocks in np.array_split(block_refs, n) - ]) + if not equal or len(blocks) > 0 + ], n) # If the locality_hints is set, we use a two-round greedy algorithm # to co-locate the blocks with the actors based on block @@ -532,7 +659,7 @@ def build_node_id_by_actor(actors: List[Any]) -> Dict[Any, str]: [metadata_mapping[b] for b in allocation_per_actor[actor]])) for actor in locality_hints - ]) + ], n) def split_at_indices(self, indices: List[int]) -> List["Dataset[T]"]: """Split the dataset at the given indices (like np.split). 
@@ -580,6 +707,9 @@ def split_at_indices(self, indices: List[int]) -> List["Dataset[T]"]: def union(self, *other: List["Dataset[T]"]) -> "Dataset[T]": """Combine this dataset with others of the same type. + The order of the blocks in the datasets is preserved, as is the + relative ordering between the datasets passed in the argument list. + Args: other: List of datasets to combine with this one. The datasets must have the same schema as this dataset, otherwise the @@ -589,35 +719,21 @@ def union(self, *other: List["Dataset[T]"]) -> "Dataset[T]": A new dataset holding the union of their data. """ - blocks: List[ObjectRef[Block]] = [] + calls: List[Callable[[], ObjectRef[Block]]] = [] metadata: List[BlockMetadata] = [] - pending_blocks: List[Callable[[], ObjectRef[Block]]] = [] - pending_metadata: List[BlockMetadata] = [] + blocks: List[ObjectRef[Block]] = [] datasets = [self] + list(other) for ds in datasets: bl = ds._blocks if isinstance(bl, LazyBlockList): - for block, meta in zip(bl._blocks, bl._metadata): - blocks.append(block) - metadata.append(meta) - lim = len(bl._blocks) - for call, meta in zip(bl._calls[lim:], bl._metadata[lim:]): - pending_blocks.append(call) - pending_metadata.append(meta) + calls.extend(bl._calls) else: - assert isinstance(bl, BlockList), bl - blocks.extend(list(bl._blocks)) - metadata.extend(bl.get_metadata()) + calls.extend([None] * len(bl)) + metadata.extend(bl._metadata) + blocks.extend(bl._blocks) - result = LazyBlockList([], []) - result._calls = ([None] * len(blocks)) + pending_blocks - result._blocks = blocks - result._metadata = metadata + pending_metadata - - assert len(result._calls) == len(result._metadata), result - assert len(result._blocks) <= len(result._calls), result - return Dataset(result) + return Dataset(LazyBlockList(calls, metadata, blocks)) def sort(self, key: Union[None, str, List[str], Callable[[T], Any]] = None, @@ -653,6 +769,9 @@ def sort(self, Returns: A new, sorted dataset. """ + # Handle empty dataset. + if self.num_blocks() == 0: + return self return Dataset(sort_impl(self._blocks, key, descending)) def zip(self, other: "Dataset[U]") -> "Dataset[(T, U)]": @@ -678,8 +797,8 @@ def zip(self, other: "Dataset[U]") -> "Dataset[(T, U)]": comes from the first dataset and v comes from the second. """ - blocks1 = self.get_blocks() - blocks2 = other.get_blocks() + blocks1 = self.get_internal_block_refs() + blocks2 = other.get_internal_block_refs() if len(blocks1) != len(blocks2): # TODO(ekl) consider supporting if num_rows are equal. @@ -761,6 +880,9 @@ def count(self) -> int: Returns: The number of records in the dataset. """ + # Handle empty dataset. + if self.num_blocks() == 0: + return 0 # For parquet, we can return the count directly from metadata. meta_count = self._meta_count() @@ -849,6 +971,8 @@ def write_parquet(self, path: str, *, filesystem: Optional["pyarrow.fs.FileSystem"] = None, + try_create_dir: bool = True, + arrow_open_stream_args: Optional[Dict[str, Any]] = None, **arrow_parquet_args) -> None: """Write the dataset to parquet. @@ -867,6 +991,10 @@ def write_parquet(self, path: The path to the destination root directory, where Parquet files will be written to. filesystem: The filesystem implementation to write to. + try_create_dir: Try to create all directories in destination path + if True. Does nothing if all directories already exist. 
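The ordering guarantee documented for union() above can be sketched directly (from_items and take are existing Datasets APIs):

    import ray

    a = ray.data.from_items([0, 1])
    b = ray.data.from_items([2, 3])
    assert a.union(b).take() == [0, 1, 2, 3]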
+ arrow_open_stream_args: kwargs passed to + pyarrow.fs.FileSystem.open_output_stream arrow_parquet_args: Options to pass to pyarrow.parquet.write_table(), which is used to write out each block to a file. @@ -876,12 +1004,16 @@ def write_parquet(self, path=path, dataset_uuid=self._uuid, filesystem=filesystem, + try_create_dir=try_create_dir, + open_stream_args=arrow_open_stream_args, **arrow_parquet_args) def write_json(self, path: str, *, filesystem: Optional["pyarrow.fs.FileSystem"] = None, + try_create_dir: bool = True, + arrow_open_stream_args: Optional[Dict[str, Any]] = None, **pandas_json_args) -> None: """Write the dataset to json. @@ -900,6 +1032,10 @@ def write_json(self, path: The path to the destination root directory, where json files will be written to. filesystem: The filesystem implementation to write to. + try_create_dir: Try to create all directories in destination path + if True. Does nothing if all directories already exist. + arrow_open_stream_args: kwargs passed to + pyarrow.fs.FileSystem.open_output_stream pandas_json_args: These args will be passed to pandas.DataFrame.to_json(), which we use under the hood to write out each Datasets block. These @@ -910,12 +1046,16 @@ def write_json(self, path=path, dataset_uuid=self._uuid, filesystem=filesystem, + try_create_dir=try_create_dir, + open_stream_args=arrow_open_stream_args, **pandas_json_args) def write_csv(self, path: str, *, filesystem: Optional["pyarrow.fs.FileSystem"] = None, + try_create_dir: bool = True, + arrow_open_stream_args: Optional[Dict[str, Any]] = None, **arrow_csv_args) -> None: """Write the dataset to csv. @@ -934,6 +1074,10 @@ def write_csv(self, path: The path to the destination root directory, where csv files will be written to. filesystem: The filesystem implementation to write to. + try_create_dir: Try to create all directories in destination path + if True. Does nothing if all directories already exist. + arrow_open_stream_args: kwargs passed to + pyarrow.fs.FileSystem.open_output_stream arrow_csv_args: Other CSV write options to pass to pyarrow. """ self.write_datasource( @@ -941,17 +1085,23 @@ def write_csv(self, path=path, dataset_uuid=self._uuid, filesystem=filesystem, + try_create_dir=try_create_dir, + open_stream_args=arrow_open_stream_args, **arrow_csv_args) def write_numpy( self, path: str, *, - filesystem: Optional["pyarrow.fs.FileSystem"] = None) -> None: - """Write the dataset to npy files. + column: str = "value", + filesystem: Optional["pyarrow.fs.FileSystem"] = None, + try_create_dir: bool = True, + arrow_open_stream_args: Optional[Dict[str, Any]] = None) -> None: + """Write a tensor column of the dataset to npy files. - This is only supported for datasets of Tensor records. - To control the number of files, use ``.repartition()``. + This is only supported for datasets convertible to Arrow records that + contain a TensorArray column. To control the number of files, use + ``.repartition()``. The format of the output files will be {self._uuid}_{block_idx}.npy, where ``uuid`` is an unique id for the dataset. @@ -964,13 +1114,22 @@ def write_numpy( Args: path: The path to the destination root directory, where npy files will be written to. + column: The name of the table column that contains the tensor to + be written. This defaults to "value". filesystem: The filesystem implementation to write to. + try_create_dir: Try to create all directories in destination path + if True. Does nothing if all directories already exist. 
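A sketch of the new write-side options threaded through above; the S3 path is hypothetical, and buffer_size is one kwarg that pyarrow.fs.FileSystem.open_output_stream accepts:

    import ray

    ds = ray.data.range_arrow(100)
    ds.write_parquet(
        "s3://my-bucket/out",  # hypothetical, pre-existing destination
        try_create_dir=False,  # skip filesystem.create_dir()
        arrow_open_stream_args={"buffer_size": 4 * 1024 * 1024})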
+ arrow_open_stream_args: kwargs passed to + pyarrow.fs.FileSystem.open_output_stream """ self.write_datasource( NumpyDatasource(), path=path, dataset_uuid=self._uuid, - filesystem=filesystem) + column=column, + filesystem=filesystem, + try_create_dir=try_create_dir, + open_stream_args=arrow_open_stream_args) def write_datasource(self, datasource: Datasource[T], **write_args) -> None: @@ -1042,7 +1201,7 @@ def iter_batches(self, batch_format: The format in which to return each batch. Specify "native" to use the current block format, "pandas" to select ``pandas.DataFrame`` or "pyarrow" to select - ``pyarrow.Table/Tensor``. Default is "native". + ``pyarrow.Table``. Default is "native". drop_last: Whether to drop the last batch if it's incomplete. Returns: @@ -1310,14 +1469,15 @@ def to_modin(self) -> "modin.DataFrame": """Convert this dataset into a Modin dataframe. This works by first converting this dataset into a distributed set of - Pandas dataframes (using ``.to_pandas()``). Please see caveats there. - Then the individual dataframes are used to create the modin DataFrame - using + Pandas dataframes (using ``.to_pandas_refs()``). Please see caveats + there. Then the individual dataframes are used to create the modin + DataFrame using ``modin.distributed.dataframe.pandas.partitions.from_partitions()``. This is only supported for datasets convertible to Arrow records. This function induces a copy of the data. For zero-copy access to the - underlying data, consider using ``.to_arrow()`` or ``.get_blocks()``. + underlying data, consider using ``.to_arrow()`` or + ``.get_internal_block_refs()``. Time complexity: O(dataset size / parallelism) @@ -1327,7 +1487,7 @@ def to_modin(self) -> "modin.DataFrame": from modin.distributed.dataframe.pandas.partitions import ( from_partitions) - pd_objs = self.to_pandas() + pd_objs = self.to_pandas_refs() return from_partitions(pd_objs, axis=0) def to_spark(self, @@ -1343,17 +1503,45 @@ def to_spark(self, core_worker = ray.worker.global_worker.core_worker locations = [ core_worker.get_owner_address(block) - for block in self.get_blocks() + for block in self.get_internal_block_refs() ] return raydp.spark.ray_dataset_to_spark_dataframe( - spark, self.schema(), self.get_blocks(), locations) + spark, self.schema(), self.get_internal_block_refs(), locations) + + def to_pandas(self, limit: int = 1000) -> "pandas.DataFrame": + """Convert this dataset into a single Pandas DataFrame. + + This is only supported for datasets convertible to Arrow records. This + limits the number of records returned to the provided limit. + + Time complexity: O(limit) + + Args: + limit: The maximum number of records to return. - def to_pandas(self) -> List[ObjectRef["pandas.DataFrame"]]: + Returns: + A Pandas DataFrame created from this dataset, containing a limited + number of records. + """ + + if self.count() > limit: + logger.warning(f"Only returning the first {limit} records from " + "to_pandas()") + limited_ds = self.limit(limit) + blocks = limited_ds.get_internal_block_refs() + output = DelegatingArrowBlockBuilder() + for block in ray.get(blocks): + output.add_block(block) + return output.build().to_pandas() + + @DeveloperAPI + def to_pandas_refs(self) -> List[ObjectRef["pandas.DataFrame"]]: """Convert this dataset into a distributed set of Pandas dataframes. This is only supported for datasets convertible to Arrow records. This function induces a copy of the data. For zero-copy access to the - underlying data, consider using ``.to_arrow()`` or ``.get_blocks()``. 
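The conversion split introduced here, as a sketch: to_pandas() now yields a single, size-limited DataFrame, while to_pandas_refs() keeps the old distributed behavior:

    import ray

    ds = ray.data.range_arrow(10_000)
    df = ds.to_pandas(limit=1000)  # one DataFrame; warns and truncates
    refs = ds.to_pandas_refs()     # List[ObjectRef[pandas.DataFrame]]
    dfs = ray.get(refs)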
+ underlying data, consider using ``.to_arrow()`` or + ``.get_internal_block_refs()``. Time complexity: O(dataset size / parallelism) @@ -1364,23 +1552,48 @@ def to_pandas(self) -> List[ObjectRef["pandas.DataFrame"]]: block_to_df = cached_remote_fn(_block_to_df) return [block_to_df.remote(block) for block in self._blocks] - def to_numpy(self) -> List[ObjectRef[np.ndarray]]: + def to_numpy(self, *, + column: Optional[str] = None) -> List[ObjectRef[np.ndarray]]: """Convert this dataset into a distributed set of NumPy ndarrays. This is only supported for datasets convertible to NumPy ndarrays. This function induces a copy of the data. For zero-copy access to the - underlying data, consider using ``.to_arrow()`` or ``.get_blocks()``. + underlying data, consider using ``.to_arrow()`` or + ``.get_internal_block_refs()``. Time complexity: O(dataset size / parallelism) + Args: + column: The name of the column to convert to numpy, or None to + specify the entire row. Required for Arrow tables. + Returns: A list of remote NumPy ndarrays created from this dataset. """ block_to_ndarray = cached_remote_fn(_block_to_ndarray) - return [block_to_ndarray.remote(block) for block in self._blocks] + return [ + block_to_ndarray.remote(block, column=column) + for block in self._blocks + ] + + def to_arrow(self) -> List["pyarrow.Table"]: + """Convert this dataset into a list of Arrow tables. + + This is only supported for datasets convertible to Arrow records. + This function is zero-copy if the existing data is already in Arrow + format. Otherwise, the data will be converted to Arrow format. + + Time complexity: O(1) unless conversion is required. - def to_arrow(self) -> List[ObjectRef["pyarrow.Table"]]: + Returns: + A list of Arrow tables created from this dataset. + """ + + return ray.get(self.to_arrow_refs()) + + @DeveloperAPI + def to_arrow_refs(self) -> List[ObjectRef["pyarrow.Table"]]: """Convert this dataset into a distributed set of Arrow tables. This is only supported for datasets convertible to Arrow records. @@ -1450,28 +1663,32 @@ def __init__(self, ds: "Dataset[T]"): def __iter__(self): return Iterator(self._ds) - return DatasetPipeline(Iterable(self), length=times) + return DatasetPipeline(Iterable(self), length=times or float("inf")) def pipeline(self, *, parallelism: int = 10) -> "DatasetPipeline[T]": - """Pipeline the dataset execution by splitting its blocks into groups. + raise DeprecationWarning("Use .window(blocks_per_window=n) instead of " + ".pipeline(parallelism=n)") - Transformations prior to the call to ``pipeline()`` are evaluated in + def window(self, *, blocks_per_window: int = 10) -> "DatasetPipeline[T]": + """Convert this into a DatasetPipeline by windowing over data blocks. + + Transformations prior to the call to ``window()`` are evaluated in bulk on the entire dataset. Transformations done on the returned - pipeline are evaluated incrementally per group of blocks as data is + pipeline are evaluated incrementally per window of blocks as data is read from the output of the pipeline. - Pipelining execution allows for output to be read sooner without + Windowing execution allows for output to be read sooner without waiting for all transformations to fully execute, and can also improve efficiency if transforms use different resources (e.g., GPUs). - Without pipelining:: + Without windowing:: [preprocessing......] [inference.......] [write........] 
Time -----------------------------------------------------------> - With pipelining:: + With windowing:: [prep1] [prep2] [prep3] [infer1] [infer2] [infer3] @@ -1481,20 +1698,20 @@ def pipeline(self, *, parallelism: int = 10) -> "DatasetPipeline[T]": Examples: >>> # Create an inference pipeline. >>> ds = ray.data.read_binary_files(dir) - >>> pipe = ds.pipeline(parallelism=10).map(infer) - DatasetPipeline(num_stages=2, length=40) + >>> pipe = ds.window(blocks_per_window=10).map(infer) + DatasetPipeline(num_windows=40, num_stages=2) >>> # The higher the stage parallelism, the shorter the pipeline. - >>> pipe = ds.pipeline(parallelism=20).map(infer) - DatasetPipeline(num_stages=2, length=20) + >>> pipe = ds.window(blocks_per_window=20).map(infer) + DatasetPipeline(num_windows=20, num_stages=2) >>> # Outputs can be incrementally read from the pipeline. >>> for item in pipe.iter_rows(): ... print(item) Args: - parallelism: The parallelism (number of blocks) per stage. - Increasing parallelism increases pipeline throughput, but also + blocks_per_window: The window size (parallelism) in blocks. + Increasing window size increases pipeline throughput, but also increases the latency to initial output, since it decreases the length of the pipeline. Setting this to infinity effectively disables pipelining. @@ -1518,7 +1735,7 @@ def gen(): class Iterable: def __init__(self, blocks): - self._splits = blocks.split(split_size=parallelism) + self._splits = blocks.split(split_size=blocks_per_window) def __iter__(self): return Iterator(self._splits) @@ -1527,7 +1744,7 @@ def __iter__(self): return DatasetPipeline(it, length=len(it._splits)) @DeveloperAPI - def get_blocks(self) -> List[ObjectRef[Block]]: + def get_internal_block_refs(self) -> List[ObjectRef[Block]]: """Get a list of references to the underlying blocks of this dataset. This function can be used for zero-copy access to the data. @@ -1581,13 +1798,14 @@ def _split(self, index: int, right = None return left, right + def _divide(self, block_idx: int) -> ("Dataset[T]", "Dataset[T]"): + left, right = self._blocks.divide(block_idx) + return Dataset(left), Dataset(right) + def __repr__(self) -> str: schema = self.schema() if schema is None: schema_str = "Unknown schema" - elif isinstance(schema, dict): - schema_str = "".format( - schema["shape"], schema["dtype"]) elif isinstance(schema, type): schema_str = str(schema) else: @@ -1599,8 +1817,6 @@ def __repr__(self) -> str: schema_str = ", ".join(schema_str) schema_str = "{" + schema_str + "}" count = self._meta_count() - if count is None: - count = "?" 
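The accessor rename above in one line; behavior is unchanged, the new name just flags that the refs are internal:

    import ray

    ds = ray.data.range(100)
    block_refs = ds.get_internal_block_refs()  # formerly ds.get_blocks()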
return "Dataset(num_blocks={}, num_rows={}, schema={})".format( len(self._blocks), count, schema_str) @@ -1640,9 +1856,9 @@ def _block_to_df(block: Block): return block.to_pandas() -def _block_to_ndarray(block: Block): +def _block_to_ndarray(block: Block, column: Optional[str]): block = BlockAccessor.for_block(block) - return block.to_numpy() + return block.to_numpy(column) def _block_to_arrow(block: Block): diff --git a/python/ray/data/dataset_pipeline.py b/python/ray/data/dataset_pipeline.py index 158905e70e9f9..962961105f895 100644 --- a/python/ray/data/dataset_pipeline.py +++ b/python/ray/data/dataset_pipeline.py @@ -1,7 +1,7 @@ import functools import time from typing import Any, Callable, List, Iterator, Iterable, Generic, Union, \ - TYPE_CHECKING + Optional, TYPE_CHECKING import ray from ray.data.dataset import Dataset, T, U, BatchType @@ -13,13 +13,15 @@ if TYPE_CHECKING: import pyarrow -# Operations that can be naively applied per dataset in the pipeline. +# Operations that can be naively applied per dataset row in the pipeline. PER_DATASET_OPS = [ - "map", "map_batches", "flat_map", "filter", "repartition", - "random_shuffle", "sort", "write_json", "write_csv", "write_parquet", - "write_datasource" + "map", "map_batches", "flat_map", "filter", "write_json", "write_csv", + "write_parquet", "write_datasource" ] +# Operations that apply to each dataset holistically in the pipeline. +HOLISTIC_PER_DATASET_OPS = ["repartition", "random_shuffle", "sort"] + # Similar to above but we should force evaluation immediately. PER_DATASET_OUTPUT_OPS = [ "write_json", "write_csv", "write_parquet", "write_datasource" @@ -40,7 +42,7 @@ class DatasetPipeline(Generic[T]): A DatasetPipeline can be created by either repeating a Dataset (``ds.repeat(times=None)``), by turning a single Dataset into a pipeline - (``ds.pipeline(parallelism=10)``), or defined explicitly using + (``ds.window(blocks_per_window=10)``), or defined explicitly using ``DatasetPipeline.from_iterable()``. DatasetPipeline supports the all the per-record transforms of Datasets @@ -57,7 +59,7 @@ def __init__(self, """Construct a DatasetPipeline (internal API). The constructor is not part of the DatasetPipeline API. Use the - ``Dataset.repeat()``, ``Dataset.pipeline()``, or + ``Dataset.repeat()``, ``Dataset.window()``, or ``DatasetPipeline.from_iterable()`` methods to construct a pipeline. """ self._base_iterable = base_iterable @@ -240,6 +242,124 @@ def __next__(self): for idx in range(n) ] + def rewindow(self, *, blocks_per_window: int) -> "DatasetPipeline[T]": + """Change the windowing (blocks per dataset) of this pipeline. + + Changes the windowing of this pipeline to the specified size. For + example, if the current pipeline has two blocks per dataset, and + `.rewindow(blocks_per_window=4)` is requested, adjacent datasets will + be merged until each dataset is 4 blocks. If + `.rewindow(blocks_per_window)` was requested the datasets will be + split into smaller windows. + + Args: + blocks_per_window: The new target blocks per window. + """ + + class WindowIterator: + def __init__(self, original_iter): + self._original_iter = original_iter + self._buffer: Optional[Dataset[T]] = None + + def __next__(self) -> Dataset[T]: + try: + # Merge windows until we meet the requested window size. + if self._buffer is None: + self._buffer = next(self._original_iter) + while self._buffer.num_blocks() < blocks_per_window: + self._buffer = self._buffer.union( + next(self._original_iter)) + # Slice off the left-most chunk and return it. 
+ res, self._buffer = self._buffer._divide(blocks_per_window) + assert res.num_blocks() <= blocks_per_window, res + return lambda: res + except StopIteration: + # Return the left-over data as a single window. + if self._buffer and self._buffer.num_blocks() > 0: + res = self._buffer + assert res.num_blocks() <= blocks_per_window, res + self._buffer = None + return lambda: res + else: + raise + + class WindowIterable: + def __init__(self, original_iter): + self._original_iter = original_iter + + def __iter__(self): + return WindowIterator(self._original_iter) + + return DatasetPipeline( + WindowIterable(self.iter_datasets()), length=None) + + def repeat(self, times: int = None) -> "DatasetPipeline[T]": + """Repeat this pipeline a given number or times, or indefinitely. + + This operation is only allowed for pipelines of a finite length. An + error will be raised for pipelines of infinite length. + + Transformations prior to the call to ``repeat()`` are evaluated once. + Transformations done on the repeated pipeline are evaluated on each + loop of the pipeline over the base pipeline. + + Args: + times: The number of times to loop over this pipeline, or None + to repeat indefinitely. + """ + + if self._length == float("inf"): + raise ValueError("Cannot repeat a pipeline of infinite length.") + + class RepeatIterator: + def __init__(self, original_iter): + self._original_iter = original_iter + # Holds results to repeat. + self._results = [] + # Incrementing cursor over results. + self._i = 0 + # This is calculated later. + self._max_i = None + + def __next__(self) -> Dataset[T]: + # Still going through the original pipeline. + if self._original_iter: + try: + res = next(self._original_iter) + self._results.append(res) + return lambda: res + except StopIteration: + self._original_iter = None + # Calculate the cursor limit. + if times: + self._max_i = len(self._results) * (times - 1) + else: + self._max_i = float("inf") + # Going through a repeat of the pipeline. + if self._i < self._max_i: + res = self._results[self._i % len(self._results)] + self._i += 1 + return lambda: res + else: + raise StopIteration + + class RepeatIterable: + def __init__(self, original_iter): + self._original_iter = original_iter + + def __iter__(self): + return RepeatIterator(self._original_iter) + + if not times: + length = float("inf") + elif times and self._length: + length = times * self._length + else: + length = None + + return DatasetPipeline( + RepeatIterable(self.iter_datasets()), length=length) + def schema(self) -> Union[type, "pyarrow.lib.Schema"]: """Return the schema of the dataset pipeline. @@ -287,6 +407,19 @@ def sum(self) -> int: total += elem return total + def show_windows(self, limit_per_dataset: int = 10) -> None: + """Print up to the given number of records from each window/dataset. + + This is helpful as a debugging tool for understanding the structure of + dataset pipelines. + + Args: + limit_per_dataset: Rows to print per window/dataset. + """ + for i, ds in enumerate(self.iter_datasets()): + print("=== Window {} ===".format(i)) + ds.show(limit_per_dataset) + @DeveloperAPI def iter_datasets(self) -> Iterator[Dataset[T]]: """Iterate over the output datasets of this pipeline. @@ -300,9 +433,9 @@ def iter_datasets(self) -> Iterator[Dataset[T]]: return PipelineExecutor(self) @DeveloperAPI - def foreach_dataset(self, fn: Callable[[Dataset[T]], Dataset[U]] - ) -> "DatasetPipeline[U]": - """Apply a transform to each dataset in this pipeline. 
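Putting the new pipeline operations together in a sketch (block counts are illustrative; range(8) yields single-row blocks here):

    import ray

    ds = ray.data.range(8)
    pipe = ds.window(blocks_per_window=2)      # 4 windows of 2 blocks
    pipe = pipe.rewindow(blocks_per_window=4)  # merged into 2 windows
    pipe = pipe.repeat(2)                      # loop over the windows twice
    pipe.show_windows(limit_per_dataset=3)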
+ def foreach_window(self, fn: Callable[[Dataset[T]], Dataset[U]] + ) -> "DatasetPipeline[U]": + """Apply a transform to each dataset/window in this pipeline. Args: fn: The function to transform each dataset with. @@ -319,6 +452,10 @@ def foreach_dataset(self, fn: Callable[[Dataset[T]], Dataset[U]] self._progress_bars, _executed=self._executed) + def foreach_dataset(self, *a, **kw) -> None: + raise DeprecationWarning( + "`foreach_dataset` has been renamed to `foreach_window`.") + @staticmethod def from_iterable(iterable: Iterable[Callable[[], Dataset[T]]], ) -> "DatasetPipeline[T]": @@ -335,7 +472,7 @@ def from_iterable(iterable: Iterable[Callable[[], Dataset[T]]], return DatasetPipeline(iterable, length=length) def __repr__(self) -> str: - return "DatasetPipeline(length={}, num_stages={})".format( + return "DatasetPipeline(num_windows={}, num_stages={})".format( self._length, 1 + len(self._stages)) def __str__(self) -> str: @@ -355,7 +492,7 @@ def make_impl(method): @functools.wraps(delegate) def impl(self, *args, **kwargs): - return self.foreach_dataset( + return self.foreach_window( lambda ds: getattr(ds, method)(*args, **kwargs)) if impl.__annotations__.get("return"): @@ -366,6 +503,33 @@ def impl(self, *args, **kwargs): setattr(DatasetPipeline, method, make_impl(method)) +for method in HOLISTIC_PER_DATASET_OPS: + + def make_impl(method): + delegate = getattr(Dataset, method) + + @functools.wraps(delegate) + def impl(self, *args, **kwargs): + return self.foreach_window( + lambda ds: getattr(ds, method)(*args, **kwargs)) + + if impl.__annotations__.get("return"): + impl.__annotations__["return"] = impl.__annotations__[ + "return"].replace("Dataset", "DatasetPipeline") + + return impl + + def deprecation_warning(method: str): + def impl(*a, **kw): + raise DeprecationWarning( + "`{}` has been renamed to `{}_each_window`.".format( + method, method)) + + return impl + + setattr(DatasetPipeline, method, deprecation_warning(method)) + setattr(DatasetPipeline, method + "_each_window", make_impl(method)) + for method in PER_DATASET_OUTPUT_OPS: def make_impl(method): diff --git a/python/ray/data/datasource/__init__.py b/python/ray/data/datasource/__init__.py index b09aa9acb0c0d..fd3b5e21c6eb0 100644 --- a/python/ray/data/datasource/__init__.py +++ b/python/ray/data/datasource/__init__.py @@ -1,5 +1,6 @@ from ray.data.datasource.datasource import (Datasource, RangeDatasource, - DummyOutputDatasource, ReadTask) + DummyOutputDatasource, ReadTask, + RandomIntRowDatasource) from ray.data.datasource.json_datasource import JSONDatasource from ray.data.datasource.csv_datasource import CSVDatasource from ray.data.datasource.numpy_datasource import NumpyDatasource @@ -18,6 +19,7 @@ "_S3FileSystemWrapper", "Datasource", "RangeDatasource", + "RandomIntRowDatasource", "DummyOutputDatasource", "ReadTask", ] diff --git a/python/ray/data/datasource/datasource.py b/python/ray/data/datasource/datasource.py index 46b313ab3bfd0..0945120306ebf 100644 --- a/python/ray/data/datasource/datasource.py +++ b/python/ray/data/datasource/datasource.py @@ -130,10 +130,11 @@ def make_block(start: int, count: int) -> Block: return pyarrow.Table.from_arrays( [np.arange(start, start + count)], names=["value"]) elif block_format == "tensor": - return np.ones( - tensor_shape, dtype=np.int64) * np.expand_dims( + tensor = TensorArray( + np.ones(tensor_shape, dtype=np.int64) * np.expand_dims( np.arange(start, start + count), - tuple(range(1, 1 + len(tensor_shape)))) + tuple(range(1, 1 + len(tensor_shape))))) + return 
pyarrow.Table.from_pydict({"value": tensor}) else: return list(builtins.range(start, start + count)) @@ -145,7 +146,14 @@ def make_block(start: int, count: int) -> Block: import pyarrow schema = pyarrow.Table.from_pydict({"value": [0]}).schema elif block_format == "tensor": - schema = {"dtype": "int64", "shape": (None, ) + tensor_shape} + _check_pyarrow_version() + from ray.data.extensions import TensorArray + import pyarrow + tensor = TensorArray( + np.ones(tensor_shape, dtype=np.int64) * np.expand_dims( + np.arange(0, 10), tuple( + range(1, 1 + len(tensor_shape))))) + schema = pyarrow.Table.from_pydict({"value": tensor}).schema elif block_format == "list": schema = int else: @@ -213,3 +221,50 @@ def on_write_complete(self, write_results: List[WriteResult]) -> None: def on_write_failed(self, write_results: List[ObjectRef[WriteResult]], error: Exception) -> None: self.num_failed += 1 + + +class RandomIntRowDatasource(Datasource[ArrowRow]): + """An example datasource that generates rows with random int64 columns. + + Examples: + >>> source = RandomIntRowDatasource() + >>> ray.data.read_datasource(source, n=10, num_columns=2).take() + ... ArrowRow({'c_0': 1717767200176864416, 'c_1': 999657309586757214}) + ... ArrowRow({'c_0': 4983608804013926748, 'c_1': 1160140066899844087}) + """ + + def prepare_read(self, parallelism: int, n: int, + num_columns: int) -> List[ReadTask]: + _check_pyarrow_version() + import pyarrow + + read_tasks: List[ReadTask] = [] + block_size = max(1, n // parallelism) + + def make_block(count: int, num_columns: int) -> Block: + return pyarrow.Table.from_arrays( + np.random.randint( + np.iinfo(np.int64).max, + size=(num_columns, count), + dtype=np.int64), + names=[f"c_{i}" for i in range(num_columns)]) + + schema = pyarrow.Table.from_pydict( + {f"c_{i}": [0] + for i in range(num_columns)}).schema + + i = 0 + while i < n: + count = min(block_size, n - i) + read_tasks.append( + ReadTask( + lambda count=count, num_columns=num_columns: + make_block(count, num_columns), + BlockMetadata( + num_rows=count, + size_bytes=8 * count * num_columns, + schema=schema, + input_files=None))) + i += block_size + + return read_tasks diff --git a/python/ray/data/datasource/file_based_datasource.py b/python/ray/data/datasource/file_based_datasource.py index 9a326ebdcf62d..054f18b0436f7 100644 --- a/python/ray/data/datasource/file_based_datasource.py +++ b/python/ray/data/datasource/file_based_datasource.py @@ -1,6 +1,7 @@ import logging import os -from typing import Callable, Optional, List, Tuple, Union, Any, TYPE_CHECKING +from typing import Callable, Optional, List, Tuple, Union, Any, Dict, \ + TYPE_CHECKING import urllib.parse if TYPE_CHECKING: @@ -36,6 +37,7 @@ def prepare_read( paths: Union[str, List[str]], filesystem: Optional["pyarrow.fs.FileSystem"] = None, schema: Optional[Union[type, "pyarrow.lib.Schema"]] = None, + open_stream_args: Optional[Dict[str, Any]] = None, _block_udf: Optional[Callable[[Block], Block]] = None, **reader_args) -> List[ReadTask]: """Creates and returns read tasks for a file-based datasource. 
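For the holistic per-dataset operations renamed in the dataset_pipeline.py diff above, the per-window form now carries an _each_window suffix and the bare names raise DeprecationWarning:

    import ray

    pipe = ray.data.range(100).window(blocks_per_window=10)
    pipe = pipe.random_shuffle_each_window()  # was: pipe.random_shuffle()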
@@ -52,6 +54,9 @@ def prepare_read( filesystem = _wrap_s3_serialization_workaround(filesystem) + if open_stream_args is None: + open_stream_args = {} + def read_files( read_paths: List[str], fs: Union["pyarrow.fs.FileSystem", _S3FileSystemWrapper]): @@ -60,7 +65,7 @@ def read_files( fs = fs.unwrap() builder = DelegatingArrowBlockBuilder() for read_path in read_paths: - with fs.open_input_stream(read_path) as f: + with fs.open_input_stream(read_path, **open_stream_args) as f: data = read_file(f, read_path, **reader_args) if isinstance(data, pa.Table) or isinstance( data, np.ndarray): @@ -115,16 +120,22 @@ def do_write(self, path: str, dataset_uuid: str, filesystem: Optional["pyarrow.fs.FileSystem"] = None, + try_create_dir: bool = True, + open_stream_args: Optional[Dict[str, Any]] = None, _block_udf: Optional[Callable[[Block], Block]] = None, **write_args) -> List[ObjectRef[WriteResult]]: """Creates and returns write tasks for a file-based datasource.""" path, filesystem = _resolve_paths_and_filesystem(path, filesystem) path = path[0] - filesystem.create_dir(path, recursive=True) + if try_create_dir: + filesystem.create_dir(path, recursive=True) filesystem = _wrap_s3_serialization_workaround(filesystem) _write_block_to_file = self._write_block + if open_stream_args is None: + open_stream_args = {} + def write_block(write_path: str, block: Block): logger.debug(f"Writing {write_path} file.") fs = filesystem @@ -132,8 +143,10 @@ def write_block(write_path: str, block: Block): fs = fs.unwrap() if _block_udf is not None: block = _block_udf(block) - with fs.open_output_stream(write_path) as f: - _write_block_to_file(f, BlockAccessor.for_block(block)) + + with fs.open_output_stream(write_path, **open_stream_args) as f: + _write_block_to_file(f, BlockAccessor.for_block(block), + **write_args) write_block = cached_remote_fn(write_block) @@ -188,9 +201,8 @@ def _resolve_paths_and_filesystem( compatibility. """ import pyarrow as pa - from pyarrow.fs import (FileSystem, PyFileSystem, FSSpecHandler, - _resolve_filesystem_and_path) - import fsspec + from pyarrow.fs import FileSystem, PyFileSystem, FSSpecHandler, \ + _resolve_filesystem_and_path if isinstance(paths, str): paths = [paths] @@ -202,11 +214,20 @@ def _resolve_paths_and_filesystem( raise ValueError("Must provide at least one path.") if filesystem and not isinstance(filesystem, FileSystem): + err_msg = f"The filesystem passed must either conform to " \ + f"pyarrow.fs.FileSystem, or " \ + f"fsspec.spec.AbstractFileSystem. The provided " \ + f"filesystem was: {filesystem}" + try: + import fsspec + except ModuleNotFoundError: + # If filesystem is not a pyarrow filesystem and fsspec isn't + # installed, then filesystem is neither a pyarrow filesystem nor + # an fsspec filesystem, so we raise a TypeError. + raise TypeError(err_msg) if not isinstance(filesystem, fsspec.spec.AbstractFileSystem): - raise TypeError(f"The filesystem passed must either conform to " - f"pyarrow.fs.FileSystem, or " - f"fsspec.spec.AbstractFileSystem. 
The provided " - f"filesystem was: {filesystem}") + raise TypeError(err_msg) + filesystem = PyFileSystem(FSSpecHandler(filesystem)) resolved_paths = [] @@ -266,9 +287,10 @@ def _expand_paths(paths: Union[str, List[str]], return expanded_paths, file_infos -def _expand_directory(path: str, - filesystem: "pyarrow.fs.FileSystem", - exclude_prefixes: List[str] = [".", "_"]) -> List[str]: +def _expand_directory( + path: str, + filesystem: "pyarrow.fs.FileSystem", + exclude_prefixes: Optional[List[str]] = None) -> List[str]: """ Expand the provided directory path to a list of file paths. @@ -283,6 +305,9 @@ def _expand_directory(path: str, Returns: A list of file paths contained in the provided directory. """ + if exclude_prefixes is None: + exclude_prefixes = [".", "_"] + from pyarrow.fs import FileSelector selector = FileSelector(path, recursive=True) files = filesystem.get_file_info(selector) @@ -295,7 +320,7 @@ def _expand_directory(path: str, if not file_path.startswith(base_path): continue relative = file_path[len(base_path):] - if any(relative.startswith(prefix) for prefix in [".", "_"]): + if any(relative.startswith(prefix) for prefix in exclude_prefixes): continue filtered_paths.append((file_path, file_)) # We sort the paths to guarantee a stable order. diff --git a/python/ray/data/datasource/numpy_datasource.py b/python/ray/data/datasource/numpy_datasource.py index 08bc7f2c0916e..8ba02e9d40cc5 100644 --- a/python/ray/data/datasource/numpy_datasource.py +++ b/python/ray/data/datasource/numpy_datasource.py @@ -7,7 +7,7 @@ import pyarrow from ray.data.block import BlockAccessor -from ray.data.datasource.file_based_datasource import (FileBasedDatasource) +from ray.data.datasource.file_based_datasource import FileBasedDatasource class NumpyDatasource(FileBasedDatasource): @@ -21,17 +21,22 @@ class NumpyDatasource(FileBasedDatasource): """ def _read_file(self, f: "pyarrow.NativeFile", path: str, **reader_args): + from ray.data.extensions import TensorArray + import pyarrow as pa # TODO(ekl) Ideally numpy can read directly from the file, but it # seems like it requires the file to be seekable. 
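The exclude_prefixes change above replaces a mutable default argument with a None sentinel; a minimal sketch of the Python pitfall this avoids:

    def bad(xs=[]):          # the default list is created once and shared
        xs.append(1)
        return xs

    bad()
    assert bad() == [1, 1]   # state leaks across calls

    def good(xs=None):
        if xs is None:
            xs = []          # fresh list on every call
        xs.append(1)
        return xs

    assert good() == [1]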
buf = BytesIO() data = f.readall() buf.write(data) buf.seek(0) - return np.load(buf) + return pa.Table.from_pydict({ + "value": TensorArray(np.load(buf, allow_pickle=True)) + }) def _write_block(self, f: "pyarrow.NativeFile", block: BlockAccessor, - **writer_args): - np.save(f, block.to_arrow()) + column: str, **writer_args): + value = block.to_numpy(column) + np.save(f, value) def _file_format(self): return "npy" diff --git a/python/ray/data/examples/demo_infer.py b/python/ray/data/examples/demo_infer.py index 18237f7898541..352d8ddf31ec6 100644 --- a/python/ray/data/examples/demo_infer.py +++ b/python/ray/data/examples/demo_infer.py @@ -18,7 +18,7 @@ def __call__(self, x): return x -ds = ds.pipeline(parallelism=10) \ +ds = ds.window(blocks_per_window=10) \ .map(preprocess) \ .map(Model, compute="actors", num_gpus=1) diff --git a/python/ray/data/extensions/tensor_extension.py b/python/ray/data/extensions/tensor_extension.py index 3c80fed64242f..9872cf7e225ef 100644 --- a/python/ray/data/extensions/tensor_extension.py +++ b/python/ray/data/extensions/tensor_extension.py @@ -140,7 +140,7 @@ class TensorDtype(pd.api.extensions.ExtensionDtype): one: int64 two: extension> - >>> read_df = ray.get(read_ds.to_pandas())[0] + >>> read_df = ray.get(read_ds.to_pandas_refs())[0] >>> read_df.dtypes one int64 two TensorDtype @@ -422,7 +422,7 @@ class TensorArray(pd.api.extensions.ExtensionArray, TensorOpsMixin): one: int64 two: extension> - >>> read_df = ray.get(read_ds.to_pandas())[0] + >>> read_df = ray.get(read_ds.to_pandas_refs())[0] >>> read_df.dtypes one int64 two TensorDtype @@ -1155,6 +1155,10 @@ def __arrow_ext_class__(self): """ return ArrowTensorArray + def __str__(self): + return "".format( + self.shape, self.storage_type.value_type) + @PublicAPI(stability="beta") class ArrowTensorArray(pa.ExtensionArray): diff --git a/python/ray/data/impl/arrow_block.py b/python/ray/data/impl/arrow_block.py index 41c5875bb6c16..a9d0634930a49 100644 --- a/python/ray/data/impl/arrow_block.py +++ b/python/ray/data/impl/arrow_block.py @@ -13,7 +13,6 @@ from ray.data.block import Block, BlockAccessor, BlockMetadata from ray.data.impl.block_builder import BlockBuilder from ray.data.impl.simple_block import SimpleBlockBuilder -from ray.data.impl.tensor_block import TensorBlockBuilder if TYPE_CHECKING: import pandas @@ -78,8 +77,6 @@ def add(self, item: Any) -> None: self._builder = ArrowBlockBuilder() except (TypeError, pyarrow.lib.ArrowInvalid): self._builder = SimpleBlockBuilder() - elif isinstance(item, np.ndarray): - self._builder = TensorBlockBuilder() else: self._builder = SimpleBlockBuilder() self._builder.add(item) @@ -188,8 +185,21 @@ def schema(self) -> "pyarrow.lib.Schema": def to_pandas(self) -> "pandas.DataFrame": return self._table.to_pandas() - def to_numpy(self) -> np.ndarray: - return np.array(self._table) + def to_numpy(self, column: str = None) -> np.ndarray: + if not column: + raise ValueError( + "`column` must be specified when calling .to_numpy() " + "on Arrow blocks.") + if column not in self._table.column_names: + raise ValueError( + "Cannot find column {}, available columns: {}".format( + column, self._table.column_names)) + array = self._table[column] + if array.num_chunks > 1: + # TODO(ekl) combine fails since we can't concat ArrowTensorType? 
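The column-wise conversion enforced above, sketched against a plain Arrow block (multi-chunk columns are combined before conversion):

    import pyarrow as pa
    from ray.data.block import BlockAccessor

    table = pa.Table.from_pydict({"value": [1, 2, 3]})
    arr = BlockAccessor.for_block(table).to_numpy(column="value")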
+ array = array.combine_chunks() + assert array.num_chunks == 1, array + return self._table[column].chunk(0).to_numpy() def to_arrow(self) -> "pyarrow.Table": return self._table diff --git a/python/ray/data/impl/block_list.py b/python/ray/data/impl/block_list.py index b6e88c8fe4fc0..691a710b5faa6 100644 --- a/python/ray/data/impl/block_list.py +++ b/python/ray/data/impl/block_list.py @@ -42,6 +42,13 @@ def split(self, split_size: int) -> List["BlockList"]: output.append(BlockList(b.tolist(), m.tolist())) return output + def divide(self, block_idx: int) -> ("BlockList", "BlockList"): + self._check_if_cleared() + return (BlockList(self._blocks[:block_idx], + self._metadata[:block_idx]), + BlockList(self._blocks[block_idx:], + self._metadata[block_idx:])) + def __len__(self): self._check_if_cleared() return len(self._blocks) diff --git a/python/ray/data/impl/compute.py b/python/ray/data/impl/compute.py index e52aa3bce13d1..8f0a7fb8e41f0 100644 --- a/python/ray/data/impl/compute.py +++ b/python/ray/data/impl/compute.py @@ -35,6 +35,10 @@ def _map_block(block: Block, meta: BlockMetadata, class TaskPool(ComputeStrategy): def apply(self, fn: Any, remote_args: dict, blocks: BlockList[Any]) -> BlockList[Any]: + # Handle empty datasets. + if len(blocks) == 0: + return blocks + map_bar = ProgressBar("Map Progress", total=len(blocks)) kwargs = remote_args.copy() @@ -47,8 +51,23 @@ def apply(self, fn: Any, remote_args: dict, ] new_blocks, new_metadata = zip(*refs) - map_bar.block_until_complete(list(new_blocks)) - new_metadata = ray.get(list(new_metadata)) + new_metadata = list(new_metadata) + try: + new_metadata = map_bar.fetch_until_complete(new_metadata) + except (ray.exceptions.RayTaskError, KeyboardInterrupt) as e: + # One or more mapper tasks failed, or we received a SIGINT signal + # while waiting; either way, we cancel all map tasks. + for ref in new_metadata: + ray.cancel(ref) + # Wait until all tasks have failed or been cancelled. + for ref in new_metadata: + try: + ray.get(ref) + except (ray.exceptions.RayTaskError, + ray.exceptions.TaskCancelledError): + pass + # Reraise the original task failure exception. + raise e from None return BlockList(list(new_blocks), list(new_metadata)) diff --git a/python/ray/data/impl/lazy_block_list.py b/python/ray/data/impl/lazy_block_list.py index 7ccf8e58295ae..0bfd1e0ac1093 100644 --- a/python/ray/data/impl/lazy_block_list.py +++ b/python/ray/data/impl/lazy_block_list.py @@ -9,19 +9,25 @@ class LazyBlockList(BlockList[T]): - def __init__(self, calls: Callable[[], ObjectRef[Block]], - metadata: List[BlockMetadata]): - assert len(calls) == len(metadata), (calls, metadata) + def __init__(self, + calls: Callable[[], ObjectRef[Block]], + metadata: List[BlockMetadata], + blocks: List[ObjectRef[Block]] = None): self._calls = calls - self._blocks = [calls[0]()] if calls else [] self._metadata = metadata + if blocks: + self._blocks = blocks + else: + self._blocks = [None] * len(calls) + # Immediately compute the first block at least. 
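The TaskPool failure path above uses a cancel-and-drain pattern; a standalone sketch with a hypothetical remote task:

    import ray

    @ray.remote
    def work(i):
        return i

    refs = [work.remote(i) for i in range(10)]
    try:
        results = ray.get(refs)
    except (ray.exceptions.RayTaskError, KeyboardInterrupt) as e:
        for ref in refs:
            ray.cancel(ref)
        # Drain: wait until every task has failed or been cancelled.
        for ref in refs:
            try:
                ray.get(ref)
            except (ray.exceptions.RayTaskError,
                    ray.exceptions.TaskCancelledError):
                pass
        raise e from None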
+ if calls: + self._blocks[0] = calls[0]() + assert len(calls) == len(metadata), (calls, metadata) + assert len(calls) == len(self._blocks), (calls, self._blocks) def copy(self) -> "LazyBlockList": - new_list = LazyBlockList.__new__(LazyBlockList) - new_list._calls = self._calls - new_list._blocks = self._blocks - new_list._metadata = self._metadata - return new_list + return LazyBlockList(self._calls.copy(), self._metadata.copy(), + self._blocks.copy()) def clear(self): super().clear() @@ -32,11 +38,22 @@ def split(self, split_size: int) -> List["LazyBlockList"]: num_splits = math.ceil(len(self._calls) / split_size) calls = np.array_split(self._calls, num_splits) meta = np.array_split(self._metadata, num_splits) + blocks = np.array_split(self._blocks, num_splits) output = [] - for c, m in zip(calls, meta): - output.append(LazyBlockList(c.tolist(), m.tolist())) + for c, m, b in zip(calls, meta, blocks): + output.append(LazyBlockList(c.tolist(), m.tolist(), b.tolist())) return output + def divide(self, block_idx: int) -> ("BlockList", "BlockList"): + self._check_if_cleared() + left = LazyBlockList(self._calls[:block_idx], + self._metadata[:block_idx], + self._blocks[:block_idx]) + right = LazyBlockList(self._calls[block_idx:], + self._metadata[block_idx:], + self._blocks[block_idx:]) + return left, right + def __len__(self): self._check_if_cleared() return len(self._calls) @@ -64,9 +81,19 @@ def _get_or_compute(self, i: int) -> ObjectRef[Block]: self._check_if_cleared() assert i < len(self._calls), i # Check if we need to compute more blocks. - if i >= len(self._blocks): - start = len(self._blocks) + if not self._blocks[i]: # Exponentially increase the number of blocks computed per batch. - for c in self._calls[start:max(i + 1, start * 2)]: - self._blocks.append(c()) + for j in range(max(i + 1, i * 2)): + if j >= len(self._blocks): + break + if not self._blocks[j]: + self._blocks[j] = self._calls[j]() + assert self._blocks[i], self._blocks return self._blocks[i] + + def _num_computed(self): + i = 0 + for b in self._blocks: + if b is not None: + i += 1 + return i diff --git a/python/ray/data/impl/pipeline_executor.py b/python/ray/data/impl/pipeline_executor.py index c02b04ffdabb4..7eeacc0a8cac1 100644 --- a/python/ray/data/impl/pipeline_executor.py +++ b/python/ray/data/impl/pipeline_executor.py @@ -10,7 +10,7 @@ from ray.data.dataset_pipeline import DatasetPipeline -@ray.remote +@ray.remote(num_cpus=0, placement_group=None) def pipeline_stage(fn: Callable[[], Dataset[T]]) -> Dataset[T]: try: prev = set_progress_bars(False) @@ -27,12 +27,15 @@ def __init__(self, pipeline: "DatasetPipeline[T]"): self._iter = iter(self._pipeline._base_iterable) self._stages[0] = pipeline_stage.remote(next(self._iter)) + if self._pipeline._length and self._pipeline._length != float("inf"): + length = self._pipeline._length + else: + length = 1 + if self._pipeline._progress_bars: self._bars = [ - ProgressBar( - "Stage {}".format(i), - self._pipeline._length or 1, - position=i) for i in range(len(self._stages)) + ProgressBar("Stage {}".format(i), length, position=i) + for i in range(len(self._stages)) ] else: self._bars = None @@ -84,7 +87,7 @@ def __next__(self): return output -@ray.remote +@ray.remote(num_cpus=0, placement_group=None) class PipelineSplitExecutorCoordinator: def __init__(self, pipeline: "DatasetPipeline[T]", n: int, splitter: Callable[[Dataset], "DatasetPipeline[T]"]): diff --git a/python/ray/data/impl/progress_bar.py b/python/ray/data/impl/progress_bar.py index c9c1caa43cb5b..fc28da681f3ee 
100644 --- a/python/ray/data/impl/progress_bar.py +++ b/python/ray/data/impl/progress_bar.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Any import ray from ray.types import ObjectRef @@ -50,6 +50,16 @@ def block_until_complete(self, remaining: List[ObjectRef]) -> None: done, remaining = ray.wait(remaining, fetch_local=False) self.update(len(done)) + def fetch_until_complete(self, refs: List[ObjectRef]) -> List[Any]: + ref_to_result = {} + remaining = refs + while remaining: + done, remaining = ray.wait(remaining, fetch_local=True) + for ref, result in zip(done, ray.get(done)): + ref_to_result[ref] = result + self.update(len(done)) + return [ref_to_result[ref] for ref in refs] + def set_description(self, name: str) -> None: if self._bar: self._bar.set_description(name) diff --git a/python/ray/data/impl/remote_fn.py b/python/ray/data/impl/remote_fn.py index 968380e187c50..a6b4eb06d0f46 100644 --- a/python/ray/data/impl/remote_fn.py +++ b/python/ray/data/impl/remote_fn.py @@ -13,7 +13,10 @@ def cached_remote_fn(fn: Any, **ray_remote_args) -> Any: which means ray.remote cannot be used top-level in ray.data). """ if fn not in CACHED_FUNCTIONS: - default_ray_remote_args = {"retry_exceptions": True} + default_ray_remote_args = { + "retry_exceptions": True, + "placement_group": None, + } CACHED_FUNCTIONS[fn] = ray.remote(**{ **default_ray_remote_args, **ray_remote_args diff --git a/python/ray/data/impl/simple_block.py b/python/ray/data/impl/simple_block.py index ba20d1334b06b..f609c65bd28b8 100644 --- a/python/ray/data/impl/simple_block.py +++ b/python/ray/data/impl/simple_block.py @@ -58,7 +58,9 @@ def to_pandas(self) -> "pandas.DataFrame": import pandas return pandas.DataFrame(self._items) - def to_numpy(self) -> np.ndarray: + def to_numpy(self, column: str = None) -> np.ndarray: + if column: + raise ValueError("`column` arg not supported for list block") return np.array(self._items) def to_arrow(self) -> "pyarrow.Table": diff --git a/python/ray/data/impl/tensor_block.py b/python/ray/data/impl/tensor_block.py deleted file mode 100644 index 3ad8d8afad71b..0000000000000 --- a/python/ray/data/impl/tensor_block.py +++ /dev/null @@ -1,80 +0,0 @@ -from typing import Iterator, List, TypeVar, Dict, TYPE_CHECKING - -import numpy as np - -if TYPE_CHECKING: - import pandas - import pyarrow - -from ray.data.block import Block, BlockAccessor -from ray.data.impl.block_builder import BlockBuilder - -T = TypeVar("T") - - -# TODO(ekl) switch to pyarrow.Tensor as the block type; currently there is a -# serialization issue with pyarrow tensors. 
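pipeline_stage and cached_remote_fn above now opt out of any ambient placement group, so Datasets tasks launched from inside a placement group don't contend for its bundles; the equivalent standalone decorator:

    import ray

    @ray.remote(retry_exceptions=True, placement_group=None)
    def read_block():
        return [0, 1, 2]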
-class TensorBlockBuilder(BlockBuilder[T]): - def __init__(self): - self._rows = [] - self._tensors: List[np.ndarray] = [] - self._num_rows = 0 - - def add(self, row: np.ndarray) -> None: - self._rows.append(row) - self._num_rows += 1 - - def add_block(self, block: np.ndarray) -> None: - assert isinstance(block, np.ndarray), block - self._tensors.append(block) - self._num_rows += len(block) - - def build(self) -> Block: - tensors = self._tensors.copy() - if self._rows: - tensors.append(np.stack(self._rows, axis=0)) - return np.concatenate(tensors, axis=0) - - def num_rows(self) -> int: - return self._num_rows - - -class TensorBlockAccessor(BlockAccessor): - def __init__(self, tensor: np.ndarray): - self._tensor = tensor - - def iter_rows(self) -> Iterator[np.ndarray]: - return iter(self._tensor) - - def slice(self, start: int, end: int, - copy: bool) -> "TensorBlockAccessor[T]": - view = self._tensor[start:end] - if copy: - view = view.copy() - return view - - def to_pandas(self) -> "pandas.DataFrame": - import pandas - return pandas.DataFrame(self._tensor) - - def to_numpy(self) -> np.ndarray: - return self._tensor - - def to_arrow(self) -> "pyarrow.Tensor": - import pyarrow - return pyarrow.Tensor.from_numpy(self._tensor) - - def schema(self) -> Dict: - shape = self._tensor.shape - shape = (None, ) + shape[1:] - return {"shape": shape, "dtype": self._tensor.dtype.name} - - def num_rows(self) -> int: - return len(self._tensor) - - def size_bytes(self) -> int: - return self._tensor.nbytes - - @staticmethod - def builder() -> TensorBlockBuilder[T]: - return TensorBlockBuilder() diff --git a/python/ray/data/read_api.py b/python/ray/data/read_api.py index 887e08baa1495..fb98561489987 100644 --- a/python/ray/data/read_api.py +++ b/python/ray/data/read_api.py @@ -14,7 +14,7 @@ import ray from ray.types import ObjectRef -from ray.util.annotations import PublicAPI +from ray.util.annotations import PublicAPI, DeveloperAPI from ray.data.block import Block, BlockAccessor, BlockMetadata from ray.data.dataset import Dataset from ray.data.datasource import Datasource, RangeDatasource, \ @@ -283,6 +283,7 @@ def read_json(paths: Union[str, List[str]], filesystem: Optional["pyarrow.fs.FileSystem"] = None, parallelism: int = 200, ray_remote_args: Dict[str, Any] = None, + arrow_open_stream_args: Optional[Dict[str, Any]] = None, **arrow_json_args) -> Dataset[ArrowRow]: """Create an Arrow dataset from json files. @@ -302,6 +303,8 @@ def read_json(paths: Union[str, List[str]], filesystem: The filesystem implementation to read from. parallelism: The amount of parallelism to use for the dataset. ray_remote_args: kwargs passed to ray.remote in the read tasks. + arrow_open_stream_args: kwargs passed to + pyarrow.fs.FileSystem.open_input_stream arrow_json_args: Other json read options to pass to pyarrow. Returns: @@ -313,6 +316,7 @@ def read_json(paths: Union[str, List[str]], paths=paths, filesystem=filesystem, ray_remote_args=ray_remote_args, + open_stream_args=arrow_open_stream_args, **arrow_json_args) @@ -322,6 +326,7 @@ def read_csv(paths: Union[str, List[str]], filesystem: Optional["pyarrow.fs.FileSystem"] = None, parallelism: int = 200, ray_remote_args: Dict[str, Any] = None, + arrow_open_stream_args: Optional[Dict[str, Any]] = None, **arrow_csv_args) -> Dataset[ArrowRow]: """Create an Arrow dataset from csv files. @@ -341,6 +346,8 @@ def read_csv(paths: Union[str, List[str]], filesystem: The filesystem implementation to read from. parallelism: The amount of parallelism to use for the dataset. 
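# (Aside, illustrative usage of the new arrow_open_stream_args parameter that
# this patch threads through the readers; the path and codec below are
# hypothetical example values, and the dict is forwarded to
# pyarrow.fs.FileSystem.open_input_stream.)
import ray
ds = ray.data.read_csv(
    "example://data.csv.gz",  # hypothetical path
    arrow_open_stream_args={"compression": "gzip"})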
ray_remote_args: kwargs passed to ray.remote in the read tasks. + arrow_open_stream_args: kwargs passed to + pyarrow.fs.FileSystem.open_input_stream arrow_csv_args: Other csv read options to pass to pyarrow. Returns: @@ -352,6 +359,7 @@ def read_csv(paths: Union[str, List[str]], paths=paths, filesystem=filesystem, ray_remote_args=ray_remote_args, + open_stream_args=arrow_open_stream_args, **arrow_csv_args) @@ -362,6 +370,7 @@ def read_text( encoding: str = "utf-8", filesystem: Optional["pyarrow.fs.FileSystem"] = None, parallelism: int = 200, + arrow_open_stream_args: Optional[Dict[str, Any]] = None, ) -> Dataset[str]: """Create a dataset from lines stored in text files. @@ -377,13 +386,18 @@ def read_text( encoding: The encoding of the files (e.g., "utf-8" or "ascii"). filesystem: The filesystem implementation to read from. parallelism: The amount of parallelism to use for the dataset. + arrow_open_stream_args: kwargs passed to + pyarrow.fs.FileSystem.open_input_stream Returns: Dataset holding lines of text read from the specified paths. """ return read_binary_files( - paths, filesystem=filesystem, parallelism=parallelism).flat_map( + paths, + filesystem=filesystem, + parallelism=parallelism, + arrow_open_stream_args=arrow_open_stream_args).flat_map( lambda x: x.decode(encoding).split("\n")) @@ -392,7 +406,8 @@ def read_numpy(paths: Union[str, List[str]], *, filesystem: Optional["pyarrow.fs.FileSystem"] = None, parallelism: int = 200, - **numpy_load_args) -> Dataset[np.ndarray]: + arrow_open_stream_args: Optional[Dict[str, Any]] = None, + **numpy_load_args) -> Dataset[ArrowRow]: """Create an Arrow dataset from csv files. Examples: @@ -410,6 +425,8 @@ def read_numpy(paths: Union[str, List[str]], A list of paths can contain both files and directories. filesystem: The filesystem implementation to read from. parallelism: The amount of parallelism to use for the dataset. + arrow_open_stream_args: kwargs passed to + pyarrow.fs.FileSystem.open_input_stream numpy_load_args: Other options to pass to np.load. Returns: @@ -420,6 +437,7 @@ def read_numpy(paths: Union[str, List[str]], parallelism=parallelism, paths=paths, filesystem=filesystem, + open_stream_args=arrow_open_stream_args, **numpy_load_args) @@ -431,6 +449,7 @@ def read_binary_files( filesystem: Optional["pyarrow.fs.FileSystem"] = None, parallelism: int = 200, ray_remote_args: Dict[str, Any] = None, + arrow_open_stream_args: Optional[Dict[str, Any]] = None, ) -> Dataset[Union[Tuple[str, bytes], bytes]]: """Create a dataset from binary files of arbitrary contents. @@ -449,6 +468,8 @@ def read_binary_files( filesystem: The filesystem implementation to read from. ray_remote_args: kwargs passed to ray.remote in the read tasks. parallelism: The amount of parallelism to use for the dataset. + arrow_open_stream_args: kwargs passed to + pyarrow.fs.FileSystem.open_input_stream Returns: Dataset holding Arrow records read from the specified paths. @@ -460,6 +481,7 @@ def read_binary_files( include_paths=include_paths, filesystem=filesystem, ray_remote_args=ray_remote_args, + open_stream_args=arrow_open_stream_args, schema=bytes) @@ -509,12 +531,27 @@ def from_modin(df: "modin.DataFrame") -> Dataset[ArrowRow]: from modin.distributed.dataframe.pandas.partitions import unwrap_partitions parts = unwrap_partitions(df, axis=0) - return from_pandas(parts) + return from_pandas_refs(parts) @PublicAPI(stability="beta") -def from_pandas(dfs: List[ObjectRef["pandas.DataFrame"]]) -> Dataset[ArrowRow]: - """Create a dataset from a set of Pandas dataframes. 
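# (Aside.) The hunk continuing below splits the pandas constructor in two:
# from_pandas() now takes plain DataFrames and puts them into the object
# store itself, while from_pandas_refs() keeps the ObjectRef-based signature
# as a DeveloperAPI. Equivalent calls, sketched:
import pandas as pd
import ray
df = pd.DataFrame({"one": [1, 2, 3]})
ds1 = ray.data.from_pandas([df])
ds2 = ray.data.from_pandas_refs([ray.put(df)])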
+def from_pandas(dfs: List["pandas.DataFrame"]) -> Dataset[ArrowRow]: + """Create a dataset from a list of Pandas dataframes. + + Args: + dfs: A list of Pandas dataframes. + + Returns: + Dataset holding Arrow records read from the dataframes. + """ + return from_pandas_refs([ray.put(df) for df in dfs]) + + +@DeveloperAPI +def from_pandas_refs( + dfs: List[ObjectRef["pandas.DataFrame"]]) -> Dataset[ArrowRow]: + """Create a dataset from a list of Ray object references to Pandas + dataframes. Args: dfs: A list of Ray object references to pandas dataframes. @@ -529,7 +566,7 @@ def from_pandas(dfs: List[ObjectRef["pandas.DataFrame"]]) -> Dataset[ArrowRow]: return Dataset(BlockList(blocks, ray.get(list(metadata)))) -def from_numpy(ndarrays: List[ObjectRef[np.ndarray]]) -> Dataset[np.ndarray]: +def from_numpy(ndarrays: List[ObjectRef[np.ndarray]]) -> Dataset[ArrowRow]: """Create a dataset from a set of NumPy ndarrays. Args: @@ -546,8 +583,23 @@ def from_numpy(ndarrays: List[ObjectRef[np.ndarray]]) -> Dataset[np.ndarray]: @PublicAPI(stability="beta") -def from_arrow(tables: List[ObjectRef[Union["pyarrow.Table", bytes]]] - ) -> Dataset[ArrowRow]: +def from_arrow( + tables: List[Union["pyarrow.Table", bytes]]) -> Dataset[ArrowRow]: + """Create a dataset from a list of Arrow tables. + + Args: + tables: A list of Ray object references to Arrow tables, + or its streaming format in bytes. + + Returns: + Dataset holding Arrow records from the tables. + """ + return from_arrow_refs([ray.put(t) for t in tables]) + + +@DeveloperAPI +def from_arrow_refs(tables: List[ObjectRef[Union["pyarrow.Table", bytes]]] + ) -> Dataset[ArrowRow]: """Create a dataset from a set of Arrow tables. Args: @@ -590,8 +642,11 @@ def _df_to_block(df: "pandas.DataFrame") -> Block[ArrowRow]: def _ndarray_to_block(ndarray: np.ndarray) -> Block[np.ndarray]: - return (ndarray, - BlockAccessor.for_block(ndarray).get_metadata(input_files=None)) + import pyarrow as pa + from ray.data.extensions import TensorArray + table = pa.Table.from_pydict({"value": TensorArray(ndarray)}) + return (table, + BlockAccessor.for_block(table).get_metadata(input_files=None)) def _get_schema(block: Block) -> Any: diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py index 7562e2c5a7105..91aa91e5c2eb2 100644 --- a/python/ray/data/tests/test_dataset.py +++ b/python/ray/data/tests/test_dataset.py @@ -17,9 +17,11 @@ import ray from ray.tests.conftest import * # noqa +from ray.data.dataset import Dataset from ray.data.datasource import DummyOutputDatasource from ray.data.datasource.csv_datasource import CSVDatasource from ray.data.block import BlockAccessor +from ray.data.impl.block_list import BlockList from ray.data.datasource.file_based_datasource import _unwrap_protocol from ray.data.extensions.tensor_extension import ( TensorArray, TensorDtype, ArrowTensorType, ArrowTensorArray) @@ -29,7 +31,7 @@ def maybe_pipeline(ds, enabled): if enabled: - return ds.pipeline(parallelism=1) + return ds.window(blocks_per_window=1) else: return ds @@ -58,7 +60,10 @@ def run(): assert sorted(ds.iter_rows()) == [0, 1, 2, 3, 4] pg = ray.util.placement_group([{"CPU": 1}]) - ray.get(run.options(placement_group=pg).remote()) + ray.get( + run.options( + placement_group=pg, + placement_group_capture_child_tasks=True).remote()) @pytest.mark.parametrize("pipelined", [False, True]) @@ -142,6 +147,102 @@ def __call__(self, x): assert len(actor_reuse) == 10, actor_reuse +def test_transform_failure(shutdown_only): + ray.init(num_cpus=2) + ds = 
ray.data.from_items([0, 10], parallelism=2) + + def mapper(x): + time.sleep(x) + raise ValueError("oops") + return x + + with pytest.raises(ray.exceptions.RayTaskError): + ds.map(mapper) + + +@pytest.mark.parametrize( + "block_sizes,num_splits", + [ + ( # Test baseline. + [3, 6, 3], 3), + ( # Already balanced. + [3, 3, 3], 3), + ( # Row truncation. + [3, 6, 4], 3), + ( # Row truncation, smaller number of blocks. + [3, 6, 2, 3], 3), + ( # Row truncation, larger number of blocks. + [5, 6, 2, 5], 5), + ( # All smaller but one. + [1, 1, 1, 1, 6], 5), + ( # All larger but one. + [4, 4, 4, 4, 1], 5), + ( # Single block. + [2], 2), + ( # Single split. + [2, 5], 1), + ]) +def test_equal_split_balanced(ray_start_regular_shared, block_sizes, + num_splits): + _test_equal_split_balanced(block_sizes, num_splits) + + +def _test_equal_split_balanced(block_sizes, num_splits): + blocks = [] + metadata = [] + total_rows = 0 + for block_size in block_sizes: + block = list(range(total_rows, total_rows + block_size)) + blocks.append(ray.put(block)) + metadata.append(BlockAccessor.for_block(block).get_metadata(None)) + total_rows += block_size + block_list = BlockList(blocks, metadata) + ds = Dataset(block_list) + + splits = ds.split(num_splits, equal=True) + split_counts = [split.count() for split in splits] + assert len(split_counts) == num_splits + expected_block_size = total_rows // num_splits + # Check that all splits are the expected size. + assert all([count == expected_block_size for count in split_counts]) + expected_total_rows = sum(split_counts) + # Check that the expected number of rows were dropped. + assert total_rows - expected_total_rows == total_rows % num_splits + # Check that all rows are unique (content check). + split_rows = [row for split in splits for row in split.take(total_rows)] + assert len(set(split_rows)) == len(split_rows) + + +def test_equal_split_balanced_grid(ray_start_regular_shared): + + # Tests balanced equal splitting over a grid of configurations. + # Grid: num_blocks x num_splits x num_rows_block_1 x ... x num_rows_block_n + seed = int(time.time()) + print(f"Seeding RNG for test_equal_split_balanced_grid with: {seed}") + random.seed(seed) + max_num_splits = 20 + num_splits_samples = 5 + max_num_blocks = 50 + max_num_rows_per_block = 100 + num_blocks_samples = 5 + block_sizes_samples = 5 + for num_splits in np.random.randint( + 2, max_num_splits + 1, size=num_splits_samples): + for num_blocks in np.random.randint( + 1, max_num_blocks + 1, size=num_blocks_samples): + block_sizes_list = [ + np.random.randint( + 1, max_num_rows_per_block + 1, size=num_blocks) + for _ in range(block_sizes_samples) + ] + for block_sizes in block_sizes_list: + if sum(block_sizes) < num_splits: + min_ = math.ceil(num_splits / num_blocks) + block_sizes = np.random.randint( + min_, max_num_rows_per_block + 1, size=num_blocks) + _test_equal_split_balanced(block_sizes, num_splits) + + @pytest.mark.parametrize("pipelined", [False, True]) def test_basic(ray_start_regular_shared, pipelined): ds0 = ray.data.range(5) @@ -195,30 +296,15 @@ def test_batch_tensors(ray_start_regular_shared): def test_tensors(ray_start_regular_shared): # Create directly. ds = ray.data.range_tensor(5, shape=(3, 5)) - assert str(ds) == ("Dataset(num_blocks=5, num_rows=5, " - "schema=)") - - # Transform. - ds = ds.map_batches(lambda t: np.expand_dims(t, 3)) - assert str(ds) == ("Dataset(num_blocks=5, num_rows=5, " - "schema=)") + assert str(ds) == ( + "Dataset(num_blocks=5, num_rows=5, " + "schema={value: })") # Pandas conversion. 
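# (Aside, illustrative arithmetic behind the equal-split tests above.)
# split(num_splits, equal=True) gives every split exactly
# total_rows // num_splits rows and truncates the remainder.
block_sizes, num_splits = [3, 6, 4], 3
total_rows = sum(block_sizes)                 # 13
rows_per_split = total_rows // num_splits     # 4
rows_dropped = total_rows % num_splits        # 1
assert rows_per_split * num_splits + rows_dropped == total_rows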
res = ray.data.range_tensor(10).map_batches( lambda t: t + 2, batch_format="pandas").take(2) - assert str(res) == "[ArrowRow({'0': 2}), ArrowRow({'0': 3})]", res - - # From other formats. - ds = ray.data.range(10).map_batches(lambda x: np.array(x)) - assert str(ds) == ("Dataset(num_blocks=10, num_rows=10, " - "schema=)") - ds = ray.data.range(10).map(lambda x: np.array(x)) - assert str(ds) == ("Dataset(num_blocks=10, num_rows=10, " - "schema=)") - ds = ray.data.from_items([np.zeros(shape=(2, 2, 2)) for _ in range(4)]) - assert str(ds) == ( - "Dataset(num_blocks=4, num_rows=4, " - "schema=)"), ds + assert str(res) == \ + "[ArrowRow({'value': array([2])}), ArrowRow({'value': array([3])})]" def test_tensor_array_ops(ray_start_regular_shared): @@ -308,7 +394,7 @@ def test_tensors_in_tables_from_pandas(ray_start_regular_shared): df = pd.DataFrame({"one": list(range(outer_dim)), "two": list(arr)}) # Cast column to tensor extension dtype. df["two"] = df["two"].astype(TensorDtype()) - ds = ray.data.from_pandas([ray.put(df)]) + ds = ray.data.from_pandas([df]) values = [[s["one"], s["two"]] for s in ds.take()] expected = list(zip(list(range(outer_dim)), arr)) for v, e in zip(sorted(values), expected): @@ -322,8 +408,8 @@ def test_tensors_in_tables_pandas_roundtrip(ray_start_regular_shared): num_items = np.prod(np.array(shape)) arr = np.arange(num_items).reshape(shape) df = pd.DataFrame({"one": list(range(outer_dim)), "two": TensorArray(arr)}) - ds = ray.data.from_pandas([ray.put(df)]) - ds_df = ray.get(ds.to_pandas())[0] + ds = ray.data.from_pandas([df]) + ds_df = ds.to_pandas() assert ds_df.equals(df) @@ -335,7 +421,7 @@ def test_tensors_in_tables_parquet_roundtrip(ray_start_regular_shared, num_items = np.prod(np.array(shape)) arr = np.arange(num_items).reshape(shape) df = pd.DataFrame({"one": list(range(outer_dim)), "two": TensorArray(arr)}) - ds = ray.data.from_pandas([ray.put(df)]) + ds = ray.data.from_pandas([df]) ds.write_parquet(str(tmp_path)) ds = ray.data.read_parquet(str(tmp_path)) values = [[s["one"], s["two"]] for s in ds.take()] @@ -352,7 +438,7 @@ def test_tensors_in_tables_parquet_with_schema(ray_start_regular_shared, num_items = np.prod(np.array(shape)) arr = np.arange(num_items).reshape(shape) df = pd.DataFrame({"one": list(range(outer_dim)), "two": TensorArray(arr)}) - ds = ray.data.from_pandas([ray.put(df)]) + ds = ray.data.from_pandas([df]) ds.write_parquet(str(tmp_path)) schema = pa.schema([ ("one", pa.int32()), @@ -378,7 +464,7 @@ def test_tensors_in_tables_parquet_pickle_manual_serde( "one": list(range(outer_dim)), "two": [pickle.dumps(a) for a in arr] }) - ds = ray.data.from_pandas([ray.put(df)]) + ds = ray.data.from_pandas([df]) ds.write_parquet(str(tmp_path)) ds = ray.data.read_parquet(str(tmp_path)) @@ -421,7 +507,7 @@ def test_tensors_in_tables_parquet_bytes_manual_serde(ray_start_regular_shared, "one": list(range(outer_dim)), "two": [a.tobytes() for a in arr] }) - ds = ray.data.from_pandas([ray.put(df)]) + ds = ray.data.from_pandas([df]) ds.write_parquet(str(tmp_path)) ds = ray.data.read_parquet(str(tmp_path)) @@ -460,7 +546,7 @@ def test_tensors_in_tables_parquet_bytes_manual_serde_udf( "one": list(range(outer_dim)), tensor_col_name: [a.tobytes() for a in arr] }) - ds = ray.data.from_pandas([ray.put(df)]) + ds = ray.data.from_pandas([df]) ds.write_parquet(str(tmp_path)) # Manually deserialize the tensor bytes and cast to a TensorArray. 
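# (Sketch of the manual serde pattern exercised by the surrounding tests:
# tensors are stored as pickled bytes in an ordinary DataFrame column, then
# deserialized and cast back to the tensor extension dtype after reading.)
import pickle
import numpy as np
import pandas as pd
from ray.data.extensions import TensorDtype

arr = np.arange(24).reshape(4, 2, 3)
df = pd.DataFrame({"one": list(range(4)),
                   "two": [pickle.dumps(a) for a in arr]})
# ... round-trip df through Parquet here ...
df["two"] = df["two"].map(pickle.loads).astype(TensorDtype())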
@@ -499,7 +585,7 @@ def test_tensors_in_tables_parquet_bytes_manual_serde_col_schema( "one": list(range(outer_dim)), tensor_col_name: [a.tobytes() for a in arr] }) - ds = ray.data.from_pandas([ray.put(df)]) + ds = ray.data.from_pandas([df]) ds.write_parquet(str(tmp_path)) def _block_udf(block: pa.Table): @@ -536,7 +622,7 @@ def test_tensors_in_tables_parquet_bytes_with_schema(ray_start_regular_shared, "one": list(range(outer_dim)), "two": [a.tobytes() for a in arr] }) - ds = ray.data.from_pandas([ray.put(df)]) + ds = ray.data.from_pandas([df]) ds.write_parquet(str(tmp_path)) schema = pa.schema([ ("one", pa.int32()), @@ -574,7 +660,7 @@ def test_tensors_in_tables_to_torch(ray_start_regular_shared, pipelined): "label": [4.0, 5.0, 6.0] }) df = pd.concat([df1, df2]) - ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)]) + ds = ray.data.from_pandas([df1, df2]) ds = maybe_pipeline(ds, pipelined) torchd = ds.to_torch(label_column="label", batch_size=2) @@ -614,7 +700,7 @@ def test_tensors_in_tables_to_tf(ray_start_regular_shared, pipelined): "label": TensorArray(arr2), }) df = pd.concat([df1, df2]) - ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)]) + ds = ray.data.from_pandas([df1, df2]) ds = maybe_pipeline(ds, pipelined) tfd = ds.to_tf( label_column="label", @@ -639,13 +725,11 @@ def test_numpy_roundtrip(ray_start_regular_shared, fs, data_path): ds = ray.data.range_tensor(10, parallelism=2) ds.write_numpy(data_path, filesystem=fs) ds = ray.data.read_numpy(data_path, filesystem=fs) - assert str(ds) == ("Dataset(num_blocks=2, num_rows=?, " - "schema=)") - - assert str( - ds.take()) == ("[array([0]), array([1]), array([2]), " - "array([3]), array([4]), array([5]), array([6]), " - "array([7]), array([8]), array([9])]"), ds.take() + assert str(ds) == ( + "Dataset(num_blocks=2, num_rows=None, " + "schema={value: })") + assert str(ds.take(2)) == \ + "[ArrowRow({'value': array([0])}), ArrowRow({'value': array([1])})]" def test_numpy_read(ray_start_regular_shared, tmp_path): @@ -654,13 +738,11 @@ def test_numpy_read(ray_start_regular_shared, tmp_path): np.save( os.path.join(path, "test.npy"), np.expand_dims(np.arange(0, 10), 1)) ds = ray.data.read_numpy(path) - assert str(ds) == ("Dataset(num_blocks=1, num_rows=?, " - "schema=)") - - assert str( - ds.take()) == ("[array([0]), array([1]), array([2]), " - "array([3]), array([4]), array([5]), array([6]), " - "array([7]), array([8]), array([9])]"), ds.take() + assert str(ds) == ( + "Dataset(num_blocks=1, num_rows=None, " + "schema={value: })") + assert str(ds.take(2)) == \ + "[ArrowRow({'value': array([0])}), ArrowRow({'value': array([1])})]" @pytest.mark.parametrize("fs,data_path,endpoint_url", [ @@ -682,7 +764,12 @@ def test_numpy_write(ray_start_regular_shared, fs, data_path, endpoint_url): s3 = S3FileSystem(client_kwargs={"endpoint_url": endpoint_url}) arr1 = np.load(s3.open(file_path1)) arr2 = np.load(s3.open(file_path2)) - np.testing.assert_equal(np.concatenate((arr1, arr2)), ds.take()) + assert ds.count() == 10 + assert len(arr1) == 5 + assert len(arr2) == 5 + assert arr1.sum() == 10 + assert arr2.sum() == 35 + assert str(ds.take(1)) == "[ArrowRow({'value': array([0])})]" def test_read_text(ray_start_regular_shared, tmp_path): @@ -733,6 +820,16 @@ def test_empty_dataset(ray_start_regular_shared): assert str(ds) == \ "Dataset(num_blocks=1, num_rows=0, schema=Unknown schema)" + # Test map on empty dataset. + ds = ray.data.from_items([]) + ds = ds.map(lambda x: x) + assert ds.count() == 0 + + # Test filter on empty dataset. 
+ ds = ray.data.from_items([]) + ds = ds.filter(lambda: True) + assert ds.count() == 0 + def test_schema(ray_start_regular_shared): ds = ray.data.range(10) @@ -751,17 +848,17 @@ def test_schema(ray_start_regular_shared): def test_lazy_loading_exponential_rampup(ray_start_regular_shared): ds = ray.data.range(100, parallelism=20) - assert len(ds._blocks._blocks) == 1 + assert ds._blocks._num_computed() == 1 assert ds.take(10) == list(range(10)) - assert len(ds._blocks._blocks) == 2 + assert ds._blocks._num_computed() == 2 assert ds.take(20) == list(range(20)) - assert len(ds._blocks._blocks) == 4 + assert ds._blocks._num_computed() == 4 assert ds.take(30) == list(range(30)) - assert len(ds._blocks._blocks) == 8 + assert ds._blocks._num_computed() == 8 assert ds.take(50) == list(range(50)) - assert len(ds._blocks._blocks) == 16 + assert ds._blocks._num_computed() == 16 assert ds.take(100) == list(range(100)) - assert len(ds._blocks._blocks) == 20 + assert ds._blocks._num_computed() == 20 def test_limit(ray_start_regular_shared): @@ -834,7 +931,16 @@ def test_repartition_arrow(ray_start_regular_shared): def test_from_pandas(ray_start_regular_shared): df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]}) df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]}) - ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)]) + ds = ray.data.from_pandas([df1, df2]) + values = [(r["one"], r["two"]) for r in ds.take(6)] + rows = [(r.one, r.two) for _, r in pd.concat([df1, df2]).iterrows()] + assert values == rows + + +def test_from_pandas_refs(ray_start_regular_shared): + df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]}) + df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]}) + ds = ray.data.from_pandas_refs([ray.put(df1), ray.put(df2)]) values = [(r["one"], r["two"]) for r in ds.take(6)] rows = [(r.one, r.two) for _, r in pd.concat([df1, df2]).iterrows()] assert values == rows @@ -845,13 +951,27 @@ def test_from_numpy(ray_start_regular_shared): arr2 = np.expand_dims(np.arange(4, 8), 1) ds = ray.data.from_numpy([ray.put(arr1), ray.put(arr2)]) values = np.array(ds.take(8)) - np.testing.assert_equal(np.concatenate((arr1, arr2)), values) + for i in range(4): + assert values[i]["value"] == arr1[i] + for i in range(4, 8): + assert values[i]["value"] == arr2[i - 4] def test_from_arrow(ray_start_regular_shared): df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]}) df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]}) - ds = ray.data.from_arrow([ + ds = ray.data.from_arrow( + [pa.Table.from_pandas(df1), + pa.Table.from_pandas(df2)]) + values = [(r["one"], r["two"]) for r in ds.take(6)] + rows = [(r.one, r.two) for _, r in pd.concat([df1, df2]).iterrows()] + assert values == rows + + +def test_from_arrow_refs(ray_start_regular_shared): + df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]}) + df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]}) + ds = ray.data.from_arrow_refs([ ray.put(pa.Table.from_pandas(df1)), ray.put(pa.Table.from_pandas(df2)) ]) @@ -864,20 +984,36 @@ def test_to_pandas(ray_start_regular_shared): n = 5 df = pd.DataFrame({"value": list(range(n))}) ds = ray.data.range_arrow(n) - dfds = pd.concat(ray.get(ds.to_pandas()), ignore_index=True) + dfds = ds.to_pandas() + assert df.equals(dfds) + + # Test limit. + dfds = ds.to_pandas(limit=3) + assert df[:3].equals(dfds) + + # Test limit greater than number of rows. 
+ dfds = ds.to_pandas(limit=6) + assert df.equals(dfds) + + +def test_to_pandas_refs(ray_start_regular_shared): + n = 5 + df = pd.DataFrame({"value": list(range(n))}) + ds = ray.data.range_arrow(n) + dfds = pd.concat(ray.get(ds.to_pandas_refs()), ignore_index=True) assert df.equals(dfds) def test_to_numpy(ray_start_regular_shared): # Tensor Dataset ds = ray.data.range_tensor(10, parallelism=2) - arr = np.concatenate(ray.get(ds.to_numpy())) + arr = np.concatenate(ray.get(ds.to_numpy(column="value"))) np.testing.assert_equal(arr, np.expand_dims(np.arange(0, 10), 1)) # Table Dataset ds = ray.data.range_arrow(10) - arr = np.concatenate(ray.get(ds.to_numpy())) - np.testing.assert_equal(arr, np.expand_dims(np.arange(0, 10), 1)) + arr = np.concatenate(ray.get(ds.to_numpy(column="value"))) + np.testing.assert_equal(arr, np.arange(0, 10)) # Simple Dataset ds = ray.data.range(10) @@ -888,23 +1024,41 @@ def test_to_numpy(ray_start_regular_shared): def test_to_arrow(ray_start_regular_shared): n = 5 + # Zero-copy. + df = pd.DataFrame({"value": list(range(n))}) + ds = ray.data.range_arrow(n) + dfds = pd.concat([t.to_pandas() for t in ds.to_arrow()], ignore_index=True) + assert df.equals(dfds) + + # Conversion. + df = pd.DataFrame({0: list(range(n))}) + ds = ray.data.range(n) + dfds = pd.concat([t.to_pandas() for t in ds.to_arrow()], ignore_index=True) + assert df.equals(dfds) + + +def test_to_arrow_refs(ray_start_regular_shared): + n = 5 + # Zero-copy. df = pd.DataFrame({"value": list(range(n))}) ds = ray.data.range_arrow(n) dfds = pd.concat( - [t.to_pandas() for t in ray.get(ds.to_arrow())], ignore_index=True) + [t.to_pandas() for t in ray.get(ds.to_arrow_refs())], + ignore_index=True) assert df.equals(dfds) # Conversion. df = pd.DataFrame({0: list(range(n))}) ds = ray.data.range(n) dfds = pd.concat( - [t.to_pandas() for t in ray.get(ds.to_arrow())], ignore_index=True) + [t.to_pandas() for t in ray.get(ds.to_arrow_refs())], + ignore_index=True) assert df.equals(dfds) -def test_get_blocks(ray_start_regular_shared): - blocks = ray.data.range(10).get_blocks() +def test_get_internal_block_refs(ray_start_regular_shared): + blocks = ray.data.range(10).get_internal_block_refs() assert len(blocks) == 10 out = [] for b in ray.get(blocks): @@ -916,9 +1070,9 @@ def test_get_blocks(ray_start_regular_shared): def test_pandas_roundtrip(ray_start_regular_shared, tmp_path): df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]}) df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]}) - ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)]) - dfds = pd.concat(ray.get(ds.to_pandas())) - assert pd.concat([df1, df2]).equals(dfds) + ds = ray.data.from_pandas([df1, df2]) + dfds = ds.to_pandas() + assert pd.concat([df1, df2], ignore_index=True).equals(dfds) def test_fsspec_filesystem(ray_start_regular_shared, tmp_path): @@ -942,7 +1096,7 @@ def test_fsspec_filesystem(ray_start_regular_shared, tmp_path): ds = ray.data.read_parquet([path1, path2], filesystem=fs) # Test metadata-only parquet ops. - assert len(ds._blocks._blocks) == 1 + assert ds._blocks._num_computed() == 1 assert ds.count() == 6 out_path = os.path.join(tmp_path, "out") @@ -981,7 +1135,7 @@ def test_parquet_read(ray_start_regular_shared, fs, data_path): ds = ray.data.read_parquet(data_path, filesystem=fs) # Test metadata-only parquet ops. 
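# (Aside, a sketch of the metadata-only behavior asserted just below; `path`
# is a hypothetical Parquet directory.) count(), schema() and size_bytes()
# are served from Parquet file metadata, so only one block is computed until
# rows are actually requested.
import ray
ds = ray.data.read_parquet(path)
assert ds._blocks._num_computed() == 1   # metadata only, no data read yet
ds.count()
ds.schema()
rows = ds.take()                         # forces the real read tasks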
- assert len(ds._blocks._blocks) == 1 + assert ds._blocks._num_computed() == 1 assert ds.count() == 6 assert ds.size_bytes() > 0 assert ds.schema() is not None @@ -995,11 +1149,11 @@ def test_parquet_read(ray_start_regular_shared, fs, data_path): assert repr(ds) == \ "Dataset(num_blocks=2, num_rows=6, " \ "schema={one: int64, two: string})", ds - assert len(ds._blocks._blocks) == 1 + assert ds._blocks._num_computed() == 1 # Forces a data read. values = [[s["one"], s["two"]] for s in ds.take()] - assert len(ds._blocks._blocks) == 2 + assert ds._blocks._num_computed() == 2 assert sorted(values) == [[1, "a"], [2, "b"], [3, "c"], [4, "e"], [5, "f"], [6, "g"]] @@ -1030,7 +1184,7 @@ def test_parquet_read_partitioned(ray_start_regular_shared, fs, data_path): ds = ray.data.read_parquet(data_path, filesystem=fs) # Test metadata-only parquet ops. - assert len(ds._blocks._blocks) == 1 + assert ds._blocks._num_computed() == 1 assert ds.count() == 6 assert ds.size_bytes() > 0 assert ds.schema() is not None @@ -1044,11 +1198,11 @@ def test_parquet_read_partitioned(ray_start_regular_shared, fs, data_path): "Dataset(num_blocks=2, num_rows=6, " \ "schema={two: string, " \ "one: dictionary})", ds - assert len(ds._blocks._blocks) == 1 + assert ds._blocks._num_computed() == 1 # Forces a data read. values = [[s["one"], s["two"]] for s in ds.take()] - assert len(ds._blocks._blocks) == 2 + assert ds._blocks._num_computed() == 2 assert sorted(values) == [[1, "a"], [1, "b"], [1, "c"], [3, "e"], [3, "f"], [3, "g"]] @@ -1077,7 +1231,7 @@ def test_parquet_read_partitioned_with_filter(ray_start_regular_shared, str(tmp_path), parallelism=1, filter=(pa.dataset.field("two") == "a")) values = [[s["one"], s["two"]] for s in ds.take()] - assert len(ds._blocks._blocks) == 1 + assert ds._blocks._num_computed() == 1 assert sorted(values) == [[1, "a"], [1, "a"]] # 2 partitions, 1 empty partition, 2 block/read tasks, 1 empty block @@ -1086,7 +1240,7 @@ def test_parquet_read_partitioned_with_filter(ray_start_regular_shared, str(tmp_path), parallelism=2, filter=(pa.dataset.field("two") == "a")) values = [[s["one"], s["two"]] for s in ds.take()] - assert len(ds._blocks._blocks) == 2 + assert ds._blocks._num_computed() == 2 assert sorted(values) == [[1, "a"], [1, "a"]] @@ -1114,7 +1268,7 @@ def _block_udf(block: pa.Table): str(tmp_path), parallelism=1, _block_udf=_block_udf) ones, twos = zip(*[[s["one"], s["two"]] for s in ds.take()]) - assert len(ds._blocks._blocks) == 1 + assert ds._blocks._num_computed() == 1 np.testing.assert_array_equal(sorted(ones), np.array(one_data) + 1) # 2 blocks/read tasks @@ -1123,7 +1277,7 @@ def _block_udf(block: pa.Table): str(tmp_path), parallelism=2, _block_udf=_block_udf) ones, twos = zip(*[[s["one"], s["two"]] for s in ds.take()]) - assert len(ds._blocks._blocks) == 2 + assert ds._blocks._num_computed() == 2 np.testing.assert_array_equal(sorted(ones), np.array(one_data) + 1) # 2 blocks/read tasks, 1 empty block @@ -1135,7 +1289,7 @@ def _block_udf(block: pa.Table): _block_udf=_block_udf) ones, twos = zip(*[[s["one"], s["two"]] for s in ds.take()]) - assert len(ds._blocks._blocks) == 2 + assert ds._blocks._num_computed() == 2 np.testing.assert_array_equal(sorted(ones), np.array(one_data[:2]) + 1) @@ -1152,7 +1306,7 @@ def test_parquet_write(ray_start_regular_shared, fs, data_path, endpoint_url): df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]}) df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]}) df = pd.concat([df1, df2]) - ds = ray.data.from_pandas([ray.put(df1), 
ray.put(df2)]) + ds = ray.data.from_pandas([df1, df2]) path = os.path.join(data_path, "test_parquet_dir") if fs is None: os.mkdir(path) @@ -1187,7 +1341,7 @@ def test_parquet_write_create_dir(ray_start_regular_shared, fs, data_path, df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]}) df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]}) df = pd.concat([df1, df2]) - ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)]) + ds = ray.data.from_pandas([df1, df2]) path = os.path.join(data_path, "test_parquet_dir") ds._set_uuid("data") ds.write_parquet(path, filesystem=fs) @@ -1241,7 +1395,7 @@ def test_parquet_write_with_udf(ray_start_regular_shared, tmp_path): df1 = pd.DataFrame({"one": one_data[:3], "two": ["a", "b", "c"]}) df2 = pd.DataFrame({"one": one_data[3:], "two": ["e", "f", "g"]}) df = pd.concat([df1, df2]) - ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)]) + ds = ray.data.from_pandas([df1, df2]) def _block_udf(block: pa.Table): df = block.to_pandas() @@ -1266,7 +1420,7 @@ def _block_udf(block: pa.Table): def test_parquet_roundtrip(ray_start_regular_shared, fs, data_path): df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]}) df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]}) - ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)]) + ds = ray.data.from_pandas([df1, df2]) ds._set_uuid("data") path = os.path.join(data_path, "test_parquet_dir") if fs is None: @@ -1275,8 +1429,8 @@ def test_parquet_roundtrip(ray_start_regular_shared, fs, data_path): fs.create_dir(_unwrap_protocol(path)) ds.write_parquet(path, filesystem=fs) ds2 = ray.data.read_parquet(path, parallelism=2, filesystem=fs) - ds2df = pd.concat(ray.get(ds2.to_pandas())) - assert pd.concat([df1, df2]).equals(ds2df) + ds2df = ds2.to_pandas() + assert pd.concat([df1, df2], ignore_index=True).equals(ds2df) # Test metadata ops. for block, meta in zip(ds2._blocks, ds2._blocks.get_metadata()): BlockAccessor.for_block(ray.get(block)).size_bytes() == meta.size_bytes @@ -1357,9 +1511,7 @@ def test_iter_batches_basic(ray_start_regular_shared): df3 = pd.DataFrame({"one": [7, 8, 9], "two": [8, 9, 10]}) df4 = pd.DataFrame({"one": [10, 11, 12], "two": [11, 12, 13]}) dfs = [df1, df2, df3, df4] - ds = ray.data.from_pandas( - [ray.put(df1), ray.put(df2), - ray.put(df3), ray.put(df4)]) + ds = ray.data.from_pandas(dfs) # Default. for batch, df in zip(ds.iter_batches(batch_format="pandas"), dfs): @@ -1469,7 +1621,7 @@ def test_iter_batches_grid(ray_start_regular_shared): })) running_size += block_size num_rows = running_size - ds = ray.data.from_pandas([ray.put(df) for df in dfs]) + ds = ray.data.from_pandas(dfs) for batch_size in np.random.randint( 1, num_rows + 1, size=batch_size_samples): for drop_last in (False, True): @@ -1485,10 +1637,7 @@ def test_iter_batches_grid(ray_start_regular_shared): # Concatenated batches should equal the DataFrame # representation of the entire dataset. assert pd.concat( - batches, ignore_index=True).equals( - pd.concat( - ray.get(ds.to_pandas()), - ignore_index=True)) + batches, ignore_index=True).equals(ds.to_pandas()) else: # Number of batches should be equal to # num_rows / batch_size, rounded down. @@ -1498,9 +1647,8 @@ def test_iter_batches_grid(ray_start_regular_shared): # remainder sliced off. 
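# (Aside, the arithmetic checked just below.) With drop_last=True,
# iter_batches yields num_rows // batch_size full batches, so their
# concatenation equals the first batch_size * (num_rows // batch_size)
# rows of the dataset.
num_rows, batch_size = 10, 3
assert num_rows // batch_size == 3
assert batch_size * (num_rows // batch_size) == 9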
assert pd.concat( batches, ignore_index=True).equals( - pd.concat( - ray.get(ds.to_pandas()), ignore_index=True) - [:batch_size * (num_rows // batch_size)]) + ds.to_pandas()[:batch_size * + (num_rows // batch_size)]) if num_rows % batch_size == 0 or drop_last: assert all( len(batch) == batch_size for batch in batches) @@ -1515,7 +1663,7 @@ def test_lazy_loading_iter_batches_exponential_rampup( ds = ray.data.range(32, parallelism=8) expected_num_blocks = [1, 2, 4, 4, 8, 8, 8, 8] for _, expected in zip(ds.iter_batches(), expected_num_blocks): - assert len(ds._blocks._blocks) == expected + assert ds._blocks._num_computed() == expected def test_map_batch(ray_start_regular_shared, tmp_path): @@ -1769,7 +1917,7 @@ def test_from_dask(ray_start_regular_shared): df = pd.DataFrame({"one": list(range(100)), "two": list(range(100))}) ddf = dd.from_pandas(df, npartitions=10) ds = ray.data.from_dask(ddf) - dfds = pd.concat(ray.get(ds.to_pandas())) + dfds = ds.to_pandas() assert df.equals(dfds) @@ -1778,7 +1926,7 @@ def test_to_dask(ray_start_regular_shared): df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]}) df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]}) df = pd.concat([df1, df2]) - ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)]) + ds = ray.data.from_pandas([df1, df2]) ddf = ds.to_dask() # Explicit Dask-on-Ray assert df.equals(ddf.compute(scheduler=ray_dask_get)) @@ -1791,7 +1939,7 @@ def test_from_modin(ray_start_regular_shared): df = pd.DataFrame({"one": list(range(100)), "two": list(range(100))}, ) modf = mopd.DataFrame(df) ds = ray.data.from_modin(modf) - dfds = pd.concat(ray.get(ds.to_pandas())) + dfds = ds.to_pandas() assert df.equals(dfds) @@ -1823,7 +1971,7 @@ def test_to_tf(ray_start_regular_shared, pipelined): }) df3 = pd.DataFrame({"one": [7, 8], "two": [7.0, 8.0], "label": [7.0, 8.0]}) df = pd.concat([df1, df2, df3]) - ds = ray.data.from_pandas([ray.put(df1), ray.put(df2), ray.put(df3)]) + ds = ray.data.from_pandas([df1, df2, df3]) ds = maybe_pipeline(ds, pipelined) tfd = ds.to_tf( label_column="label", @@ -1851,7 +1999,7 @@ def test_to_tf_feature_columns(ray_start_regular_shared): }) df3 = pd.DataFrame({"one": [7, 8], "two": [7.0, 8.0], "label": [7.0, 8.0]}) df = pd.concat([df1, df2, df3]).drop("two", axis=1) - ds = ray.data.from_pandas([ray.put(df1), ray.put(df2), ray.put(df3)]) + ds = ray.data.from_pandas([df1, df2, df3]) tfd = ds.to_tf( label_column="label", feature_columns=["one"], @@ -1880,7 +2028,7 @@ def test_to_torch(ray_start_regular_shared, pipelined): }) df3 = pd.DataFrame({"one": [7, 8], "two": [7.0, 8.0], "label": [7.0, 8.0]}) df = pd.concat([df1, df2, df3]) - ds = ray.data.from_pandas([ray.put(df1), ray.put(df2), ray.put(df3)]) + ds = ray.data.from_pandas([df1, df2, df3]) ds = maybe_pipeline(ds, pipelined) torchd = ds.to_torch(label_column="label", batch_size=3) @@ -1907,7 +2055,7 @@ def test_to_torch_feature_columns(ray_start_regular_shared): }) df3 = pd.DataFrame({"one": [7, 8], "two": [7.0, 8.0], "label": [7.0, 8.0]}) df = pd.concat([df1, df2, df3]).drop("two", axis=1) - ds = ray.data.from_pandas([ray.put(df1), ray.put(df2), ray.put(df3)]) + ds = ray.data.from_pandas([df1, df2, df3]) torchd = ds.to_torch( label_column="label", feature_columns=["one"], batch_size=3) iterations = [] @@ -1934,7 +2082,7 @@ def test_json_read(ray_start_regular_shared, fs, data_path, endpoint_url): df1.to_json( path1, orient="records", lines=True, storage_options=storage_options) ds = ray.data.read_json(path1, filesystem=fs) - dsdf = ray.get(ds.to_pandas())[0] 
+ dsdf = ds.to_pandas() assert df1.equals(dsdf) # Test metadata ops. assert ds.count() == 3 @@ -1947,8 +2095,8 @@ def test_json_read(ray_start_regular_shared, fs, data_path, endpoint_url): df2.to_json( path2, orient="records", lines=True, storage_options=storage_options) ds = ray.data.read_json([path1, path2], parallelism=2, filesystem=fs) - dsdf = pd.concat(ray.get(ds.to_pandas())) - df = pd.concat([df1, df2]) + dsdf = ds.to_pandas() + df = pd.concat([df1, df2], ignore_index=True) assert df.equals(dsdf) # Test metadata ops. for block, meta in zip(ds._blocks, ds._blocks.get_metadata()): @@ -1962,7 +2110,7 @@ def test_json_read(ray_start_regular_shared, fs, data_path, endpoint_url): ds = ray.data.read_json( [path1, path2, path3], parallelism=2, filesystem=fs) df = pd.concat([df1, df2, df3], ignore_index=True) - dsdf = pd.concat(ray.get(ds.to_pandas()), ignore_index=True) + dsdf = ds.to_pandas() assert df.equals(dsdf) # Directory, two files. @@ -1980,8 +2128,8 @@ def test_json_read(ray_start_regular_shared, fs, data_path, endpoint_url): df2.to_json( path2, orient="records", lines=True, storage_options=storage_options) ds = ray.data.read_json(path, filesystem=fs) - df = pd.concat([df1, df2]) - dsdf = pd.concat(ray.get(ds.to_pandas())) + df = pd.concat([df1, df2], ignore_index=True) + dsdf = ds.to_pandas() assert df.equals(dsdf) if fs is None: shutil.rmtree(path) @@ -2019,8 +2167,8 @@ def test_json_read(ray_start_regular_shared, fs, data_path, endpoint_url): lines=True, storage_options=storage_options) ds = ray.data.read_json([path1, path2], filesystem=fs) - df = pd.concat([df1, df2, df3]) - dsdf = pd.concat(ray.get(ds.to_pandas())) + df = pd.concat([df1, df2, df3], ignore_index=True) + dsdf = ds.to_pandas() assert df.equals(dsdf) if fs is None: shutil.rmtree(path1) @@ -2044,8 +2192,8 @@ def test_json_read(ray_start_regular_shared, fs, data_path, endpoint_url): df2.to_json( path2, orient="records", lines=True, storage_options=storage_options) ds = ray.data.read_json([dir_path, path2], filesystem=fs) - df = pd.concat([df1, df2]) - dsdf = pd.concat(ray.get(ds.to_pandas())) + df = pd.concat([df1, df2], ignore_index=True) + dsdf = ds.to_pandas() assert df.equals(dsdf) if fs is None: shutil.rmtree(dir_path) @@ -2059,7 +2207,7 @@ def test_zipped_json_read(ray_start_regular_shared, tmp_path): path1 = os.path.join(tmp_path, "test1.json.gz") df1.to_json(path1, compression="gzip", orient="records", lines=True) ds = ray.data.read_json(path1) - assert df1.equals(ray.get(ds.to_pandas())[0]) + assert df1.equals(ds.to_pandas()) # Test metadata ops. assert ds.count() == 3 assert ds.input_files() == [path1] @@ -2069,8 +2217,8 @@ def test_zipped_json_read(ray_start_regular_shared, tmp_path): path2 = os.path.join(tmp_path, "test2.json.gz") df2.to_json(path2, compression="gzip", orient="records", lines=True) ds = ray.data.read_json([path1, path2], parallelism=2) - dsdf = pd.concat(ray.get(ds.to_pandas())) - assert pd.concat([df1, df2]).equals(dsdf) + dsdf = ds.to_pandas() + assert pd.concat([df1, df2], ignore_index=True).equals(dsdf) # Test metadata ops. 
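# (Aside.) As elsewhere in this patch, ds.to_pandas() now returns a single
# concatenated DataFrame, while the old list-of-ObjectRefs behavior lives in
# ds.to_pandas_refs(). Equivalent results, sketched for some dataset ds:
import pandas as pd
import ray
df_new = ds.to_pandas()
df_old_style = pd.concat(ray.get(ds.to_pandas_refs()), ignore_index=True)
assert df_new.equals(df_old_style)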
for block, meta in zip(ds._blocks, ds._blocks.get_metadata()): BlockAccessor.for_block(ray.get(block)).size_bytes() @@ -2085,8 +2233,8 @@ def test_zipped_json_read(ray_start_regular_shared, tmp_path): path2 = os.path.join(tmp_path, "data1.json.gz") df2.to_json(path2, compression="gzip", orient="records", lines=True) ds = ray.data.read_json([dir_path, path2]) - df = pd.concat([df1, df2]) - dsdf = pd.concat(ray.get(ds.to_pandas())) + df = pd.concat([df1, df2], ignore_index=True) + dsdf = ds.to_pandas() assert df.equals(dsdf) shutil.rmtree(dir_path) @@ -2103,7 +2251,7 @@ def test_json_write(ray_start_regular_shared, fs, data_path, endpoint_url): storage_options = dict(client_kwargs=dict(endpoint_url=endpoint_url)) # Single block. df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]}) - ds = ray.data.from_pandas([ray.put(df1)]) + ds = ray.data.from_pandas([df1]) ds._set_uuid("data") ds.write_json(data_path, filesystem=fs) file_path = os.path.join(data_path, "data_000000.json") @@ -2116,7 +2264,7 @@ def test_json_write(ray_start_regular_shared, fs, data_path, endpoint_url): # Two blocks. df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]}) - ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)]) + ds = ray.data.from_pandas([df1, df2]) ds._set_uuid("data") ds.write_json(data_path, filesystem=fs) file_path2 = os.path.join(data_path, "data_000001.json") @@ -2143,12 +2291,12 @@ def test_json_write(ray_start_regular_shared, fs, data_path, endpoint_url): def test_json_roundtrip(ray_start_regular_shared, fs, data_path): # Single block. df = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]}) - ds = ray.data.from_pandas([ray.put(df)]) + ds = ray.data.from_pandas([df]) ds._set_uuid("data") ds.write_json(data_path, filesystem=fs) file_path = os.path.join(data_path, "data_000000.json") ds2 = ray.data.read_json([file_path], filesystem=fs) - ds2df = pd.concat(ray.get(ds2.to_pandas())) + ds2df = ds2.to_pandas() assert ds2df.equals(df) # Test metadata ops. for block, meta in zip(ds2._blocks, ds2._blocks.get_metadata()): @@ -2161,12 +2309,12 @@ def test_json_roundtrip(ray_start_regular_shared, fs, data_path): # Two blocks. df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]}) - ds = ray.data.from_pandas([ray.put(df), ray.put(df2)]) + ds = ray.data.from_pandas([df, df2]) ds._set_uuid("data") ds.write_json(data_path, filesystem=fs) ds2 = ray.data.read_json(data_path, parallelism=2, filesystem=fs) - ds2df = pd.concat(ray.get(ds2.to_pandas())) - assert pd.concat([df, df2]).equals(ds2df) + ds2df = ds2.to_pandas() + assert pd.concat([df, df2], ignore_index=True).equals(ds2df) # Test metadata ops. for block, meta in zip(ds2._blocks, ds2._blocks.get_metadata()): BlockAccessor.for_block(ray.get(block)).size_bytes() == meta.size_bytes @@ -2190,7 +2338,7 @@ def test_csv_read(ray_start_regular_shared, fs, data_path, endpoint_url): path1 = os.path.join(data_path, "test1.csv") df1.to_csv(path1, index=False, storage_options=storage_options) ds = ray.data.read_csv(path1, filesystem=fs) - dsdf = ray.get(ds.to_pandas())[0] + dsdf = ds.to_pandas() assert df1.equals(dsdf) # Test metadata ops. 
assert ds.count() == 3 @@ -2202,8 +2350,8 @@ def test_csv_read(ray_start_regular_shared, fs, data_path, endpoint_url): path2 = os.path.join(data_path, "test2.csv") df2.to_csv(path2, index=False, storage_options=storage_options) ds = ray.data.read_csv([path1, path2], parallelism=2, filesystem=fs) - dsdf = pd.concat(ray.get(ds.to_pandas())) - df = pd.concat([df1, df2]) + dsdf = ds.to_pandas() + df = pd.concat([df1, df2], ignore_index=True) assert df.equals(dsdf) # Test metadata ops. for block, meta in zip(ds._blocks, ds._blocks.get_metadata()): @@ -2215,7 +2363,7 @@ def test_csv_read(ray_start_regular_shared, fs, data_path, endpoint_url): df3.to_csv(path3, index=False, storage_options=storage_options) ds = ray.data.read_csv([path1, path2, path3], parallelism=2, filesystem=fs) df = pd.concat([df1, df2, df3], ignore_index=True) - dsdf = pd.concat(ray.get(ds.to_pandas()), ignore_index=True) + dsdf = ds.to_pandas() assert df.equals(dsdf) # Directory, two files. @@ -2231,8 +2379,8 @@ def test_csv_read(ray_start_regular_shared, fs, data_path, endpoint_url): path2 = os.path.join(path, "data1.csv") df2.to_csv(path2, index=False, storage_options=storage_options) ds = ray.data.read_csv(path, filesystem=fs) - df = pd.concat([df1, df2]) - dsdf = pd.concat(ray.get(ds.to_pandas())) + df = pd.concat([df1, df2], ignore_index=True) + dsdf = ds.to_pandas() assert df.equals(dsdf) if fs is None: shutil.rmtree(path) @@ -2258,8 +2406,8 @@ def test_csv_read(ray_start_regular_shared, fs, data_path, endpoint_url): file_path3 = os.path.join(path2, "data2.csv") df3.to_csv(file_path3, index=False, storage_options=storage_options) ds = ray.data.read_csv([path1, path2], filesystem=fs) - df = pd.concat([df1, df2, df3]) - dsdf = pd.concat(ray.get(ds.to_pandas())) + df = pd.concat([df1, df2, df3], ignore_index=True) + dsdf = ds.to_pandas() assert df.equals(dsdf) if fs is None: shutil.rmtree(path1) @@ -2281,8 +2429,8 @@ def test_csv_read(ray_start_regular_shared, fs, data_path, endpoint_url): path2 = os.path.join(data_path, "data1.csv") df2.to_csv(path2, index=False, storage_options=storage_options) ds = ray.data.read_csv([dir_path, path2], filesystem=fs) - df = pd.concat([df1, df2]) - dsdf = pd.concat(ray.get(ds.to_pandas())) + df = pd.concat([df1, df2], ignore_index=True) + dsdf = ds.to_pandas() assert df.equals(dsdf) if fs is None: shutil.rmtree(dir_path) @@ -2302,7 +2450,7 @@ def test_csv_write(ray_start_regular_shared, fs, data_path, endpoint_url): storage_options = dict(client_kwargs=dict(endpoint_url=endpoint_url)) # Single block. df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]}) - ds = ray.data.from_pandas([ray.put(df1)]) + ds = ray.data.from_pandas([df1]) ds._set_uuid("data") ds.write_csv(data_path, filesystem=fs) file_path = os.path.join(data_path, "data_000000.csv") @@ -2310,7 +2458,7 @@ def test_csv_write(ray_start_regular_shared, fs, data_path, endpoint_url): # Two blocks. df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]}) - ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)]) + ds = ray.data.from_pandas([df1, df2]) ds._set_uuid("data") ds.write_csv(data_path, filesystem=fs) file_path2 = os.path.join(data_path, "data_000001.csv") @@ -2329,12 +2477,12 @@ def test_csv_write(ray_start_regular_shared, fs, data_path, endpoint_url): def test_csv_roundtrip(ray_start_regular_shared, fs, data_path): # Single block. 
df = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]}) - ds = ray.data.from_pandas([ray.put(df)]) + ds = ray.data.from_pandas([df]) ds._set_uuid("data") ds.write_csv(data_path, filesystem=fs) file_path = os.path.join(data_path, "data_000000.csv") ds2 = ray.data.read_csv([file_path], filesystem=fs) - ds2df = pd.concat(ray.get(ds2.to_pandas())) + ds2df = ds2.to_pandas() assert ds2df.equals(df) # Test metadata ops. for block, meta in zip(ds2._blocks, ds2._blocks.get_metadata()): @@ -2342,12 +2490,12 @@ def test_csv_roundtrip(ray_start_regular_shared, fs, data_path): # Two blocks. df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]}) - ds = ray.data.from_pandas([ray.put(df), ray.put(df2)]) + ds = ray.data.from_pandas([df, df2]) ds._set_uuid("data") ds.write_csv(data_path, filesystem=fs) ds2 = ray.data.read_csv(data_path, parallelism=2, filesystem=fs) - ds2df = pd.concat(ray.get(ds2.to_pandas())) - assert pd.concat([df, df2]).equals(ds2df) + ds2df = ds2.to_pandas() + assert pd.concat([df, df2], ignore_index=True).equals(ds2df) # Test metadata ops. for block, meta in zip(ds2._blocks, ds2._blocks.get_metadata()): BlockAccessor.for_block(ray.get(block)).size_bytes() == meta.size_bytes @@ -2365,13 +2513,21 @@ def test_sort_simple(ray_start_regular_shared): assert ds.sort(key=lambda x: -x).take(num_items) == list( reversed(range(num_items))) + # Test empty dataset. + ds = ray.data.from_items([]) + s1 = ds.sort() + assert s1.count() == 0 + assert s1 == ds + @pytest.mark.parametrize("pipelined", [False, True]) def test_random_shuffle(shutdown_only, pipelined): def range(n, parallelism=200): ds = ray.data.range(n, parallelism=parallelism) if pipelined: - return ds.repeat(2) + pipe = ds.repeat(2) + pipe.random_shuffle = pipe.random_shuffle_each_window + return pipe else: return ds @@ -2416,6 +2572,12 @@ def range(n, parallelism=200): r2 = range(100).random_shuffle(_move=True).take(999) assert r1 != r2, (r1, r2) + # Test empty dataset. + ds = ray.data.from_items([]) + r1 = ds.random_shuffle() + assert r1.count() == 0 + assert r1 == ds + def test_random_shuffle_spread(ray_start_cluster): cluster = ray_start_cluster @@ -2437,7 +2599,7 @@ def get_node_id(): ds = ray.data.range( 100, parallelism=2).random_shuffle(_spread_resource_prefix="bar:") - blocks = ds.get_blocks() + blocks = ds.get_internal_block_refs() ray.wait(blocks, num_returns=len(blocks), fetch_local=False) location_data = ray.experimental.get_object_locations(blocks) locations = [] @@ -2478,7 +2640,7 @@ def get_node_id(): ds = ray.data.read_parquet(data_path, _spread_resource_prefix="bar:") # Force reads. 
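# (Aside, a sketch of the locality-check pattern used in the shuffle and read
# tests here; get_internal_block_refs() is the new name for get_blocks() in
# this patch, and ds stands for a dataset built above.)
import ray
blocks = ds.get_internal_block_refs()
ray.wait(blocks, num_returns=len(blocks), fetch_local=False)
locations = ray.experimental.get_object_locations(blocks)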
- blocks = ds.get_blocks() + blocks = ds.get_internal_block_refs() assert len(blocks) == 2 ray.wait(blocks, num_returns=len(blocks), fetch_local=False) @@ -2505,7 +2667,7 @@ def test_sort_arrow(ray_start_regular, num_items, parallelism): offset += shard if offset < num_items: dfs.append(pd.DataFrame({"a": a[offset:], "b": b[offset:]})) - ds = ray.data.from_pandas([ray.put(df) for df in dfs]) + ds = ray.data.from_pandas(dfs) def assert_sorted(sorted_ds, expected_rows): assert [tuple(row.values()) @@ -2535,7 +2697,7 @@ def __init__(self): def _read_file(self, f: "pa.NativeFile", path: str, **reader_args): count = self.counter.increment.remote() if ray.get(count) == 1: - raise ValueError() + raise ValueError("oops") else: return CSVDatasource._read_file(self, f, path, **reader_args) @@ -2543,7 +2705,7 @@ def _write_block(self, f: "pa.NativeFile", block: BlockAccessor, **writer_args): count = self.counter.increment.remote() if ray.get(count) == 1: - raise ValueError() + raise ValueError("oops") else: CSVDatasource._write_block(self, f, block, **writer_args) @@ -2563,7 +2725,7 @@ def _write_block(self, f: "pa.NativeFile", block: BlockAccessor, def flaky_mapper(x): count = counter.increment.remote() if ray.get(count) == 1: - raise ValueError() + raise ValueError("oops") else: return ray.get(count) diff --git a/python/ray/data/tests/test_dataset_pipeline.py b/python/ray/data/tests/test_dataset_pipeline.py index b199374f80437..cffb378f36861 100644 --- a/python/ray/data/tests/test_dataset_pipeline.py +++ b/python/ray/data/tests/test_dataset_pipeline.py @@ -30,14 +30,14 @@ def block_on_ones(x: int) -> int: time.sleep(999999) return x - pipe = ray.data.range(2).pipeline(parallelism=1) + pipe = ray.data.range(2).window(blocks_per_window=1) pipe = pipe.map(block_on_ones) assert pipe.take(1) == [0] def test_cannot_read_twice(ray_start_regular_shared): ds = ray.data.range(10) - pipe = ds.pipeline(parallelism=1) + pipe = ds.window(blocks_per_window=1) assert pipe.count() == 10 with pytest.raises(RuntimeError): pipe.count() @@ -52,25 +52,70 @@ def test_cannot_read_twice(ray_start_regular_shared): def test_basic_pipeline(ray_start_regular_shared): ds = ray.data.range(10) - pipe = ds.pipeline(parallelism=1) - assert str(pipe) == "DatasetPipeline(length=10, num_stages=1)" + pipe = ds.window(blocks_per_window=1) + assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=1)" assert pipe.count() == 10 - pipe = ds.pipeline(parallelism=1).map(lambda x: x).map(lambda x: x) - assert str(pipe) == "DatasetPipeline(length=10, num_stages=3)" + pipe = ds.window(blocks_per_window=1).map(lambda x: x).map(lambda x: x) + assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=3)" assert pipe.take() == list(range(10)) - pipe = ds.pipeline(parallelism=999) - assert str(pipe) == "DatasetPipeline(length=1, num_stages=1)" + pipe = ds.window(blocks_per_window=999) + assert str(pipe) == "DatasetPipeline(num_windows=1, num_stages=1)" assert pipe.count() == 10 pipe = ds.repeat(10) - assert str(pipe) == "DatasetPipeline(length=10, num_stages=1)" + assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=1)" assert pipe.count() == 100 pipe = ds.repeat(10) assert pipe.sum() == 450 +def test_window(ray_start_regular_shared): + ds = ray.data.range(10) + pipe = ds.window(blocks_per_window=1) + assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=1)" + pipe = pipe.rewindow(blocks_per_window=3) + assert str(pipe) == "DatasetPipeline(num_windows=None, num_stages=1)" + datasets = list(pipe.iter_datasets()) + 
assert len(datasets) == 4 + assert datasets[0].take() == [0, 1, 2] + assert datasets[1].take() == [3, 4, 5] + assert datasets[2].take() == [6, 7, 8] + assert datasets[3].take() == [9] + + ds = ray.data.range(10) + pipe = ds.window(blocks_per_window=5) + assert str(pipe) == "DatasetPipeline(num_windows=2, num_stages=1)" + pipe = pipe.rewindow(blocks_per_window=3) + assert str(pipe) == "DatasetPipeline(num_windows=None, num_stages=1)" + datasets = list(pipe.iter_datasets()) + assert len(datasets) == 4 + assert datasets[0].take() == [0, 1, 2] + assert datasets[1].take() == [3, 4, 5] + assert datasets[2].take() == [6, 7, 8] + assert datasets[3].take() == [9] + + +def test_repeat(ray_start_regular_shared): + ds = ray.data.range(5) + pipe = ds.window(blocks_per_window=1) + assert str(pipe) == "DatasetPipeline(num_windows=5, num_stages=1)" + pipe = pipe.repeat(2) + assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=1)" + assert pipe.take() == (list(range(5)) + list(range(5))) + + ds = ray.data.range(5) + pipe = ds.window(blocks_per_window=1) + pipe = pipe.repeat() + assert str(pipe) == "DatasetPipeline(num_windows=inf, num_stages=1)" + assert len(pipe.take(99)) == 99 + + pipe = ray.data.range(5).repeat() + with pytest.raises(ValueError): + pipe.repeat() + + def test_from_iterable(ray_start_regular_shared): pipe = DatasetPipeline.from_iterable( [lambda: ray.data.range(3), lambda: ray.data.range(2)]) @@ -80,7 +125,7 @@ def test_from_iterable(ray_start_regular_shared): def test_repeat_forever(ray_start_regular_shared): ds = ray.data.range(10) pipe = ds.repeat() - assert str(pipe) == "DatasetPipeline(length=None, num_stages=1)" + assert str(pipe) == "DatasetPipeline(num_windows=inf, num_stages=1)" for i, v in enumerate(pipe.iter_rows()): assert v == i % 10, (v, i, i % 10) if i > 1000: @@ -89,38 +134,38 @@ def test_repeat_forever(ray_start_regular_shared): def test_repartition(ray_start_regular_shared): pipe = ray.data.range(10).repeat(10) - assert pipe.repartition(1).sum() == 450 + assert pipe.repartition_each_window(1).sum() == 450 pipe = ray.data.range(10).repeat(10) - assert pipe.repartition(10).sum() == 450 + assert pipe.repartition_each_window(10).sum() == 450 pipe = ray.data.range(10).repeat(10) - assert pipe.repartition(100).sum() == 450 + assert pipe.repartition_each_window(100).sum() == 450 def test_iter_batches(ray_start_regular_shared): - pipe = ray.data.range(10).pipeline(parallelism=2) + pipe = ray.data.range(10).window(blocks_per_window=2) batches = list(pipe.iter_batches()) assert len(batches) == 10 assert all(len(e) == 1 for e in batches) def test_iter_datasets(ray_start_regular_shared): - pipe = ray.data.range(10).pipeline(parallelism=2) + pipe = ray.data.range(10).window(blocks_per_window=2) ds = list(pipe.iter_datasets()) assert len(ds) == 5 - pipe = ray.data.range(10).pipeline(parallelism=5) + pipe = ray.data.range(10).window(blocks_per_window=5) ds = list(pipe.iter_datasets()) assert len(ds) == 2 -def test_foreach_dataset(ray_start_regular_shared): - pipe = ray.data.range(5).pipeline(parallelism=2) - pipe = pipe.foreach_dataset(lambda ds: ds.map(lambda x: x * 2)) +def test_foreach_window(ray_start_regular_shared): + pipe = ray.data.range(5).window(blocks_per_window=2) + pipe = pipe.foreach_window(lambda ds: ds.map(lambda x: x * 2)) assert pipe.take() == [0, 2, 4, 6, 8] def test_schema(ray_start_regular_shared): - pipe = ray.data.range(5).pipeline(parallelism=2) + pipe = ray.data.range(5).window(blocks_per_window=2) assert pipe.schema() == int @@ -178,8 +223,8 @@ def 
test_parquet_write(ray_start_regular_shared, tmp_path): df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]}) df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]}) df = pd.concat([df1, df2]) - ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)]) - ds = ds.pipeline(parallelism=1) + ds = ray.data.from_pandas([df1, df2]) + ds = ds.window(blocks_per_window=1) path = os.path.join(tmp_path, "test_parquet_dir") os.mkdir(path) ds._set_uuid("data") diff --git a/python/ray/data/tests/test_raydp_dataset.py b/python/ray/data/tests/test_raydp_dataset.py index c86c6a0803c13..c23b672f97e38 100644 --- a/python/ray/data/tests/test_raydp_dataset.py +++ b/python/ray/data/tests/test_raydp_dataset.py @@ -16,6 +16,10 @@ def stop_all(): return spark +@pytest.mark.skip( + reason=( + "raydp.spark.spark_dataframe_to_ray_dataset needs to be updated to " + "use ray.data.from_arrow_refs.")) def test_raydp_roundtrip(spark_on_ray_small): spark = spark_on_ray_small spark_df = spark.createDataFrame([(1, "a"), (2, "b"), (3, "c")], diff --git a/python/ray/exceptions.py b/python/ray/exceptions.py index 58d1706549d2a..f46be7c0a1a15 100644 --- a/python/ray/exceptions.py +++ b/python/ray/exceptions.py @@ -28,7 +28,11 @@ def from_bytes(b): ray_exception = RayException() ray_exception.ParseFromString(b) if ray_exception.language == PYTHON: - return pickle.loads(ray_exception.serialized_exception) + try: + return pickle.loads(ray_exception.serialized_exception) + except Exception as e: + msg = "Failed to unpickle serialized exception" + raise RuntimeError(msg) from e else: return CrossLanguageError(ray_exception) diff --git a/python/ray/experimental/array/remote/core.py b/python/ray/experimental/array/remote/core.py index f4572da82babe..7b6d24f75b283 100644 --- a/python/ray/experimental/array/remote/core.py +++ b/python/ray/experimental/array/remote/core.py @@ -68,8 +68,8 @@ def diag(v, k=0): @ray.remote -def transpose(a, axes=[]): - axes = None if axes == [] else axes +def transpose(a, axes=None): + axes = None if (axes == [] or axes is None) else axes return np.transpose(a, axes=axes) diff --git a/python/ray/experimental/internal_kv.py b/python/ray/experimental/internal_kv.py index e434c3cf5f979..456adabcb66ca 100644 --- a/python/ray/experimental/internal_kv.py +++ b/python/ray/experimental/internal_kv.py @@ -35,7 +35,7 @@ def _initialize_internal_kv(gcs_client: "ray._raylet.GcsClient" = None): return global_gcs_client -@client_mode_hook +@client_mode_hook(auto_init=False) def _internal_kv_initialized(): gcs_client = _initialize_internal_kv() @@ -46,7 +46,7 @@ def _internal_kv_initialized(): return hasattr(worker, "mode") and worker.mode is not None -@client_mode_hook +@client_mode_hook(auto_init=False) def _internal_kv_get(key: Union[str, bytes]) -> bytes: """Fetch the value of a binary key.""" gcs_client = _initialize_internal_kv() @@ -57,7 +57,7 @@ def _internal_kv_get(key: Union[str, bytes]) -> bytes: return ray.worker.global_worker.redis_client.hget(key, "value") -@client_mode_hook +@client_mode_hook(auto_init=False) def _internal_kv_exists(key: Union[str, bytes]) -> bool: """Check key exists or not.""" gcs_client = _initialize_internal_kv() @@ -67,7 +67,7 @@ def _internal_kv_exists(key: Union[str, bytes]) -> bool: return ray.worker.global_worker.redis_client.hexists(key, "value") -@client_mode_hook +@client_mode_hook(auto_init=False) def _internal_kv_put(key: Union[str, bytes], value: Union[str, bytes], overwrite: bool = True) -> bool: @@ -91,7 +91,7 @@ def _internal_kv_put(key: Union[str, bytes], 
return updated == 0 # already exists -@client_mode_hook +@client_mode_hook(auto_init=False) def _internal_kv_del(key: Union[str, bytes]): gcs_client = _initialize_internal_kv() if gcs_client is not None: @@ -100,7 +100,7 @@ def _internal_kv_del(key: Union[str, bytes]): return ray.worker.global_worker.redis_client.delete(key) -@client_mode_hook +@client_mode_hook(auto_init=False) def _internal_kv_list(prefix: Union[str, bytes]) -> List[bytes]: """List all keys in the internal KV store that start with the prefix. """ diff --git a/python/ray/experimental/raysort/constants.py b/python/ray/experimental/raysort/constants.py index 5ab3b2df29831..9c32b5f07330e 100644 --- a/python/ray/experimental/raysort/constants.py +++ b/python/ray/experimental/raysort/constants.py @@ -1,12 +1,15 @@ import os -from ray.experimental.raysort.types import ByteCount, RecordCount +from ray.experimental.raysort.types import ByteCount, PartId, RecordCount __DIR__ = os.path.dirname(os.path.abspath(__file__)) # Basics RECORD_SIZE = 100 # bytes +# Progress Tracker Actor +PROGRESS_TRACKER_ACTOR = "ProgressTrackerActor" + # Executable locations GENSORT_PATH = os.path.join(__DIR__, "bin/gensort/64/gensort") VALSORT_PATH = os.path.join(__DIR__, "bin/gensort/64/valsort") @@ -18,10 +21,12 @@ DATA_DIR_FMT = { "input": "{mnt}/tmp/input/", "output": "{mnt}/tmp/output/", + "temp": "{mnt}/tmp/temp/", } FILENAME_FMT = { "input": "input-{part_id:08}", "output": "output-{part_id:08}", + "temp": "temp-{part_id:08}", } # Prometheus config @@ -33,3 +38,7 @@ def bytes_to_records(n_bytes: ByteCount) -> RecordCount: assert n_bytes % RECORD_SIZE == 0 return int(n_bytes / RECORD_SIZE) + + +def merge_part_ids(reducer_id: PartId, mapper_id: PartId) -> PartId: + return reducer_id * 1_000_000 + mapper_id diff --git a/python/ray/experimental/raysort/main.py b/python/ray/experimental/raysort/main.py index 1cc8d0df1c5af..0df5bfb59ec75 100644 --- a/python/ray/experimental/raysort/main.py +++ b/python/ray/experimental/raysort/main.py @@ -1,10 +1,12 @@ import argparse +import contextlib import csv import logging import os import random import subprocess -from typing import Iterable, List +import tempfile +from typing import Callable, Dict, Iterable, List import numpy as np import ray @@ -13,14 +15,17 @@ from ray.experimental.raysort import logging_utils from ray.experimental.raysort import sortlib from ray.experimental.raysort import tracing_utils -from ray.experimental.raysort.types import BlockInfo, ByteCount, RecordCount, PartId, PartitionInfo, Path # noqa: E501 +from ray.experimental.raysort.types import \ + BlockInfo, ByteCount, RecordCount, PartId, PartInfo, Path + +Args = argparse.Namespace # ------------------------------------------------------------ # Parse Arguments # ------------------------------------------------------------ -def get_args(): +def get_args(*args, **kwargs): parser = argparse.ArgumentParser() parser.add_argument( "--ray_address", @@ -30,27 +35,39 @@ def get_args(): ) parser.add_argument( "--total_data_size", - default=1_000_000_000, + default=1 * 1000 * 1024 * 1024 * 1024, type=ByteCount, - help="partition size in bytes", + help="total data size in bytes", ) parser.add_argument( "--num_mappers", - default=4, + default=256, type=int, help="number of map tasks", ) + parser.add_argument( + "--num_mappers_per_round", + default=16, + type=int, + help="number of map tasks per first-stage merge tasks", + ) parser.add_argument( "--num_reducers", + default=16, + type=int, + help="number of second-stage reduce tasks", + ) + 
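+    # With the defaults above (num_mappers=256, num_mappers_per_round=16),
+    # get_args() later derives num_rounds = 256 / 16 = 16; each round's map
+    # outputs feed one first-stage merge task per reducer.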
parser.add_argument( + "--num_concurrent_rounds", default=4, type=int, - help="number of reduce tasks", + help="max number of rounds of map/merge tasks in flight", ) parser.add_argument( - "--reducer_batch_num_records", - default=1_000_000, - type=RecordCount, - help="number of bytes to buffer before writing the output to EBS", + "--reducer_input_chunk", + default=100 * 1024 * 1024, + type=ByteCount, + help="bytes to read from each file in reduce tasks", ) parser.add_argument( "--skip_sorting", @@ -75,13 +92,13 @@ def get_args(): "tasks to run", "if no task is specified, will run all tasks") tasks = ["generate_input", "sort", "validate_output"] for task in tasks: - tasks_group.add_argument( - f"--{task}", action="store_true", help=f"run task {task}") + tasks_group.add_argument(f"--{task}", action="store_true") - args = parser.parse_args() + args = parser.parse_args(*args, **kwargs) # Derive additional arguments. args.input_part_size = ByteCount(args.total_data_size / args.num_mappers) - args.output_part_size = ByteCount(args.total_data_size / args.num_reducers) + assert args.num_mappers % args.num_mappers_per_round == 0 + args.num_rounds = int(args.num_mappers / args.num_mappers_per_round) args.mount_points = _get_mount_points() # If no tasks are specified, run all tasks. args_dict = vars(args) @@ -92,28 +109,29 @@ def get_args(): def _get_mount_points(): + default_ret = [tempfile.gettempdir()] mnt = "/mnt" - if not os.path.exists(mnt): - return [] - return [os.path.join(mnt, d) for d in os.listdir(mnt)] + if os.path.exists(mnt): + ret = [os.path.join(mnt, d) for d in os.listdir(mnt)] + if len(ret) > 0: + return ret + return default_ret -args = None - # ------------------------------------------------------------ # Generate Input # ------------------------------------------------------------ -def _make_partition_info(part_id: PartId, kind="input") -> PartitionInfo: +def _part_info(args: Args, part_id: PartId, kind="input") -> PartInfo: node = ray.worker.global_worker.node_ip_address mnt = random.choice(args.mount_points) filepath = _get_part_path(mnt, part_id, kind) - return PartitionInfo(part_id, node, filepath) + return PartInfo(part_id, node, filepath) def _get_part_path(mnt: Path, part_id: PartId, kind="input") -> Path: - assert kind in {"input", "output"} + assert kind in {"input", "output", "temp"} dir_fmt = constants.DATA_DIR_FMT[kind] dirpath = dir_fmt.format(mnt=mnt) os.makedirs(dirpath, exist_ok=True) @@ -124,26 +142,25 @@ def _get_part_path(mnt: Path, part_id: PartId, kind="input") -> Path: @ray.remote -def generate_part(part_id: PartId, size: RecordCount, - offset: RecordCount) -> PartitionInfo: +def generate_part(args: Args, part_id: PartId, size: RecordCount, + offset: RecordCount) -> PartInfo: logging_utils.init() - pinfo = _make_partition_info(part_id) - if not args.skip_input: - subprocess.run( - [constants.GENSORT_PATH, f"-b{offset}", f"{size}", pinfo.path], - check=True) - logging.info(f"Generated input {pinfo}") + pinfo = _part_info(args, part_id) + subprocess.run( + [constants.GENSORT_PATH, f"-b{offset}", f"{size}", pinfo.path], + check=True) + logging.info(f"Generated input {pinfo}") return pinfo -def generate_input(): +def generate_input(args: Args): if args.skip_input: return size = constants.bytes_to_records(args.input_part_size) offset = 0 tasks = [] for part_id in range(args.num_mappers): - tasks.append(generate_part.remote(part_id, size, offset)) + tasks.append(generate_part.remote(args, part_id, size, offset)) offset += size assert offset == 
constants.bytes_to_records(args.total_data_size), args logging.info(f"Generating {len(tasks)} partitions") @@ -158,22 +175,21 @@ def generate_input(): # ------------------------------------------------------------ -def _load_manifest(path: Path) -> List[PartitionInfo]: +def _load_manifest(args: Args, path: Path) -> List[PartInfo]: if args.skip_input: - return _load_dummy_manifest() + return [PartInfo(i, None, None) for i in range(args.num_mappers)] with open(path) as fin: reader = csv.reader(fin) return [ - PartitionInfo(int(part_id), node, path) + PartInfo(int(part_id), node, path) for part_id, node, path in reader ] -def _load_dummy_manifest() -> List[PartitionInfo]: - return [PartitionInfo(i, "", "") for i in range(args.num_mappers)] - - -def _load_partition(path: Path) -> np.ndarray: +def _load_partition(args: Args, path: Path) -> np.ndarray: + if args.skip_input: + return np.frombuffer( + np.random.bytes(args.input_part_size), dtype=np.uint8).copy() return np.fromfile(path, dtype=np.uint8) @@ -190,115 +206,214 @@ def _dummy_sort_and_partition(part: np.ndarray, @ray.remote -def mapper(boundaries: List[int], mapper_id: PartId, - path: Path) -> List[ray.ObjectRef]: +@tracing_utils.timeit("map") +def mapper(args: Args, mapper_id: PartId, boundaries: List[int], + path: Path) -> List[np.ndarray]: logging_utils.init() - task_id = f"M-{mapper_id} Mapper" - logging.info(f"{task_id} starting {args}") - if args.skip_input: - block_size = int(np.ceil(args.input_part_size / args.num_reducers)) - return [ - ray.put( - np.frombuffer(np.random.bytes(block_size), dtype=np.uint8)) - for _ in range(args.num_reducers) - ] - - part = _load_partition(path) + part = _load_partition(args, path) sort_fn = _dummy_sort_and_partition \ if args.skip_sorting else sortlib.sort_and_partition blocks = sort_fn(part, boundaries) - logging.info(f"{task_id} saving to object store") - return [ray.put(part[offset:offset + size]) for offset, size in blocks] + return [part[offset:offset + size] for offset, size in blocks] -def _dummy_merge(blocks: List[np.ndarray], _n: int) -> Iterable[memoryview]: - for block in blocks: +def _dummy_merge( + num_blocks: int, _n: int, + get_block: Callable[[int, int], np.ndarray]) -> Iterable[np.ndarray]: + blocks = [((i, 0), get_block(i, 0)) for i in range(num_blocks)] + while len(blocks) > 0: + (m, d), block = blocks.pop(random.randrange(len(blocks))) yield block - - -@ray.remote -def reducer(reducer_id: PartId, *blocks: List[ray.ObjectRef]) -> PartitionInfo: - logging_utils.init() - task_id = f"R-{reducer_id} Reducer" - logging.info(f"{task_id} starting") - blocks = [np.copy(ray.get(block)) for block in blocks] + d_ = d + 1 + block = get_block(m, d_) + if block is None: + continue + blocks.append(((m, d_), block)) + + +def _merge_impl(args: Args, + M: int, + pinfo: PartInfo, + get_block: Callable[[int, int], np.ndarray], + skip_output=False): merge_fn = _dummy_merge if args.skip_sorting else sortlib.merge_partitions - merger = merge_fn(blocks, args.reducer_batch_num_records) - if args.skip_output: + merger = merge_fn(M, get_block) + + if skip_output: for datachunk in merger: del datachunk - logging.info(f"{task_id} done") - return None else: - pinfo = _make_partition_info(reducer_id, "output") with open(pinfo.path, "wb") as fout: for datachunk in merger: fout.write(datachunk) - logging.info(f"{task_id} done") - return pinfo + return pinfo -@tracing_utils.timeit("sorting") -def sort_main(): - partitions = _load_manifest(constants.INPUT_MANIFEST_FILE) +# See worker_placement_groups() for why 
`num_cpus=0`. +@ray.remote(num_cpus=0, resources={"worker": 1}) +@tracing_utils.timeit("merge") +def merge_mapper_blocks(args: Args, reducer_id: PartId, mapper_id: PartId, + *blocks: List[np.ndarray]) -> PartInfo: + part_id = constants.merge_part_ids(reducer_id, mapper_id) + pinfo = _part_info(args, part_id, kind="temp") + M = len(blocks) + + def get_block(i, d): + if i >= M or d > 0: + return None + return blocks[i] + + return _merge_impl(args, M, pinfo, get_block) + + +# See worker_placement_groups() for why `num_cpus=0`. +@ray.remote(num_cpus=0, resources={"worker": 1}) +@tracing_utils.timeit("reduce") +def final_merge(args: Args, reducer_id: PartId, + *merged_parts: List[PartInfo]) -> PartInfo: + M = len(merged_parts) + + def _load_block_chunk(pinfo: PartInfo, d: int) -> np.ndarray: + return np.fromfile( + pinfo.path, + dtype=np.uint8, + count=args.reducer_input_chunk, + offset=d * args.reducer_input_chunk) + + def get_block(i, d): + ret = _load_block_chunk(merged_parts[i], d) + if ret.size == 0: + return None + return ret + + pinfo = _part_info(args, reducer_id, "output") + return _merge_impl(args, M, pinfo, get_block, args.skip_output) + + +def _node_res(node: str) -> Dict[str, float]: + return {"resources": {f"node:{node}": 1e-3}} + + +@contextlib.contextmanager +def worker_placement_groups(args: Args) -> List[ray.PlacementGroupID]: + """ + Returns one placement group per node with a `worker` resource. To run + tasks in the placement group, use + `@ray.remote(num_cpus=0, resources={"worker": 1})`. Ray does not + automatically reserve CPU resources, so tasks must specify `num_cpus=0` + in order to run in a placement group. + """ + pgs = [ + ray.util.placement_group([{ + "worker": 1 + }]) for _ in range(args.num_reducers) + ] + ray.get([pg.ready() for pg in pgs]) + try: + yield pgs + finally: + for pg in pgs: + ray.util.remove_placement_group(pg) + + +@tracing_utils.timeit("sort", report_time=True) +def sort_main(args: Args): + parts = _load_manifest(args, constants.INPUT_MANIFEST_FILE) + assert len(parts) == args.num_mappers boundaries = sortlib.get_boundaries(args.num_reducers) - mapper_results = np.empty( - (args.num_mappers, args.num_reducers), dtype=object) - for part_id, node, path in partitions: - opt = {} if args.skip_input else { - "resources": { - f"node:{node}": 1 / args.num_mappers - }, - "memory": args.input_part_size * 1.2, - } - opt.update(num_returns=args.num_reducers) - mapper_results[part_id, :] = mapper.options(**opt).remote( - boundaries, part_id, path) - - reducer_results = [] - for r in range(args.num_reducers): - opt = { - "memory": args.output_part_size * 1.0, - } - blocks = mapper_results[:, r].tolist() - ret = reducer.options(**opt).remote(r, *blocks) - reducer_results.append(ret) - - reducer_results = ray.get(reducer_results) + + mapper_opt = { + "num_returns": args.num_reducers, + "num_cpus": os.cpu_count() / args.num_concurrent_rounds, + } # Load balance across worker nodes by setting `num_cpus`. + merge_results = np.empty( + (args.num_rounds, args.num_reducers), dtype=object) + + part_id = 0 + with worker_placement_groups(args) as pgs: + for round in range(args.num_rounds): + # Limit the number of in-flight rounds. + num_extra_rounds = round - args.num_concurrent_rounds + 1 + if num_extra_rounds > 0: + ray.wait( + [f for f in merge_results.flatten() if f is not None], + num_returns=num_extra_rounds * args.num_reducers) + + # Submit map tasks. 
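+            # Each mapper returns num_reducers blocks; mapper_results[m, r]
+            # holds the block from mapper m destined for reducer r, so the
+            # column mapper_results[:, r] below is exactly the input of
+            # reducer r's merge task for this round.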
+ mapper_results = np.empty( + (args.num_mappers_per_round, args.num_reducers), dtype=object) + for _ in range(args.num_mappers_per_round): + _, node, path = parts[part_id] + m = part_id % args.num_mappers_per_round + mapper_results[m, :] = mapper.options(**mapper_opt).remote( + args, part_id, boundaries, path) + part_id += 1 + + # Submit merge tasks. + merge_results[round, :] = [ + merge_mapper_blocks.options(placement_group=pgs[r]).remote( + args, r, round, *mapper_results[:, r].tolist()) + for r in range(args.num_reducers) + ] + + # Delete local references to mapper results. + mapper_results = None + + # Submit second-stage reduce tasks. + reducer_results = [ + final_merge.options(placement_group=pgs[r]).remote( + args, r, *merge_results[:, r].tolist()) + for r in range(args.num_reducers) + ] + reducer_results = ray.get(reducer_results) + if not args.skip_output: with open(constants.OUTPUT_MANIFEST_FILE, "w") as fout: writer = csv.writer(fout) writer.writerows(reducer_results) + logging.info(ray.internal.internal_api.memory_summary(stats_only=True)) + # ------------------------------------------------------------ # Validate Output # ------------------------------------------------------------ +def _run_valsort(args: List[str]): + proc = subprocess.run([constants.VALSORT_PATH] + args, capture_output=True) + if proc.returncode != 0: + logging.critical("\n" + proc.stderr.decode("ascii")) + raise RuntimeError(f"Validation failed: {args}") + + @ray.remote def validate_part(path: Path): logging_utils.init() - proc = subprocess.run([constants.VALSORT_PATH, path], capture_output=True) - if proc.returncode != 0: - logging.critical("\n" + proc.stderr.decode("ascii")) - raise RuntimeError(f"Validation failed: {path}") + sum_path = path + ".sum" + _run_valsort(["-o", sum_path, path]) logging.info(f"Validated output {path}") + with open(sum_path, "rb") as fin: + return os.path.getsize(path), fin.read() -def validate_output(): - if args.skip_output: +def validate_output(args: Args): + if args.skip_sorting or args.skip_output: return - partitions = _load_manifest(constants.OUTPUT_MANIFEST_FILE) - tasks = [] + partitions = _load_manifest(args, constants.OUTPUT_MANIFEST_FILE) + results = [] for _, node, path in partitions: - tasks.append( - validate_part.options(resources={ - f"node:{node}": 1 / args.num_reducers - }).remote(path)) - logging.info(f"Validating {len(tasks)} partitions") - ray.get(tasks) - logging.info("All done!") + results.append(validate_part.options(**_node_res(node)).remote(path)) + logging.info(f"Validating {len(results)} partitions") + results = ray.get(results) + total = sum(s for s, _ in results) + assert total == args.total_data_size, total - args.total_data_size + all_checksum = b"".join(c for _, c in results) + with tempfile.NamedTemporaryFile() as fout: + fout.write(all_checksum) + fout.flush() + _run_valsort(["-s", fout.name]) + logging.info("All OK!") # ------------------------------------------------------------ @@ -306,30 +421,34 @@ def validate_output(): # ------------------------------------------------------------ -def init(): - if args.ray_address is None: - ray.init() +def init(args: Args): + if not args.ray_address: + ray.init(resources={"worker": os.cpu_count()}) else: ray.init(address=args.ray_address) logging_utils.init() logging.info(args) - logging.info(ray.available_resources()) os.makedirs(constants.WORK_DIR, exist_ok=True) + resources = ray.cluster_resources() + logging.info(resources) + args.num_workers = resources["worker"] + progress_tracker = 
tracing_utils.create_progress_tracker(args) + return progress_tracker -def main(): - init() +def main(args: Args): + # Keep the actor handle in scope for the duration of the program. + _progress_tracker = init(args) # noqa F841 if args.generate_input: - generate_input() + generate_input(args) if args.sort: - sort_main() + sort_main(args) if args.validate_output: - validate_output() + validate_output(args) if __name__ == "__main__": - args = get_args() - main() + main(get_args()) diff --git a/python/ray/experimental/raysort/sortlib.py b/python/ray/experimental/raysort/sortlib.py index ea79ec7168de4..6242867286d5f 100644 --- a/python/ray/experimental/raysort/sortlib.py +++ b/python/ray/experimental/raysort/sortlib.py @@ -1,4 +1,4 @@ -from typing import Iterable, List +from typing import Callable, Iterable, List import numpy as np @@ -21,7 +21,9 @@ def sort_and_partition(part: np.ndarray, return blocks -def merge_partitions(blocks: List[np.ndarray], - _n: int) -> Iterable[memoryview]: +def merge_partitions( + num_blocks: int, + get_block: Callable[[int, int], np.ndarray]) -> Iterable[memoryview]: + blocks = [get_block(i, 0) for i in range(num_blocks)] for block in blocks: yield block diff --git a/python/ray/experimental/raysort/tracing_utils.py b/python/ray/experimental/raysort/tracing_utils.py index e75bae8297429..e67584b62588c 100644 --- a/python/ray/experimental/raysort/tracing_utils.py +++ b/python/ray/experimental/raysort/tracing_utils.py @@ -1,13 +1,122 @@ -import contextlib +import datetime +import functools import logging import time +from typing import List, Tuple +import ray +from ray.util.metrics import Gauge, Histogram -@contextlib.contextmanager -def timeit(event="operation", args={}): - start = time.time() - yield - end = time.time() - duration = end - start - args = {"duration": duration} - logging.info(f"{event} {args}") +from ray.experimental.raysort import constants +from ray.experimental.raysort import logging_utils + +HISTOGRAM_BOUNDARIES = list(range(50, 200, 50)) + + +def timeit( + event: str, + report_time=False, + report_in_progress=True, + report_completed=True, +): + def decorator(f): + @functools.wraps(f) + def wrapped_f(*args, **kwargs): + progress_tracker = ray.get_actor(constants.PROGRESS_TRACKER_ACTOR) + progress_tracker.inc.remote( + f"{event}_in_progress", echo=report_in_progress) + try: + start = time.time() + ret = f(*args, **kwargs) + end = time.time() + duration = end - start + progress_tracker.observe.remote( + f"{event}_time", + duration, + echo=report_time, + ) + progress_tracker.inc.remote( + f"{event}_completed", echo=report_completed) + return ret + finally: + progress_tracker.dec.remote(f"{event}_in_progress") + + return wrapped_f + + return decorator + + +def get_metrics(_args): + return { + "gauges": [ + "map_in_progress", + "merge_in_progress", + "reduce_in_progress", + "sort_in_progress", + "map_completed", + "merge_completed", + "reduce_completed", + "sort_completed", + ], + "histograms": [ + ("map_time", HISTOGRAM_BOUNDARIES), + ("merge_time", HISTOGRAM_BOUNDARIES), + ("reduce_time", HISTOGRAM_BOUNDARIES), + ("sort_time", HISTOGRAM_BOUNDARIES), + ], + } + + +def create_progress_tracker(args): + return ProgressTracker.options( + name=constants.PROGRESS_TRACKER_ACTOR).remote(**get_metrics(args)) + + +@ray.remote +class ProgressTracker: + def __init__( + self, + gauges: List[str], + histograms: List[Tuple[str, List[int]]], + ): + self.counts = {m: 0 for m in gauges} + self.gauges = {m: Gauge(m) for m in gauges} + self.reset_gauges() + 
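+        # The gauges are only ever written via set(), so the actor keeps its
+        # own running totals in self.counts and publishes the absolute value
+        # on every inc()/dec().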
self.histograms = { + m: Histogram(m, boundaries=b) + for m, b in histograms + } + logging_utils.init() + + def reset_gauges(self): + for g in self.gauges.values(): + g.set(0) + + def inc(self, metric_name, value=1, echo=False): + gauge = self.gauges.get(metric_name) + if gauge is None: + logging.warning(f"No such Gauge: {metric_name}") + return + self.counts[metric_name] += value + gauge.set(self.counts[metric_name]) + if echo: + logging.info(f"{metric_name} {self.counts[metric_name]}") + + def dec(self, metric_name, value=1, echo=False): + return self.inc(metric_name, -value, echo) + + def observe(self, metric_name, value, echo=False): + histogram = self.histograms.get(metric_name) + if histogram is None: + logging.warning(f"No such Histogram: {metric_name}") + return + histogram.observe(value) + if echo: + logging.info(f"{metric_name} {value}") + + +def export_timeline(): + timestr = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + filename = f"/tmp/ray-timeline-{timestr}.json" + ray.timeline(filename=filename) + logging.info(f"Exported Ray timeline to {filename}") diff --git a/python/ray/experimental/raysort/types.py b/python/ray/experimental/raysort/types.py index 02c6f70e5004a..5d1c39a33a521 100644 --- a/python/ray/experimental/raysort/types.py +++ b/python/ray/experimental/raysort/types.py @@ -7,6 +7,12 @@ RecordCount = int BlockInfo = Tuple[int, int] -PartitionInfo = NamedTuple("PartitionInfo", - [("part_id", PartId), ("node", NodeAddress), - ("path", Path)]) + + +class PartInfo(NamedTuple): + part_id: PartId + node: NodeAddress + path: Path + + def __repr__(self): + return f"Part({self.node}:{self.path})" diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd index 33d9ce1a92fd4..17b2f3879f05b 100644 --- a/python/ray/includes/common.pxd +++ b/python/ray/includes/common.pxd @@ -260,8 +260,7 @@ cdef extern from "ray/core_worker/common.h" nogil: unordered_map[c_string, double] &resources, c_string concurrency_group_name, c_string serialized_runtime_env, - const unordered_map[c_string, c_string] - &override_environment_variables) + c_vector[c_string] runtime_env_uris) cdef cppclass CActorCreationOptions "ray::core::ActorCreationOptions": CActorCreationOptions() @@ -277,8 +276,7 @@ cdef extern from "ray/core_worker/common.h" nogil: c_pair[CPlacementGroupID, int64_t] placement_options, c_bool placement_group_capture_child_tasks, c_string serialized_runtime_env, - const unordered_map[c_string, c_string] - &override_environment_variables) + c_vector[c_string] runtime_env_uris) cdef cppclass CPlacementGroupCreationOptions \ "ray::core::PlacementGroupCreationOptions": diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index 7e56dab60965a..a95eaee2c228f 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -2,7 +2,7 @@ # distutils: language = c++ # cython: embedsignature = True -from libc.stdint cimport int64_t +from libc.stdint cimport int64_t, uint64_t from libcpp cimport bool as c_bool from libcpp.memory cimport shared_ptr, unique_ptr from libcpp.pair cimport pair as c_pair @@ -177,7 +177,6 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: c_vector[CObjectReference] GetObjectRefs( const c_vector[CObjectID] &object_ids) const - void PromoteObjectToPlasma(const CObjectID &object_id) void GetOwnershipInfo(const CObjectID &object_id, CAddress *owner_address, c_string *object_status) @@ -254,6 +253,8 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: int64_t 
GetNumLeasesRequested() const
+
+        unordered_map[c_string, c_vector[uint64_t]] GetActorCallStats() const
+
     cdef cppclass CCoreWorkerOptions "ray::core::CoreWorkerOptions":
         CWorkerType worker_type
         CLanguage language
diff --git a/python/ray/job_config.py b/python/ray/job_config.py
index 9ba513f71195e..e9dc6b3d7cd7d 100644
--- a/python/ray/job_config.py
+++ b/python/ray/job_config.py
@@ -1,17 +1,13 @@
 from typing import Any, Dict, Optional
 import uuid
-import json
 
 import ray._private.gcs_utils as gcs_utils
-from ray.core.generated.common_pb2 import RuntimeEnv as RuntimeEnvPB
 
 
 class JobConfig:
     """A class used to store the configurations of a job.
 
     Attributes:
-        worker_env (dict): Environment variables to be set on worker
-            processes.
         num_java_workers_per_process (int): The number of java workers per
             worker process.
         jvm_options (str[]): The jvm options for java workers of the job.
@@ -24,7 +20,6 @@ class JobConfig:
     """
 
     def __init__(self,
-                 worker_env=None,
                  num_java_workers_per_process=1,
                  jvm_options=None,
                  code_search_path=None,
@@ -32,10 +27,6 @@ def __init__(self,
                  client_job=False,
                  metadata=None,
                  ray_namespace=None):
-        if worker_env is None:
-            self.worker_env = dict()
-        else:
-            self.worker_env = worker_env
         self.num_java_workers_per_process = num_java_workers_per_process
         self.jvm_options = jvm_options or []
         self.code_search_path = code_search_path or []
@@ -54,21 +45,23 @@ def set_metadata(self, key: str, value: str) -> None:
 
     def serialize(self):
         """Serialize the struct into protobuf string"""
-        job_config = self.get_proto_job_config()
-        return job_config.SerializeToString()
+        return self.get_proto_job_config().SerializeToString()
 
     def set_runtime_env(self, runtime_env: Optional[Dict[str, Any]]) -> None:
-        # Lazily import this to avoid circular dependencies.
-        import ray._private.runtime_env as runtime_support
-        if runtime_env:
-            self._parsed_runtime_env = runtime_support.RuntimeEnvDict(
-                runtime_env)
-            self.worker_env.update(
-                self._parsed_runtime_env.get_parsed_dict().get("env_vars")
-                or {})
-        else:
-            self._parsed_runtime_env = runtime_support.RuntimeEnvDict({})
+        # TODO(edoakes): this is really unfortunate, but JobConfig is imported
+        # all over the place so this causes circular imports. We should remove
+        # this dependency and pass in a validated runtime_env instead.
+        from ray._private.runtime_env.validation import ParsedRuntimeEnv
+        self._parsed_runtime_env = ParsedRuntimeEnv(runtime_env or {})
         self.runtime_env = runtime_env or dict()
+        eager_install = False
+        if runtime_env and "eager_install" in runtime_env:
+            eager_install = runtime_env["eager_install"]
+        self.runtime_env_eager_install = eager_install
+        assert isinstance(self.runtime_env_eager_install, bool), \
+            "eager_install must be a bool, got " \
+            f"{type(self.runtime_env_eager_install)}."
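+        # Illustrative call (not part of this patch):
+        #   JobConfig().set_runtime_env({"uris": [], "eager_install": True})
+        # eager_install is assumed to mean the env is set up when the job
+        # starts rather than lazily when first needed.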
self._cached_pb = None def set_ray_namespace(self, ray_namespace: str) -> None: @@ -84,35 +77,27 @@ def get_proto_job_config(self): self._cached_pb.ray_namespace = str(uuid.uuid4()) else: self._cached_pb.ray_namespace = self.ray_namespace - for key in self.worker_env: - self._cached_pb.worker_env[key] = self.worker_env[key] self._cached_pb.num_java_workers_per_process = ( self.num_java_workers_per_process) self._cached_pb.jvm_options.extend(self.jvm_options) self._cached_pb.code_search_path.extend(self.code_search_path) - self._cached_pb.runtime_env.CopyFrom(self._get_proto_runtime()) - self._cached_pb.serialized_runtime_env = \ - self.get_serialized_runtime_env() + self._cached_pb.runtime_env.uris[:] = self.get_runtime_env_uris() + serialized_env = self.get_serialized_runtime_env() + self._cached_pb.runtime_env.serialized_runtime_env = serialized_env for k, v in self.metadata.items(): self._cached_pb.metadata[k] = v + self._cached_pb.runtime_env.runtime_env_eager_install = \ + self.runtime_env_eager_install return self._cached_pb def get_runtime_env_uris(self): """Get the uris of runtime environment""" - if self.runtime_env.get("uris"): - return self.runtime_env.get("uris") - return [] - - def set_runtime_env_uris(self, uris): - self.runtime_env["uris"] = uris - self._parsed_runtime_env.set_uris(uris) + return self._parsed_runtime_env.get("uris") or [] def get_serialized_runtime_env(self) -> str: """Return the JSON-serialized parsed runtime env dict""" return self._parsed_runtime_env.serialize() - def _get_proto_runtime(self) -> RuntimeEnvPB: - runtime_env = RuntimeEnvPB() - runtime_env.uris[:] = self.get_runtime_env_uris() - runtime_env.raw_json = json.dumps(self.runtime_env) - return runtime_env + def set_runtime_env_uris(self, uris): + self.runtime_env["uris"] = uris + self._parsed_runtime_env["uris"] = uris diff --git a/python/ray/node.py b/python/ray/node.py index cee0f8bfebeac..0bc731e815bc2 100644 --- a/python/ray/node.py +++ b/python/ray/node.py @@ -356,7 +356,11 @@ def merge_resources(env_dict, params_dict): env_string = os.getenv( ray_constants.RESOURCES_ENVIRONMENT_VARIABLE) if env_string: - env_resources = json.loads(env_string) + try: + env_resources = json.loads(env_string) + except Exception: + logger.exception("Failed to load {}".format(env_string)) + raise logger.debug( f"Autoscaler overriding resources: {env_resources}.") num_cpus, num_gpus, memory, object_store_memory, resources = \ @@ -572,7 +576,10 @@ def _get_log_file_names(self, name, unique=False): log_stderr = os.path.join(self._logs_dir, f"{name}.err") return log_stdout, log_stderr - def _get_unused_port(self, close_on_exit=True): + def _get_unused_port(self, allocated_ports=None): + if allocated_ports is None: + allocated_ports = set() + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.bind(("", 0)) port = s.getsockname()[1] @@ -582,6 +589,10 @@ def _get_unused_port(self, close_on_exit=True): # from this method has been used by a different process. for _ in range(NUM_PORT_RETRIES): new_port = random.randint(port, 65535) + if new_port in allocated_ports: + # This port is allocated for other usage already, + # so we shouldn't use it even if it's not in use right now. 
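+                # _get_cached_port() passes in the ports it has already
+                # handed out (see below), so a cached-but-currently-unbound
+                # port is never reused for a different component.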
+ continue new_s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) try: new_s.bind(("", new_port)) @@ -589,13 +600,11 @@ def _get_unused_port(self, close_on_exit=True): new_s.close() continue s.close() - if close_on_exit: - new_s.close() - return new_port, new_s + new_s.close() + return new_port logger.error("Unable to succeed in selecting a random port.") - if close_on_exit: - s.close() - return port, s + s.close() + return port def _prepare_socket_file(self, socket_path, default_prefix): """Prepare the socket file for raylet and plasma. @@ -613,7 +622,7 @@ def _prepare_socket_file(self, socket_path, default_prefix): if sys.platform == "win32": if socket_path is None: result = (f"tcp://{self._localhost}" - f":{self._get_unused_port()[0]}") + f":{self._get_unused_port()}") else: if socket_path is None: result = self._make_inc_temp( @@ -665,7 +674,8 @@ def _get_cached_port(self, port = int(ports_by_node[self.unique_id][port_name]) else: # Pick a new port to use and cache it at this node. - port = (default_port or self._get_unused_port()[0]) + port = (default_port or self._get_unused_port( + set(ports_by_node[self.unique_id].values()))) ports_by_node[self.unique_id][port_name] = port with open(file_path, "w") as f: json.dump(ports_by_node, f) @@ -836,6 +846,7 @@ def start_raylet(self, start_initial_python_workers_for_first_job=self._ray_params. start_initial_python_workers_for_first_job, ray_debugger_external=self._ray_params.ray_debugger_external, + env_updates=self._ray_params.env_vars, ) assert ray_constants.PROCESS_TYPE_RAYLET not in self.all_processes self.all_processes[ray_constants.PROCESS_TYPE_RAYLET] = [process_info] diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py index 6854a93535b9e..ea3df3acb2f9a 100644 --- a/python/ray/remote_function.py +++ b/python/ray/remote_function.py @@ -1,7 +1,7 @@ -import uuid -import logging -import inspect from functools import wraps +import inspect +import logging +import uuid from ray import cloudpickle as pickle from ray._raylet import PythonFunctionDescriptor @@ -14,7 +14,8 @@ get_current_placement_group, ) import ray._private.signature -import ray._private.runtime_env as runtime_support +from ray._private.runtime_env.validation import ( + override_task_or_actor_runtime_env, ParsedRuntimeEnv) from ray.util.tracing.tracing_helper import (_tracing_task_invocation, _inject_tracing_into_function) @@ -78,7 +79,7 @@ class RemoteFunction: def __init__(self, language, function, function_descriptor, num_cpus, num_gpus, memory, object_store_memory, resources, accelerator_type, num_returns, max_calls, max_retries, - retry_exceptions, runtime_env): + retry_exceptions, runtime_env, placement_group): if inspect.iscoroutinefunction(function): raise ValueError("'async def' should not be used for remote " "tasks. You can wrap the async function with " @@ -108,7 +109,12 @@ def __init__(self, language, function, function_descriptor, num_cpus, self._retry_exceptions = (DEFAULT_REMOTE_FUNCTION_RETRY_EXCEPTIONS if retry_exceptions is None else retry_exceptions) - self._runtime_env = runtime_env + # Parse local pip/conda config files here. If we instead did it in + # .remote(), it would get run in the Ray Client server, which runs on + # a remote node where the files aren't available. 
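+        # For example, a runtime_env={"pip": "requirements.txt"} pointing at
+        # a local file only exists on the submitting machine, so it must be
+        # resolved here (illustrative value, not from this patch).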
+ self._runtime_env = ParsedRuntimeEnv( + runtime_env or {}, is_task_or_actor=True) + self._placement_group = placement_group self._decorator = getattr(function, "__ray_invocation_decorator__", None) self._function_signature = ray._private.signature.extract_signature( @@ -145,7 +151,6 @@ def options(self, placement_group_bundle_index=-1, placement_group_capture_child_tasks=None, runtime_env=None, - override_environment_variables=None, name=""): """Configures and overrides the task invocation parameters. @@ -164,6 +169,11 @@ def f(): """ func_cls = self + # Parse local pip/conda config files here. If we instead did it in + # .remote(), it would get run in the Ray Client server, which runs on + # a remote node where the files aren't available. + new_runtime_env = ParsedRuntimeEnv( + runtime_env or {}, is_task_or_actor=True) class FuncWrapper: def remote(self, *args, **kwargs): @@ -183,9 +193,7 @@ def remote(self, *args, **kwargs): placement_group_bundle_index=placement_group_bundle_index, placement_group_capture_child_tasks=( placement_group_capture_child_tasks), - runtime_env=runtime_env, - override_environment_variables=( - override_environment_variables), + runtime_env=new_runtime_env, name=name) return FuncWrapper() @@ -207,10 +215,10 @@ def _remote(self, placement_group_bundle_index=-1, placement_group_capture_child_tasks=None, runtime_env=None, - override_environment_variables=None, name=""): """Submit the remote function for execution.""" - if client_mode_should_convert(): + + if client_mode_should_convert(auto_init=True): return client_mode_convert_function( self, args, @@ -229,7 +237,6 @@ def _remote(self, placement_group_capture_child_tasks=( placement_group_capture_child_tasks), runtime_env=runtime_env, - override_environment_variables=override_environment_variables, name=name) worker = ray.worker.global_worker @@ -270,7 +277,12 @@ def _remote(self, placement_group_capture_child_tasks = ( worker.should_capture_child_tasks_in_placement_group) - if placement_group == "default": + if self._placement_group != "default": + if self._placement_group: + placement_group = self._placement_group + else: + placement_group = PlacementGroup.empty() + elif placement_group == "default": if placement_group_capture_child_tasks: placement_group = get_current_placement_group() else: @@ -288,18 +300,16 @@ def _remote(self, num_cpus, num_gpus, memory, object_store_memory, resources, accelerator_type) - if runtime_env is None: + if runtime_env and not isinstance(runtime_env, ParsedRuntimeEnv): + runtime_env = ParsedRuntimeEnv(runtime_env) + elif isinstance(runtime_env, ParsedRuntimeEnv): + pass + else: runtime_env = self._runtime_env - job_runtime_env = worker.core_worker.get_current_runtime_env_dict() - runtime_env_dict = runtime_support.override_task_or_actor_runtime_env( - runtime_env, job_runtime_env) - - if override_environment_variables: - logger.warning("override_environment_variables is deprecated and " - "will be removed in Ray 1.6. Please use " - ".options(runtime_env={'env_vars': {...}}).remote()" - "instead.") + parent_runtime_env = worker.core_worker.get_current_runtime_env() + parsed_runtime_env = override_task_or_actor_runtime_env( + runtime_env, parent_runtime_env) def invocation(args, kwargs): if self._is_cross_language: @@ -315,21 +325,12 @@ def invocation(args, kwargs): "Cross language remote function " \ "cannot be executed locally." 
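+            # The env is now shipped as the serialized ParsedRuntimeEnv plus
+            # its URIs, replacing the runtime_env_dict/env-var override pair
+            # removed below.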
object_refs = worker.core_worker.submit_task(
-                self._language,
-                self._function_descriptor,
-                list_args,
-                name,
-                num_returns,
-                resources,
-                max_retries,
-                retry_exceptions,
-                placement_group.id,
-                placement_group_bundle_index,
+                self._language, self._function_descriptor, list_args, name,
+                num_returns, resources, max_retries, retry_exceptions,
+                placement_group.id, placement_group_bundle_index,
                 placement_group_capture_child_tasks,
-                worker.debugger_breakpoint,
-                runtime_env_dict,
-                override_environment_variables=override_environment_variables
-                or dict())
+                worker.debugger_breakpoint, parsed_runtime_env.serialize(),
+                parsed_runtime_env.get("uris") or [])
             # Reset worker's debug context from the last "remote" command
             # (which applies only to this .remote call).
             worker.debugger_breakpoint = b""
diff --git a/python/ray/runtime_context.py b/python/ray/runtime_context.py
index 750e213cc12b0..64bee3fc7cf79 100644
--- a/python/ray/runtime_context.py
+++ b/python/ray/runtime_context.py
@@ -152,7 +152,7 @@ def should_capture_child_tasks_in_placement_group(self):
 
     @property
     def runtime_env(self):
-        """Get the runtime env passed to job_config
+        """Get the runtime env used for the current driver or worker.
 
         Returns:
             The runtime env currently using by this worker.
@@ -172,12 +172,24 @@ def current_actor(self):
         worker.check_connected()
         return worker.core_worker.get_actor_handle(self.actor_id)
 
+    def _get_actor_call_stats(self):
+        """Get the current worker's task counters.
+
+        Returns:
+            A dictionary keyed by the function name. The values are
+            dictionaries with form ``{"received": 0, "executing": 1,
+            "executed": 2}``.
+        """
+        worker = self.worker
+        worker.check_connected()
+        return worker.core_worker.get_actor_call_stats()
+
 
 _runtime_context = None
 
 
 @PublicAPI(stability="beta")
-@client_mode_hook
+@client_mode_hook(auto_init=False)
 def get_runtime_context():
     """Get the runtime context of the current driver/worker.
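A minimal driver-side sketch of the `_get_actor_call_stats()` hook added above (hypothetical usage, not part of this patch; the counter keys follow the docstring):

    import ray

    ray.init()


    @ray.remote
    class Echo:
        def ping(self):
            # Private API: returns a dict keyed by function name whose values
            # are counter dicts such as {"received": 1, "executing": 1,
            # "executed": 0} (shape per the docstring above).
            return ray.get_runtime_context()._get_actor_call_stats()


    echo = Echo.remote()
    print(ray.get(echo.ping.remote()))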
diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index 08465f7a422e0..dec530cb3022b 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -25,6 +25,8 @@ get_local_dump_archive, get_cluster_dump_archive, debug_status, RUN_ENV_TYPES) from ray.autoscaler._private.constants import RAY_PROCESSES +from ray.autoscaler._private.fake_multi_node.node_provider import \ + FAKE_HEAD_NODE_ID from ray.autoscaler._private.util import DEBUG_AUTOSCALING_ERROR, \ DEBUG_AUTOSCALING_STATUS @@ -432,12 +434,6 @@ def debug(address): hidden=True, type=json.loads, help="Override system configuration defaults.") -@click.option( - "--lru-evict", - is_flag=True, - hidden=True, - default=False, - help="Specify whether LRU evict will be used for this cluster.") @click.option( "--enable-object-reconstruction", is_flag=True, @@ -483,9 +479,9 @@ def start(node_ip_address, address, port, redis_password, redis_shard_ports, dashboard_agent_listen_port, block, plasma_directory, autoscaling_config, no_redirect_worker_output, no_redirect_output, plasma_store_socket_name, raylet_socket_name, temp_dir, - system_config, lru_evict, enable_object_reconstruction, - metrics_export_port, no_monitor, tracing_startup_hook, - ray_debugger_external, log_style, log_color, verbose): + system_config, enable_object_reconstruction, metrics_export_port, + no_monitor, tracing_startup_hook, ray_debugger_external, log_style, + log_color, verbose): """Start Ray processes manually on the local machine.""" cli_logger.configure(log_style, log_color, verbose) if gcs_server_port and not head: @@ -540,7 +536,6 @@ def start(node_ip_address, address, port, redis_password, redis_shard_ports, dashboard_port=dashboard_port, dashboard_agent_listen_port=dashboard_agent_listen_port, _system_config=system_config, - lru_evict=lru_evict, enable_object_reconstruction=enable_object_reconstruction, metrics_export_port=metrics_export_port, no_monitor=no_monitor, @@ -556,6 +551,11 @@ def start(node_ip_address, address, port, redis_password, redis_shard_ports, s.bind(("", 0)) port = s.getsockname()[1] + if os.environ.get("RAY_FAKE_CLUSTER"): + ray_params.env_vars = { + "RAY_OVERRIDE_NODE_ID_FOR_TESTING": FAKE_HEAD_NODE_ID + } + num_redis_shards = None # Start Ray on the head node. if redis_shard_ports is not None and address is None: @@ -598,10 +598,9 @@ def start(node_ip_address, address, port, redis_password, redis_shard_ports, "password.", cf.bold("--redis-password"), cf.bold("--address")) - node_ip_address = services.get_node_ip_address() - # Get the node IP address if one is not provided. 
- ray_params.update_if_absent(node_ip_address=node_ip_address) + ray_params.update_if_absent( + node_ip_address=services.get_node_ip_address()) cli_logger.labeled_value("Local node IP", ray_params.node_ip_address) ray_params.update_if_absent( redis_port=port, @@ -614,7 +613,7 @@ def start(node_ip_address, address, port, redis_password, redis_shard_ports, # Fail early when starting a new cluster when one is already running if address is None: - default_address = f"{node_ip_address}:{port}" + default_address = f"{ray_params.node_ip_address}:{port}" redis_addresses = services.find_redis_address(default_address) if len(redis_addresses) > 0: raise ConnectionError( diff --git a/python/ray/serialization.py b/python/ray/serialization.py index bc335e4a8c539..5bf8c0d1437f3 100644 --- a/python/ray/serialization.py +++ b/python/ray/serialization.py @@ -91,7 +91,7 @@ def object_ref_reducer(obj): worker = ray.worker.global_worker worker.check_connected() obj, owner_address, object_status = ( - worker.core_worker.serialize_and_promote_object_ref(obj)) + worker.core_worker.serialize_object_ref(obj)) return _object_ref_deserializer, \ (obj.binary(), obj.call_site(), owner_address, object_status) diff --git a/python/ray/serve/BUILD b/python/ray/serve/BUILD index cfab567726e77..9417f7c3798a3 100644 --- a/python/ray/serve/BUILD +++ b/python/ray/serve/BUILD @@ -106,7 +106,7 @@ py_test( py_test( name = "test_ray_client", - size = "small", + size = "medium", srcs = serve_tests_srcs, tags = ["exclusive", "team:serverless"], deps = [":serve_lib"], @@ -338,3 +338,11 @@ py_test( tags = ["exclusive", "team:serve"], deps = [":serve_lib"] ) + +py_test( + name = "conda_env", + size = "medium", + srcs = glob(["examples/doc/*.py"]), + tags = ["exclusive", "post_wheel_build", "team:serve"], + deps = [":serve_lib"] +) diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index cd0ea1b033816..05a40dc34df0a 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -7,8 +7,7 @@ import time from dataclasses import dataclass from functools import wraps -from typing import (Any, Callable, Dict, List, Optional, Tuple, Type, Union, - overload) +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union, overload from weakref import WeakValueDictionary from fastapi import APIRouter, FastAPI @@ -189,7 +188,8 @@ def _wait_for_goal(self, def deploy(self, name: str, backend_def: Union[Callable, Type[Callable], str], - *init_args: Any, + init_args: Tuple[Any], + init_kwargs: Dict[Any, Any], ray_actor_options: Optional[Dict] = None, config: Optional[Union[BackendConfig, Dict[str, Any]]] = None, version: Optional[str] = None, @@ -213,7 +213,10 @@ def deploy(self, del ray_actor_options["runtime_env"]["working_dir"] replica_config = ReplicaConfig( - backend_def, *init_args, ray_actor_options=ray_actor_options) + backend_def, + init_args=init_args, + init_kwargs=init_kwargs, + ray_actor_options=ray_actor_options) if isinstance(config, dict): backend_config = BackendConfig.parse_obj(config) @@ -222,16 +225,10 @@ def deploy(self, else: raise TypeError("config must be a BackendConfig or a dictionary.") - python_methods = [] - if inspect.isclass(backend_def): - for method_name, _ in inspect.getmembers(backend_def, - inspect.isfunction): - python_methods.append(method_name) - goal_id, updating = ray.get( self._controller.deploy.remote( - name, backend_config.to_proto_bytes(), replica_config, - python_methods, version, prev_version, route_prefix, + name, backend_config.to_proto_bytes(), replica_config, version, + 
prev_version, route_prefix,
                 ray.get_runtime_context().job_id))
 
         tag = f"component=serve deployment={name}"
@@ -318,27 +315,16 @@ def get_handle(
             "to create sync handle. Learn more at https://docs.ray.io/en/"
             "master/serve/http-servehandle.html#sync-and-async-handles")
 
-        if endpoint_name in all_endpoints:
-            this_endpoint = all_endpoints[endpoint_name]
-            python_methods: List[str] = this_endpoint["python_methods"]
-        else:
-            # This can happen in the missing_ok=True case.
-            # handle.method_name.remote won't work and user must
-            # use the legacy handle.options(method).remote().
-            python_methods: List[str] = []
-
         if sync:
             handle = RayServeSyncHandle(
                 self._controller,
                 endpoint_name,
-                known_python_methods=python_methods,
                 _internal_pickled_http_request=_internal_pickled_http_request,
             )
         else:
             handle = RayServeHandle(
                 self._controller,
                 endpoint_name,
-                known_python_methods=python_methods,
                 _internal_pickled_http_request=_internal_pickled_http_request,
             )
@@ -619,6 +605,7 @@ def __init__(self,
                  version: Optional[str] = None,
                  prev_version: Optional[str] = None,
                  init_args: Optional[Tuple[Any]] = None,
+                 init_kwargs: Optional[Dict[Any, Any]] = None,
                  route_prefix: Optional[str] = None,
                  ray_actor_options: Optional[Dict] = None,
                  _internal=False) -> None:
@@ -644,6 +631,8 @@ def __init__(self,
             raise TypeError("prev_version must be a string.")
         if not (init_args is None or isinstance(init_args, tuple)):
             raise TypeError("init_args must be a tuple.")
+        if not (init_kwargs is None or isinstance(init_kwargs, dict)):
+            raise TypeError("init_kwargs must be a dict.")
         if route_prefix is not None:
             if not isinstance(route_prefix, str):
                 raise TypeError("route_prefix must be a string.")
@@ -660,6 +649,16 @@ def __init__(self,
 
         if init_args is None:
             init_args = ()
+        if init_kwargs is None:
+            init_kwargs = {}
+
+        # TODO(architkulkarni): Enforce that autoscaling_config and
+        # user-provided num_replicas should be mutually exclusive.
+        if version is None and config.autoscaling_config is not None:
+            # TODO(architkulkarni): Remove this restriction.
+            raise ValueError(
+                "Currently autoscaling is only supported for "
+                "versioned deployments. Try @serve.deployment(version=...).")
 
         self._func_or_class = func_or_class
         self._name = name
@@ -667,6 +666,7 @@ def __init__(self,
         self._prev_version = prev_version
         self._config = config
         self._init_args = init_args
+        self._init_kwargs = init_kwargs
         self._route_prefix = route_prefix
         self._ray_actor_options = ray_actor_options
@@ -724,7 +724,12 @@ def ray_actor_options(self) -> Optional[Dict]:
 
     @property
     def init_args(self) -> Tuple[Any]:
-        """Arguments passed to the underlying class's constructor."""
+        """Positional args passed to the underlying class's constructor."""
         return self._init_args
+
+    @property
+    def init_kwargs(self) -> Optional[Dict[Any, Any]]:
+        """Keyword args passed to the underlying class's constructor."""
+        return self._init_kwargs
 
     @property
@@ -738,20 +743,25 @@ def __call__(self):
             "Use `deployment.deploy() instead.`")
 
     @PublicAPI
-    def deploy(self, *init_args, _blocking=True):
+    def deploy(self, *init_args, _blocking=True, **init_kwargs):
         """Deploy or update this deployment.
 
         Args:
             init_args (optional): args to pass to the class __init__
                 method. Not valid if this deployment wraps a function.
+            init_kwargs (optional): kwargs to pass to the class __init__
+                method. Not valid if this deployment wraps a function.
         """
         if len(init_args) == 0 and self._init_args is not None:
             init_args = self._init_args
+        if len(init_kwargs) == 0 and self._init_kwargs is not None:
+            init_kwargs = self._init_kwargs
 
         return _get_global_client().deploy(
             self._name,
             self._func_or_class,
-            *init_args,
+            init_args,
+            init_kwargs,
             ray_actor_options=self._ray_actor_options,
             config=self._config,
             version=self._version,
@@ -783,19 +793,23 @@ def get_handle(self, sync: Optional[bool] = True
             self._name, missing_ok=True, sync=sync)
 
     @PublicAPI
-    def options(
-            self,
-            func_or_class: Optional[Callable] = None,
-            name: Optional[str] = None,
-            version: Optional[str] = None,
-            prev_version: Optional[str] = None,
-            init_args: Optional[Tuple[Any]] = None,
-            route_prefix: Optional[str] = None,
-            num_replicas: Optional[int] = None,
-            ray_actor_options: Optional[Dict] = None,
-            user_config: Optional[Any] = None,
-            max_concurrent_queries: Optional[int] = None,
-    ) -> "Deployment":
+    def options(self,
+                func_or_class: Optional[Callable] = None,
+                name: Optional[str] = None,
+                version: Optional[str] = None,
+                prev_version: Optional[str] = None,
+                init_args: Optional[Tuple[Any]] = None,
+                init_kwargs: Optional[Dict[Any, Any]] = None,
+                route_prefix: Optional[str] = None,
+                num_replicas: Optional[int] = None,
+                ray_actor_options: Optional[Dict] = None,
+                user_config: Optional[Any] = None,
+                max_concurrent_queries: Optional[int] = None,
+                _autoscaling_config: Optional[Union[Dict,
+                                                    AutoscalingConfig]] = None,
+                _graceful_shutdown_wait_loop_s: Optional[float] = None,
+                _graceful_shutdown_timeout_s: Optional[float] = None
+                ) -> "Deployment":
         """Return a copy of this deployment with updated options.
 
         Only those options passed in will be updated, all others will remain
@@ -821,6 +835,9 @@ def options(
         if init_args is None:
             init_args = self._init_args
 
+        if init_kwargs is None:
+            init_kwargs = self._init_kwargs
+
         if route_prefix is None:
             if self._route_prefix == f"/{self._name}":
                 route_prefix = None
@@ -830,6 +847,17 @@ def options(
         if ray_actor_options is None:
             ray_actor_options = self._ray_actor_options
 
+        if _autoscaling_config is not None:
+            new_config.autoscaling_config = _autoscaling_config
+
+        if _graceful_shutdown_wait_loop_s is not None:
+            new_config.graceful_shutdown_wait_loop_s = (
+                _graceful_shutdown_wait_loop_s)
+
+        if _graceful_shutdown_timeout_s is not None:
+            new_config.graceful_shutdown_timeout_s = (
+                _graceful_shutdown_timeout_s)
+
         return Deployment(
             func_or_class,
             name,
@@ -837,6 +865,7 @@
             version=version,
             prev_version=prev_version,
             init_args=init_args,
+            init_kwargs=init_kwargs,
             route_prefix=route_prefix,
             ray_actor_options=ray_actor_options,
             _internal=True,
@@ -848,6 +877,7 @@ def __eq__(self, other):
             self._version == other._version,
             self._config == other._config,
             self._init_args == other._init_args,
+            self._init_kwargs == other._init_kwargs,
             self._route_prefix == other._route_prefix,
             self._ray_actor_options == self._ray_actor_options,
         ])
@@ -871,16 +901,20 @@ def deployment(func_or_class: Callable) -> Deployment:
 
 
 @overload
-def deployment(name: Optional[str] = None,
-               version: Optional[str] = None,
-               prev_version: Optional[str] = None,
-               num_replicas: Optional[int] = None,
-               init_args: Optional[Tuple[Any]] = None,
-               ray_actor_options: Optional[Dict] = None,
-               user_config: Optional[Any] = None,
-               max_concurrent_queries: Optional[int] = None,
-               _autoscaling_config: Optional[dict] = None
-               ) -> Callable[[Callable], Deployment]:
+def deployment(
+        name: Optional[str] = None,
+        version: Optional[str] = None,
+        prev_version:
Optional[str] = None, + num_replicas: Optional[int] = None, + init_args: Optional[Tuple[Any]] = None, + init_kwargs: Optional[Dict[Any, Any]] = None, + ray_actor_options: Optional[Dict] = None, + user_config: Optional[Any] = None, + max_concurrent_queries: Optional[int] = None, + _autoscaling_config: Optional[Union[Dict, AutoscalingConfig]] = None, + _graceful_shutdown_wait_loop_s: Optional[float] = None, + _graceful_shutdown_timeout_s: Optional[float] = None +) -> Callable[[Callable], Deployment]: pass @@ -892,11 +926,14 @@ def deployment( prev_version: Optional[str] = None, num_replicas: Optional[int] = None, init_args: Optional[Tuple[Any]] = None, + init_kwargs: Optional[Dict[Any, Any]] = None, route_prefix: Optional[str] = None, ray_actor_options: Optional[Dict] = None, user_config: Optional[Any] = None, max_concurrent_queries: Optional[int] = None, - _autoscaling_config: Optional[dict] = None, + _autoscaling_config: Optional[Union[Dict, AutoscalingConfig]] = None, + _graceful_shutdown_wait_loop_s: Optional[float] = None, + _graceful_shutdown_timeout_s: Optional[float] = None ) -> Callable[[Callable], Deployment]: """Define a Serve deployment. @@ -915,7 +952,10 @@ def deployment( not check the existing deployment's version. num_replicas (Optional[int]): The number of processes to start up that will handle requests to this deployment. Defaults to 1. - init_args (Optional[Tuple]): Arguments to be passed to the class + init_args (Optional[Tuple]): Positional args to be passed to the class + constructor when starting up deployment replicas. These can also be + passed when you call `.deploy()` on the returned Deployment. + init_kwargs (Optional[Dict]): Keyword args to be passed to the class constructor when starting up deployment replicas. These can also be passed when you call `.deploy()` on the returned Deployment. 
route_prefix (Optional[str]): Requests to paths under this HTTP path @@ -962,8 +1002,13 @@ class MyDeployment: config.max_concurrent_queries = max_concurrent_queries if _autoscaling_config is not None: - config.autoscaling_config = AutoscalingConfig.parse_obj( - _autoscaling_config) + config.autoscaling_config = _autoscaling_config + + if _graceful_shutdown_wait_loop_s is not None: + config.graceful_shutdown_wait_loop_s = _graceful_shutdown_wait_loop_s + + if _graceful_shutdown_timeout_s is not None: + config.graceful_shutdown_timeout_s = _graceful_shutdown_timeout_s def decorator(_func_or_class): return Deployment( @@ -973,6 +1018,7 @@ def decorator(_func_or_class): version=version, prev_version=prev_version, init_args=init_args, + init_kwargs=init_kwargs, route_prefix=route_prefix, ray_actor_options=ray_actor_options, _internal=True, @@ -1014,6 +1060,7 @@ def get_deployment(name: str) -> Deployment: backend_info.backend_config, version=backend_info.version, init_args=backend_info.replica_config.init_args, + init_kwargs=backend_info.replica_config.init_kwargs, route_prefix=route_prefix, ray_actor_options=backend_info.replica_config.ray_actor_options, _internal=True, @@ -1037,6 +1084,7 @@ def list_deployments() -> Dict[str, Deployment]: backend_info.backend_config, version=backend_info.version, init_args=backend_info.replica_config.init_args, + init_kwargs=backend_info.replica_config.init_kwargs, route_prefix=route_prefix, ray_actor_options=backend_info.replica_config.ray_actor_options, _internal=True, diff --git a/python/ray/serve/autoscaling_metrics.py b/python/ray/serve/autoscaling_metrics.py index 084996760d297..4b0d030d700cc 100644 --- a/python/ray/serve/autoscaling_metrics.py +++ b/python/ray/serve/autoscaling_metrics.py @@ -74,7 +74,7 @@ def add_metrics_point(self, data_points: Dict[str, float], Args: data_points(dict): dictionary containing the metrics values. The - key should be a string that uniquely identitify this time series + key should be a string that uniquely identifies this time series and to be used to perform aggregation. timestamp(float): the unix epoch timestamp the metrics are collected at. @@ -98,6 +98,9 @@ def window_average(self, do_compact(bool): whether or not to delete the datapoints that's before `window_start_timestamp_s` to save memory. Default is true. + Returns: + The average of all the datapoints for the key on and after time + window_start_timestamp_s, or None if there are no such points. """ datapoints = self.data[key] diff --git a/python/ray/serve/autoscaling_policy.py b/python/ray/serve/autoscaling_policy.py index 6a9887fb7497c..23dbbf65159e9 100644 --- a/python/ray/serve/autoscaling_policy.py +++ b/python/ray/serve/autoscaling_policy.py @@ -16,7 +16,6 @@ def calculate_desired_num_replicas(autoscaling_config: AutoscalingConfig, current_num_ongoing_requests (List[float]): A list of the number of ongoing requests for each replica. Assumes each entry has already been time-averaged over the desired lookback window. - current_num_replicas (int): The current number of active replicas. 
Returns: desired_num_replicas: The desired number of replicas to scale to, based diff --git a/python/ray/serve/backend_state.py b/python/ray/serve/backend_state.py index 2ab4c5e41d99d..6068887f9bd7a 100644 --- a/python/ray/serve/backend_state.py +++ b/python/ray/serve/backend_state.py @@ -76,6 +76,7 @@ def __init__(self, actor_name: str, detached: bool, controller_name: str, self._ready_obj_ref = None self._graceful_shutdown_ref = None + self._graceful_shutdown_timeout_s = None self._actor_resources = None self._health_check_ref = None @@ -147,6 +148,8 @@ def start(self, backend_info: BackendInfo, version: BackendVersion): Start a new actor for current BackendReplica instance. """ self._actor_resources = backend_info.replica_config.resource_dict + self._graceful_shutdown_timeout_s = ( + backend_info.backend_config.graceful_shutdown_timeout_s) if USE_PLACEMENT_GROUP: self._placement_group = self.create_placement_group( self._placement_group_name, self._actor_resources) @@ -164,6 +167,7 @@ def start(self, backend_info: BackendInfo, version: BackendVersion): **backend_info.replica_config.ray_actor_options).remote( self.backend_tag, self.replica_tag, backend_info.replica_config.init_args, + backend_info.replica_config.init_kwargs, backend_info.backend_config.to_proto_bytes(), version, self._controller_name, self._detached) @@ -243,14 +247,19 @@ def actor_resources(self) -> Dict[str, float]: def available_resources(self) -> Dict[str, float]: return ray.available_resources() - def graceful_stop(self) -> None: - """Request the actor to exit gracefully.""" + def graceful_stop(self) -> Duration: + """Request the actor to exit gracefully. + + Returns the timeout after which to kill the actor. + """ try: handle = ray.get_actor(self._actor_name) self._graceful_shutdown_ref = handle.prepare_for_shutdown.remote() except ValueError: pass + return self._graceful_shutdown_timeout_s + def check_stopped(self) -> bool: """Check if the actor has exited.""" try: @@ -386,14 +395,15 @@ def check_started(self) -> ReplicaStartupStatus: return status - def stop(self, graceful_shutdown_timeout_s: Duration = 0) -> None: + def stop(self, graceful: bool = True) -> None: """Stop the replica. Should handle the case where the replica is already stopped. """ - self._actor.graceful_stop() - self._graceful_shutdown_timeout_s = graceful_shutdown_timeout_s - self._shutdown_deadline = time.time() + graceful_shutdown_timeout_s + timeout_s = self._actor.graceful_stop() + if not graceful: + timeout_s = 0 + self._shutdown_deadline = time.time() + timeout_s def check_stopped(self) -> bool: """Check if the replica has finished stopping.""" @@ -402,14 +412,13 @@ def check_stopped(self) -> bool: self._actor.cleanup() return True - timeout_passed = time.time() >= self._shutdown_deadline - + timeout_passed = time.time() > self._shutdown_deadline if timeout_passed: # Graceful period passed, kill it forcefully. # This will be called repeatedly until the replica shuts down. logger.debug( - f"Replica {self.replica_tag} did not shutdown after " - f"{self._graceful_shutdown_timeout_s}s, force-killing. " + f"Replica {self.replica_tag} did not shut down after grace " + "period, force-killing it. " f"component=serve deployment={self.backend_tag} " f"replica={self.replica_tag}") @@ -722,9 +731,9 @@ def deploy(self, backend_info: BackendInfo) -> Tuple[Optional[GoalId], bool]: """Deploy the backend. - If the backend already exists with the same version, this is a no-op - and returns the GoalId corresponding to the existing update if there - is one. 
+ If the backend already exists with the same version and BackendConfig, + this is a no-op and returns the GoalId corresponding to the existing + update if there is one. Returns: GoalId, bool: The GoalId for the client to wait for and whether or @@ -760,11 +769,8 @@ def deploy(self, self._goal_manager.complete_goal(existing_goal_id) return new_goal_id, True - def delete(self, force_kill: bool = False) -> Optional[GoalId]: + def delete(self) -> Optional[GoalId]: new_goal_id, existing_goal_id = self._set_backend_goal(None) - if force_kill: - self._target_info.backend_config.\ - experimental_graceful_shutdown_timeout_s = 0 self._save_checkpoint_func() self._notify_backend_configs_changed() @@ -822,9 +828,6 @@ def _stop_wrong_version_replicas(self) -> int: states=[ReplicaState.STARTING, ReplicaState.RUNNING], max_replicas=max_to_stop) - graceful_shutdown_timeout_s = ( - self._target_info.backend_config. - experimental_graceful_shutdown_timeout_s) code_version_changes = 0 user_config_changes = 0 for replica in replicas_to_update: @@ -834,8 +837,7 @@ def _stop_wrong_version_replicas(self) -> int: if (replica.version.code_version != self._target_version.code_version): code_version_changes += 1 - replica.stop( - graceful_shutdown_timeout_s=graceful_shutdown_timeout_s) + replica.stop() self._replicas.add(ReplicaState.STOPPING, replica) # If only the user_config is a mismatch, we update it dynamically # without restarting the replica. @@ -869,10 +871,6 @@ def _scale_backend_replicas(self) -> bool: assert self._target_replicas >= 0, ("Number of replicas must be" " greater than or equal to 0.") - graceful_shutdown_timeout_s = ( - self._target_info.backend_config. - experimental_graceful_shutdown_timeout_s) - self._stop_wrong_version_replicas() current_replicas = self._replicas.count(states=[ @@ -924,8 +922,7 @@ def _scale_backend_replicas(self) -> bool: for replica in replicas_to_stop: logger.debug(f"Adding STOPPING to replica_tag: {replica}, " f"backend_tag: {self._name}") - replica.stop( - graceful_shutdown_timeout_s=graceful_shutdown_timeout_s) + replica.stop() self._replicas.add(ReplicaState.STOPPING, replica) return True @@ -1014,7 +1011,7 @@ def _check_startup_replicas(self, # Increase startup failure counter if we're tracking it self._replica_constructor_retry_counter += 1 - replica.stop(graceful_shutdown_timeout_s=0) + replica.stop(graceful=False) self._replicas.add(ReplicaState.STOPPING, replica) transitioned = True elif start_status == ReplicaStartupStatus.PENDING: @@ -1026,7 +1023,7 @@ def _check_startup_replicas(self, if not stop_on_slow: self._replicas.add(original_state, replica) else: - replica.stop(graceful_shutdown_timeout_s=0) + replica.stop(graceful=False) self._replicas.add(ReplicaState.STOPPING, replica) transitioned = True slow_replicas.append(replica) @@ -1049,7 +1046,7 @@ def _check_and_update_replicas(self) -> bool: f"{self._name} failed health check, stopping it. " f"component=serve deployment={self._name} " f"replica={replica.replica_tag}") - replica.stop(graceful_shutdown_timeout_s=0) + replica.stop(graceful=False) self._replicas.add(ReplicaState.STOPPING, replica) slow_start_replicas = [] @@ -1073,8 +1070,9 @@ def _check_and_update_replicas(self) -> bool: f"Deployment '{self._name}' has " f"{len(slow_start_replicas)} replicas that have taken " f"more than {SLOW_STARTUP_WARNING_S}s to start up. This " - "may be caused by waiting for the cluster to auto-scale " - "or because the constructor is slow. 
Resources required " + "may be caused by waiting for the cluster to auto-scale, " + "waiting for a runtime environment to install, or a slow " + "constructor. Resources required " f"for each replica: {required}, resources available: " f"{available}. component=serve deployment={self._name}") @@ -1236,7 +1234,7 @@ def shutdown(self) -> List[GoalId]: shutdown_goals = [] for backend_state in self._backend_states.values(): - goal = backend_state.delete(force_kill=True) + goal = backend_state.delete() if goal is not None: shutdown_goals.append(goal) @@ -1302,9 +1300,9 @@ def deploy_backend(self, backend_tag: BackendTag, backend_info: BackendInfo ) -> Tuple[Optional[GoalId], bool]: """Deploy the backend. - If the backend already exists with the same version, this is a no-op - and returns the GoalId corresponding to the existing update if there - is one. + If the backend already exists with the same version and BackendConfig, + this is a no-op and returns the GoalId corresponding to the existing + update if there is one. Returns: GoalId, bool: The GoalId for the client to wait for and whether or @@ -1319,15 +1317,14 @@ def deploy_backend(self, backend_tag: BackendTag, backend_info: BackendInfo return self._backend_states[backend_tag].deploy(backend_info) - def delete_backend(self, backend_tag: BackendTag, - force_kill: bool = False) -> Optional[GoalId]: + def delete_backend(self, backend_tag: BackendTag) -> Optional[GoalId]: # This method must be idempotent. We should validate that the # specified backend exists on the client. if backend_tag not in self._backend_states: return None backend_state = self._backend_states[backend_tag] - return backend_state.delete(force_kill=force_kill) + return backend_state.delete() def update(self) -> bool: """Updates the state of all backends to match their goal state.""" diff --git a/python/ray/serve/common.py b/python/ray/serve/common.py index be13503c97334..0e97b5cf98eb7 100644 --- a/python/ray/serve/common.py +++ b/python/ray/serve/common.py @@ -1,7 +1,7 @@ import ray -from dataclasses import dataclass, field -from typing import List, Optional +from dataclasses import dataclass +from typing import Optional from uuid import UUID from ray.actor import ActorClass @@ -17,7 +17,6 @@ @dataclass class EndpointInfo: - python_methods: Optional[List[str]] = field(default_factory=list) route: Optional[str] = None diff --git a/python/ray/serve/config.py b/python/ray/serve/config.py index 4002550ae109f..b7d5c08457691 100644 --- a/python/ray/serve/config.py +++ b/python/ray/serve/config.py @@ -1,7 +1,7 @@ import inspect import pickle from enum import Enum -from typing import Any, List, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple import pydantic from google.protobuf.json_format import MessageToDict @@ -24,8 +24,10 @@ class AutoscalingConfig(BaseModel): # Private options below # Metrics scraping options + + # How often to scrape for metrics metrics_interval_s: float = 10.0 - loop_period_s: float = 30.0 + # Time window to average over for metrics. 
look_back_period_s: float = 30.0 # Internal autoscaling configuration options @@ -34,6 +36,7 @@ class AutoscalingConfig(BaseModel): smoothing_factor: float = 1.0 # TODO(architkulkarni): implement below + # loop_period_s = 30 # How frequently to make autoscaling decisions # How long to wait before scaling down replicas # downscale_delay_s: float = 600.0 # How long to wait before scaling up replicas @@ -52,8 +55,6 @@ class AutoscalingConfig(BaseModel): class BackendConfig(BaseModel): """Configuration options for a backend, to be set by the user. - DEPRECATED. Will be removed in Ray 1.5. See docs for details. - Args: num_replicas (Optional[int]): The number of processes to start up that will handle requests to this backend. Defaults to 1. @@ -63,10 +64,10 @@ class BackendConfig(BaseModel): user_config (Optional[Any]): Arguments to pass to the reconfigure method of the backend. The reconfigure method is called if user_config is not None. - experimental_graceful_shutdown_wait_loop_s (Optional[float]): Duration + graceful_shutdown_wait_loop_s (Optional[float]): Duration that backend workers will wait until there is no more work to be done before shutting down. Defaults to 2s. - experimental_graceful_shutdown_timeout_s (Optional[float]): + graceful_shutdown_timeout_s (Optional[float]): Controller waits for this duration to forcefully kill the replica for shutdown. Defaults to 20s. """ @@ -75,8 +76,8 @@ class BackendConfig(BaseModel): max_concurrent_queries: Optional[int] = None user_config: Any = None - experimental_graceful_shutdown_wait_loop_s: NonNegativeFloat = 2.0 - experimental_graceful_shutdown_timeout_s: NonNegativeFloat = 20.0 + graceful_shutdown_wait_loop_s: NonNegativeFloat = 2.0 + graceful_shutdown_timeout_s: NonNegativeFloat = 20.0 autoscaling_config: Optional[AutoscalingConfig] = None @@ -121,16 +122,23 @@ def from_proto_bytes(cls, proto_bytes: bytes): class ReplicaConfig: - def __init__(self, backend_def, *init_args, ray_actor_options=None): + def __init__(self, + backend_def: Callable, + init_args: Optional[Tuple[Any]] = None, + init_kwargs: Optional[Dict[Any, Any]] = None, + ray_actor_options=None): # Validate that backend_def is an import path, function, or class. 
if isinstance(backend_def, str): self.func_or_class_name = backend_def pass elif inspect.isfunction(backend_def): self.func_or_class_name = backend_def.__name__ - if len(init_args) != 0: + if init_args: raise ValueError( "init_args not supported for function backend.") + if init_kwargs: + raise ValueError( + "init_kwargs not supported for function backend.") elif inspect.isclass(backend_def): self.func_or_class_name = backend_def.__name__ else: @@ -139,7 +147,8 @@ def __init__(self, backend_def, *init_args, ray_actor_options=None): format(type(backend_def))) self.serialized_backend_def = cloudpickle.dumps(backend_def) - self.init_args = init_args + self.init_args = init_args if init_args is not None else () + self.init_kwargs = init_kwargs if init_kwargs is not None else {} if ray_actor_options is None: self.ray_actor_options = {} else: @@ -158,12 +167,13 @@ def _validate(self): raise TypeError("ray_actor_options must be a dictionary.") elif "lifetime" in self.ray_actor_options: raise ValueError( - "Specifying lifetime in init_args is not allowed.") + "Specifying lifetime in ray_actor_options is not allowed.") elif "name" in self.ray_actor_options: - raise ValueError("Specifying name in init_args is not allowed.") + raise ValueError( + "Specifying name in ray_actor_options is not allowed.") elif "max_restarts" in self.ray_actor_options: raise ValueError("Specifying max_restarts in " - "init_args is not allowed.") + "ray_actor_options is not allowed.") else: # Ray defaults to zero CPUs for placement, we default to one here. if "num_cpus" not in self.ray_actor_options: diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py index c367dc4232b81..cdaf1cf008151 100644 --- a/python/ray/serve/controller.py +++ b/python/ray/serve/controller.py @@ -8,8 +8,8 @@ import ray from ray.actor import ActorHandle from ray.serve.async_goal_manager import AsyncGoalManager +from ray.serve.autoscaling_policy import calculate_desired_num_replicas from ray.serve.backend_state import ReplicaState, BackendStateManager -from ray.serve.backend_worker import create_backend_replica from ray.serve.common import ( BackendInfo, BackendTag, @@ -20,9 +20,10 @@ ReplicaTag, ) from ray.serve.config import BackendConfig, HTTPOptions, ReplicaConfig -from ray.serve.constants import (CONTROL_LOOP_PERIOD_S, SERVE_ROOT_URL_ENV_KEY) +from ray.serve.constants import CONTROL_LOOP_PERIOD_S, SERVE_ROOT_URL_ENV_KEY from ray.serve.endpoint_state import EndpointState from ray.serve.http_state import HTTPState +from ray.serve.replica import create_replica_wrapper from ray.serve.storage.checkpoint_path import make_kv_store from ray.serve.long_poll import LongPollHost from ray.serve.utils import logger @@ -104,6 +105,10 @@ def record_autoscaling_metrics(self, data: Dict[str, float], def _dump_autoscaling_metrics_for_testing(self): return self.autoscaling_metrics_store.data + def _dump_replica_states_for_testing(self, deployment_name): + return self.backend_state_manager._backend_states[ + deployment_name]._replicas + async def wait_for_goal(self, goal_id: GoalId) -> Optional[Exception]: return await self.goal_manager.wait_for_goal(goal_id) @@ -129,8 +134,55 @@ def get_http_proxies(self) -> Dict[NodeId, ActorHandle]: """Returns a dictionary of node ID to http_proxy actor handles.""" return self.http_state.get_http_proxy_handles() + def autoscale(self) -> None: + """Update autoscaling deployments with calculated num_replicas.""" + for deployment_name, (backend_info, + route_prefix) in self.list_deployments().items(): + 
backend_config = backend_info.backend_config + autoscaling_config = backend_config.autoscaling_config + + if autoscaling_config is None: + continue + + replicas = self.backend_state_manager._backend_states[ + deployment_name]._replicas + running_replicas = replicas.get([ReplicaState.RUNNING]) + + current_num_ongoing_requests = [] + for replica in running_replicas: + replica_tag = replica.replica_tag + num_ongoing_requests = ( + self.autoscaling_metrics_store.window_average( + replica_tag, + time.time() - autoscaling_config.look_back_period_s)) + if num_ongoing_requests is not None: + current_num_ongoing_requests.append(num_ongoing_requests) + + if len(current_num_ongoing_requests) == 0: + continue + + new_backend_config = backend_config.copy() + new_backend_config.num_replicas = calculate_desired_num_replicas( + autoscaling_config, current_num_ongoing_requests) + + replica_config = backend_info.replica_config + deployer_job_id = backend_info.deployer_job_id + backend_config_proto_bytes = new_backend_config.to_proto_bytes() + goal_id, updating = self.deploy( + deployment_name, + backend_config_proto_bytes, + replica_config, + version=backend_info.version, + prev_version=backend_info.version, + route_prefix=route_prefix, + deployer_job_id=deployer_job_id) + async def run_control_loop(self) -> None: while True: + try: + self.autoscale() + except Exception: + logger.exception("Exception while autoscaling deployments.") async with self.write_lock: try: self.http_state.update() @@ -218,57 +270,56 @@ async def shutdown(self) -> List[GoalId]: return goal_ids - async def deploy(self, - name: str, - backend_config_proto_bytes: bytes, - replica_config: ReplicaConfig, - python_methods: List[str], - version: Optional[str], - prev_version: Optional[str], - route_prefix: Optional[str], - deployer_job_id: "Optional[ray._raylet.JobID]" = None - ) -> Tuple[Optional[GoalId], bool]: + def deploy(self, + name: str, + backend_config_proto_bytes: bytes, + replica_config: ReplicaConfig, + version: Optional[str], + prev_version: Optional[str], + route_prefix: Optional[str], + deployer_job_id: "Optional[ray._raylet.JobID]" = None + ) -> Tuple[Optional[GoalId], bool]: if route_prefix is not None: assert route_prefix.startswith("/") backend_config = BackendConfig.from_proto_bytes( backend_config_proto_bytes) - async with self.write_lock: - if prev_version is not None: - existing_backend_info = self.backend_state_manager.get_backend( - name) - if (existing_backend_info is None - or not existing_backend_info.version): - raise ValueError( - f"prev_version '{prev_version}' is specified but " - "there is no existing deployment.") - if existing_backend_info.version != prev_version: - raise ValueError( - f"prev_version '{prev_version}' " - "does not match with the existing " - f"version '{existing_backend_info.version}'.") - backend_info = BackendInfo( - actor_def=ray.remote( - create_backend_replica( - name, replica_config.serialized_backend_def)), - version=version, - backend_config=backend_config, - replica_config=replica_config, - deployer_job_id=deployer_job_id, - start_time_ms=int(time.time() * 1000)) - - goal_id, updating = self.backend_state_manager.deploy_backend( - name, backend_info) - endpoint_info = EndpointInfo( - route=route_prefix, python_methods=python_methods) - self.endpoint_state.update_endpoint(name, endpoint_info) - return goal_id, updating + if prev_version is not None: + existing_backend_info = self.backend_state_manager.get_backend( + name) + if (existing_backend_info is None + or not 
existing_backend_info.version): + raise ValueError( + f"prev_version '{prev_version}' is specified but " + "there is no existing deployment.") + if existing_backend_info.version != prev_version: + raise ValueError(f"prev_version '{prev_version}' " + "does not match with the existing " + f"version '{existing_backend_info.version}'.") + backend_info = BackendInfo( + actor_def=ray.remote( + create_replica_wrapper(name, + replica_config.serialized_backend_def)), + version=version, + backend_config=backend_config, + replica_config=replica_config, + deployer_job_id=deployer_job_id, + start_time_ms=int(time.time() * 1000)) + # TODO(architkulkarni): When a deployment is redeployed, even if + # the only change was num_replicas, the start_time_ms is refreshed. + # This is probably not the desired behavior for an autoscaling + # deployment, which redeploys very often to change num_replicas. + + goal_id, updating = self.backend_state_manager.deploy_backend( + name, backend_info) + endpoint_info = EndpointInfo(route=route_prefix) + self.endpoint_state.update_endpoint(name, endpoint_info) + return goal_id, updating def delete_deployment(self, name: str) -> Optional[GoalId]: self.endpoint_state.delete_endpoint(name) - return self.backend_state_manager.delete_backend( - name, force_kill=False) + return self.backend_state_manager.delete_backend(name) def get_deployment_info(self, name: str) -> Tuple[BackendInfo, str]: """Get the current information about a deployment. diff --git a/python/ray/serve/endpoint_state.py b/python/ray/serve/endpoint_state.py index 6483f7355ff0e..5bba277001c54 100644 --- a/python/ray/serve/endpoint_state.py +++ b/python/ray/serve/endpoint_state.py @@ -79,7 +79,6 @@ def get_endpoints(self) -> Dict[EndpointTag, Dict[str, Any]]: for endpoint, info in self._endpoints.items(): endpoints[endpoint] = { "route": info.route, - "python_methods": info.python_methods, } return endpoints diff --git a/python/ray/serve/examples/doc/conda_env.py b/python/ray/serve/examples/doc/conda_env.py index c431964bb0772..1607cf7d60e37 100644 --- a/python/ray/serve/examples/doc/conda_env.py +++ b/python/ray/serve/examples/doc/conda_env.py @@ -1,27 +1,28 @@ import requests from ray import serve -import tensorflow as tf serve.start() @serve.deployment -def tf_version(request): - return ("Tensorflow " + tf.__version__) +def requests_version(request): + return requests.__version__ -tf_version.options( - name="tf1", ray_actor_options={ +requests_version.options( + name="25", + ray_actor_options={ "runtime_env": { - "conda": "ray-tf1" + "pip": ["ray[serve]", "requests==2.25.1"] } }).deploy() -tf_version.options( - name="tf2", ray_actor_options={ +requests_version.options( + name="26", + ray_actor_options={ "runtime_env": { - "conda": "ray-tf2" + "pip": ["ray[serve]", "requests==2.26.0"] } }).deploy() -print(requests.get("http://127.0.0.1:8000/tf1").text) # Tensorflow 1.15.0 -print(requests.get("http://127.0.0.1:8000/tf2").text) # Tensorflow 2.3.0 +assert requests.get("http://127.0.0.1:8000/25").text == "2.25.1" +assert requests.get("http://127.0.0.1:8000/26").text == "2.26.0" diff --git a/python/ray/serve/handle.py b/python/ray/serve/handle.py index 7c315f66605f4..340be1f987a7c 100644 --- a/python/ray/serve/handle.py +++ b/python/ray/serve/handle.py @@ -1,7 +1,7 @@ import asyncio import concurrent.futures from dataclasses import dataclass, field -from typing import Dict, List, Optional, Union, Coroutine +from typing import Dict, Optional, Union, Coroutine import threading from enum import Enum @@ -75,14 +75,12 @@ 
def __init__( endpoint_name: EndpointTag, handle_options: Optional[HandleOptions] = None, *, - known_python_methods: List[str] = [], _router: Optional[Router] = None, _internal_pickled_http_request: bool = False, ): self.controller_handle = controller_handle self.endpoint_name = endpoint_name self.handle_options = handle_options or HandleOptions() - self.known_python_methods = known_python_methods self.handle_tag = f"{self.endpoint_name}#{get_random_letters()}" self._pickled_http_request = _internal_pickled_http_request @@ -181,21 +179,11 @@ def __reduce__(self): "controller_handle": self.controller_handle, "endpoint_name": self.endpoint_name, "handle_options": self.handle_options, - "known_python_methods": self.known_python_methods, "_internal_pickled_http_request": self._pickled_http_request, } return lambda kwargs: RayServeHandle(**kwargs), (serialized_data, ) def __getattr__(self, name): - if name not in self.known_python_methods: - raise AttributeError( - f"ServeHandle for endpoint {self.endpoint_name} doesn't have " - f"python method {name}. If you used the " - f"get_handle('{self.endpoint_name}', missing_ok=True) flag, " - f"Serve cannot know all methods for {self.endpoint_name}. " - "You can set the method manually via " - f"handle.options(method_name='{name}').remote().") - return self.options(method_name=name) @@ -237,7 +225,6 @@ def __reduce__(self): "controller_handle": self.controller_handle, "endpoint_name": self.endpoint_name, "handle_options": self.handle_options, - "known_python_methods": self.known_python_methods, "_internal_pickled_http_request": self._pickled_http_request, } return lambda kwargs: RayServeSyncHandle(**kwargs), (serialized_data, ) diff --git a/python/ray/serve/http_proxy.py b/python/ray/serve/http_proxy.py index 7eedc17fcfd5a..e129f5d60cab5 100644 --- a/python/ray/serve/http_proxy.py +++ b/python/ray/serve/http_proxy.py @@ -259,8 +259,11 @@ def __init__(self, port: int, controller_name: str, controller_namespace: str, - http_middlewares: List[ - "starlette.middleware.Middleware"] = []): # noqa: F821 + http_middlewares: Optional[List[ + "starlette.middleware.Middleware"]] = None): # noqa: F821 + if http_middlewares is None: + http_middlewares = [] + self.host = host self.port = port diff --git a/python/ray/serve/long_poll.py b/python/ray/serve/long_poll.py index b1133adb5a251..9d5a31bf86e6b 100644 --- a/python/ray/serve/long_poll.py +++ b/python/ray/serve/long_poll.py @@ -103,13 +103,14 @@ def _process_update(self, updates: Dict[str, UpdatedObject]): "Shutting down.") return + if isinstance(updates, ConnectionError): + logger.warning("LongPollClient connection failed, shutting down.") + return + if isinstance(updates, (ray.exceptions.RayTaskError)): - # This can happen during shutdown where the controller doesn't - # contain this key, we will just repull. - # NOTE(simon): should we repull or just wait in the long poll - # host? - if not isinstance(updates.as_instanceof_cause(), ValueError): - logger.error("LongPollHost errored\n" + updates.traceback_str) + # Some error happened in the controller. It could be a bug or some + # undesired state. + logger.error("LongPollHost errored\n" + updates.traceback_str) self._poll_next() return @@ -167,22 +168,21 @@ async def listen_for_change( until there's one updates. 
""" watched_keys = keys_to_snapshot_ids.keys() - nonexistent_keys = set(watched_keys) - set(self.snapshot_ids.keys()) - if len(nonexistent_keys) > 0: - raise ValueError(f"Keys not found: {nonexistent_keys}.") + existent_keys = set(watched_keys).intersection( + set(self.snapshot_ids.keys())) - # 2. If there are any outdated keys (by comparing snapshot ids) - # return immediately. + # If there are any outdated keys (by comparing snapshot ids) + # return immediately. client_outdated_keys = { key: UpdatedObject(self.object_snapshots[key], self.snapshot_ids[key]) - for key in watched_keys + for key in existent_keys if self.snapshot_ids[key] != keys_to_snapshot_ids[key] } if len(client_outdated_keys) > 0: return client_outdated_keys - # 3. Otherwise, register asyncio events to be waited. + # Otherwise, register asyncio events to be waited. async_task_to_watched_keys = {} for key in watched_keys: # Create a new asyncio event for this key diff --git a/python/ray/serve/backend_worker.py b/python/ray/serve/replica.py similarity index 91% rename from python/ray/serve/backend_worker.py rename to python/ray/serve/replica.py index a049bdfac3a84..cc90ada23fd7e 100644 --- a/python/ray/serve/backend_worker.py +++ b/python/ray/serve/replica.py @@ -15,12 +15,11 @@ from ray.serve.autoscaling_metrics import start_metrics_pusher from ray.serve.common import BackendTag, ReplicaTag +from ray.serve.config import BackendConfig from ray.serve.http_util import ASGIHTTPSender from ray.serve.utils import parse_request_item, _get_logger from ray.serve.exceptions import RayServeException from ray.util import metrics -from ray.serve.config import BackendConfig -from ray.serve.long_poll import LongPollClient, LongPollNamespace from ray.serve.router import Query, RequestMetadata from ray.serve.constants import ( BACKEND_RECONFIGURE_METHOD, @@ -32,7 +31,7 @@ logger = _get_logger() -def create_backend_replica(name: str, serialized_backend_def: bytes): +def create_replica_wrapper(name: str, serialized_backend_def: bytes): """Creates a replica class wrapping the provided function or class. This approach is picked over inheritance to avoid conflict between user @@ -43,7 +42,7 @@ def create_backend_replica(name: str, serialized_backend_def: bytes): # TODO(architkulkarni): Add type hints after upgrading cloudpickle class RayServeWrappedReplica(object): async def __init__(self, backend_tag, replica_tag, init_args, - backend_config_proto_bytes: bytes, + init_kwargs, backend_config_proto_bytes: bytes, version: BackendVersion, controller_name: str, detached: bool): backend = cloudpickle.loads(serialized_backend_def) @@ -72,7 +71,8 @@ async def __init__(self, backend_tag, replica_tag, init_args, # This allows backends to define an async __init__ method # (required for FastAPI backend definition). _callable = backend.__new__(backend) - await sync_to_async(_callable.__init__)(*init_args) + await sync_to_async(_callable.__init__)(*init_args, + **init_kwargs) # Setting the context again to update the servable_object. 
ray.serve.api._set_internal_replica_context( backend_tag, @@ -149,8 +149,6 @@ def __init__(self, _callable: Callable, backend_tag: BackendTag, self.replica_tag = replica_tag self.callable = _callable self.is_function = is_function - - self.backend_config = backend_config self.user_config = user_config self.version = version @@ -166,16 +164,6 @@ def __init__(self, _callable: Callable, backend_tag: BackendTag, "replica": self.replica_tag }) - self.loop = asyncio.get_event_loop() - self.long_poll_client = LongPollClient( - controller_handle, - { - (LongPollNamespace.BACKEND_CONFIGS, self.backend_tag): self. - _update_backend_configs, - }, - call_in_event_loop=self.loop, - ) - self.error_counter = metrics.Counter( "serve_deployment_error_counter", description=("The number of exceptions that have " @@ -217,6 +205,9 @@ def __init__(self, _callable: Callable, backend_tag: BackendTag, self.restart_counter.inc() + self._shutdown_wait_loop_s = ( + backend_config.graceful_shutdown_wait_loop_s) + if backend_config.autoscaling_config: config = backend_config.autoscaling_config start_metrics_pusher( @@ -240,10 +231,19 @@ def _collect_autoscaling_metrics(self): def get_runner_method(self, request_item: Query) -> Callable: method_name = request_item.metadata.call_method if not hasattr(self.callable, method_name): - raise RayServeException("Backend doesn't have method {} " - "which is specified in the request. " - "The available methods are {}".format( - method_name, dir(self.callable))) + # Filter to methods that don't start with '__' prefix. + def callable_method_filter(attr): + if attr.startswith("__"): + return False + elif not callable(getattr(self.callable, attr)): + return False + + return True + + methods = list(filter(callable_method_filter, dir(self.callable))) + raise RayServeException(f"Tried to call a method '{method_name}' " + "that does not exist. Available methods: " + f"{methods}.") if self.is_function: return self.callable return getattr(self.callable, method_name) @@ -309,9 +309,6 @@ async def reconfigure(self, getattr(self.callable, BACKEND_RECONFIGURE_METHOD)) await reconfigure_method(user_config) - def _update_backend_configs(self, new_config_bytes: bytes) -> None: - self.backend_config = BackendConfig.from_proto_bytes(new_config_bytes) - async def handle_request(self, request: Query) -> asyncio.Future: request.tick_enter_replica = time.time() logger.debug("Replica {} received request {}".format( @@ -341,18 +338,17 @@ async def prepare_for_shutdown(self): Trigger a graceful shutdown protocol that will wait for all the queued tasks to be completed and return to the controller. """ - sleep_time = self.backend_config.experimental_graceful_shutdown_wait_loop_s # noqa: E501 while True: # Sleep first because we want to make sure all the routers receive # the notification to remove this replica first. - await asyncio.sleep(sleep_time) + await asyncio.sleep(self._shutdown_wait_loop_s) if self.num_ongoing_requests == 0: break else: logger.info( - f"Waiting for an additional {sleep_time}s to shut down " - f"because there are {self.num_ongoing_requests} " - "ongoing requests.") + "Waiting for an additional " + f"{self._shutdown_wait_loop_s}s to shut down because " + f"there are {self.num_ongoing_requests} ongoing requests.") # Explicitly call the del method to trigger clean up. 
        # We set the del method to noop after successfully calling it so the
diff --git a/python/ray/serve/router.py b/python/ray/serve/router.py
index fa18456546fa2..90fce03b03ceb 100644
--- a/python/ray/serve/router.py
+++ b/python/ray/serve/router.py
@@ -1,3 +1,4 @@
+import sys
 import asyncio
 import pickle
 import itertools
@@ -66,7 +67,14 @@ def __init__(
         # Used to unblock this replica set waiting for free replicas. A newly
         # added replica or updated max_concurrent_queries value means the
         # query that waits on a free replica might be unblocked on.
-        self.config_updated_event = asyncio.Event(loop=event_loop)
+
+        # Python 3.8 has deprecated the 'loop' parameter, and Python 3.10 has
+        # removed it altogether. Construct the event accordingly.
+        if sys.version_info.major >= 3 and sys.version_info.minor >= 10:
+            self.config_updated_event = asyncio.Event()
+        else:
+            self.config_updated_event = asyncio.Event(loop=event_loop)
+
         self.num_queued_queries = 0
         self.num_queued_queries_gauge = metrics.Gauge(
             "serve_deployment_queued_queries",
diff --git a/python/ray/serve/storage/checkpoint_path.py b/python/ray/serve/storage/checkpoint_path.py
index de892f0728978..f6abc8da22566 100644
--- a/python/ray/serve/storage/checkpoint_path.py
+++ b/python/ray/serve/storage/checkpoint_path.py
@@ -32,7 +32,9 @@ def make_kv_store(checkpoint_path, namespace):
 
     if parsed_url.scheme == "s3":
         bucket = parsed_url.netloc
-        prefix = parsed_url.path
+        # We need to strip the leading "/" from the path to get the right key
+        # to use in boto3. Ex: s3://bucket/folder/file.zip -> key = "folder/file.zip"
+        prefix = parsed_url.path.lstrip("/")
         logger.info(
             "Using Ray S3 KVStore for controller checkpoint and recovery: "
             f"bucket={bucket} checkpoint_path={checkpoint_path}")
diff --git a/python/ray/serve/storage/kv_store.py b/python/ray/serve/storage/kv_store.py
index 74b17d7a75932..ea24de4541abe 100644
--- a/python/ray/serve/storage/kv_store.py
+++ b/python/ray/serve/storage/kv_store.py
@@ -186,7 +186,7 @@ def __init__(
     ):
         self._namespace = namepsace
         self._bucket = bucket
-        self._prefix = prefix + "/" if prefix else ""
+        self._prefix = prefix
         if not boto3:
             raise ImportError(
                 "You tried to use S3KVstore client without boto3 installed."
@@ -199,7 +199,7 @@ def __init__(
             aws_session_token=aws_session_token)
 
     def get_storage_key(self, key: str) -> str:
-        return f"{self._prefix}{self._namespace}-{key}"
+        return f"{self._prefix}/{self._namespace}-{key}"
 
     def put(self, key: str, val: bytes) -> bool:
         """Put the key-value pair into the store.
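
# Illustrative sketch, not part of the diff above: the S3 object-key layout
# that the checkpoint_path.py and kv_store.py changes produce together. The
# bucket, prefix, namespace, and key values are hypothetical.
from urllib.parse import urlparse

parsed = urlparse("s3://my-bucket/serve/checkpoints")
bucket = parsed.netloc              # "my-bucket"
prefix = parsed.path.lstrip("/")    # "serve/checkpoints", leading "/" removed

# get_storage_key() now joins the prefix and key with an explicit "/":
namespace, key = "serve-namespace", "checkpoint"
storage_key = f"{prefix}/{namespace}-{key}"
assert storage_key == "serve/checkpoints/serve-namespace-checkpoint"
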
diff --git a/python/ray/serve/tests/conftest.py b/python/ray/serve/tests/conftest.py index 1e635dd1b647a..36fdba0d5b7cc 100644 --- a/python/ray/serve/tests/conftest.py +++ b/python/ray/serve/tests/conftest.py @@ -9,6 +9,13 @@ serve.controller._CRASH_AFTER_CHECKPOINT_PROBABILITY = 0.5 +@pytest.fixture +def ray_shutdown(): + yield + serve.shutdown() + ray.shutdown() + + @pytest.fixture(scope="session") def _shared_serve_instance(): # Note(simon): diff --git a/python/ray/serve/tests/test_advanced.py b/python/ray/serve/tests/test_advanced.py index 74287ed358bef..03f606e58fcbd 100644 --- a/python/ray/serve/tests/test_advanced.py +++ b/python/ray/serve/tests/test_advanced.py @@ -9,12 +9,11 @@ def test_serve_forceful_shutdown(serve_instance): - @serve.deployment + @serve.deployment(_graceful_shutdown_timeout_s=0.1) def sleeper(): while True: time.sleep(1000) - sleeper._config.experimental_graceful_shutdown_timeout_s = 0.1 sleeper.deploy() handle = sleeper.get_handle() @@ -28,14 +27,15 @@ def sleeper(): def test_serve_graceful_shutdown(serve_instance): signal = SignalActor.remote() - @serve.deployment(name="wait", max_concurrent_queries=10) + @serve.deployment( + name="wait", + max_concurrent_queries=10, + _graceful_shutdown_timeout_s=1000, + _graceful_shutdown_wait_loop_s=0.5) class Wait: async def __call__(self, signal_actor): await signal_actor.wait.remote() - return "" - Wait._config.experimental_graceful_shutdown_wait_loop_s = 0.5 - Wait._config.experimental_graceful_shutdown_timeout_s = 1000 Wait.deploy() handle = Wait.get_handle() refs = [handle.remote(signal) for _ in range(10)] diff --git a/python/ray/serve/tests/test_autoscaling_metrics.py b/python/ray/serve/tests/test_autoscaling_metrics.py index e641f515d372d..d8a92d8a28b7a 100644 --- a/python/ray/serve/tests/test_autoscaling_metrics.py +++ b/python/ray/serve/tests/test_autoscaling_metrics.py @@ -59,20 +59,20 @@ def test_e2e(serve_instance): "min_replicas": 1, "max_replicas": 1 }, - max_concurrent_queries=1000) + # We will send over a lot of queries. This will make sure replicas are + # killed quickly during cleanup. + _graceful_shutdown_timeout_s=1, + max_concurrent_queries=1000, + version="v1") class A: def __call__(self): time.sleep(0.5) - # We will send over a lot of queries. This will make sure replicas are - # killed quickly during cleanup. - A._config.experimental_graceful_shutdown_timeout_s = 1 - A.deploy() handle = A.get_handle() [handle.remote() for _ in range(100)] - # Wait for metrics to propogate + # Wait for metrics to propagate def get_data(): return ray.get(serve_instance._controller. 
_dump_autoscaling_metrics_for_testing.remote()) diff --git a/python/ray/serve/tests/test_autoscaling_policy.py b/python/ray/serve/tests/test_autoscaling_policy.py index 56fb72ac4eea1..e72c2f68b65ce 100644 --- a/python/ray/serve/tests/test_autoscaling_policy.py +++ b/python/ray/serve/tests/test_autoscaling_policy.py @@ -1,3 +1,11 @@ +import sys +import time +import pytest + +import ray +from ray import serve +from ray._private.test_utils import wait_for_condition +from ray.serve.backend_state import ReplicaState from ray.serve.config import AutoscalingConfig from ray.serve.autoscaling_policy import calculate_desired_num_replicas @@ -71,3 +79,47 @@ def test_smoothing_factor(self): autoscaling_config=config, current_num_ongoing_requests=num_ongoing_requests) assert 5 <= desired_num_replicas <= 8 # 10 + 0.5 * (2.5 - 10) = 6.25 + + +@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.") +def test_e2e_basic_scale_up_down(serve_instance): + """Send 100 requests and check that we autoscale up, and then back down.""" + + @serve.deployment( + _autoscaling_config={ + "metrics_interval_s": 0.1, + "min_replicas": 1, + "max_replicas": 2, + "look_back_period_s": 0.2 + }, + # We will send over a lot of queries. This will make sure replicas are + # killed quickly during cleanup. + _graceful_shutdown_timeout_s=1, + max_concurrent_queries=1000, + version="v1") + class A: + def __call__(self): + time.sleep(1) + + A.deploy() + handle = A.get_handle() + [handle.remote() for _ in range(100)] + + controller = serve_instance._controller + + def get_num_running_replicas(): + replicas = ray.get( + controller._dump_replica_states_for_testing.remote("A")) + running_replicas = replicas.get([ReplicaState.RUNNING]) + return len(running_replicas) + + wait_for_condition(lambda: get_num_running_replicas() >= 2) + + # As the queue is drained, we should scale back down. + wait_for_condition(lambda: get_num_running_replicas() <= 1) + + +if __name__ == "__main__": + import sys + import pytest + sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_backend_state.py b/python/ray/serve/tests/test_backend_state.py index aa31dc6a9d82a..0112868821388 100644 --- a/python/ray/serve/tests/test_backend_state.py +++ b/python/ray/serve/tests/test_backend_state.py @@ -1,3 +1,5 @@ +import os +import sys import time from typing import Any, Dict, List, Optional, Tuple from unittest.mock import patch, Mock @@ -181,6 +183,7 @@ def set_starting_version(self, version: BackendVersion): def start(self, backend_info: BackendInfo, version: BackendVersion): self.started = True self.version = version + self.backend_info = backend_info def update_user_config(self, user_config: Any): self.started = True @@ -218,6 +221,7 @@ def available_resources(self) -> Dict[str, float]: def graceful_stop(self) -> None: assert self.started self.stopped = True + return self.backend_info.backend_config.graceful_shutdown_timeout_s def check_stopped(self) -> bool: return self.done_stopping @@ -526,9 +530,6 @@ def test_create_delete_single_replica(mock_backend_state): # Now the replica should be marked running. backend_state.update() check_counts(backend_state, total=1, by_state=[(ReplicaState.RUNNING, 1)]) - - # TODO(edoakes): can we remove this extra update period for completing it? - backend_state.update() assert goal_manager.check_complete(create_goal) # Removing the replica should transition it to stopping. 
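
# Illustrative sketch, not part of the diff: the calculation that the
# autoscaling e2e test added above exercises. It assumes AutoscalingConfig
# exposes a target_num_ongoing_requests_per_replica field; all values here
# are hypothetical.
from ray.serve.config import AutoscalingConfig
from ray.serve.autoscaling_policy import calculate_desired_num_replicas

config = AutoscalingConfig(
    min_replicas=1,
    max_replicas=10,
    target_num_ongoing_requests_per_replica=2)

# Three replicas each averaging ~4 ongoing requests: the policy should ask
# for roughly 3 * (4 / 2) = 6 replicas, clamped to [min_replicas, max_replicas].
desired = calculate_desired_num_replicas(
    autoscaling_config=config,
    current_num_ongoing_requests=[4.0, 4.0, 4.0])
assert 1 <= desired <= 10
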
@@ -542,12 +543,9 @@ def test_create_delete_single_replica(mock_backend_state): # Once it's done stopping, replica should be removed. replica = backend_state._replicas.get()[0] replica._actor.set_done_stopping() - backend_state.update() - check_counts(backend_state, total=0) - - # TODO(edoakes): can we remove this extra update period for completing it? deleted = backend_state.update() assert deleted + check_counts(backend_state, total=0) assert goal_manager.check_complete(delete_goal) assert replica._actor.cleaned_up @@ -557,7 +555,7 @@ def test_force_kill(mock_backend_state): grace_period_s = 10 b_info_1, b_version_1 = backend_info( - experimental_graceful_shutdown_timeout_s=grace_period_s) + graceful_shutdown_timeout_s=grace_period_s) # Create and delete the backend. backend_state.deploy(b_info_1) @@ -571,8 +569,8 @@ def test_force_kill(mock_backend_state): check_counts(backend_state, total=1, by_state=[(ReplicaState.STOPPING, 1)]) assert backend_state._replicas.get()[0]._actor.stopped - backend_state.update() - backend_state.update() + for _ in range(10): + backend_state.update() # force_stop shouldn't be called until after the timer. assert not backend_state._replicas.get()[0]._actor.force_stopped_counter @@ -597,12 +595,9 @@ def test_force_kill(mock_backend_state): # Once the replica is done stopping, it should be removed. replica = backend_state._replicas.get()[0] replica._actor.set_done_stopping() - backend_state.update() - check_counts(backend_state, total=0) - - # TODO(edoakes): can we remove this extra update period for completing it? deleted = backend_state.update() assert deleted + check_counts(backend_state, total=0) assert goal_manager.check_complete(delete_goal) assert replica._actor.cleaned_up @@ -644,8 +639,6 @@ def test_redeploy_same_version(mock_backend_state): version=b_version_1, total=1, by_state=[(ReplicaState.RUNNING, 1)]) - - backend_state.update() assert goal_manager.check_complete(goal_1) # Test redeploying after the initial deployment has finished. @@ -727,12 +720,10 @@ def test_redeploy_no_version(mock_backend_state): states=[ReplicaState.STARTING])[0]._actor.set_ready() check_counts(backend_state, total=1, by_state=[(ReplicaState.STARTING, 1)]) - backend_state.update() - check_counts(backend_state, total=1, by_state=[(ReplicaState.RUNNING, 1)]) - deleted = backend_state.update() - assert goal_manager.check_complete(goal_3) assert not deleted + check_counts(backend_state, total=1, by_state=[(ReplicaState.RUNNING, 1)]) + assert goal_manager.check_complete(goal_3) def test_redeploy_new_version(mock_backend_state): @@ -826,16 +817,14 @@ def test_redeploy_new_version(mock_backend_state): total=1, by_state=[(ReplicaState.STARTING, 1)]) - backend_state.update() + deleted = backend_state.update() + assert not deleted check_counts( backend_state, version=b_version_3, total=1, by_state=[(ReplicaState.RUNNING, 1)]) - - deleted = backend_state.update() assert goal_manager.check_complete(goal_3) - assert not deleted def test_deploy_new_config_same_version(mock_backend_state): @@ -855,7 +844,6 @@ def test_deploy_new_config_same_version(mock_backend_state): version=b_version_1, total=1, by_state=[(ReplicaState.RUNNING, 1)]) - backend_state.update() assert goal_manager.check_complete(goal_id) # Update to a new config without changing the version. 
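
# Illustrative sketch, not part of the diff: the version semantics asserted
# by the redeploy tests above, expressed through the public API. The handler
# deployment below is hypothetical.
from ray import serve

@serve.deployment(version="v1")
def handler(request):
    return "v1"

handler.deploy()                         # starts the replicas
handler.deploy()                         # same version: no-op, no restarts
handler.options(version="v2").deploy()   # new version: rolling update
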
@@ -886,8 +874,6 @@ def test_deploy_new_config_same_version(mock_backend_state): version=b_version_2, total=1, by_state=[(ReplicaState.RUNNING, 1)]) - - backend_state.update() assert goal_manager.check_complete(goal_id) @@ -907,7 +893,6 @@ def test_deploy_new_config_new_version(mock_backend_state): version=b_version_1, total=1, by_state=[(ReplicaState.RUNNING, 1)]) - backend_state.update() assert goal_manager.check_complete(create_goal) # Update to a new config and a new version. @@ -945,8 +930,6 @@ def test_deploy_new_config_new_version(mock_backend_state): version=b_version_2, total=1, by_state=[(ReplicaState.RUNNING, 1)]) - - backend_state.update() assert goal_manager.check_complete(update_goal) @@ -966,8 +949,6 @@ def test_initial_deploy_no_throttling(mock_backend_state): for replica in backend_state._replicas.get(): replica._actor.set_ready() - backend_state.update() - # Check that the new replicas have started. backend_state.update() check_counts( @@ -994,8 +975,6 @@ def test_new_version_deploy_throttling(mock_backend_state): for replica in backend_state._replicas.get(): replica._actor.set_ready() - backend_state.update() - # Check that the new replicas have started. backend_state.update() check_counts( @@ -1236,8 +1215,6 @@ def test_new_version_deploy_throttling(mock_backend_state): version=b_version_2, total=10, by_state=[(ReplicaState.RUNNING, 10)]) - - backend_state.update() assert goal_manager.check_complete(goal_2) @@ -1258,8 +1235,6 @@ def test_reconfigure_throttling(mock_backend_state): for replica in backend_state._replicas.get(): replica._actor.set_ready() - backend_state.update() - # Check that the new replicas have started. backend_state.update() check_counts(backend_state, total=2, by_state=[(ReplicaState.RUNNING, 2)]) @@ -1318,8 +1293,6 @@ def test_reconfigure_throttling(mock_backend_state): version=b_version_2, total=2, by_state=[(ReplicaState.RUNNING, 2)]) - - backend_state.update() assert goal_manager.check_complete(goal_1) @@ -1341,8 +1314,6 @@ def test_new_version_and_scale_down(mock_backend_state): for replica in backend_state._replicas.get(): replica._actor.set_ready() - backend_state.update() - # Check that the new replicas have started. backend_state.update() check_counts( @@ -1479,8 +1450,6 @@ def test_new_version_and_scale_down(mock_backend_state): version=b_version_2, total=2, by_state=[(ReplicaState.RUNNING, 2)]) - - backend_state.update() assert goal_manager.check_complete(goal_2) @@ -1501,8 +1470,6 @@ def test_new_version_and_scale_up(mock_backend_state): for replica in backend_state._replicas.get(): replica._actor.set_ready() - backend_state.update() - # Check that the new replicas have started. backend_state.update() check_counts(backend_state, total=2, by_state=[(ReplicaState.RUNNING, 2)]) @@ -1610,8 +1577,6 @@ def test_health_check(mock_backend_state): # Check that the new replicas have started. backend_state.update() check_counts(backend_state, total=2, by_state=[(ReplicaState.RUNNING, 2)]) - - backend_state.update() assert goal_manager.check_complete(goal_1) backend_state.update() @@ -1859,6 +1824,9 @@ def mock_backend_state_manager( yield backend_state_manager, timer, goal_manager # Clear checkpoint at the end of each test kv_store.delete(CHECKPOINT_KEY) + if sys.platform != "win32": + # This line fails on windows with a PermissionError. 
+ os.remove("test_kv_store.db") def test_shutdown(mock_backend_state_manager): @@ -1870,7 +1838,9 @@ def test_shutdown(mock_backend_state_manager): tag = "test" - b_info_1, b_version_1 = backend_info() + grace_period_s = 10 + b_info_1, b_version_1 = backend_info( + graceful_shutdown_timeout_s=grace_period_s) create_goal, updating = backend_state_manager.deploy_backend(tag, b_info_1) backend_state = backend_state_manager._backend_states[tag] @@ -1889,25 +1859,21 @@ def test_shutdown(mock_backend_state_manager): shutdown_goal = backend_state_manager.shutdown()[0] + timer.advance(grace_period_s + 0.1) backend_state_manager.update() check_counts(backend_state, total=1, by_state=[(ReplicaState.STOPPING, 1)]) assert backend_state._replicas.get()[0]._actor.stopped - assert backend_state._replicas.get()[0]._actor.force_stopped_counter == 1 assert not backend_state._replicas.get()[0]._actor.cleaned_up assert not goal_manager.check_complete(shutdown_goal) # Once it's done stopping, replica should be removed. replica = backend_state._replicas.get()[0] replica._actor.set_done_stopping() - backend_state.update() - check_counts(backend_state, total=0) - - # TODO(edoakes): can we remove this extra update period for completing it? backend_state_manager.update() + check_counts(backend_state, total=0) assert goal_manager.check_complete(shutdown_goal) assert replica._actor.cleaned_up - assert len(backend_state_manager._backend_states) == 0 @@ -1974,5 +1940,4 @@ def test_resume_backend_state_from_replica_tags(mock_backend_state_manager): if __name__ == "__main__": - import sys sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_config.py b/python/ray/serve/tests/test_config.py index dd30aeab0f77f..8c71cf8ae3a91 100644 --- a/python/ray/serve/tests/test_config.py +++ b/python/ray/serve/tests/test_config.py @@ -52,6 +52,8 @@ def function(_): # Check ray_actor_options validation. ReplicaConfig( Class, + tuple(), + dict(), ray_actor_options={ "num_cpus": 1.0, "num_gpus": 10, diff --git a/python/ray/serve/tests/test_deploy.py b/python/ray/serve/tests/test_deploy.py index d593081a43a9b..ab5ab2e2f3d75 100644 --- a/python/ray/serve/tests/test_deploy.py +++ b/python/ray/serve/tests/test_deploy.py @@ -10,6 +10,7 @@ import ray from ray._private.test_utils import SignalActor, wait_for_condition from ray import serve +from ray.serve.exceptions import RayServeException from ray.serve.utils import get_random_letters @@ -676,8 +677,8 @@ def b(self, *args): assert ray.get(handle.options(method_name="b").remote()) == "hello" # New code path assert ray.get(handle.b.remote()) == "hello" - with pytest.raises(AttributeError): - handle.c.remote() + with pytest.raises(RayServeException): + ray.get(handle.c.remote()) def test_init_args(serve_instance): @@ -733,6 +734,58 @@ def check(*args): check(10, 11, 12) +def test_init_kwargs(serve_instance): + with pytest.raises(TypeError): + + @serve.deployment(init_kwargs=[1, 2, 3]) + class BadInitArgs: + pass + + @serve.deployment(init_kwargs={"a": 1, "b": 2}) + class D: + def __init__(self, **kwargs): + self._kwargs = kwargs + + def get_kwargs(self, *args): + return self._kwargs + + D.deploy() + handle = D.get_handle() + + def check(kwargs): + assert ray.get(handle.get_kwargs.remote()) == kwargs + + # Basic sanity check. + check({"a": 1, "b": 2}) + + # Check passing args to `.deploy()`. + D.deploy(a=3, b=4) + check({"a": 3, "b": 4}) + + # Passing args to `.deploy()` shouldn't override those passed in decorator. 
+    D.deploy()
+    check({"a": 1, "b": 2})
+
+    # Check setting with `.options()`.
+    new_D = D.options(init_kwargs={"c": 8, "d": 10})
+    new_D.deploy()
+    check({"c": 8, "d": 10})
+
+    # Should not have changed old deployment object.
+    D.deploy()
+    check({"a": 1, "b": 2})
+
+    # Check that args are only updated on version change.
+    D.options(version="1").deploy()
+    check({"a": 1, "b": 2})
+
+    D.options(version="1").deploy(c=10, d=11)
+    check({"a": 1, "b": 2})
+
+    D.options(version="2").deploy(c=10, d=11)
+    check({"c": 10, "d": 11})
+
+
 def test_input_validation():
     name = "test"
 
diff --git a/python/ray/serve/tests/test_get_deployment.py b/python/ray/serve/tests/test_get_deployment.py
index cb1d6c9484e31..1f6968abe4974 100644
--- a/python/ray/serve/tests/test_get_deployment.py
+++ b/python/ray/serve/tests/test_get_deployment.py
@@ -116,6 +116,37 @@ def __call__(self, *arg):
         assert pid3 != pid2
 
 
+def test_init_kwargs(serve_instance):
+    name = "test"
+
+    @serve.deployment(name=name)
+    class D:
+        def __init__(self, *, val=None):
+            assert val is not None
+            self._val = val
+
+        def __call__(self, *arg):
+            return self._val, os.getpid()
+
+    D.deploy(val="1")
+    val1, pid1 = ray.get(D.get_handle().remote())
+    assert val1 == "1"
+
+    del D
+
+    D2 = serve.get_deployment(name=name)
+    D2.deploy()
+    val2, pid2 = ray.get(D2.get_handle().remote())
+    assert val2 == "1"
+    assert pid2 != pid1
+
+    D2 = serve.get_deployment(name=name)
+    D2.deploy(val="2")
+    val3, pid3 = ray.get(D2.get_handle().remote())
+    assert val3 == "2"
+    assert pid3 != pid2
+
+
 def test_scale_replicas(serve_instance):
     name = "test"
 
diff --git a/python/ray/serve/tests/test_handle.py b/python/ray/serve/tests/test_handle.py
index 95c55aba35b3e..360fb3336b247 100644
--- a/python/ray/serve/tests/test_handle.py
+++ b/python/ray/serve/tests/test_handle.py
@@ -1,9 +1,10 @@
+import concurrent.futures
 import pytest
 import requests
 
 import ray
-import concurrent.futures
 from ray import serve
+from ray.serve.exceptions import RayServeException
 
 
 @pytest.mark.asyncio
@@ -167,6 +168,30 @@ def call():
         ray.get(obj_ref)
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync", [True, False])
+async def test_nonexistent_method(serve_instance, sync):
+    @serve.deployment
+    class A:
+        def exists(self):
+            pass
+
+    A.deploy()
+    handle = A.get_handle(sync=sync)
+
+    if sync:
+        obj_ref = handle.does_not_exist.remote()
+    else:
+        obj_ref = await handle.does_not_exist.remote()
+
+    with pytest.raises(RayServeException) as excinfo:
+        ray.get(obj_ref)
+
+    exception_string = str(excinfo.value)
+    assert "'does_not_exist'" in exception_string
+    assert "Available methods: ['exists']" in exception_string
+
+
 if __name__ == "__main__":
     import sys
     import pytest
diff --git a/python/ray/serve/tests/test_long_poll.py b/python/ray/serve/tests/test_long_poll.py
index 79cf0c841ea35..2081e705d976e 100644
--- a/python/ray/serve/tests/test_long_poll.py
+++ b/python/ray/serve/tests/test_long_poll.py
@@ -37,6 +37,20 @@ def test_host_standalone(serve_instance):
     assert "key_2" in result
 
 
+def test_long_poll_wait_for_keys(serve_instance):
+    # Variation of the basic case, but the keys are requested before any
+    # values are set.
+    host = ray.remote(LongPollHost).remote()
+    object_ref = host.listen_for_change.remote({"key_1": -1, "key_2": -1})
+    ray.get(host.notify_changed.remote("key_1", 999))
+    ray.get(host.notify_changed.remote("key_2", 999))
+
+    # We should be able to get one of the results immediately.
+    result: Dict[str, UpdatedObject] = ray.get(object_ref)
+    assert set(result.keys()).issubset({"key_1", "key_2"})
+    assert {v.object_snapshot for v in result.values()} == {999}
+
+
 def test_long_poll_restarts(serve_instance):
     @ray.remote(
         max_restarts=-1,
diff --git a/python/ray/serve/tests/test_ray_client.py b/python/ray/serve/tests/test_ray_client.py
index 7bc2d54aad388..db640970eedbe 100644
--- a/python/ray/serve/tests/test_ray_client.py
+++ b/python/ray/serve/tests/test_ray_client.py
@@ -126,7 +126,7 @@ def hello(request):
 
 
 @pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows")
-def test_quickstart_task(serve_with_client):
+def test_quickstart_counter(serve_with_client):
     serve.start()
 
     @serve.deployment
@@ -140,10 +140,13 @@ def __call__(self, *args):
 
     # Deploy our class.
     Counter.deploy()
+    print("deploy finished")
 
     # Query our endpoint in two different ways: from HTTP and from Python.
     assert requests.get("http://127.0.0.1:8000/Counter").json() == {"count": 1}
+    print("query 1 finished")
     assert ray.get(Counter.get_handle().remote()) == {"count": 2}
+    print("query 2 finished")
 
 
 if __name__ == "__main__":
diff --git a/python/ray/serve/tests/test_regression.py b/python/ray/serve/tests/test_regression.py
index e4d519bc06ffc..9ac205803e492 100644
--- a/python/ray/serve/tests/test_regression.py
+++ b/python/ray/serve/tests/test_regression.py
@@ -71,7 +71,7 @@ async def __call__(self, _request):
     assert result.json() == 100.0
 
 
-def test_backend_worker_memory_growth(serve_instance):
+def test_replica_memory_growth(serve_instance):
     # https://github.com/ray-project/ray/issues/12395
     @serve.deployment(name="model")
     def gc_unreachable_objects(*args):
diff --git a/python/ray/serve/tests/test_standalone.py b/python/ray/serve/tests/test_standalone.py
index ce8183b2d8577..1c6df064247f5 100644
--- a/python/ray/serve/tests/test_standalone.py
+++ b/python/ray/serve/tests/test_standalone.py
@@ -27,13 +27,6 @@
 import ray._private.gcs_utils as gcs_utils
 
 
-@pytest.fixture
-def ray_shutdown():
-    yield
-    serve.shutdown()
-    ray.shutdown()
-
-
 @pytest.fixture
 def ray_cluster():
     cluster = Cluster()
@@ -102,7 +95,7 @@ def test_detached_deployment(ray_cluster):
     # https://github.com/ray-project/ray/issues/11437
 
     cluster = ray_cluster
-    head_node = cluster.add_node(node_ip_address="127.0.0.1", num_cpus=6)
+    head_node = cluster.add_node(num_cpus=6)
 
     # Create first job, check we can run a simple serve endpoint
     ray.init(head_node.address, namespace="serve")
diff --git a/python/ray/sgd/__init__.py b/python/ray/sgd/__init__.py
index c5d4677aa041e..d5f8ec4c0d6f1 100644
--- a/python/ray/sgd/__init__.py
+++ b/python/ray/sgd/__init__.py
@@ -1,2 +1 @@
-from ray.util.sgd.v2 import *  # noqa: F401, F403
-from ray.util.sgd.v2.callbacks import JsonLoggerCallback, TBXLoggerCallback  # noqa: E501, F401, F403
+from ray.util.sgd.v2 import *  # noqa: F401, F403
diff --git a/python/ray/sgd/callbacks.py b/python/ray/sgd/callbacks.py
new file mode 100644
index 0000000000000..9b85815190b9b
--- /dev/null
+++ b/python/ray/sgd/callbacks.py
@@ -0,0 +1 @@
+from ray.util.sgd.v2.callbacks import *  # noqa: E501, F401, F403
diff --git a/python/ray/state.py b/python/ray/state.py
index 3c2f2185caffb..b074fd4062641 100644
--- a/python/ray/state.py
+++ b/python/ray/state.py
@@ -1,7 +1,6 @@
 from collections import defaultdict
 import json
 import logging
-import os
 
 import ray
 
@@ -50,10 +49,6 @@ def _check_connected(self):
         # _really_init_global_state should have set self.global_state_accessor
         if self.global_state_accessor is None:
-            if os.environ.get("RAY_ENABLE_AUTO_CONNECT", "") != "0":
-                ray.client().connect()
-                # Retry connect!
-                return self._check_connected()
             raise ray.exceptions.RaySystemError(
                 "Ray has not been started yet. You can start Ray with "
                 "'ray.init()'.")
@@ -720,6 +715,7 @@ def _live_node_ids(self):
     def _available_resources_per_node(self):
         """Returns a dictionary mapping node id to available resources."""
+        self._check_connected()
 
         available_resources_by_id = {}
         all_available_resources = \
@@ -811,7 +807,7 @@ def next_job_id():
 
 
 @DeveloperAPI
-@client_mode_hook
+@client_mode_hook(auto_init=False)
 def nodes():
     """Get a list of the nodes in the cluster (for debugging only).
 
@@ -875,7 +871,7 @@ def actors(actor_id=None):
 
 
 @DeveloperAPI
-@client_mode_hook
+@client_mode_hook(auto_init=False)
 def timeline(filename=None):
     """Return a list of profiling events that can be viewed as a timeline.
 
@@ -917,7 +913,7 @@ def object_transfer_timeline(filename=None):
 
 
 @DeveloperAPI
-@client_mode_hook
+@client_mode_hook(auto_init=False)
 def cluster_resources():
     """Get the current total cluster resources.
 
@@ -932,7 +928,7 @@ def cluster_resources():
 
 
 @DeveloperAPI
-@client_mode_hook
+@client_mode_hook(auto_init=False)
 def available_resources():
     """Get the current available cluster resources.
 
diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD
index f854f00e560e7..1fad95edee1e3 100644
--- a/python/ray/tests/BUILD
+++ b/python/ray/tests/BUILD
@@ -48,6 +48,7 @@ py_test_module_list(
     files = [
         "test_client.py",
         "test_client_builder.py",
+        "test_client_compat.py",
         "test_client_init.py",
         "test_client_multi.py",
         "test_client_proxy.py",
@@ -77,12 +78,12 @@ py_test_module_list(
         "test_placement_group.py",
        "test_placement_group_2.py",
         "test_placement_group_3.py",
-        "test_placement_group_mini_integration.py",
         "test_ray_init.py",
         "test_reconstruction.py",
         "test_reference_counting.py",
         "test_resource_demand_scheduler.py",
         "test_runtime_env_env_vars.py",
+        "test_runtime_env_plugin.py",
         "test_runtime_env_fork_process.py",
         "test_serialization.py",
         "test_shuffle.py",
@@ -101,6 +102,7 @@ py_test_module_list(
 
 py_test_module_list(
     files = [
+        "test_autoscaler_fake_multinode.py",  # Temporarily owned by core.
         "test_args.py",
         "test_asyncio_cluster.py",
         "test_asyncio.py",
@@ -167,6 +169,7 @@ py_test_module_list(
         "test_failure_4.py",
         "test_object_spilling.py",
         "test_plasma_unlimited.py",
+        "test_placement_group_mini_integration.py",
     ],
     size = "large",
     extra_srcs = SRCS,
@@ -300,6 +303,14 @@ py_test(
     deps = ["//:ray_lib"],
 )
 
+py_test(
+    name = "test_runtime_env_validation",
+    size = "small",
+    srcs = SRCS + ["test_runtime_env_validation.py"],
+    tags = ["exclusive", "team:serve"],
+    deps = ["//:ray_lib"],
+)
+
 # TODO(ekl) we can't currently support tagging these as flaky since there's
 # no way to filter by both flaky and client mode tests in bazel.
py_test_module_list( diff --git a/python/ray/tests/client_test_utils.py b/python/ray/tests/client_test_utils.py index c7b0081d3274c..30c016d32bd3a 100644 --- a/python/ray/tests/client_test_utils.py +++ b/python/ray/tests/client_test_utils.py @@ -18,3 +18,20 @@ async def wait(self, should_wait=True): await self.ready_event.wait() return SignalActor + + +# See test_client::test_wrapped_actor_creation for details on usage of +# run_wrapped_actor_creation and SomeClass. +def run_wrapped_actor_creation(): + import ray + RemoteClass = ray.remote(SomeClass) + handle = RemoteClass.remote() + return ray.get(handle.ready.remote()) + + +class SomeClass: + def __init__(self): + pass + + def ready(self): + return 1 diff --git a/python/ray/tests/mock_setup_worker.py b/python/ray/tests/mock_setup_worker.py index a19a9ce22d1fd..7cd981b9ac00f 100644 --- a/python/ray/tests/mock_setup_worker.py +++ b/python/ray/tests/mock_setup_worker.py @@ -30,6 +30,9 @@ parser.add_argument( "--session-dir", type=str, help="the directory for the current session") +parser.add_argument( + "--language", type=str, help="the language type of the worker") + args, remaining_args = parser.parse_known_args() # add worker-shim-pid argument diff --git a/python/ray/tests/test_advanced.py b/python/ray/tests/test_advanced.py index b7962ff71e44b..041e5e7bb559a 100644 --- a/python/ray/tests/test_advanced.py +++ b/python/ray/tests/test_advanced.py @@ -777,14 +777,13 @@ def method(self): # This case tests whether RequestWorkerLeaseReply carries normal task resources # when the request is rejected (due to resource preemption by normal tasks). -@pytest.mark.skip( - reason="The period of pull based resource report (10ms) is hard-coded.") +@pytest.mark.skipif(sys.platform == "win32", reason="Time out on Windows") def test_worker_lease_reply_with_resources(ray_start_cluster): cluster = ray_start_cluster cluster.add_node( memory=2000 * 1024**2, _system_config={ - "raylet_report_resources_period_milliseconds": 1000000, + "gcs_resource_report_poll_period_ms": 1000000, "gcs_actor_scheduling_enabled": True, }) node2 = cluster.add_node(memory=1000 * 1024**2) diff --git a/python/ray/tests/test_advanced_3.py b/python/ray/tests/test_advanced_3.py index a03850916328a..90d3de16dd60d 100644 --- a/python/ray/tests/test_advanced_3.py +++ b/python/ray/tests/test_advanced_3.py @@ -1,5 +1,6 @@ # coding: utf-8 import glob +import json import logging import os import sys @@ -726,20 +727,19 @@ def test_k8s_cpu(): def test_sync_job_config(shutdown_only): num_java_workers_per_process = 8 - worker_env = { - "key": "value", - } + runtime_env = {"env_vars": {"key": "value"}} ray.init( job_config=ray.job_config.JobConfig( num_java_workers_per_process=num_java_workers_per_process, - worker_env=worker_env)) + runtime_env=runtime_env)) # Check that the job config is synchronized at the driver side. 
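     # (The runtime_env travels inside the job config as a serialized JSON
     # blob, which is why the assertions below json.loads it back out.)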
job_config = ray.worker.global_worker.core_worker.get_job_config() assert (job_config.num_java_workers_per_process == num_java_workers_per_process) - assert (job_config.worker_env == worker_env) + job_runtime_env = json.loads(job_config.runtime_env.serialized_runtime_env) + assert job_runtime_env["env_vars"] == runtime_env["env_vars"] @ray.remote def get_job_config(): @@ -751,7 +751,8 @@ def get_job_config(): job_config.ParseFromString(ray.get(get_job_config.remote())) assert (job_config.num_java_workers_per_process == num_java_workers_per_process) - assert (job_config.worker_env == worker_env) + job_runtime_env = json.loads(job_config.runtime_env.serialized_runtime_env) + assert job_runtime_env["env_vars"] == runtime_env["env_vars"] def test_duplicated_arg(ray_start_cluster): diff --git a/python/ray/tests/test_autoscaler.py b/python/ray/tests/test_autoscaler.py index d428188173cbd..4cc7ee63570fe 100644 --- a/python/ray/tests/test_autoscaler.py +++ b/python/ray/tests/test_autoscaler.py @@ -1,6 +1,7 @@ import json import jsonschema import os +import re import shutil from subprocess import CalledProcessError import tempfile @@ -13,7 +14,7 @@ from collections import defaultdict from ray.autoscaler._private.commands import get_or_create_head_node from jsonschema.exceptions import ValidationError -from typing import Dict, Callable +from typing import Dict, Callable, List, Optional import ray from ray.autoscaler._private.util import prepare_config, validate_config @@ -105,42 +106,56 @@ def check_output(self, cmd): return return_string.encode() - def assert_has_call(self, ip, pattern=None, exact=None): + def assert_has_call(self, + ip: str, + pattern: Optional[str] = None, + exact: Optional[List[str]] = None): + """Checks if the given value was called by this process runner. + + NOTE: Either pattern or exact must be specified, not both! + + Args: + ip: IP address of the node that the given call was executed on. + pattern: RegEx that matches one specific call. + exact: List of strings that when joined exactly match one call. + """ with self.lock: - assert pattern or exact, \ + assert bool(pattern) ^ bool(exact), \ "Must specify either a pattern or exact match." - out = "" + debug_output = "" if pattern is not None: for cmd in self.command_history(): if ip in cmd: - out += cmd - out += "\n" - if pattern in out: - return True + debug_output += cmd + debug_output += "\n" + if re.search(pattern, cmd): + return True else: raise Exception( - f"Did not find [{pattern}] in [{out}] for ip={ip}." - f"\n\nFull output: {self.command_history()}") + f"Did not find [{pattern}] in [{debug_output}] for " + f"ip={ip}.\n\nFull output: {self.command_history()}") elif exact is not None: exact_cmd = " ".join(exact) for cmd in self.command_history(): if ip in cmd: - out += cmd - out += "\n" + debug_output += cmd + debug_output += "\n" if cmd == exact_cmd: return True raise Exception( - f"Did not find [{exact_cmd}] in [{out}] for ip={ip}." - f"\n\nFull output: {self.command_history()}") + f"Did not find [{exact_cmd}] in [{debug_output}] for " + f"ip={ip}.\n\nFull output: {self.command_history()}") - def assert_not_has_call(self, ip, pattern): + def assert_not_has_call(self, ip: str, pattern: str): + """Ensure that the given regex pattern was never called. 
+ """ with self.lock: out = "" for cmd in self.command_history(): if ip in cmd: out += cmd out += "\n" - if pattern in out: + if re.search(pattern, out): raise Exception("Found [{}] in [{}] for {}".format( pattern, out, ip)) else: @@ -449,7 +464,10 @@ def waitFor(self, condition, num_retries=50, fail_msg=None): fail_msg = fail_msg or "Timed out waiting for {}".format(condition) raise RayTestTimeoutException(fail_msg) - def waitForNodes(self, expected, comparison=None, tag_filters={}): + def waitForNodes(self, expected, comparison=None, tag_filters=None): + if tag_filters is None: + tag_filters = {} + MAX_ITER = 50 for i in range(MAX_ITER): n = len(self.provider.non_terminated_nodes(tag_filters)) @@ -2560,8 +2578,7 @@ def testContinuousFileMounts(self): for i in [0, 1]: runner.assert_not_has_call(f"172.0.0.{i}", "setup_cmd") runner.assert_has_call( - f"172.0.0.{i}", f"172.0.0.{i}", - f"{file_mount_dir}/ ubuntu@172.0.0.{i}:" + f"172.0.0.{i}", f"{file_mount_dir}/ ubuntu@172.0.0.{i}:" f"{docker_mount_prefix}/home/test-folder/") def testFileMountsNonContinuous(self): @@ -2596,8 +2613,7 @@ def testFileMountsNonContinuous(self): for i in [0, 1]: runner.assert_has_call(f"172.0.0.{i}", "setup_cmd") runner.assert_has_call( - f"172.0.0.{i}", f"172.0.0.{i}", - f"{file_mount_dir}/ ubuntu@172.0.0.{i}:" + f"172.0.0.{i}", f"{file_mount_dir}/ ubuntu@172.0.0.{i}:" f"{docker_mount_prefix}/home/test-folder/") runner.clear_history() @@ -2640,8 +2656,7 @@ def testFileMountsNonContinuous(self): for i in [0, 1]: runner.assert_has_call(f"172.0.0.{i}", "setup_cmd") runner.assert_has_call( - f"172.0.0.{i}", f"172.0.0.{i}", - f"{file_mount_dir}/ ubuntu@172.0.0.{i}:" + f"172.0.0.{i}", f"{file_mount_dir}/ ubuntu@172.0.0.{i}:" f"{docker_mount_prefix}/home/test-folder/") def testAutodetectResources(self): diff --git a/python/ray/tests/test_autoscaler_fake_multinode.py b/python/ray/tests/test_autoscaler_fake_multinode.py new file mode 100644 index 0000000000000..1f6c96b3d3c19 --- /dev/null +++ b/python/ray/tests/test_autoscaler_fake_multinode.py @@ -0,0 +1,58 @@ +import pytest +import platform + +import ray +from ray.cluster_utils import AutoscalingCluster + + +@pytest.mark.skipif( + platform.system() == "Windows", reason="Failing on Windows.") +def test_fake_autoscaler_basic_e2e(shutdown_only): + cluster = AutoscalingCluster( + head_resources={"CPU": 2}, + worker_node_types={ + "cpu_node": { + "resources": { + "CPU": 4, + "object_store_memory": 1024 * 1024 * 1024, + }, + "node_config": {}, + "min_workers": 0, + "max_workers": 2, + }, + "gpu_node": { + "resources": { + "CPU": 2, + "GPU": 1, + "object_store_memory": 1024 * 1024 * 1024, + }, + "node_config": {}, + "min_workers": 0, + "max_workers": 2, + }, + }) + + try: + cluster.start() + ray.init("auto") + + # Triggers the addition of a GPU node. + @ray.remote(num_gpus=1) + def f(): + print("gpu ok") + + # Triggers the addition of a CPU node. 
+        @ray.remote(num_cpus=3)
+        def g():
+            print("cpu ok")
+
+        ray.get(f.remote())
+        ray.get(g.remote())
+        ray.shutdown()
+    finally:
+        cluster.shutdown()
+
+
+if __name__ == "__main__":
+    import sys
+    sys.exit(pytest.main(["-v", __file__]))
diff --git a/python/ray/tests/test_autoscaler_yaml.py b/python/ray/tests/test_autoscaler_yaml.py
index 137d188c8caaf..ad4ef152acd9e 100644
--- a/python/ray/tests/test_autoscaler_yaml.py
+++ b/python/ray/tests/test_autoscaler_yaml.py
@@ -91,6 +91,9 @@ def testValidateDefaultConfig(self):
             if "local" in config_path:
                 # local tested in testValidateLocal
                 continue
+            if "fake_multi_node" in config_path:
+                # not supported with ray up
+                continue
             with open(config_path) as f:
                 config = yaml.safe_load(f)
             config = prepare_config(config)
diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py
index d5b73ece9bf54..ad4d844b7c304 100644
--- a/python/ray/tests/test_basic.py
+++ b/python/ray/tests/test_basic.py
@@ -76,8 +76,7 @@ def test_omp_threads_set(shutdown_only):
     assert os.environ["OMP_NUM_THREADS"] == "1"
 
 
-@pytest.mark.parametrize("use_tls", [False, True], indirect=True)
-def test_submit_api(shutdown_only, use_tls):
+def test_submit_api(shutdown_only):
     ray.init(num_cpus=2, num_gpus=1, resources={"Custom": 1})
 
     @ray.remote
@@ -141,8 +140,7 @@ def method(self):
     assert ray.get([id1, id2, id3, id4]) == [0, 1, "test", 2]
 
 
-@pytest.mark.parametrize("use_tls", [False, True], indirect=True)
-def test_invalid_arguments(shutdown_only, use_tls):
+def test_invalid_arguments(shutdown_only):
     ray.init(num_cpus=2)
 
     for opt in [np.random.randint(-100, -1), np.random.uniform(0, 1)]:
@@ -238,8 +236,7 @@ def check():
             {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"})
 
 
-@pytest.mark.parametrize("use_tls", [False, True], indirect=True)
-def test_put_get(shutdown_only, use_tls):
+def test_put_get(shutdown_only):
     ray.init(num_cpus=0)
 
     for i in range(100):
@@ -268,8 +265,7 @@ def test_put_get(shutdown_only, use_tls):
 
 
 @pytest.mark.skipif(sys.platform != "linux", reason="Failing on Windows")
-@pytest.mark.parametrize("use_tls", [False, True], indirect=True)
-def test_wait_timing(shutdown_only, use_tls):
+def test_wait_timing(shutdown_only):
     ray.init(num_cpus=2)
 
     @ray.remote
@@ -303,8 +299,7 @@ def test_function_descriptor():
     assert d.get(python_descriptor2) == 123
 
 
-@pytest.mark.parametrize("use_tls", [False, True], indirect=True)
-def test_ray_options(shutdown_only, use_tls):
+def test_ray_options(shutdown_only):
     ray.init(num_cpus=10, num_gpus=10, resources={"custom1": 2})
 
     @ray.remote(
diff --git a/python/ray/tests/test_basic_3.py b/python/ray/tests/test_basic_3.py
index 400f79c407b8f..9e050e0b04979 100644
--- a/python/ray/tests/test_basic_3.py
+++ b/python/ray/tests/test_basic_3.py
@@ -168,7 +168,16 @@ def f():
 
 @pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
 def test_fair_queueing(shutdown_only):
-    ray.init(num_cpus=1)
+    ray.init(
+        num_cpus=1,
+        _system_config={
+            # Having parallel leases is slow in this case because tasks are
+            # scheduled FIFO: the more parallelism we have, the more workers
+            # we need to start to execute the f and g tasks before we can
+            # execute the first h task.
+ "max_pending_lease_requests_per_scheduling_category": 1 + }) @ray.remote def h(): diff --git a/python/ray/tests/test_client.py b/python/ray/tests/test_client.py index de552b1fe2977..0f6dcadb10cbc 100644 --- a/python/ray/tests/test_client.py +++ b/python/ray/tests/test_client.py @@ -6,9 +6,11 @@ import queue import threading import _thread +from unittest.mock import patch import ray.util.client.server.server as ray_client_server from ray.tests.client_test_utils import create_remote_signal_actor +from ray.tests.client_test_utils import run_wrapped_actor_creation from ray.util.client.common import ClientObjectRef from ray.util.client.ray_client_helpers import connect_to_client_or_not from ray.util.client.ray_client_helpers import ray_start_client_server @@ -24,11 +26,11 @@ def test_client_context_manager(ray_start_regular_shared, connect_to_client): with connect_to_client_or_not(connect_to_client): if connect_to_client: # Client mode is on. - assert client_mode_should_convert() + assert client_mode_should_convert(auto_init=True) # We're connected to Ray client. assert ray.util.client.ray.is_connected() else: - assert not client_mode_should_convert() + assert not client_mode_should_convert(auto_init=True) assert not ray.util.client.ray.is_connected() @@ -70,20 +72,20 @@ def run(self): def test_client_mode_hook_thread_safe(ray_start_regular_shared): with ray_start_client_server(): with enable_client_mode(): - assert client_mode_should_convert() + assert client_mode_should_convert(auto_init=True) lock = threading.Lock() lock.acquire() q = queue.Queue() def disable(): with disable_client_hook(): - q.put(client_mode_should_convert()) + q.put(client_mode_should_convert(auto_init=True)) lock.acquire() - q.put(client_mode_should_convert()) + q.put(client_mode_should_convert(auto_init=True)) t = threading.Thread(target=disable) t.start() - assert client_mode_should_convert() + assert client_mode_should_convert(auto_init=True) lock.release() t.join() assert q.get( @@ -467,8 +469,11 @@ def print_on_stderr_and_stdout(s): time.sleep(1) print_on_stderr_and_stdout.remote("Hello world") time.sleep(1) - assert len(log_msgs) == 2 - assert all((msg.find("Hello world") for msg in log_msgs)) + num_hello = 0 + for msg in log_msgs: + if "Hello world" in msg: + num_hello += 1 + assert num_hello == 2, f"Invalid logs: {log_msgs}" @pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.") @@ -648,6 +653,7 @@ def stop_server(server): @pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.") +@patch.dict(os.environ, {"RAY_ENABLE_AUTO_CONNECT": "0"}) def test_client_gpu_ids(call_ray_stop_only): import ray ray.init(num_cpus=2) @@ -702,7 +708,42 @@ def test_object_ref_cleanup(): # See https://github.com/ray-project/ray/issues/17968 for details with ray_start_client_server(): result = run_string_as_driver(object_ref_cleanup_script) - assert result == "" + assert "Error in sys.excepthook:" not in result + assert "AttributeError: 'NoneType' object has no " not in result + assert "Exception ignored in" not in result + + +@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.") +@pytest.mark.parametrize( + "call_ray_start", + ["ray start --head --ray-client-server-port 25552 --port 0"], + indirect=True) +def test_wrapped_actor_creation(call_ray_start): + """ + When the client schedules an actor, the server will load a separate + copy of the actor class if it's defined in a separate file. 
This
+    means that modifications to the client's copy of the actor class
+    aren't propagated to the server. Currently, tracing logic modifies
+    the signatures of actor methods to pass around metadata when
+    ray.remote is applied to an actor class. However, if a user does
+    something like:
+
+    class SomeActor:
+        def __init__(self):
+            pass
+
+    def decorate_actor():
+        RemoteActor = ray.remote(SomeActor)
+        ...
+
+    Then the SomeActor class will have its signatures modified on the
+    client side, but not on the server side, since ray.remote was applied
+    inside of the function instead of directly on the actor. Note if it
+    were directly applied to the actor then the signature would be
+    modified when the server imports the class.
+    """
+    import ray
+    ray.init("ray://localhost:25552")
+    run_wrapped_actor_creation()
 
 
 if __name__ == "__main__":
diff --git a/python/ray/tests/test_client_compat.py b/python/ray/tests/test_client_compat.py
new file mode 100644
index 0000000000000..98f4e9f4ba43d
--- /dev/null
+++ b/python/ray/tests/test_client_compat.py
@@ -0,0 +1,33 @@
+import pytest
+import sys
+
+import ray
+try:
+    import pyspark  # noqa
+except ImportError:
+    pyspark = None
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
+@pytest.mark.skipif(pyspark is None, reason="PySpark dependency not found")
+@pytest.mark.parametrize(
+    "call_ray_start", [
+        "ray start --head --num-cpus=1 --min-worker-port=0 "
+        "--max-worker-port=0 --port 0 --ray-client-server-port 10002",
+    ],
+    indirect=True)
+def test_client_data_get(call_ray_start):
+    """PySpark import changes NamedTuple pickling behavior, leading
+    to incompatibilities with the Ray client and Ray Data. This test
+    makes sure that our fix in the ClientPickler works."""
+    address = call_ray_start
+    ip = address.split(":")[0]
+
+    ray.util.connect(f"{ip}:10002")
+
+    ray_pipeline = ray.data.from_items(list(range(1_000)))
+    ray.get(ray_pipeline.to_numpy()[0])
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main(["-v", __file__]))
diff --git a/python/ray/tests/test_client_library_integration.py b/python/ray/tests/test_client_library_integration.py
index 774f46954d045..417b31efb5e3b 100644
--- a/python/ray/tests/test_client_library_integration.py
+++ b/python/ray/tests/test_client_library_integration.py
@@ -14,11 +14,11 @@ def test_rllib_integration(ray_start_regular_shared):
     import ray.rllib.agents.dqn as dqn
     # Confirming the behavior of this context manager.
     # (Client mode hook not yet enabled.)
-    assert not client_mode_should_convert()
+    assert not client_mode_should_convert(auto_init=True)
     # Need to enable this for client APIs to be used.
     with enable_client_mode():
         # Confirming mode hook is enabled.
-        assert client_mode_should_convert()
+        assert client_mode_should_convert(auto_init=True)
         config = dqn.SIMPLE_Q_DEFAULT_CONFIG.copy()
         # Run locally.
@@ -38,11 +38,11 @@ def test_rllib_integration_tune(ray_start_regular_shared):
     with ray_start_client_server():
         # Confirming the behavior of this context manager.
         # (Client mode hook not yet enabled.)
-        assert not client_mode_should_convert()
+        assert not client_mode_should_convert(auto_init=True)
         # Need to enable this for client APIs to be used.
         with enable_client_mode():
             # Confirming mode hook is enabled.
- assert client_mode_should_convert() + assert client_mode_should_convert(auto_init=True) tune.run( "DQN", config={"env": "CartPole-v1"}, diff --git a/python/ray/tests/test_client_proxy.py b/python/ray/tests/test_client_proxy.py index 03d1f34cb6582..8440268da6980 100644 --- a/python/ray/tests/test_client_proxy.py +++ b/python/ray/tests/test_client_proxy.py @@ -253,7 +253,10 @@ def test_prepare_runtime_init_req_no_modification(): """ Check that `prepare_runtime_init_req` properly extracts the JobConfig. """ - job_config = JobConfig(worker_env={"KEY": "VALUE"}, ray_namespace="abc") + job_config = JobConfig( + runtime_env={"env_vars": { + "KEY": "VALUE" + }}, ray_namespace="abc") init_req = ray_client_pb2.DataRequest( init=ray_client_pb2.InitRequest( job_config=pickle.dumps(job_config), @@ -273,7 +276,10 @@ def test_prepare_runtime_init_req_modified_job(): Check that `prepare_runtime_init_req` properly extracts the JobConfig and modifies it according to `ray_client_server_env_prep`. """ - job_config = JobConfig(worker_env={"KEY": "VALUE"}, ray_namespace="abc") + job_config = JobConfig( + runtime_env={"env_vars": { + "KEY": "VALUE" + }}, ray_namespace="abc") init_req = ray_client_pb2.DataRequest( init=ray_client_pb2.InitRequest( job_config=pickle.dumps(job_config), diff --git a/python/ray/tests/test_client_reconnect.py b/python/ray/tests/test_client_reconnect.py index b830403449ba3..0672b755f9eb1 100644 --- a/python/ray/tests/test_client_reconnect.py +++ b/python/ray/tests/test_client_reconnect.py @@ -294,6 +294,7 @@ def disconnect(middleman): disconnect_thread.join() +@pytest.mark.skipif(sys.platform == "win32", reason="Flaky on windows") def test_valid_actor_state(): """ Repeatedly inject errors in the middle of mutating actor calls. Check @@ -311,24 +312,28 @@ def incr(self): return self.val i = 0 + # This is to prevent erroring in the initial connection logic. + started = False def fail_every_seven(_): # Inject an error every seventh time this method is called - nonlocal i + nonlocal i, started i += 1 - if i % 7 == 0: + if i % 7 == 0 and started: raise RuntimeError with start_middleman_server( on_data_response=fail_every_seven, on_task_request=fail_every_seven, on_task_response=fail_every_seven): + started = True actor = IncrActor.remote() for _ in range(100): ref = actor.incr.remote() assert ray.get(ref) == 100 +@pytest.mark.skipif(sys.platform == "win32", reason="Flaky on windows") def test_valid_actor_state_2(): """ Do a full disconnect (cancel channel) every 11 requests. 
Failure
diff --git a/python/ray/tests/test_dashboard.py b/python/ray/tests/test_dashboard.py
index c92d9610ead84..578707baebf4a 100644
--- a/python/ray/tests/test_dashboard.py
+++ b/python/ray/tests/test_dashboard.py
@@ -4,14 +4,34 @@
 import sys
 import time
 
+import psutil
 import pytest
 import requests
 
-from ray._private.test_utils import run_string_as_driver, wait_for_condition
+from ray._private.test_utils import (run_string_as_driver, wait_for_condition,
+                                     get_error_message)
 import ray
 from ray import ray_constants
 
 
+def search_agents(cluster):
+    all_processes = cluster.head_node.all_processes
+    raylet_proc_info = all_processes[ray_constants.PROCESS_TYPE_RAYLET][0]
+    raylet_proc = psutil.Process(raylet_proc_info.process.pid)
+
+    def _search_agent(processes):
+        for p in processes:
+            try:
+                for c in p.cmdline():
+                    if "dashboard/agent.py" in c:
+                        return p
+            except Exception:
+                pass
+
+    agent_proc = _search_agent(raylet_proc.children())
+    return agent_proc
+
+
 def test_ray_start_default_port_conflict(call_ray_stop_only, shutdown_only):
     subprocess.check_call(["ray", "start", "--head"])
     ray.init(address="auto")
@@ -90,8 +110,6 @@ def test_port_conflict(call_ray_stop_only, shutdown_only):
     sock.close()
 
 
-@pytest.mark.skipif(
-    sys.version_info < (3, 5, 3), reason="requires python3.5.3 or higher")
 def test_dashboard(shutdown_only):
     addresses = ray.init(include_dashboard=True, num_cpus=1)
     dashboard_url = addresses["webui_url"]
@@ -121,8 +139,32 @@ def test_dashboard(shutdown_only):
             f"Dashboard output log: {out_log}\n")
 
 
-if __name__ == "__main__":
-    import sys
+@pytest.mark.parametrize(
+    "ray_start_cluster_head", [{
+        "metrics_export_port": 6379,
+        "_system_config": {
+            "agent_restart_interval_ms": 10,
+            "agent_max_restart_count": 5
+        }
+    }],
+    indirect=True)
+def test_dashboard_agent_restart(ray_start_cluster_head, error_pubsub):
+    """Test that when the agent fails to start many times in a row, the
+    error message is reported correctly without spamming the driver.
+    """
+    # Choose a duplicated port for the agent so that it will crash.
+    p = error_pubsub
+    errors = get_error_message(
+        p, 1, ray_constants.DASHBOARD_AGENT_DIED_ERROR, timeout=10)
+    for e in errors:
+        assert ("There are 2 possible problems "
                "if you see this error." in e.error_message)
+    # Make sure the agent process is not started anymore.
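+    # (search_agents, defined at the top of this file, scans the raylet's
+    # child processes for dashboard/agent.py; None means the agent is no
+    # longer being restarted.)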
+ cluster = ray_start_cluster_head + wait_for_condition(lambda: search_agents(cluster) is None) + +if __name__ == "__main__": import pytest sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tests/test_distributed_sort.py b/python/ray/tests/test_distributed_sort.py index 55cc7e37ebdfd..75cb682b165e8 100644 --- a/python/ray/tests/test_distributed_sort.py +++ b/python/ray/tests/test_distributed_sort.py @@ -4,14 +4,19 @@ from ray.experimental.raysort import main +@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows") def test_distributed_sort(): - main.args = main.get_args() - main.args.ray_address = None - main.args.total_data_size = 1_000_000_000 - main.args.skip_input = True - main.args.skip_output = True - main.main() + args = main.get_args([ + "--total_data_size=1_000_000_000", + "--num_mappers=4", + "--num_reducers=4", + "--num_mappers_per_round=2", + "--ray_address=", + "--skip_input", + "--skip_output", + ]) + main.main(args) if __name__ == "__main__": - sys.exit(pytest.main(["-v", __file__])) + sys.exit(pytest.main(["-sv", __file__])) diff --git a/python/ray/tests/test_failure_2.py b/python/ray/tests/test_failure_2.py index 6bb0986e649c3..3b33e1c3f173b 100644 --- a/python/ray/tests/test_failure_2.py +++ b/python/ray/tests/test_failure_2.py @@ -67,11 +67,12 @@ class Foo: pass # The actor creation should be infeasible. - Foo.remote() + a = Foo.remote() errors = get_error_message(p, 1, ray_constants.INFEASIBLE_TASK_ERROR) assert len(errors) == 1 assert errors[0].type == ray_constants.INFEASIBLE_TASK_ERROR p.close() + del a def test_warning_for_too_many_actors(shutdown_only): diff --git a/python/ray/tests/test_global_state.py b/python/ray/tests/test_global_state.py index 6d9c35bef37cd..8bf964791292e 100644 --- a/python/ray/tests/test_global_state.py +++ b/python/ray/tests/test_global_state.py @@ -287,8 +287,9 @@ def _read_resource_usage(self): def test_backlog_report(shutdown_only): cluster = ray.init( - num_cpus=1, _system_config={ - "report_worker_backlog": True, + num_cpus=1, + _system_config={ + "max_pending_lease_requests_per_scheduling_category": 1 }) global_state_accessor = GlobalStateAccessor( cluster["redis_address"], ray.ray_constants.REDIS_DEFAULT_PASSWORD) @@ -333,10 +334,7 @@ def backlog_size_set(): def test_heartbeat_ip(shutdown_only): - cluster = ray.init( - num_cpus=1, _system_config={ - "report_worker_backlog": True, - }) + cluster = ray.init(num_cpus=1) global_state_accessor = GlobalStateAccessor( cluster["redis_address"], ray.ray_constants.REDIS_DEFAULT_PASSWORD) global_state_accessor.connect() diff --git a/python/ray/tests/test_multi_tenancy.py b/python/ray/tests/test_multi_tenancy.py index 1267570d3660b..f2913a50c05ba 100644 --- a/python/ray/tests/test_multi_tenancy.py +++ b/python/ray/tests/test_multi_tenancy.py @@ -111,12 +111,14 @@ def get_pid(): all_worker_pids.add(worker_pid) -def test_worker_env(shutdown_only): +@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.") +def test_runtime_env(shutdown_only): ray.init( - job_config=ray.job_config.JobConfig(worker_env={ - "foo1": "bar1", - "foo2": "bar2" - })) + job_config=ray.job_config.JobConfig( + runtime_env={"env_vars": { + "foo1": "bar1", + "foo2": "bar2" + }})) @ray.remote def get_env(key): diff --git a/python/ray/tests/test_object_manager.py b/python/ray/tests/test_object_manager.py index 1f2c5e5dc4944..e44bf22e83187 100644 --- a/python/ray/tests/test_object_manager.py +++ b/python/ray/tests/test_object_manager.py @@ -296,8 +296,6 @@ def driver(): 
ray.get(driver.remote()) -# TODO(ekl) this sometimes takes much longer (10+s) due to a higher level -# pull retry. We should try to resolve these hangs in the chunk transfer logic. def test_pull_bundles_admission_control(shutdown_only): cluster = Cluster() object_size = int(6e6) @@ -605,6 +603,52 @@ def task(x): ray.get(t, timeout=10) +@pytest.mark.parametrize( + "ray_start_cluster_head", [{ + "num_cpus": 0, + "object_store_memory": 75 * 1024 * 1024, + "_system_config": { + "worker_lease_timeout_milliseconds": 0, + "object_manager_pull_timeout_ms": 20000, + "object_spilling_threshold": 1.0, + } + }], + indirect=True) +def test_maximize_concurrent_pull_race_condition(ray_start_cluster_head): + # Test if https://github.com/ray-project/ray/issues/18062 is mitigated + cluster = ray_start_cluster_head + cluster.add_node(num_cpus=8, object_store_memory=75 * 1024 * 1024) + + @ray.remote + class RemoteObjectCreator: + def put(self, i): + return np.random.rand(i * 1024 * 1024) # 8 MB data + + def idle(self): + pass + + @ray.remote + def f(x): + print(f"timestamp={time.time()} pulled {len(x)*8} bytes") + time.sleep(1) + return + + remote_obj_creator = RemoteObjectCreator.remote() + remote_refs = [remote_obj_creator.put.remote(1) for _ in range(7)] + print(remote_refs) + # Make sure all objects are created. + ray.get(remote_obj_creator.idle.remote()) + + local_refs = [ray.put(np.random.rand(1 * 1024 * 1024)) for _ in range(20)] + remote_tasks = [f.remote(x) for x in local_refs] + + start = time.time() + ray.get(remote_tasks) + end = time.time() + assert end - start < 20, "Too much time spent in pulling objects, " \ + "check the amount of time in retries" + + if __name__ == "__main__": import pytest import sys diff --git a/python/ray/tests/test_output.py b/python/ray/tests/test_output.py index 93cba471ee21a..958cdecfe732e 100644 --- a/python/ray/tests/test_output.py +++ b/python/ray/tests/test_output.py @@ -65,13 +65,15 @@ def test_autoscaler_no_spam(): import ray import time -ray.init(num_cpus=1) +# Check that there are no false positives with custom resources. +ray.init(num_cpus=1, resources={"node:x": 1}) -@ray.remote(num_cpus=1) +@ray.remote(num_cpus=1, resources={"node:x": 1}) def f(): time.sleep(1) + print("task done") -ray.get([f.remote() for _ in range(5)]) +ray.get([f.remote() for _ in range(15)]) """ proc = run_string_as_driver_nonblocking(script) diff --git a/python/ray/tests/test_placement_group.py b/python/ray/tests/test_placement_group.py index 55a2cc5a007e9..345f19ff80951 100644 --- a/python/ray/tests/test_placement_group.py +++ b/python/ray/tests/test_placement_group.py @@ -345,6 +345,13 @@ def test_remove_placement_group(ray_start_cluster, connect_to_client): cluster.add_node(num_cpus=4) ray.init(address=cluster.address) + @ray.remote + def warmup(): + pass + + # warm up the cluster. + ray.get([warmup.remote() for _ in range(4)]) + with connect_to_client_or_not(connect_to_client): # First try to remove a placement group that doesn't # exist. This should not do anything. 
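For reference, the placement group tests below lean on one small lifecycle; a minimal sketch, assuming only the public ray.util.placement_group API and the bundle_group resource naming asserted in the leak-regression test:

    import ray
    from ray.util.placement_group import (placement_group,
                                          remove_placement_group)

    ray.init(num_cpus=2)
    pg = placement_group([{"CPU": 1}], strategy="PACK")
    ray.get(pg.ready())  # resolves once every bundle has been reserved
    # Reserved bundles show up as custom resources, e.g. under the key
    # f"bundle_group_{pg.id.hex()}" in ray.available_resources().
    remove_placement_group(pg)  # should hand the bundle resources back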
diff --git a/python/ray/tests/test_placement_group_3.py b/python/ray/tests/test_placement_group_3.py index 12afdfee47ecb..eeb6df0f5c4bb 100644 --- a/python/ray/tests/test_placement_group_3.py +++ b/python/ray/tests/test_placement_group_3.py @@ -608,5 +608,40 @@ def is_usage_updated(): assert cpu_usage == expected +def test_placement_group_removal_leak_regression(ray_start_cluster): + """Related issue: + https://github.com/ray-project/ray/issues/19131 + """ + cluster = ray_start_cluster + cluster.add_node(num_cpus=5) + ray.init(address=cluster.address) + + TOTAL_CPUS = 8 + bundles = [{"CPU": 1, "GPU": 1}] + bundles += [{"CPU": 1} for _ in range(TOTAL_CPUS - 1)] + + pg = placement_group(bundles, strategy="PACK") + # Here, we simulate that the ready task is queued and + # the new node is up. As soon as the new node is up, + # the ready task is scheduled. + # See https://github.com/ray-project/ray/pull/19138 + # for more details about the test. + o = pg.ready() + # Add an artificial delay until the new node is up. + time.sleep(3) + cluster.add_node(num_cpus=5, num_gpus=1) + ray.get(o) + bundle_resource_name = f"bundle_group_{pg.id.hex()}" + expected_bundle_wildcard_val = TOTAL_CPUS * 1000 + + # This should fail if there's a leakage + # because the bundle resources are never returned properly. + def check_bundle_leaks(): + bundle_resources = ray.available_resources()[bundle_resource_name] + return expected_bundle_wildcard_val == bundle_resources + + wait_for_condition(check_bundle_leaks) + + if __name__ == "__main__": sys.exit(pytest.main(["-sv", __file__])) diff --git a/python/ray/tests/test_ray_debugger.py b/python/ray/tests/test_ray_debugger.py index fdc6c56da1eb3..3cc980bb14026 100644 --- a/python/ray/tests/test_ray_debugger.py +++ b/python/ray/tests/test_ray_debugger.py @@ -11,6 +11,7 @@ import ray from ray.cluster_utils import Cluster from ray._private.test_utils import run_string_as_driver, wait_for_condition +from ray._private import services def test_ray_debugger_breakpoint(shutdown_only): @@ -217,7 +218,7 @@ def f(): host, port = session["pdb_address"].split(":") if ray_debugger_external: - assert host not in ["localhost", "127.0.0.1"], host + assert host == services.get_node_ip_address(), host else: assert host == "localhost", host @@ -267,13 +268,13 @@ def f(): host1, port1 = session1["pdb_address"].split(":") if ray_debugger_external: - assert host1 not in ["localhost", "127.0.0.1"], host1 + assert host1 == services.get_node_ip_address(), host1 else: assert host1 == "localhost", host1 host2, port2 = session2["pdb_address"].split(":") if ray_debugger_external: - assert host2 not in ["localhost", "127.0.0.1"], host2 + assert host2 == services.get_node_ip_address(), host2 else: assert host2 == "localhost", host2 diff --git a/python/ray/tests/test_ray_init.py b/python/ray/tests/test_ray_init.py index 5040f4bd65ef4..3fdb6a6ea110d 100644 --- a/python/ray/tests/test_ray_init.py +++ b/python/ray/tests/test_ray_init.py @@ -11,6 +11,7 @@ from ray.client_builder import ClientContext from ray.cluster_utils import Cluster from ray._private.test_utils import run_string_as_driver +from ray._raylet import ClientObjectRef from ray.util.client.worker import Worker import grpc @@ -216,6 +217,7 @@ def test_ray_address(input, call_ray_start): res = ray.init(input) # Ensure this is not a client.connect() assert not isinstance(res, ClientContext) + ray.shutdown() class Credentials(grpc.ChannelCredentials): @@ -257,9 +259,47 @@ def mock_secure_channel(conn_str, with pytest.raises(Stop) as stop: 
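        # (Stop is raised from the mocked secure channel with the credentials
        # it was given, which is what the assert below verifies.)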
ray.init("ray://127.0.0.1", _credentials=Credentials("test")) + ray.util.disconnect() assert stop.value.credentials.name == "test" +def test_auto_init_non_client(call_ray_start): + address = call_ray_start + with unittest.mock.patch.dict(os.environ, {"RAY_ADDRESS": address}): + res = ray.put(300) + # Ensure this is not a client.connect() + assert not isinstance(res, ClientObjectRef) + ray.shutdown() + + addr = "localhost:{}".format(address.split(":")[-1]) + with unittest.mock.patch.dict(os.environ, {"RAY_ADDRESS": addr}): + res = ray.put(300) + # Ensure this is not a client.connect() + assert not isinstance(res, ClientObjectRef) + + +@pytest.mark.parametrize( + "call_ray_start", + ["ray start --head --ray-client-server-port 25036 --port 0"], + indirect=True) +@pytest.mark.parametrize( + "function", [lambda: ray.put(300), lambda: ray.remote(ray.nodes).remote()]) +def test_auto_init_client(call_ray_start, function): + address = call_ray_start.split(":")[0] + with unittest.mock.patch.dict(os.environ, + {"RAY_ADDRESS": f"ray://{address}:25036"}): + res = function() + # Ensure this is a client connection. + assert isinstance(res, ClientObjectRef) + ray.shutdown() + + with unittest.mock.patch.dict(os.environ, + {"RAY_ADDRESS": "ray://localhost:25036"}): + res = function() + # Ensure this is a client connection. + assert isinstance(res, ClientObjectRef) + + if __name__ == "__main__": import pytest import sys diff --git a/python/ray/tests/test_resource_demand_scheduler.py b/python/ray/tests/test_resource_demand_scheduler.py index add24d4a571a9..eb1260db32aa1 100644 --- a/python/ray/tests/test_resource_demand_scheduler.py +++ b/python/ray/tests/test_resource_demand_scheduler.py @@ -46,24 +46,52 @@ def get_nodes_for(*a, **kw): def test_util_score(): assert _utilization_score({"CPU": 64}, [{"TPU": 16}]) is None - assert _utilization_score({"GPU": 4}, [{"GPU": 2}]) == (0.5, 0.5) + assert _utilization_score({"GPU": 4}, [{"GPU": 2}]) == (1, 0.5, 0.5) assert _utilization_score({"GPU": 4}, [{"GPU": 1}, {"GPU": 1}]) == \ - (0.5, 0.5) - assert _utilization_score({"GPU": 2}, [{"GPU": 2}]) == (2, 2) - assert _utilization_score({"GPU": 2}, [{"GPU": 1}, {"GPU": 1}]) == (2, 2) - assert _utilization_score({"GPU": 2, "TPU": 1}, [{"GPU": 2}]) == (0, 1) - assert _utilization_score({"CPU": 64}, [{"CPU": 64}]) == (64, 64) - assert _utilization_score({"CPU": 64}, [{"CPU": 32}]) == (8, 8) + (1, 0.5, 0.5) + assert _utilization_score({"GPU": 2}, [{"GPU": 2}]) == (1, 2, 2) + assert _utilization_score({ + "GPU": 2 + }, [{ + "GPU": 1 + }, { + "GPU": 1 + }]) == (1, 2, 2) + assert _utilization_score({ + "GPU": 1 + }, [{ + "GPU": 1, + "CPU": 1 + }, { + "GPU": 1 + }]) == (1, 1, 1) + assert _utilization_score({ + "GPU": 1, + "CPU": 1 + }, [{ + "GPU": 1, + "CPU": 1 + }, { + "GPU": 1 + }]) == (2, 1, 1) + assert _utilization_score({"GPU": 2, "TPU": 1}, [{"GPU": 2}]) == (1, 0, 1) + assert _utilization_score({"CPU": 64}, [{"CPU": 64}]) == (1, 64, 64) + assert _utilization_score({"CPU": 64}, [{"CPU": 32}]) == (1, 8, 8) assert _utilization_score({"CPU": 64}, [{"CPU": 16}, {"CPU": 16}]) == \ - (8, 8) + (1, 8, 8) def test_gpu_node_util_score(): # Avoid scheduling CPU tasks on GPU node. 
assert _utilization_score({"GPU": 1, "CPU": 1}, [{"CPU": 1}]) is None assert _utilization_score({"GPU": 1, "CPU": 1}, [{"CPU": 1, "GPU": 1}]) \ - == (1.0, 1.0) - assert _utilization_score({"GPU": 1, "CPU": 1}, [{"GPU": 1}]) == (0.0, 0.5) + == (2, 1.0, 1.0) + assert _utilization_score({ + "GPU": 1, + "CPU": 1 + }, [{ + "GPU": 1 + }]) == (1, 0.0, 0.5) def test_zero_resource(): @@ -197,7 +225,7 @@ def test_get_nodes_packing_heuristic(): }] * 8) + ([{ "CPU": 1 }] * 64)) == { - "m4.16xlarge": 1, + "m4.4xlarge": 2, "p2.8xlarge": 1 } @@ -215,6 +243,47 @@ def test_get_nodes_packing_heuristic(): } +def test_node_packing_gpu_cpu_bundles(): + TYPES = { + "cpu": { + "resources": { + "CPU": 16, + }, + "max_workers": 10, + }, + "gpu": { + "resources": { + "CPU": 16, + "GPU": 1, + }, + "max_workers": 10, + }, + } + nodes = get_nodes_for(TYPES, {}, "cpu", 9999, ([{ + "CPU": 1 + }] * 30 + [{ + "GPU": 1, + "CPU": 1 + }])) + assert nodes == {"gpu": 1, "cpu": 1} + + nodes = get_nodes_for(TYPES, {}, "cpu", 9999, ([{ + "GPU": 1, + "CPU": 1 + }] + [{ + "CPU": 1 + }] * 30)) + assert nodes == {"gpu": 1, "cpu": 1} + + nodes = get_nodes_for(TYPES, {}, "cpu", 9999, ([{ + "GPU": 1, + "CPU": 1 + }] + [{ + "CPU": 1 + }] * 15)) + assert nodes == {"gpu": 1} + + def test_gpu_node_avoid_cpu_task(): types = { "cpu": { @@ -630,13 +699,8 @@ def test_backlog_queue_impact_on_binpacking_time_aux( "CPU": 1 }]) # If not for the max launch concurrency the next assert should be: - # {'m4.large': 4, 'm4.4xlarge': 2, 'm4.16xlarge': 15, 'p2.8xlarge': 125}. - assert to_launch == { - "m4.large": 4, - "m4.4xlarge": 2, - "m4.16xlarge": 5, - "p2.8xlarge": 5 - } + # {'m4.16xlarge': 1, 'p2.8xlarge': 125, 'p2.xlarge': 1} + assert to_launch == {"m4.16xlarge": 1, "p2.8xlarge": 5, "p2.xlarge": 1} # Check the time it takes when there are 100 nodes available and the demand # requires another 75 nodes. @@ -1322,7 +1386,10 @@ def tearDown(self): shutil.rmtree(self.tmpdir) ray.shutdown() - def waitForNodes(self, expected, comparison=None, tag_filters={}): + def waitForNodes(self, expected, comparison=None, tag_filters=None): + if tag_filters is None: + tag_filters = {} + MAX_ITER = 50 for i in range(MAX_ITER): n = len(self.provider.non_terminated_nodes(tag_filters)) @@ -1664,7 +1731,7 @@ def testScaleUpMinWorkers(self): assert cnt == 2 def testScaleUpIgnoreUsed(self): - config = MULTI_WORKER_CLUSTER.copy() + config = copy.deepcopy(MULTI_WORKER_CLUSTER) # Commenting out this line causes the test case to fail?!?! 
config["min_workers"] = 0 config["target_utilization_fraction"] = 1.0 @@ -1705,7 +1772,7 @@ def testScaleUpIgnoreUsed(self): assert self.provider.mock_nodes[1].node_type == "p2.xlarge" def testRequestBundlesAccountsForHeadNode(self): - config = MULTI_WORKER_CLUSTER.copy() + config = copy.deepcopy(MULTI_WORKER_CLUSTER) config["head_node_type"] = "p2.8xlarge" config["min_workers"] = 0 config["max_workers"] = 50 @@ -1744,7 +1811,7 @@ def testRequestBundlesAccountsForHeadNode(self): assert self.provider.mock_nodes[1].node_type == "p2.8xlarge" def testRequestBundles(self): - config = MULTI_WORKER_CLUSTER.copy() + config = copy.deepcopy(MULTI_WORKER_CLUSTER) config["min_workers"] = 0 config["max_workers"] = 50 config_path = self.write_config(config) @@ -1781,7 +1848,7 @@ def testRequestBundles(self): assert self.provider.mock_nodes[4].node_type == "m4.16xlarge" def testResourcePassing(self): - config = MULTI_WORKER_CLUSTER.copy() + config = copy.deepcopy(MULTI_WORKER_CLUSTER) config["min_workers"] = 0 config["max_workers"] = 50 config_path = self.write_config(config) @@ -1812,7 +1879,7 @@ def testResourcePassing(self): assert self.provider.mock_nodes[2].node_type == "p2.8xlarge" # TODO (Alex): Autoscaler creates the node during one update then - # starts the updater in the enxt update. The sleep is largely + # starts the updater in the next update. The sleep is largely # unavoidable because the updater runs in its own thread and we have no # good way of ensuring that the commands are sent in time. autoscaler.update() @@ -1827,7 +1894,7 @@ def testResourcePassing(self): runner.assert_has_call("172.0.0.2", "\"GPU\":8") def testScaleUpLoadMetrics(self): - config = MULTI_WORKER_CLUSTER.copy() + config = copy.deepcopy(MULTI_WORKER_CLUSTER) config["min_workers"] = 0 config["max_workers"] = 50 config_path = self.write_config(config) @@ -1858,16 +1925,15 @@ def testScaleUpLoadMetrics(self): "CPU": 16 }]) autoscaler.update() - self.waitForNodes(2, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) + self.waitForNodes(1, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) nodes = { self.provider.mock_nodes[1].node_type, - self.provider.mock_nodes[2].node_type } - assert nodes == {"p2.xlarge", "m4.4xlarge"} + assert nodes == {"p2.xlarge"} def testCommandPassing(self): t = "custom" - config = MULTI_WORKER_CLUSTER.copy() + config = copy.deepcopy(MULTI_WORKER_CLUSTER) config["available_node_types"]["p2.8xlarge"][ "worker_setup_commands"] = ["new_worker_setup_command"] config["available_node_types"]["p2.xlarge"][ @@ -1923,7 +1989,7 @@ def testCommandPassing(self): "init_cmd") def testDockerWorkers(self): - config = MULTI_WORKER_CLUSTER.copy() + config = copy.deepcopy(MULTI_WORKER_CLUSTER) config["available_node_types"]["p2.8xlarge"]["docker"] = { "worker_image": "p2.8x_image:latest", "worker_run_options": ["p2.8x-run-options"] @@ -1981,7 +2047,7 @@ def testDockerWorkers(self): }]) autoscaler.update() self.waitForNodes(5) - assert self.provider.mock_nodes[4].node_type == "m4.16xlarge" + assert self.provider.mock_nodes[4].node_type == "m4.large" autoscaler.update() sleep(0.1) runner.assert_has_call(self.provider.mock_nodes[2].internal_ip, @@ -2044,7 +2110,7 @@ def testUpdateConfig(self): self.waitForNodes(0, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) def testEmptyDocker(self): - config = MULTI_WORKER_CLUSTER.copy() + config = copy.deepcopy(MULTI_WORKER_CLUSTER) del config["docker"] config["min_workers"] = 0 config["max_workers"] = 10 diff --git a/python/ray/tests/test_runtime_context.py 
b/python/ray/tests/test_runtime_context.py
index 1c069e10066df..8ce983da2085a 100644
--- a/python/ray/tests/test_runtime_context.py
+++ b/python/ray/tests/test_runtime_context.py
@@ -4,6 +4,8 @@
 import time
 import sys
 
+from ray._private.test_utils import SignalActor
+
 
 def test_was_current_actor_reconstructed(shutdown_only):
     ray.init()
@@ -113,6 +115,119 @@ def echo2(self, s):
     assert ray.get(ray.get(obj)) == "hello"
 
 
+def test_actor_stats_normal_task(ray_start_regular):
+    # Because it works at the core worker level, this API works for tasks.
+    @ray.remote
+    def func():
+        return ray.get_runtime_context()._get_actor_call_stats()
+
+    assert ray.get(func.remote())["func"] == {
+        "pending": 0,
+        "running": 1,
+        "finished": 0,
+    }
+
+
+def test_actor_stats_sync_actor(ray_start_regular):
+    signal = SignalActor.remote()
+
+    @ray.remote
+    class SyncActor:
+        def run(self):
+            return ray.get_runtime_context()._get_actor_call_stats()
+
+        def wait_signal(self):
+            ray.get(signal.wait.remote())
+            return ray.get_runtime_context()._get_actor_call_stats()
+
+    actor = SyncActor.remote()
+    counts = ray.get(actor.run.remote())
+    assert counts == {
+        "SyncActor.run": {
+            "pending": 0,
+            "running": 1,
+            "finished": 0
+        },
+        "SyncActor.__init__": {
+            "pending": 0,
+            "running": 0,
+            "finished": 1
+        }
+    }
+
+    ref = actor.wait_signal.remote()
+    other_refs = [actor.run.remote() for _ in range(3)
+                  ] + [actor.wait_signal.remote() for _ in range(5)]
+    ray.wait(other_refs, timeout=1)
+    signal.send.remote()
+    counts = ray.get(ref)
+    assert counts == {
+        "SyncActor.run": {
+            "pending": 3,
+            "running": 0,
+            "finished": 1,  # from previous run
+        },
+        "SyncActor.wait_signal": {
+            "pending": 5,
+            "running": 1,
+            "finished": 0,
+        },
+        "SyncActor.__init__": {
+            "pending": 0,
+            "running": 0,
+            "finished": 1
+        }
+    }
+
+
+def test_actor_stats_threaded_actor(ray_start_regular):
+    signal = SignalActor.remote()
+
+    @ray.remote
+    class ThreadedActor:
+        def func(self):
+            ray.get(signal.wait.remote())
+            return ray.get_runtime_context()._get_actor_call_stats()
+
+    actor = ThreadedActor.options(max_concurrency=3).remote()
+    refs = [actor.func.remote() for _ in range(6)]
+    ready, _ = ray.wait(refs, timeout=1)
+    assert len(ready) == 0
+    signal.send.remote()
+    results = ray.get(refs)
+    assert max(result["ThreadedActor.func"]["running"]
+               for result in results) > 1
+    assert max(result["ThreadedActor.func"]["pending"]
+               for result in results) > 1
+
+
+def test_actor_stats_async_actor(ray_start_regular):
+    signal = SignalActor.remote()
+
+    @ray.remote
+    class AsyncActor:
+        async def func(self):
+            await signal.wait.remote()
+            return ray.get_runtime_context()._get_actor_call_stats()
+
+    actor = AsyncActor.options(max_concurrency=3).remote()
+    refs = [actor.func.remote() for _ in range(6)]
+    ready, _ = ray.wait(refs, timeout=1)
+    assert len(ready) == 0
+    signal.send.remote()
+    results = ray.get(refs)
+    assert max(result["AsyncActor.func"]["running"] for result in results) == 3
+    assert max(result["AsyncActor.func"]["pending"] for result in results) == 3
+
+
+# get_runtime_context() can be called outside of Ray so it should not start
+# Ray automatically.
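+# (Auto-init in the opposite direction is covered by test_auto_init_non_client
+# and test_auto_init_client in test_ray_init.py.)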
+def test_no_auto_init(shutdown_only): + assert not ray.is_initialized() + ray.get_runtime_context() + assert not ray.is_initialized() + + if __name__ == "__main__": import pytest sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tests/test_runtime_env.py b/python/ray/tests/test_runtime_env.py index 110beb4490a6b..0f9297238c3cd 100644 --- a/python/ray/tests/test_runtime_env.py +++ b/python/ray/tests/test_runtime_env.py @@ -13,7 +13,6 @@ from ray._private.test_utils import ( run_string_as_driver, run_string_as_driver_nonblocking, wait_for_condition) from ray._private.runtime_env import working_dir as working_dir_pkg -from ray._private.runtime_env.validation import override_task_or_actor_runtime_env # noqa: E501 from ray._private.utils import (get_wheel_filename, get_master_wheel_url, get_release_wheel_url) @@ -774,41 +773,38 @@ def test_container_option_serialize(): job_config = ray.job_config.JobConfig(runtime_env=runtime_env) job_config_serialized = job_config.serialize() # job_config_serialized is JobConfig protobuf serialized string, - # job_config.runtime_env.raw_json has container_option info - # job_config.serialized_runtime_env also has container_option info - assert job_config_serialized.count(b"image") == 2 + # job_config.runtime_env.serialized_runtime_env has container_option info + assert job_config_serialized.count(b"image") == 1 def test_working_dir_override_failure(shutdown_only): ray.init() - @ray.remote(runtime_env={"working_dir": "."}) - def f(): - pass - with pytest.raises(NotImplementedError): - f.remote() + + @ray.remote(runtime_env={"working_dir": "."}) + def f(): + pass @ray.remote def g(): pass with pytest.raises(NotImplementedError): - g.options(runtime_env={"working_dir": "."}).remote() - - @ray.remote(runtime_env={"working_dir": "."}) - class A: - pass + g.options(runtime_env={"working_dir": "."}) with pytest.raises(NotImplementedError): - A.remote() + + @ray.remote(runtime_env={"working_dir": "."}) + class A: + pass @ray.remote class B: pass with pytest.raises(NotImplementedError): - B.options(runtime_env={"working_dir": "."}).remote() + B.options(runtime_env={"working_dir": "."}) @pytest.mark.skipif( @@ -944,46 +940,6 @@ def test_large_file_error(shutdown_only): os.chdir(old_dir) -class TestOverrideTaskOrActorRuntimeEnv: - def test_working_dir_in_child_invalid(self): - child_env = {"working_dir": "some_dir"} - parent_env = {"working_dir": "other_dir", "uris": ["a", "b"]} - - with pytest.raises(NotImplementedError): - override_task_or_actor_runtime_env(child_env, parent_env) - - def test_uri_inherit(self): - child_env = {} - parent_env = {"working_dir": "other_dir", "uris": ["a", "b"]} - result_env = override_task_or_actor_runtime_env(child_env, parent_env) - assert result_env == {"uris": ["a", "b"]} - - # The dicts passed in should not be mutated. - assert child_env == {} - assert parent_env == {"working_dir": "other_dir", "uris": ["a", "b"]} - - def test_uri_override(self): - child_env = {"uris": ["c", "d"]} - parent_env = {"working_dir": "other_dir", "uris": ["a", "b"]} - result_env = override_task_or_actor_runtime_env(child_env, parent_env) - assert result_env["uris"] == ["c", "d"] - assert result_env.get("working_dir") is None - - # The dicts passed in should not be mutated. 
- assert child_env == {"uris": ["c", "d"]} - assert parent_env == {"working_dir": "other_dir", "uris": ["a", "b"]} - - def test_no_mutate(self): - child_env = {} - parent_env = {"working_dir": "other_dir", "uris": ["a", "b"]} - result_env = override_task_or_actor_runtime_env(child_env, parent_env) - assert result_env == {"uris": ["a", "b"]} - - # The dictis passed in should not be mutated. - assert child_env == {} - assert parent_env == {"working_dir": "other_dir", "uris": ["a", "b"]} - - if __name__ == "__main__": import sys sys.exit(pytest.main(["-sv", __file__])) diff --git a/python/ray/tests/test_runtime_env_complicated.py b/python/ray/tests/test_runtime_env_complicated.py index d8c334c413606..e5c7047f275b5 100644 --- a/python/ray/tests/test_runtime_env_complicated.py +++ b/python/ray/tests/test_runtime_env_complicated.py @@ -12,15 +12,16 @@ import yaml import ray -from ray._private.runtime_env import RuntimeEnvDict from ray._private.runtime_env.conda import ( inject_dependencies, _inject_ray_to_conda_site, _resolve_install_from_source_ray_dependencies, _current_py_version, ) -from ray._private.test_utils import (run_string_as_driver, - run_string_as_driver_nonblocking) + +from ray._private.runtime_env.conda_utils import get_conda_env_list +from ray._private.test_utils import ( + run_string_as_driver, run_string_as_driver_nonblocking, wait_for_condition) from ray._private.utils import get_conda_env_dir, get_conda_bin_executable if not os.environ.get("CI"): @@ -190,6 +191,39 @@ def test_job_config_conda_env(conda_envs, shutdown_only): ray.shutdown() +@pytest.mark.skipif( + os.environ.get("CONDA_DEFAULT_ENV") is None, + reason="must be run from within a conda environment") +@pytest.mark.skipif( + os.environ.get("CI") and sys.platform != "linux", + reason="This test is only run on linux CI machines.") +def test_job_eager_install(shutdown_only): + # Test enable eager install + runtime_env = {"conda": {"dependencies": ["toolz"]}, "eager_install": True} + env_count = len(get_conda_env_list()) + ray.init(runtime_env=runtime_env) + wait_for_condition( + lambda: len(get_conda_env_list()) == env_count + 1, timeout=60) + ray.shutdown() + # Test disable eager install + runtime_env = { + "conda": { + "dependencies": ["toolz"] + }, + "eager_install": False + } + ray.init(runtime_env=runtime_env) + with pytest.raises(RuntimeError): + wait_for_condition( + lambda: len(get_conda_env_list()) == env_count + 2, timeout=60) + ray.shutdown() + # Test unavailable type + runtime_env = {"conda": {"dependencies": ["toolz"]}, "eager_install": 123} + with pytest.raises(AssertionError): + ray.init(runtime_env=runtime_env) + ray.shutdown() + + def test_get_conda_env_dir(tmp_path): """ Typical output of `conda env list`, for context: @@ -449,28 +483,6 @@ def f(): assert ray.get(f.remote()) -@pytest.mark.skipif(sys.platform == "win32", reason="Unsupported on Windows.") -@pytest.mark.parametrize("use_working_dir", [True, False]) -def test_conda_input_filepath(use_working_dir, tmp_path): - conda_dict = {"dependencies": ["pip", {"pip": ["pip-install-test==0.5"]}]} - d = tmp_path / "pip_requirements" - d.mkdir() - p = d / "environment.yml" - - p.write_text(yaml.dump(conda_dict)) - - if use_working_dir: - runtime_env_dict = RuntimeEnvDict({ - "working_dir": str(d), - "conda": "environment.yml" - }) - else: - runtime_env_dict = RuntimeEnvDict({"conda": str(p)}) - - output_conda_dict = runtime_env_dict.get_parsed_dict().get("conda") - assert output_conda_dict == conda_dict - - @skipIf(sys.platform == "win32", "Fail to 
create temp dir.") def test_experimental_package(shutdown_only): ray.init(num_cpus=2) @@ -514,7 +526,7 @@ def test_experimental_package_github(shutdown_only): ["ray start --head --ray-client-server-port 24001 --port 0"], indirect=True) def test_client_working_dir_filepath(call_ray_start, tmp_path): - """Test that pip and conda relative filepaths work with working_dir.""" + """Test that pip and conda filepaths work with working_dir.""" working_dir = tmp_path / "requirements" working_dir.mkdir() @@ -524,10 +536,7 @@ def test_client_working_dir_filepath(call_ray_start, tmp_path): pip-install-test==0.5 """ pip_file.write_text(requirements_txt) - runtime_env_pip = { - "working_dir": str(working_dir), - "pip": "requirements.txt" - } + runtime_env_pip = {"working_dir": str(working_dir), "pip": str(pip_file)} conda_file = working_dir / "environment.yml" conda_dict = {"dependencies": ["pip", {"pip": ["pip-install-test==0.5"]}]} @@ -535,7 +544,7 @@ def test_client_working_dir_filepath(call_ray_start, tmp_path): conda_file.write_text(conda_str) runtime_env_conda = { "working_dir": str(working_dir), - "conda": "environment.yml" + "conda": str(conda_file) } @ray.remote @@ -557,6 +566,64 @@ def f(): assert ray.get(f.remote()) +@pytest.mark.skipif( + os.environ.get("CI") and sys.platform != "linux", + reason="This test is only run on linux CI machines.") +@pytest.mark.parametrize( + "call_ray_start", + ["ray start --head --ray-client-server-port 24001 --port 0"], + indirect=True) +def test_conda_pip_filepaths_remote(call_ray_start, tmp_path): + """Test that pip and conda filepaths work, simulating a remote cluster.""" + + working_dir = tmp_path / "requirements" + working_dir.mkdir() + + pip_file = working_dir / "requirements.txt" + requirements_txt = """ + pip-install-test==0.5 + """ + pip_file.write_text(requirements_txt) + runtime_env_pip = {"pip": str(pip_file)} + + conda_file = working_dir / "environment.yml" + conda_dict = {"dependencies": ["pip", {"pip": ["pip-install-test==0.5"]}]} + conda_str = yaml.dump(conda_dict) + conda_file.write_text(conda_str) + runtime_env_conda = {"conda": str(conda_file)} + + @ray.remote + def f(): + import pip_install_test # noqa + return True + + with ray.client("localhost:24001").connect(): + with pytest.raises(ModuleNotFoundError): + # Ensure pip-install-test is not installed in a client that doesn't + # use the runtime_env + ray.get(f.remote()) + + # pip and conda files should be parsed when the function is declared. + f_pip = f.options(runtime_env=runtime_env_pip) + f_conda = f.options(runtime_env=runtime_env_conda) + + # Remove the pip and conda files from the local filesystem. This is + # necessary to simulate the files not being present on the remote cluster, + # because in this single-machine test, the cluster has the same filesystem. + os.remove(pip_file) + os.remove(conda_file) + + # Test with and without a working_dir. + client_envs = [{}, {"working_dir": str(working_dir)}] + for runtime_env in client_envs: + with ray.client("localhost:24001").env(runtime_env).connect(): + with pytest.raises(ModuleNotFoundError): + # Ensure pip-install-test is not installed on the test machine + import pip_install_test # noqa + assert ray.get(f_pip.remote()) + assert ray.get(f_conda.remote()) + + install_env_script = """ import ray import time @@ -718,7 +785,7 @@ def test(self): # Start a new job on the same cluster using the Summit 2021 requirements. 
with ray.client(f"localhost:{CLIENT_SERVER_PORT}").env({ "working_dir": str(tmp_path), - "pip": "requirements.txt" + "pip": str(requirement_path) }).connect(): @ray.remote @@ -752,7 +819,9 @@ def test(self): return Path("./test").read_text() - a = TestActor.options(runtime_env={"pip": "requirements.txt"}).remote() + a = TestActor.options(runtime_env={ + "pip": str(requirement_path) + }).remote() assert ray.get(a.test.remote()) == "Hello" # Check that per-task pip specification works and that the job's @@ -888,7 +957,7 @@ def f(self): @pytest.mark.skipif( os.environ.get("CI") and sys.platform != "linux", reason="This test is only run on linux CI machines.") -def test_runtime_env_logging_to_dirver(ray_start_regular_shared, log_pubsub): +def test_runtime_env_logging_to_driver(ray_start_regular_shared, log_pubsub): @ray.remote(runtime_env={"pip": [f"requests=={REQUEST_VERSIONS[0]}"]}) def func(): pass diff --git a/python/ray/tests/test_runtime_env_env_vars.py b/python/ray/tests/test_runtime_env_env_vars.py index 22ce5d5ce59b9..479a7f4130bd2 100644 --- a/python/ray/tests/test_runtime_env_env_vars.py +++ b/python/ray/tests/test_runtime_env_env_vars.py @@ -7,54 +7,37 @@ import ray -@pytest.mark.parametrize("use_runtime_env", [True, False]) -def test_override_environment_variables_task(ray_start_regular, - use_runtime_env): +def test_environment_variables_task(ray_start_regular): @ray.remote def get_env(key): return os.environ.get(key) - if use_runtime_env: - assert (ray.get( - get_env.options(runtime_env={ - "env_vars": { - "a": "b", - } - }).remote("a")) == "b") - else: - assert (ray.get( - get_env.options(override_environment_variables={ + assert (ray.get( + get_env.options(runtime_env={ + "env_vars": { "a": "b", - }).remote("a")) == "b") + } + }).remote("a")) == "b") -@pytest.mark.parametrize("use_runtime_env", [True, False]) -def test_override_environment_variables_actor(ray_start_regular, - use_runtime_env): +def test_environment_variables_actor(ray_start_regular): @ray.remote class EnvGetter: def get(self, key): return os.environ.get(key) - if use_runtime_env: - a = EnvGetter.options(runtime_env={ - "env_vars": { - "a": "b", - "c": "d", - } - }).remote() - else: - a = EnvGetter.options(override_environment_variables={ + a = EnvGetter.options(runtime_env={ + "env_vars": { "a": "b", "c": "d", - }).remote() + } + }).remote() + assert (ray.get(a.get.remote("a")) == "b") assert (ray.get(a.get.remote("c")) == "d") -@pytest.mark.parametrize("use_runtime_env", [True, False]) -def test_override_environment_variables_nested_task(ray_start_regular, - use_runtime_env): +def test_environment_variables_nested_task(ray_start_regular): @ray.remote def get_env(key): return os.environ.get(key) @@ -63,36 +46,19 @@ def get_env(key): def get_env_wrapper(key): return ray.get(get_env.remote(key)) - if use_runtime_env: - assert (ray.get( - get_env_wrapper.options(runtime_env={ - "env_vars": { - "a": "b", - } - }).remote("a")) == "b") - else: - assert (ray.get( - get_env_wrapper.options(override_environment_variables={ + assert (ray.get( + get_env_wrapper.options(runtime_env={ + "env_vars": { "a": "b", - }).remote("a")) == "b") - - -@pytest.mark.parametrize("use_runtime_env", [True, False]) -def test_override_environment_variables_multitenancy(shutdown_only, - use_runtime_env): - if use_runtime_env: - ray.init( - job_config=ray.job_config.JobConfig( - runtime_env={"env_vars": { - "foo1": "bar1", - "foo2": "bar2", - }})) - else: - ray.init( - job_config=ray.job_config.JobConfig(worker_env={ - "foo1": "bar1", - 
"foo2": "bar2", - })) + } + }).remote("a")) == "b") + + +def test_environment_variables_multitenancy(shutdown_only): + ray.init(runtime_env={"env_vars": { + "foo1": "bar1", + "foo2": "bar2", + }}) @ray.remote def get_env(key): @@ -100,48 +66,27 @@ def get_env(key): assert ray.get(get_env.remote("foo1")) == "bar1" assert ray.get(get_env.remote("foo2")) == "bar2" - if use_runtime_env: - assert ray.get( - get_env.options(runtime_env={ - "env_vars": { - "foo1": "baz1", - } - }).remote("foo1")) == "baz1" - assert ray.get( - get_env.options(runtime_env={ - "env_vars": { - "foo1": "baz1", - } - }).remote("foo2")) == "bar2" - else: - assert ray.get( - get_env.options(override_environment_variables={ + assert ray.get( + get_env.options(runtime_env={ + "env_vars": { "foo1": "baz1", - }).remote("foo1")) == "baz1" - assert ray.get( - get_env.options(override_environment_variables={ + } + }).remote("foo1")) == "baz1" + assert ray.get( + get_env.options(runtime_env={ + "env_vars": { "foo1": "baz1", - }).remote("foo2")) == "bar2" + } + }).remote("foo2")) == "bar2" -@pytest.mark.parametrize("use_runtime_env", [True, False]) -def test_override_environment_variables_complex(shutdown_only, - use_runtime_env): - if use_runtime_env: - ray.init(runtime_env={ - "env_vars": { - "a": "job_a", - "b": "job_b", - "z": "job_z", - } - }) - else: - ray.init( - job_config=ray.job_config.JobConfig(worker_env={ - "a": "job_a", - "b": "job_b", - "z": "job_z", - })) +def test_environment_variables_complex(shutdown_only): + ray.init( + runtime_env={"env_vars": { + "a": "job_a", + "b": "job_b", + "z": "job_z", + }}) @ray.remote def get_env(key): @@ -164,69 +109,45 @@ def get_task(self, key): return ray.get(get_env.remote(key)) def nested_get(self, key): - if use_runtime_env: - aa = NestedEnvGetter.options(runtime_env={ - "env_vars": { - "c": "e", - "d": "dd", - } - }).remote() - else: - aa = NestedEnvGetter.options(override_environment_variables={ + aa = NestedEnvGetter.options(runtime_env={ + "env_vars": { "c": "e", "d": "dd", - }).remote() + } + }).remote() return ray.get(aa.get.remote(key)) - if use_runtime_env: - a = EnvGetter.options(runtime_env={ - "env_vars": { - "a": "b", - "c": "d", - } - }).remote() - else: - a = EnvGetter.options(override_environment_variables={ + a = EnvGetter.options(runtime_env={ + "env_vars": { "a": "b", "c": "d", - }).remote() + } + }).remote() + assert (ray.get(a.get.remote("a")) == "b") assert (ray.get(a.get_task.remote("a")) == "b") assert (ray.get(a.nested_get.remote("a")) == "b") assert (ray.get(a.nested_get.remote("c")) == "e") assert (ray.get(a.nested_get.remote("d")) == "dd") - if use_runtime_env: - assert (ray.get( - get_env.options(runtime_env={ - "env_vars": { - "a": "b", - } - }).remote("a")) == "b") - else: - assert (ray.get( - get_env.options(override_environment_variables={ + assert (ray.get( + get_env.options(runtime_env={ + "env_vars": { "a": "b", - }).remote("a")) == "b") + } + }).remote("a")) == "b") assert (ray.get(a.get.remote("z")) == "job_z") assert (ray.get(a.get_task.remote("z")) == "job_z") assert (ray.get(a.nested_get.remote("z")) == "job_z") - if use_runtime_env: - assert (ray.get( - get_env.options(runtime_env={ - "env_vars": { - "a": "b", - } - }).remote("z")) == "job_z") - else: - assert (ray.get( - get_env.options(override_environment_variables={ + assert (ray.get( + get_env.options(runtime_env={ + "env_vars": { "a": "b", - }).remote("z")) == "job_z") + } + }).remote("z")) == "job_z") -@pytest.mark.parametrize("use_runtime_env", [True, False]) -def 
test_override_environment_variables_reuse(shutdown_only, use_runtime_env): +def test_environment_variables_reuse(shutdown_only): """Test that new tasks don't incorrectly reuse previous environments.""" ray.init() @@ -244,32 +165,20 @@ def g(): return os.environ.get(env_var_name) assert ray.get(f.remote()) is None - if use_runtime_env: - assert ray.get( - f.options(runtime_env={ - "env_vars": { - env_var_name: val1 - } - }).remote()) == val1 - else: - assert ray.get( - f.options(override_environment_variables={ + assert ray.get( + f.options(runtime_env={ + "env_vars": { env_var_name: val1 - }).remote()) == val1 + } + }).remote()) == val1 assert ray.get(f.remote()) is None assert ray.get(g.remote()) is None - if use_runtime_env: - assert ray.get( - f.options(runtime_env={ - "env_vars": { - env_var_name: val2 - } - }).remote()) == val2 - else: - assert ray.get( - f.options(override_environment_variables={ + assert ray.get( + f.options(runtime_env={ + "env_vars": { env_var_name: val2 - }).remote()) == val2 + } + }).remote()) == val2 assert ray.get(g.remote()) is None assert ray.get(f.remote()) is None @@ -278,9 +187,7 @@ def g(): # there aren't enough CPUs (2-4 on Travis CI vs. likely 8 on Buildkite) and # worker processes are being killed to adhere to the soft limit. @pytest.mark.skipif(sys.platform == "darwin", reason="Flaky on Travis CI.") -@pytest.mark.parametrize("use_runtime_env", [True, False]) -def test_override_environment_variables_env_caching(shutdown_only, - use_runtime_env): +def test_environment_variables_env_caching(shutdown_only): """Test that workers with specified envs are cached and reused. When a new task or actor is created with a new runtime env, a @@ -307,10 +214,7 @@ def g(): return task() def get_options(val): - if use_runtime_env: - return {"override_environment_variables": {env_var_name: val}} - else: - return {"runtime_env": {"env_vars": {env_var_name: val}}} + return {"runtime_env": {"env_vars": {env_var_name: val}}} # Empty runtime env does not set our env var. 
assert ray.get(f.remote())[0] is None diff --git a/python/ray/tests/test_runtime_env_plugin.py b/python/ray/tests/test_runtime_env_plugin.py new file mode 100644 index 0000000000000..629cdca4e6d25 --- /dev/null +++ b/python/ray/tests/test_runtime_env_plugin.py @@ -0,0 +1,75 @@ +import os +import tempfile + +import pytest +from ray._private.runtime_env.context import RuntimeEnvContext +from ray._private.runtime_env.plugin import RuntimeEnvPlugin + +import ray + +MY_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env_plugin.MyPlugin" + + +class MyPlugin(RuntimeEnvPlugin): + env_key = "MY_PLUGIN_TEST_ENVIRONMENT_KEY" + + @staticmethod + def validate(runtime_env_dict: dict) -> str: + value = runtime_env_dict["plugins"][MY_PLUGIN_CLASS_PATH] + if value == "fail": + raise ValueError("not allowed") + return value + + @staticmethod + def modify_context(uri: str, runtime_env_dict: dict, + ctx: RuntimeEnvContext) -> None: + plugin_config_dict = runtime_env_dict["plugins"][MY_PLUGIN_CLASS_PATH] + ctx.env_vars[MyPlugin.env_key] = str(plugin_config_dict["env_value"]) + ctx.command_prefix.append( + f"echo {plugin_config_dict['tmp_content']} > " + f"{plugin_config_dict['tmp_file']}") + ctx.py_executable = ( + plugin_config_dict["prefix_command"] + " " + ctx.py_executable) + + +def test_simple_env_modification_plugin(ray_start_regular): + _, tmp_file_path = tempfile.mkstemp() + + @ray.remote + def f(): + import psutil + with open(tmp_file_path, "r") as f: + content = f.read().strip() + return { + "env_value": os.environ[MyPlugin.env_key], + "tmp_content": content, + "nice": psutil.Process().nice(), + } + + with pytest.raises(ValueError, match="not allowed"): + f.options(runtime_env={ + "plugins": { + MY_PLUGIN_CLASS_PATH: "fail" + } + }).remote() + + output = ray.get( + f.options( + runtime_env={ + "plugins": { + MY_PLUGIN_CLASS_PATH: { + "env_value": 42, + "tmp_file": tmp_file_path, + "tmp_content": "hello", + # See https://en.wikipedia.org/wiki/Nice_(Unix) + "prefix_command": "nice -n 19", + } + } + }).remote()) + + assert output == {"env_value": "42", "tmp_content": "hello", "nice": 19} + + +if __name__ == "__main__": + import sys + sys.exit(pytest.main(["-sv", __file__])) diff --git a/python/ray/tests/test_runtime_env_validation.py b/python/ray/tests/test_runtime_env_validation.py new file mode 100644 index 0000000000000..1f3fc254cec29 --- /dev/null +++ b/python/ray/tests/test_runtime_env_validation.py @@ -0,0 +1,379 @@ +import os +import pytest +import sys +import tempfile +from pathlib import Path +import yaml + +from ray._private.runtime_env.validation import ( + parse_and_validate_excludes, parse_and_validate_working_dir, + parse_and_validate_conda, parse_and_validate_pip, + parse_and_validate_env_vars, ParsedRuntimeEnv, + override_task_or_actor_runtime_env) + +CONDA_DICT = {"dependencies": ["pip", {"pip": ["pip-install-test==0.5"]}]} + +PIP_LIST = ["requests==1.0.0", "pip-install-test"] + + +@pytest.fixture +def test_directory(): + with tempfile.TemporaryDirectory() as tmp_dir: + path = Path(tmp_dir) + subdir = path / "subdir" + subdir.mkdir(parents=True) + requirements_file = subdir / "requirements.txt" + with requirements_file.open(mode="w") as f: + print("\n".join(PIP_LIST), file=f) + + good_conda_file = subdir / "good_conda_env.yaml" + with good_conda_file.open(mode="w") as f: + yaml.dump(CONDA_DICT, f) + + bad_conda_file = subdir / "bad_conda_env.yaml" + with bad_conda_file.open(mode="w") as f: + print("% this is not a YAML file %", file=f) + + old_dir = os.getcwd() + os.chdir(tmp_dir) + yield 
subdir, requirements_file, good_conda_file, bad_conda_file + os.chdir(old_dir) + + +class TestValidateWorkingDir: + @pytest.mark.parametrize("absolute_path", [True, False]) + def test_validate_working_dir_valid_path(self, test_directory, + absolute_path): + subdir, _, _, _ = test_directory + + rel1 = "." + assert parse_and_validate_working_dir( + rel1, is_task_or_actor=False) == rel1 + + if absolute_path: + subdir = subdir.resolve() + + rel2 = str(subdir) + assert parse_and_validate_working_dir( + rel2, is_task_or_actor=False) == rel2 + + def test_validate_working_dir_absolute_path(self, test_directory): + subdir, _, _, _ = test_directory + + abspath = str(subdir.resolve()) + assert parse_and_validate_working_dir( + abspath, is_task_or_actor=False) == abspath + + def test_validate_working_dir_invalid_path(self): + with pytest.raises(ValueError): + parse_and_validate_working_dir("fake_path", is_task_or_actor=False) + + def test_validate_working_dir_invalid_types(self): + with pytest.raises(TypeError): + parse_and_validate_working_dir( + { + "working_dir": 1 + }, is_task_or_actor=False) + + def test_validate_working_dir_reject_task_or_actor(self): + # Can't pass working_dir for tasks/actors. + with pytest.raises(NotImplementedError): + parse_and_validate_working_dir( + { + "working_dir": "." + }, is_task_or_actor=True) + + +class TestValidateExcludes: + def test_validate_excludes_invalid_types(self): + with pytest.raises(TypeError): + parse_and_validate_excludes(1) + + with pytest.raises(TypeError): + parse_and_validate_excludes(True) + + with pytest.raises(TypeError): + parse_and_validate_excludes("string") + + with pytest.raises(TypeError): + parse_and_validate_excludes(["string", 1]) + + def test_validate_excludes_empty_list(self): + assert ParsedRuntimeEnv({"excludes": []}) == {} + + +@pytest.mark.skipif( + sys.platform == "win32", reason="Conda option not supported on Windows.") +class TestValidateConda: + def test_validate_conda_invalid_types(self): + with pytest.raises(TypeError): + parse_and_validate_conda(1) + + with pytest.raises(TypeError): + parse_and_validate_conda(True) + + def test_validate_conda_str(self, test_directory): + assert parse_and_validate_conda("my_env_name") == "my_env_name" + + def test_validate_conda_invalid_path(self): + with pytest.raises(ValueError): + parse_and_validate_conda("../bad_path.yaml") + + @pytest.mark.parametrize("absolute_path", [True, False]) + def test_validate_conda_valid_file(self, test_directory, absolute_path): + _, _, good_conda_file, _ = test_directory + + if absolute_path: + good_conda_file = good_conda_file.resolve() + + assert parse_and_validate_conda(str(good_conda_file)) == CONDA_DICT + + @pytest.mark.parametrize("absolute_path", [True, False]) + def test_validate_conda_invalid_file(self, test_directory, absolute_path): + _, _, _, bad_conda_file = test_directory + + if absolute_path: + bad_conda_file = bad_conda_file.resolve() + + with pytest.raises(ValueError): + parse_and_validate_conda(str(bad_conda_file)) + + def test_validate_conda_valid_dict(self): + assert parse_and_validate_conda(CONDA_DICT) == CONDA_DICT + + +@pytest.mark.skipif( + sys.platform == "win32", reason="Pip option not supported on Windows.") +class TestValidatePip: + def test_validate_pip_invalid_types(self): + with pytest.raises(TypeError): + parse_and_validate_pip(1) + + with pytest.raises(TypeError): + parse_and_validate_pip(True) + + def test_validate_pip_invalid_path(self): + with pytest.raises(ValueError): + parse_and_validate_pip("../bad_path.txt") + + 
@pytest.mark.parametrize("absolute_path", [True, False]) + def test_validate_pip_valid_file(self, test_directory, absolute_path): + _, requirements_file, _, _ = test_directory + + if absolute_path: + requirements_file = requirements_file.resolve() + + result = parse_and_validate_pip(str(requirements_file)) + assert result == PIP_LIST + + def test_validate_pip_valid_list(self): + result = parse_and_validate_pip(PIP_LIST) + assert result == PIP_LIST + + +class TestValidateEnvVars: + def test_type_validation(self): + # Only strings allowed. + with pytest.raises(TypeError, match=".*Dict[str, str]*"): + parse_and_validate_env_vars({"INT_ENV": 1}) + + with pytest.raises(TypeError, match=".*Dict[str, str]*"): + parse_and_validate_env_vars({1: "hi"}) + + +class TestParsedRuntimeEnv: + @pytest.mark.parametrize("is_task_or_actor", [True, False]) + def test_empty(self, is_task_or_actor): + assert ParsedRuntimeEnv({}, is_task_or_actor=is_task_or_actor) == {} + + @pytest.mark.skipif( + sys.platform == "win32", reason="Pip option not supported on Windows.") + @pytest.mark.parametrize("is_task_or_actor", [True, False]) + def test_serialization(self, is_task_or_actor): + env1 = ParsedRuntimeEnv( + { + "pip": ["requests"], + "env_vars": { + "hi1": "hi1", + "hi2": "hi2" + } + }, + is_task_or_actor=is_task_or_actor) + + env2 = ParsedRuntimeEnv( + { + "env_vars": { + "hi2": "hi2", + "hi1": "hi1" + }, + "pip": ["requests"] + }, + is_task_or_actor=is_task_or_actor) + + assert env1 == env2 + + serialized_env1 = env1.serialize() + serialized_env2 = env2.serialize() + + # Key ordering shouldn't matter. + assert serialized_env1 == serialized_env2 + + deserialized_env1 = ParsedRuntimeEnv.deserialize(serialized_env1) + deserialized_env2 = ParsedRuntimeEnv.deserialize(serialized_env2) + + assert env1 == deserialized_env1 == env2 == deserialized_env2 + + @pytest.mark.parametrize("is_task_or_actor", [True, False]) + def test_reject_pip_and_conda(self, is_task_or_actor): + with pytest.raises(ValueError): + ParsedRuntimeEnv( + { + "pip": ["requests"], + "conda": "env_name" + }, + is_task_or_actor=is_task_or_actor) + + @pytest.mark.skipif( + sys.platform == "win32", + reason="Conda and pip options not supported on Windows.") + @pytest.mark.parametrize("is_task_or_actor", [True, False]) + def test_ray_commit_injection(self, is_task_or_actor): + # Should not be injected if no pip and conda. + result = ParsedRuntimeEnv( + { + "env_vars": { + "hi": "hi" + } + }, is_task_or_actor=is_task_or_actor) + assert "_ray_commit" not in result + + # Should be injected if pip or conda present. + result = ParsedRuntimeEnv( + { + "pip": ["requests"], + }, is_task_or_actor=is_task_or_actor) + assert "_ray_commit" in result + + result = ParsedRuntimeEnv( + { + "conda": "env_name" + }, is_task_or_actor=is_task_or_actor) + assert "_ray_commit" in result + + # Should not override if passed. + result = ParsedRuntimeEnv( + { + "conda": "env_name", + "_ray_commit": "Blah" + }, + is_task_or_actor=is_task_or_actor) + assert result["_ray_commit"] == "Blah" + + @pytest.mark.parametrize("is_task_or_actor", [True, False]) + def test_inject_current_ray(self, is_task_or_actor): + # Should not be injected if not provided by env var. + result = ParsedRuntimeEnv( + { + "env_vars": { + "hi": "hi" + } + }, is_task_or_actor=is_task_or_actor) + assert "_inject_current_ray" not in result + + os.environ["RAY_RUNTIME_ENV_LOCAL_DEV_MODE"] = "1" + + # Should be injected if provided by env var. 
+ result = ParsedRuntimeEnv({}, is_task_or_actor=is_task_or_actor) + assert result["_inject_current_ray"] + + # Should be preserved if passed. + result = ParsedRuntimeEnv( + { + "_inject_current_ray": False + }, is_task_or_actor=is_task_or_actor) + assert not result["_inject_current_ray"] + + del os.environ["RAY_RUNTIME_ENV_LOCAL_DEV_MODE"] + + +class TestOverrideRuntimeEnvs: + def test_override_uris(self): + child = {} + parent = {"uris": ["a", "b"]} + assert override_task_or_actor_runtime_env(child, parent) == parent + + child = {"uris": ["a", "b"]} + parent = {"uris": ["c", "d"]} + assert override_task_or_actor_runtime_env(child, parent) == child + + child = {"uris": ["a", "b"]} + parent = {} + assert override_task_or_actor_runtime_env(child, parent) == child + + def test_override_env_vars(self): + # (child, parent, expected) + TEST_CASES = [ + ({}, {}, {}), + (None, None, None), + ({"a": "b"}, {}, {"a": "b"}), + ({"a": "b"}, None, {"a": "b"}), + ({}, {"a": "b"}, {"a": "b"}), + (None, {"a": "b"}, {"a": "b"}), + ({"a": "b"}, {"a": "d"}, {"a": "b"}), + ({"a": "b"}, {"c": "d"}, {"a": "b", "c": "d"}), + ({"a": "b"}, {"a": "e", "c": "d"}, {"a": "b", "c": "d"}) + ] # yapf: disable + + for idx, (child, parent, expected) in enumerate(TEST_CASES): + child = {"env_vars": child} if child is not None else {} + parent = {"env_vars": parent} if parent is not None else {} + expected = {"env_vars": expected} if expected is not None else {} + assert override_task_or_actor_runtime_env( + child, parent) == expected, f"TEST_INDEX:{idx}" + + def test_uri_inherit(self): + child_env = {} + parent_env = {"working_dir": "other_dir", "uris": ["a", "b"]} + result_env = override_task_or_actor_runtime_env(child_env, parent_env) + assert result_env == {"uris": ["a", "b"]} + + # The dicts passed in should not be mutated. + assert child_env == {} + assert parent_env == {"working_dir": "other_dir", "uris": ["a", "b"]} + + def test_uri_override(self): + child_env = {"uris": ["c", "d"]} + parent_env = {"working_dir": "other_dir", "uris": ["a", "b"]} + result_env = override_task_or_actor_runtime_env(child_env, parent_env) + assert result_env["uris"] == ["c", "d"] + assert result_env.get("working_dir") is None + + # The dicts passed in should not be mutated. + assert child_env == {"uris": ["c", "d"]} + assert parent_env == {"working_dir": "other_dir", "uris": ["a", "b"]} + + def test_no_mutate(self): + child_env = {} + parent_env = {"working_dir": "other_dir", "uris": ["a", "b"]} + result_env = override_task_or_actor_runtime_env(child_env, parent_env) + assert result_env == {"uris": ["a", "b"]} + + # The dicts passed in should not be mutated. 
+ assert child_env == {} + assert parent_env == {"working_dir": "other_dir", "uris": ["a", "b"]} + + def test_inherit_conda(self): + child_env = {"uris": ["a"]} + parent_env = {"conda": "my-env-name", "uris": ["a", "b"]} + result_env = override_task_or_actor_runtime_env(child_env, parent_env) + assert result_env == {"uris": ["a"], "conda": "my-env-name"} + + def test_inherit_pip(self): + child_env = {"uris": ["a"]} + parent_env = {"pip": ["pkg-name"], "uris": ["a", "b"]} + result_env = override_task_or_actor_runtime_env(child_env, parent_env) + assert result_env == {"uris": ["a"], "pip": ["pkg-name"]} + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", __file__])) diff --git a/python/ray/tests/test_scheduling.py b/python/ray/tests/test_scheduling.py index 10a4ab846e844..b834d67e0c67c 100644 --- a/python/ray/tests/test_scheduling.py +++ b/python/ray/tests/test_scheduling.py @@ -2,6 +2,7 @@ import collections import logging import platform +import subprocess import sys import time import unittest @@ -549,8 +550,8 @@ def __init__(self): def get_location(self): return ray.worker.global_worker.node.unique_id - @ray.remote - def task_cpu(num_cpus=0.5): + @ray.remote(num_cpus=0.5) + def task_cpu(): time.sleep(10) return ray.worker.global_worker.node.unique_id @@ -578,6 +579,100 @@ def launcher(): cluster.shutdown() +@pytest.mark.parametrize( + "ray_start_cluster", [{ + "num_cpus": 0, + "num_nodes": 1, + }], indirect=True) +def test_head_node_without_cpu(ray_start_cluster): + @ray.remote(num_cpus=1) + def f(): + return 1 + + f.remote() + + check_count = 0 + demand_1cpu = " {'CPU': 1.0}:" + while True: + status = subprocess.check_output(["ray", "status"]).decode() + if demand_1cpu in status: + break + check_count += 1 + assert check_count < 5, f"Incorrect demand. Last status {status}" + time.sleep(1) + + @ray.remote(num_cpus=2) + def g(): + return 2 + + g.remote() + + check_count = 0 + demand_2cpu = " {'CPU': 2.0}:" + while True: + status = subprocess.check_output(["ray", "status"]).decode() + if demand_1cpu in status and demand_2cpu in status: + break + check_count += 1 + assert check_count < 5, f"Incorrect demand. Last status {status}" + time.sleep(1) + + +@pytest.mark.skipif(sys.platform == "win32", reason="Fails on windows") +def test_gpu_scheduling_liveness(ray_start_cluster): + """Check if the GPU scheduling is in progress when + it is used with the placement group + Issue: https://github.com/ray-project/ray/issues/19130 + """ + cluster = ray_start_cluster + # Start a node without a gpu. + cluster.add_node(num_cpus=6) + ray.init(address=cluster.address) + + NUM_CPU_BUNDLES = 10 + + @ray.remote(num_cpus=1) + class Worker(object): + def __init__(self, i): + self.i = i + + def work(self): + time.sleep(0.1) + print("work ", self.i) + + @ray.remote(num_cpus=1, num_gpus=1) + class Trainer(object): + def __init__(self, i): + self.i = i + + def train(self): + time.sleep(0.2) + print("train ", self.i) + + bundles = [{"CPU": 1, "GPU": 1}] + bundles += [{"CPU": 1} for _ in range(NUM_CPU_BUNDLES)] + + pg = ray.util.placement_group(bundles, strategy="PACK") + o = pg.ready() + # Artificial delay to simulate the real world workload. + time.sleep(3) + print("Scaling up.") + cluster.add_node(num_cpus=6, num_gpus=1) + ray.get(o) + + workers = [ + Worker.options(placement_group=pg).remote(i) + for i in range(NUM_CPU_BUNDLES) + ] + trainer = Trainer.options(placement_group=pg).remote(0) + + # If the gpu scheduling doesn't properly work, the below + # code will hang. 
+ ray.get( + [workers[i].work.remote() for i in range(NUM_CPU_BUNDLES)], timeout=30) + ray.get(trainer.train.remote(), timeout=30) + + if __name__ == "__main__": import pytest sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tests/test_tls_auth.py b/python/ray/tests/test_tls_auth.py index 057c2e0b2ae32..01b234ceb8315 100644 --- a/python/ray/tests/test_tls_auth.py +++ b/python/ray/tests/test_tls_auth.py @@ -1,23 +1,81 @@ # coding: utf-8 +import logging import os import sys import pytest -import logging +import ray logger = logging.getLogger(__name__) +@pytest.mark.parametrize("use_tls", [True], indirect=True) +def test_put_get_with_tls(shutdown_only, use_tls): + ray.init(num_cpus=0) + + for i in range(100): + value_before = i * 10**6 + object_ref = ray.put(value_before) + value_after = ray.get(object_ref) + assert value_before == value_after + + for i in range(100): + value_before = i * 10**6 * 1.0 + object_ref = ray.put(value_before) + value_after = ray.get(object_ref) + assert value_before == value_after + + for i in range(100): + value_before = "h" * i + object_ref = ray.put(value_before) + value_after = ray.get(object_ref) + assert value_before == value_after + + for i in range(100): + value_before = [1] * i + object_ref = ray.put(value_before) + value_after = ray.get(object_ref) + assert value_before == value_after + + +@pytest.mark.parametrize("use_tls", [True], indirect=True) +def test_submit_with_tls(shutdown_only, use_tls): + ray.init(num_cpus=2, num_gpus=1, resources={"Custom": 1}) + + @ray.remote + def f(n): + return list(range(n)) + + id1, id2, id3 = f._remote(args=[3], num_returns=3) + assert ray.get([id1, id2, id3]) == [0, 1, 2] + + @ray.remote + class Actor: + def __init__(self, x, y=0): + self.x = x + self.y = y + + def method(self, a, b=0): + return self.x, self.y, a, b + + a = Actor._remote( + args=[0], kwargs={"y": 1}, num_gpus=1, resources={"Custom": 1}) + + id1, id2, id3, id4 = a.method._remote( + args=["test"], kwargs={"b": 2}, num_returns=4) + assert ray.get([id1, id2, id3, id4]) == [0, 1, "test", 2] + + @pytest.mark.skipif( sys.platform == "darwin", reason=("Cryptography doesn't install in Mac build pipeline")) @pytest.mark.parametrize("use_tls", [True], indirect=True) def test_client_connect_to_tls_server(use_tls, init_and_serve): - from ray.util.client import ray + from ray.util.client import ray as ray_client os.environ["RAY_USE_TLS"] = "0" with pytest.raises(ConnectionError): - ray.connect("localhost:50051") + ray_client.connect("localhost:50051") os.environ["RAY_USE_TLS"] = "1" - ray.connect("localhost:50051") + ray_client.connect("localhost:50051") diff --git a/python/ray/tests/test_traceback.py b/python/ray/tests/test_traceback.py index 3081bcc6ec3d4..fa48ec62f09cb 100644 --- a/python/ray/tests/test_traceback.py +++ b/python/ray/tests/test_traceback.py @@ -270,6 +270,45 @@ def __repr__(self): assert label_dict["repr"] == actor_repr +def test_unpickleable_stacktrace(): + expected_output = """System error: Failed to unpickle serialized exception +traceback: Traceback (most recent call last): + File "FILE", line ZZ, in from_bytes + return pickle.loads(ray_exception.serialized_exception) +TypeError: __init__() missing 1 required positional argument: 'arg' + +The above exception was the direct cause of the following exception: + +Traceback (most recent call last): + File "FILE", line ZZ, in deserialize_objects + obj = self._deserialize_object(data, metadata, object_ref) + File "FILE", line ZZ, in _deserialize_object + return RayError.from_bytes(obj) 
+ File "FILE", line ZZ, in from_bytes + raise RuntimeError(msg) from e +RuntimeError: Failed to unpickle serialized exception""" + + class NoPickleError(OSError): + def __init__(self, arg): + pass + + def g(a): + raise NoPickleError("asdf") + + @ray.remote + def f(): + a = 3 + b = 4 + c = a + b + return g(c) + + try: + ray.get(f.remote()) + except Exception as ex: + print(repr(scrub_traceback(str(ex)))) + assert clean_noqa(expected_output) == scrub_traceback(str(ex)) + + if __name__ == "__main__": import pytest import sys diff --git a/python/ray/tune/BUILD b/python/ray/tune/BUILD index c8bcfe31a8b92..47330c64c7ec6 100644 --- a/python/ray/tune/BUILD +++ b/python/ray/tune/BUILD @@ -160,7 +160,7 @@ py_test( size = "small", srcs = ["tests/test_logger.py"], deps = [":tune_lib"], - tags = ["team:ml", "jenkins_only"], + tags = ["team:ml"], ) py_test( diff --git a/python/ray/tune/analysis/experiment_analysis.py b/python/ray/tune/analysis/experiment_analysis.py index e7cfc31810e1d..fbaa7207a04a0 100644 --- a/python/ray/tune/analysis/experiment_analysis.py +++ b/python/ray/tune/analysis/experiment_analysis.py @@ -1,9 +1,11 @@ import json import logging import os +import warnings from numbers import Number from typing import Any, Dict, List, Optional, Tuple +from ray.util.debug import log_once from ray.tune.utils import flatten_dict from ray.tune.utils.serialization import TuneFunctionDecoder from ray.tune.utils.util import is_nan_or_inf @@ -556,6 +558,17 @@ def best_result(self) -> Dict: "the metric and mode explicitly and fetch the last result.") return self.best_trial.last_result + def _delimiter(self): + # Deprecate: 1.9 (default should become `/`) + delimiter = os.environ.get("TUNE_RESULT_DELIM", ".") + if delimiter == "." and log_once("delimiter_deprecation"): + warnings.warn( + "Dataframes will use '/' instead of '.' to delimit " + "nested result keys in future versions of Ray. For forward " + "compatibility, set the environment variable " + "TUNE_RESULT_DELIM='/'") + return delimiter + @property def best_result_df(self) -> DataFrame: """Get the best result of the experiment as a pandas dataframe. @@ -569,7 +582,9 @@ def best_result_df(self) -> DataFrame: if not pd: raise ValueError("`best_result_df` requires pandas. Install with " "`pip install pandas`.") - best_result = flatten_dict(self.best_result, delimiter=".") + + best_result = flatten_dict( + self.best_result, delimiter=self._delimiter()) return pd.DataFrame.from_records([best_result], index="trial_id") @property @@ -579,12 +594,13 @@ def results(self) -> Dict[str, Dict]: @property def results_df(self) -> DataFrame: + """Get all the last results as a pandas dataframe.""" if not pd: - raise ValueError("`best_result_df` requires pandas. Install with " + raise ValueError("`results_df` requires pandas. 
Install with " "`pip install pandas`.") return pd.DataFrame.from_records( [ - flatten_dict(trial.last_result, delimiter=".") + flatten_dict(trial.last_result, delimiter=self._delimiter()) for trial in self.trials ], index="trial_id") diff --git a/python/ray/tune/commands.py b/python/ray/tune/commands.py index 7fbbe9776bde2..5d47605c63181 100644 --- a/python/ray/tune/commands.py +++ b/python/ray/tune/commands.py @@ -116,10 +116,9 @@ def list_trials(experiment_path, _check_tabulate() try: - checkpoints_df = Analysis(experiment_path).dataframe( - metric="episode_reward_mean", mode="max") - except TuneError: - raise click.ClickException("No trial data found!") + checkpoints_df = Analysis(experiment_path).dataframe() # last result + except TuneError as e: + raise click.ClickException("No trial data found!") from e def key_filter(k): return k in DEFAULT_CLI_KEYS or k.startswith(CONFIG_PREFIX) diff --git a/python/ray/tune/durable_trainable.py b/python/ray/tune/durable_trainable.py index db822434f1223..77b80e510af2b 100644 --- a/python/ray/tune/durable_trainable.py +++ b/python/ray/tune/durable_trainable.py @@ -171,14 +171,16 @@ def durable(trainable: Union[str, Type[Trainable], Callable]): A durable trainable class wrapped around your trainable. """ + overwrite_name = None if isinstance(trainable, str): trainable_cls = get_trainable_cls(trainable) + overwrite_name = f"Durable{trainable}" else: trainable_cls = trainable if not inspect.isclass(trainable_cls): # Function API - return wrap_function(trainable_cls, durable=True) + return wrap_function(trainable_cls, durable=True, name=overwrite_name) if not issubclass(trainable_cls, Trainable): raise ValueError( @@ -187,8 +189,14 @@ def durable(trainable: Union[str, Type[Trainable], Callable]): f"it does. Got: {type(trainable_cls)}") # else: Class API + + # Class is already durable + + if issubclass(trainable_cls, DurableTrainable): + return trainable_cls + class _WrappedDurableTrainable(DurableTrainable, trainable_cls): - _name = trainable_cls.__name__ if hasattr(trainable_cls, "__name__") \ - else "durable_trainable" + _name = overwrite_name or (trainable_cls.__name__ if hasattr( + trainable_cls, "__name__") else "durable_trainable") return _WrappedDurableTrainable diff --git a/python/ray/tune/function_runner.py b/python/ray/tune/function_runner.py index ae4235aa89099..e4c2018068d7a 100644 --- a/python/ray/tune/function_runner.py +++ b/python/ray/tune/function_runner.py @@ -10,6 +10,8 @@ from functools import partial from numbers import Number +from typing import Any, Callable, Optional + from six.moves import queue from ray.util.debug import log_once @@ -530,7 +532,10 @@ def _report_thread_runner_error(self, block=False): pass -def wrap_function(train_func, durable=False, warn=True): +def wrap_function(train_func: Callable[[Any], Any], + durable: bool = False, + warn: bool = True, + name: Optional[str] = None): inherit_from = (FunctionRunner, ) if hasattr(train_func, "__mixins__"): @@ -562,8 +567,8 @@ def wrap_function(train_func, durable=False, warn=True): "arguments to be `func(config, checkpoint_dir=None)`.") class ImplicitFunc(*inherit_from): - _name = train_func.__name__ if hasattr(train_func, "__name__") \ - else "func" + _name = name or (train_func.__name__ + if hasattr(train_func, "__name__") else "func") def _trainable_func(self, config, reporter, checkpoint_dir): if not use_checkpoint and not use_reporter: diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py index edc8dcb5482d1..f555e684b4466 100644 --- 
a/python/ray/tune/logger.py
+++ b/python/ray/tune/logger.py
@@ -169,8 +169,8 @@ class TBXLogger(Logger):
         {"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2}
     """
 
-    VALID_HPARAMS = (str, bool, np.bool8, int, np.integer, float, list,
-                     type(None))
+    VALID_HPARAMS = (str, bool, int, float, list, type(None))
+    VALID_NP_HPARAMS = (np.bool8, np.float32, np.float64, np.int32, np.int64)
 
     def _init(self):
         try:
@@ -254,10 +254,18 @@ def _try_log_hparams(self, result):
             if isinstance(v, self.VALID_HPARAMS)
         }
 
+        np_params = {
+            k: v.tolist()
+            for k, v in flat_params.items()
+            if isinstance(v, self.VALID_NP_HPARAMS)
+        }
+
+        scrubbed_params.update(np_params)
+
         removed = {
             k: v
             for k, v in flat_params.items()
-            if not isinstance(v, self.VALID_HPARAMS)
+            if not isinstance(v, self.VALID_HPARAMS + self.VALID_NP_HPARAMS)
         }
 
         if removed:
             logger.info(
@@ -585,8 +593,7 @@ class TBXLoggerCallback(LoggerCallback):
         {"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2}
     """
 
-    # NoneType is not supported on the last TBX release yet.
-    VALID_HPARAMS = (str, bool, int, float, list)
+    VALID_HPARAMS = (str, bool, int, float, list, type(None))
     VALID_NP_HPARAMS = (np.bool8, np.float32, np.float64, np.int32, np.int64)
 
     def __init__(self):
diff --git a/python/ray/tune/progress_reporter.py b/python/ray/tune/progress_reporter.py
index 0b69faa51550d..52f19f8029da2 100644
--- a/python/ray/tune/progress_reporter.py
+++ b/python/ray/tune/progress_reporter.py
@@ -1,5 +1,6 @@
 from __future__ import print_function
 
+import datetime
 from typing import Dict, List, Optional, Union
 
 import collections
@@ -8,15 +9,17 @@
 import numpy as np
 import time
 
+from ray.util.annotations import PublicAPI, DeveloperAPI
+from ray.util.queue import Queue
+
 from ray.tune.callback import Callback
 from ray.tune.logger import pretty_print, logger
-from ray.tune.result import (DEFAULT_METRIC, EPISODE_REWARD_MEAN,
-                             MEAN_ACCURACY, MEAN_LOSS, TRAINING_ITERATION,
-                             TIME_TOTAL_S, TIMESTEPS_TOTAL, AUTO_RESULT_KEYS)
-from ray.tune.trial import DEBUG_PRINT_INTERVAL, Trial
+from ray.tune.result import (
+    DEFAULT_METRIC, EPISODE_REWARD_MEAN, MEAN_ACCURACY, MEAN_LOSS, NODE_IP,
+    PID, TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL, AUTO_RESULT_KEYS)
+from ray.tune.trial import DEBUG_PRINT_INTERVAL, Trial, Location
 from ray.tune.utils import unflattened_lookup
 from ray.tune.utils.log import Verbosity, has_verbosity
-from ray.util.annotations import PublicAPI, DeveloperAPI
 
 try:
     from collections.abc import Mapping, MutableMapping
@@ -159,6 +162,8 @@ def __init__(
         self._max_report_freqency = max_report_frequency
         self._last_report_time = 0
 
+        self._start_time = time.time()
+
         self._metric = metric
         self._mode = mode
 
@@ -188,6 +193,12 @@ def set_search_properties(self, metric: Optional[str],
     def set_total_samples(self, total_samples: int):
         self._total_samples = total_samples
 
+    def set_start_time(self, timestamp: Optional[float] = None):
+        if timestamp is not None:
+            self._start_time = timestamp
+        else:
+            self._start_time = time.time()
+
     def should_report(self, trials: List[Trial], done: bool = False):
         if time.time() - self._last_report_time > self._max_report_freqency:
             self._last_report_time = time.time()
@@ -267,7 +278,11 @@ def _progress_str(self,
         if not self._metrics_override:
             user_metrics = self._infer_user_metrics(trials, self._infer_limit)
             self._metric_columns.update(user_metrics)
-        messages = ["== Status ==", memory_debug_str(), *sys_info]
+        messages = [
+            "== Status ==",
+            time_passed_str(self._start_time, time.time()),
+            memory_debug_str(), *sys_info
+        ]
         if done:
             max_progress = None
             max_error = None
@@ -416,15 +431,32 @@ def __init__(
                 "to `tune.run()` instead.")
 
         self._overwrite = overwrite
+        self._output_queue = None
+
+    def set_output_queue(self, queue: Queue):
+        self._output_queue = queue
 
     def report(self, trials: List[Trial], done: bool, *sys_info: Dict):
-        from IPython.display import clear_output
-        from IPython.core.display import display, HTML
-        if self._overwrite:
-            clear_output(wait=True)
+        overwrite = self._overwrite
         progress_str = self._progress_str(
             trials, done, *sys_info, fmt="html", delim="<br>
") - display(HTML(progress_str)) + + def update_output(): + from IPython.display import clear_output + from IPython.core.display import display, HTML + + if overwrite: + clear_output(wait=True) + + display(HTML(progress_str)) + + if self._output_queue is not None: + # If an output queue is set, send callable (e.g. when using + # Ray client) + self._output_queue.put(update_output) + else: + # Else, output directly + update_output() @PublicAPI @@ -510,6 +542,33 @@ def memory_debug_str(): "to resolve)") +def time_passed_str(start_time: float, current_time: float): + current_time_dt = datetime.datetime.fromtimestamp(current_time) + start_time_dt = datetime.datetime.fromtimestamp(start_time) + delta: datetime.timedelta = current_time_dt - start_time_dt + + rest = delta.total_seconds() + days = rest // (60 * 60 * 24) + + rest -= days * (60 * 60 * 24) + hours = rest // (60 * 60) + + rest -= hours * (60 * 60) + minutes = rest // 60 + + seconds = rest - minutes * 60 + + if days > 0: + running_for_str = f"{days:.0f} days, " + else: + running_for_str = "" + + running_for_str += f"{hours:02.0f}:{minutes:02.0f}:{seconds:05.2f}" + + return (f"Current time: {current_time_dt:%Y-%m-%d %H:%M:%S} " + f"(running for {running_for_str})") + + def _get_trials_by_state(trials: List[Trial]): trials_by_state = collections.defaultdict(list) for t in trials: @@ -774,6 +833,18 @@ def _fair_filter_trials(trials_by_state: Dict[str, List[Trial]], return filtered_trials +def _get_trial_location(trial: Trial, result: dict) -> Location: + # we get the location from the result, as the one in trial will be + # reset when trial terminates + node_ip, pid = result.get(NODE_IP, None), result.get(PID, None) + if node_ip and pid: + location = Location(node_ip, pid) + else: + # fallback to trial location if there hasn't been a report yet + location = trial.location + return location + + def _get_trial_info(trial: Trial, parameters: List[str], metrics: List[str]): """Returns the following information about a trial: @@ -786,7 +857,8 @@ def _get_trial_info(trial: Trial, parameters: List[str], metrics: List[str]): """ result = trial.last_result config = trial.config - trial_info = [str(trial), trial.status, str(trial.location)] + location = _get_trial_location(trial, result) + trial_info = [str(trial), trial.status, str(location)] trial_info += [ unflattened_lookup(param, config, default=None) for param in parameters ] diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index 52ec1102a5f78..959fba6c0dcff 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -93,11 +93,18 @@ class _TrialCleanup: Args: threshold (int): Number of futures to hold at once. If the threshold is passed, cleanup will kick in and remove futures. + force_cleanup (int): Grace periods for forceful actor termination. + If 0, actors will not be forcefully terminated. """ - def __init__(self, threshold: int = TRIAL_CLEANUP_THRESHOLD): + def __init__(self, + threshold: int = TRIAL_CLEANUP_THRESHOLD, + force_cleanup: int = 0): self.threshold = threshold self._cleanup_map = {} + if force_cleanup < 0: + force_cleanup = 0 + self._force_cleanup = force_cleanup def add(self, trial: Trial, actor: ActorHandle): """Adds a trial actor to be stopped. @@ -123,15 +130,27 @@ def cleanup(self, partial: bool = True): If partial=False, all futures are expected to return. If a future does not return within the timeout period, the cleanup terminates. 
""" + # At this point, self._cleanup_map holds the last references + # to actors. Removing those references either one-by-one + # (graceful termination case) or all at once, by reinstantiating + # self._cleanup_map (forceful termination case) will cause Ray + # to kill the actors during garbage collection. logger.debug("Cleaning up futures") num_to_keep = int(self.threshold) / 2 if partial else 0 while len(self._cleanup_map) > num_to_keep: dones, _ = ray.wait( - list(self._cleanup_map), timeout=DEFAULT_GET_TIMEOUT) + list(self._cleanup_map), + timeout=DEFAULT_GET_TIMEOUT + if not self._force_cleanup else self._force_cleanup) if not dones: logger.warning( "Skipping cleanup - trainable.stop did not return in " "time. Consider making `stop` a faster operation.") + if not partial and self._force_cleanup: + logger.warning( + "Forcing trainable cleanup by terminating actors.") + self._cleanup_map = {} + return else: done = dones[0] del self._cleanup_map[done] @@ -165,7 +184,9 @@ def __init__(self, # We use self._paused to store paused trials here. self._paused = {} - self._trial_cleanup = _TrialCleanup() + force_trial_cleanup = int( + os.environ.get("TUNE_FORCE_TRIAL_CLEANUP_S", "0")) + self._trial_cleanup = _TrialCleanup(force_cleanup=force_trial_cleanup) self._has_cleaned_up_pgs = False self._reuse_actors = reuse_actors # The maxlen will be updated when `set_max_pending_trials()` is called diff --git a/python/ray/tune/registry.py b/python/ray/tune/registry.py index d83c727179387..9f143db42d37d 100644 --- a/python/ray/tune/registry.py +++ b/python/ray/tune/registry.py @@ -1,5 +1,8 @@ import logging +import uuid + from types import FunctionType +from typing import Optional import ray import ray.cloudpickle as pickle @@ -114,23 +117,25 @@ def check_serializability(key, value): _global_registry.register(TEST, key, value) -def _make_key(category, key): +def _make_key(prefix, category, key): """Generate a binary key for the given category and key. Args: + prefix (str): Prefix category (str): The category of the item key (str): The unique identifier for the item Returns: The key to use for storing a the value. """ - return (b"TuneRegistry:" + category.encode("ascii") + b"/" + - key.encode("ascii")) + return (b"TuneRegistry:" + prefix.encode("ascii") + b":" + + category.encode("ascii") + b"/" + key.encode("ascii")) class _Registry: - def __init__(self): + def __init__(self, prefix: Optional[str] = None): self._to_flush = {} + self._prefix = prefix or uuid.uuid4().hex[:8] def register(self, category, key, value): """Registers the value with the global registry. 
@@ -148,14 +153,14 @@ def register(self, category, key, value): def contains(self, category, key): if _internal_kv_initialized(): - value = _internal_kv_get(_make_key(category, key)) + value = _internal_kv_get(_make_key(self._prefix, category, key)) return value is not None else: return (category, key) in self._to_flush def get(self, category, key): if _internal_kv_initialized(): - value = _internal_kv_get(_make_key(category, key)) + value = _internal_kv_get(_make_key(self._prefix, category, key)) if value is None: raise ValueError( "Registry value for {}/{} doesn't exist.".format( @@ -166,11 +171,12 @@ def get(self, category, key): def flush_values(self): for (category, key), value in self._to_flush.items(): - _internal_kv_put(_make_key(category, key), value, overwrite=True) + _internal_kv_put( + _make_key(self._prefix, category, key), value, overwrite=True) self._to_flush.clear() -_global_registry = _Registry() +_global_registry = _Registry(prefix="global") ray.worker._post_init_hooks.append(_global_registry.flush_values) diff --git a/python/ray/tune/result.py b/python/ray/tune/result.py index e9eb7f40212dc..c166da2c0ce8e 100644 --- a/python/ray/tune/result.py +++ b/python/ray/tune/result.py @@ -67,6 +67,10 @@ DEFAULT_RESULT_KEYS = (TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL, MEAN_ACCURACY, MEAN_LOSS) +# Metrics that don't require at least one iteration to complete +DEBUG_METRICS = (TRIAL_ID, "experiment_id", "date", "timestamp", PID, HOSTNAME, + NODE_IP, "config") + # Make sure this doesn't regress AUTO_RESULT_KEYS = ( TRAINING_ITERATION, diff --git a/python/ray/tune/schedulers/hyperband.py b/python/ray/tune/schedulers/hyperband.py index 8b7830b79d150..1a671f7b24996 100644 --- a/python/ray/tune/schedulers/hyperband.py +++ b/python/ray/tune/schedulers/hyperband.py @@ -80,6 +80,8 @@ class HyperBandScheduler(FIFOScheduler): reaching max_t. Defaults to True. """ + _supports_buffered_results = False + def __init__(self, time_attr: str = "training_iteration", metric: Optional[str] = None, diff --git a/python/ray/tune/schedulers/trial_scheduler.py b/python/ray/tune/schedulers/trial_scheduler.py index a0626416afe00..64e9d99613ae8 100644 --- a/python/ray/tune/schedulers/trial_scheduler.py +++ b/python/ray/tune/schedulers/trial_scheduler.py @@ -14,10 +14,16 @@ class TrialScheduler: _metric = None + _supports_buffered_results = True + @property def metric(self): return self._metric + @property + def supports_buffered_results(self): + return self._supports_buffered_results + def set_search_properties(self, metric: Optional[str], mode: Optional[str]) -> bool: """Pass search properties to scheduler. 
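
Editor's note: the `_supports_buffered_results` flag introduced above defaults to `True` on `TrialScheduler` and is overridden to `False` by `HyperBandScheduler`, letting a scheduler opt out of batched result delivery. A minimal sketch of the pattern follows; the `deliver` consumer is a hypothetical illustration, since this patch only adds the class attribute and the property:

    from typing import Dict, List


    class TrialScheduler:
        # Default: the scheduler tolerates receiving several buffered
        # results in a single call.
        _supports_buffered_results = True

        @property
        def supports_buffered_results(self) -> bool:
            return self._supports_buffered_results


    class HyperBandScheduler(TrialScheduler):
        # Bracket accounting needs to see results one at a time, so this
        # scheduler opts out via the class attribute.
        _supports_buffered_results = False


    def deliver(scheduler: TrialScheduler, buffered: List[Dict]) -> None:
        # Hypothetical consumer: use batch delivery only when allowed.
        if scheduler.supports_buffered_results:
            print("batch:", buffered)
        else:
            for result in buffered:
                print("single:", result)


    deliver(TrialScheduler(), [{"it": 1}, {"it": 2}])      # one batched call
    deliver(HyperBandScheduler(), [{"it": 1}, {"it": 2}])  # two single calls
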
diff --git a/python/ray/tune/suggest/bohb.py b/python/ray/tune/suggest/bohb.py index 52ebf84e9acc2..e8f15c5082866 100644 --- a/python/ray/tune/suggest/bohb.py +++ b/python/ray/tune/suggest/bohb.py @@ -235,10 +235,10 @@ def to_wrapper(self, trial_id: str, result: Dict) -> _BOHBJobWrapper: def on_pause(self, trial_id: str): self.paused.add(trial_id) - self.running.remove(trial_id) + self.running.discard(trial_id) def on_unpause(self, trial_id: str): - self.paused.remove(trial_id) + self.paused.discard(trial_id) self.running.add(trial_id) @staticmethod diff --git a/python/ray/tune/tests/test_api.py b/python/ray/tune/tests/test_api.py index bb49de900fb1d..598b2a2dccf59 100644 --- a/python/ray/tune/tests/test_api.py +++ b/python/ray/tune/tests/test_api.py @@ -226,6 +226,14 @@ class B(Trainable): self.assertRaises(TypeError, lambda: register_trainable("foo", A)) self.assertRaises(TypeError, lambda: Experiment("foo", A)) + def testRegisterDurableTrainableTwice(self): + def train(config, reporter): + pass + + register_trainable("foo", train) + register_trainable("foo", tune.durable("foo")) + register_trainable("foo", tune.durable("foo")) + def testTrainableCallable(self): def dummy_fn(config, reporter, steps): reporter(timesteps_total=steps, done=True) diff --git a/python/ray/tune/tests/test_cluster.py b/python/ray/tune/tests/test_cluster.py index 98dd4e4b2da58..87a6f42f7af25 100644 --- a/python/ray/tune/tests/test_cluster.py +++ b/python/ray/tune/tests/test_cluster.py @@ -190,7 +190,7 @@ def test_remove_node_before_result(start_connected_emptyhead_cluster): running_trials = _get_running_trials(runner) assert len(running_trials) == 1 assert _check_trial_running(running_trials[0]) - assert not trial.last_result + assert not trial.has_reported_at_least_once assert trial.status == Trial.RUNNING cluster.remove_node(node) cluster.add_node(num_cpus=1) diff --git a/python/ray/tune/tests/test_logger.py b/python/ray/tune/tests/test_logger.py index ef75bfcfb49d5..84c633a9b0884 100644 --- a/python/ray/tune/tests/test_logger.py +++ b/python/ray/tune/tests/test_logger.py @@ -230,16 +230,6 @@ def testLegacyBadTBX(self): logger.close() assert "INFO" in cm.output[0] - config = {"None": None} - t = Trial( - evaluated_params=config, trial_id="tbx", logdir=self.test_dir) - logger = TBXLogger(config=config, logdir=self.test_dir, trial=t) - logger.on_result(result(0, 4)) - logger.on_result(result(2, 4, score=[1, 2, 3], hello={"world": 1})) - with self.assertLogs("ray.tune.logger", level="INFO") as cm: - logger.close() - assert "INFO" in cm.output[0] - def testBadTBX(self): config = {"b": (1, 2, 3)} t = Trial( @@ -253,18 +243,6 @@ def testBadTBX(self): logger.on_trial_complete(3, [], t) assert "INFO" in cm.output[0] - config = {"None": None} - t = Trial( - evaluated_params=config, trial_id="tbx", logdir=self.test_dir) - logger = TBXLoggerCallback() - logger.on_trial_result(0, [], t, result(0, 4)) - logger.on_trial_result(1, [], t, result(1, 5)) - logger.on_trial_result( - 2, [], t, result(2, 6, score=[1, 2, 3], hello={"world": 1})) - with self.assertLogs("ray.tune.logger", level="INFO") as cm: - logger.on_trial_complete(3, [], t) - assert "INFO" in cm.output[0] - if __name__ == "__main__": import pytest diff --git a/python/ray/tune/tests/test_progress_reporter.py b/python/ray/tune/tests/test_progress_reporter.py index 2e85fe0a6b368..6978d2c128c6f 100644 --- a/python/ray/tune/tests/test_progress_reporter.py +++ b/python/ray/tune/tests/test_progress_reporter.py @@ -3,13 +3,14 @@ import os import unittest from 
unittest.mock import MagicMock, Mock, patch + from ray import tune from ray._private.test_utils import run_string_as_driver from ray.tune.trial import Trial from ray.tune.result import AUTO_RESULT_KEYS -from ray.tune.progress_reporter import (CLIReporter, JupyterNotebookReporter, - _fair_filter_trials, best_trial_str, - detect_reporter, trial_progress_str) +from ray.tune.progress_reporter import ( + CLIReporter, JupyterNotebookReporter, _fair_filter_trials, best_trial_str, + detect_reporter, trial_progress_str, time_passed_str) EXPECTED_RESULT_1 = """Result logdir: /foo Number of trials: 5 (1 PENDING, 3 RUNNING, 1 TERMINATED) @@ -60,76 +61,92 @@ END_TO_END_COMMAND = """ import ray from ray import tune +from ray.tune.trial import Location +from ray.tune.progress_reporter import _get_trial_location +from unittest.mock import patch -reporter = tune.progress_reporter.CLIReporter(metric_columns=["done"]) -def f(config): - return {"done": True} +def mock_get_trial_location(trial, result): + location = _get_trial_location(trial, result) + if location.pid: + return Location("123.123.123.123", "1") + return location -ray.init(num_cpus=1) -tune.run_experiments({ - "one": { - "run": f, - "config": { - "a": tune.grid_search(list(range(10))), - }, - }, - "two": { - "run": f, - "config": { - "b": tune.grid_search(list(range(10))), - }, - }, - "three": { - "run": f, - "config": { - "c": tune.grid_search(list(range(10))), + +with patch("ray.tune.progress_reporter._get_trial_location", + mock_get_trial_location): + reporter = tune.progress_reporter.CLIReporter(metric_columns=["done"]) + + def f(config): + return {"done": True} + + ray.init(num_cpus=1) + tune.run_experiments( + { + "one": { + "run": f, + "config": { + "a": tune.grid_search(list(range(10))), + }, + }, + "two": { + "run": f, + "config": { + "b": tune.grid_search(list(range(10))), + }, + }, + "three": { + "run": f, + "config": { + "c": tune.grid_search(list(range(10))), + }, + }, }, - }, -}, verbose=3, progress_reporter=reporter)""" + verbose=3, + progress_reporter=reporter)""" EXPECTED_END_TO_END_START = """Number of trials: 30/30 (29 PENDING, 1 RUNNING) -+---------------+----------+-------+-----+-----+ -| Trial name | status | loc | a | b | -|---------------+----------+-------+-----+-----| -| f_xxxxx_00000 | RUNNING | | 0 | | -| f_xxxxx_00001 | PENDING | | 1 | |""" ++---------------+----------+-------------------+-----+-----+ +| Trial name | status | loc | a | b | +|---------------+----------+-------------------+-----+-----| +| f_xxxxx_00000 | RUNNING | 123.123.123.123:1 | 0 | | +| f_xxxxx_00001 | PENDING | | 1 | |""" EXPECTED_END_TO_END_END = """Number of trials: 30/30 (30 TERMINATED) -+---------------+------------+-------+-----+-----+-----+--------+ -| Trial name | status | loc | a | b | c | done | -|---------------+------------+-------+-----+-----+-----+--------| -| f_xxxxx_00000 | TERMINATED | | 0 | | | True | -| f_xxxxx_00001 | TERMINATED | | 1 | | | True | -| f_xxxxx_00002 | TERMINATED | | 2 | | | True | -| f_xxxxx_00003 | TERMINATED | | 3 | | | True | -| f_xxxxx_00004 | TERMINATED | | 4 | | | True | -| f_xxxxx_00005 | TERMINATED | | 5 | | | True | -| f_xxxxx_00006 | TERMINATED | | 6 | | | True | -| f_xxxxx_00007 | TERMINATED | | 7 | | | True | -| f_xxxxx_00008 | TERMINATED | | 8 | | | True | -| f_xxxxx_00009 | TERMINATED | | 9 | | | True | -| f_xxxxx_00010 | TERMINATED | | | 0 | | True | -| f_xxxxx_00011 | TERMINATED | | | 1 | | True | -| f_xxxxx_00012 | TERMINATED | | | 2 | | True | -| f_xxxxx_00013 | TERMINATED | | | 3 | | True | -| 
f_xxxxx_00014 | TERMINATED | | | 4 | | True | -| f_xxxxx_00015 | TERMINATED | | | 5 | | True | -| f_xxxxx_00016 | TERMINATED | | | 6 | | True | -| f_xxxxx_00017 | TERMINATED | | | 7 | | True | -| f_xxxxx_00018 | TERMINATED | | | 8 | | True | -| f_xxxxx_00019 | TERMINATED | | | 9 | | True | -| f_xxxxx_00020 | TERMINATED | | | | 0 | True | -| f_xxxxx_00021 | TERMINATED | | | | 1 | True | -| f_xxxxx_00022 | TERMINATED | | | | 2 | True | -| f_xxxxx_00023 | TERMINATED | | | | 3 | True | -| f_xxxxx_00024 | TERMINATED | | | | 4 | True | -| f_xxxxx_00025 | TERMINATED | | | | 5 | True | -| f_xxxxx_00026 | TERMINATED | | | | 6 | True | -| f_xxxxx_00027 | TERMINATED | | | | 7 | True | -| f_xxxxx_00028 | TERMINATED | | | | 8 | True | -| f_xxxxx_00029 | TERMINATED | | | | 9 | True | -+---------------+------------+-------+-----+-----+-----+--------+""" ++---------------+------------+-------------------+-----+-----+-----+--------+ +| Trial name | status | loc | a | b | c | done | +|---------------+------------+-------------------+-----+-----+-----+--------| +| f_xxxxx_00000 | TERMINATED | 123.123.123.123:1 | 0 | | | True | +| f_xxxxx_00001 | TERMINATED | 123.123.123.123:1 | 1 | | | True | +| f_xxxxx_00002 | TERMINATED | 123.123.123.123:1 | 2 | | | True | +| f_xxxxx_00003 | TERMINATED | 123.123.123.123:1 | 3 | | | True | +| f_xxxxx_00004 | TERMINATED | 123.123.123.123:1 | 4 | | | True | +| f_xxxxx_00005 | TERMINATED | 123.123.123.123:1 | 5 | | | True | +| f_xxxxx_00006 | TERMINATED | 123.123.123.123:1 | 6 | | | True | +| f_xxxxx_00007 | TERMINATED | 123.123.123.123:1 | 7 | | | True | +| f_xxxxx_00008 | TERMINATED | 123.123.123.123:1 | 8 | | | True | +| f_xxxxx_00009 | TERMINATED | 123.123.123.123:1 | 9 | | | True | +| f_xxxxx_00010 | TERMINATED | 123.123.123.123:1 | | 0 | | True | +| f_xxxxx_00011 | TERMINATED | 123.123.123.123:1 | | 1 | | True | +| f_xxxxx_00012 | TERMINATED | 123.123.123.123:1 | | 2 | | True | +| f_xxxxx_00013 | TERMINATED | 123.123.123.123:1 | | 3 | | True | +| f_xxxxx_00014 | TERMINATED | 123.123.123.123:1 | | 4 | | True | +| f_xxxxx_00015 | TERMINATED | 123.123.123.123:1 | | 5 | | True | +| f_xxxxx_00016 | TERMINATED | 123.123.123.123:1 | | 6 | | True | +| f_xxxxx_00017 | TERMINATED | 123.123.123.123:1 | | 7 | | True | +| f_xxxxx_00018 | TERMINATED | 123.123.123.123:1 | | 8 | | True | +| f_xxxxx_00019 | TERMINATED | 123.123.123.123:1 | | 9 | | True | +| f_xxxxx_00020 | TERMINATED | 123.123.123.123:1 | | | 0 | True | +| f_xxxxx_00021 | TERMINATED | 123.123.123.123:1 | | | 1 | True | +| f_xxxxx_00022 | TERMINATED | 123.123.123.123:1 | | | 2 | True | +| f_xxxxx_00023 | TERMINATED | 123.123.123.123:1 | | | 3 | True | +| f_xxxxx_00024 | TERMINATED | 123.123.123.123:1 | | | 4 | True | +| f_xxxxx_00025 | TERMINATED | 123.123.123.123:1 | | | 5 | True | +| f_xxxxx_00026 | TERMINATED | 123.123.123.123:1 | | | 6 | True | +| f_xxxxx_00027 | TERMINATED | 123.123.123.123:1 | | | 7 | True | +| f_xxxxx_00028 | TERMINATED | 123.123.123.123:1 | | | 8 | True | +| f_xxxxx_00029 | TERMINATED | 123.123.123.123:1 | | | 9 | True | ++---------------+------------+-------------------+-----+-----+-----+--------+""" # noqa EXPECTED_END_TO_END_AC = """Number of trials: 30/30 (30 TERMINATED) +---------------+------------+-------+-----+-----+-----+ @@ -217,15 +234,26 @@ def f(config): Trial train_xxxxx_00002 reported acc=8 with parameters={'do': 'twice'}. """ + \ "This trial completed." 
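+# NOTE: the expected table strings in this module assume the trial
+# location has been patched to the fixed value "123.123.123.123:1"
+# (see mock_get_trial_location in END_TO_END_COMMAND and VERBOSE_CMD);
+# without that patch the loc column would contain a real,
+# machine-dependent ip:pid and the comparisons would be flaky.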
-VERBOSE_TRIAL_DETAIL = """+-------------------+----------+-------+----------+ -| Trial name | status | loc | do | -|-------------------+----------+-------+----------| -| train_xxxxx_00000 | RUNNING | | complete |""" +VERBOSE_TRIAL_DETAIL = """+-------------------+----------+-------------------+----------+ +| Trial name | status | loc | do | +|-------------------+----------+-------------------+----------| +| train_xxxxx_00000 | RUNNING | 123.123.123.123:1 | complete |""" VERBOSE_CMD = """from ray import tune import random import numpy as np import time +from ray.tune.trial import Location +from ray.tune.progress_reporter import _get_trial_location +from unittest.mock import patch + + +def mock_get_trial_location(trial, result): + location = _get_trial_location(trial, result) + if location.pid: + return Location("123.123.123.123", "1") + return location + def train(config): if config["do"] == "complete": @@ -242,11 +270,14 @@ def train(config): random.seed(1234) np.random.seed(1234) -tune.run( - train, - config={ - "do": tune.grid_search(["complete", "once", "twice"]) - },""" + +with patch("ray.tune.progress_reporter._get_trial_location", + mock_get_trial_location): + tune.run( + train, + config={ + "do": tune.grid_search(["complete", "once", "twice"]) + },""" # Add "verbose=3)" etc @@ -424,6 +455,27 @@ def testProgressStr(self): best1 = best_trial_str(trials[1], "metric_1") assert best1 == EXPECTED_BEST_1 + def testTimeElapsed(self): + # Sun Feb 7 14:18:40 2016 -0800 + # (time of the first Ray commit) + time_start = 1454825920 + time_now = ( + time_start + 1 * 60 * 60 # 1 hour + + 31 * 60 # 31 minutes + + 22 # 22 seconds + ) # time to second commit + + # Local timezone output can be tricky, so we don't check the + # day and the hour in this test. + output = time_passed_str(time_start, time_now) + self.assertIn("Current time: 2016-02-", output) + self.assertIn(":50:02 (running for 01:31:22.00)", output) + + time_now += 2 * 60 * 60 * 24 # plus two days + output = time_passed_str(time_start, time_now) + self.assertIn("Current time: 2016-02-", output) + self.assertIn(":50:02 (running for 2 days, 01:31:22.00)", output) + def testCurrentBestTrial(self): trials = [] for i in range(5): diff --git a/python/ray/tune/tests/test_ray_trial_executor.py b/python/ray/tune/tests/test_ray_trial_executor.py index a21664a2c11ee..f5d87e7dd1926 100644 --- a/python/ray/tune/tests/test_ray_trial_executor.py +++ b/python/ray/tune/tests/test_ray_trial_executor.py @@ -2,6 +2,7 @@ import os import pytest +import time import unittest import ray @@ -11,7 +12,7 @@ from ray.tune.callback import Callback from ray.tune.ray_trial_executor import RayTrialExecutor from ray.tune.registry import _global_registry, TRAINABLE_CLASS -from ray.tune.result import TRAINING_ITERATION +from ray.tune.result import PID, TRAINING_ITERATION, TRIAL_ID from ray.tune.suggest import BasicVariantGenerator from ray.tune.trial import Trial, Checkpoint from ray.tune.resources import Resources @@ -252,6 +253,68 @@ def reset_config(self, config): self.assertEqual(trial.experiment_tag, "modified_mock") self.assertEqual(Trial.RUNNING, trial.status) + def testForceTrialCleanup(self): + class B(Trainable): + def step(self): + print("Step start") + time.sleep(10) + print("Step done") + return dict(my_metric=1, timesteps_this_iter=1, done=True) + + def reset_config(self, config): + self.config = config + return True + + def cleanup(self): + print("Cleanup start") + time.sleep(10) + print("Cleanup done") + + # First check if the trials terminate gracefully 
by default + trials = self.generate_trials({ + "run": B, + "config": { + "foo": 0 + }, + }, "grid_search") + trial = trials[0] + self.trial_executor.start_trial(trial) + self.assertEqual(Trial.RUNNING, trial.status) + time.sleep(5) + print("Stop trial") + self.trial_executor.stop_trial(trial) + print("Start trial cleanup") + start = time.time() + self.trial_executor.cleanup([trial]) + self.assertGreaterEqual(time.time() - start, 12.0) + + # Check forceful termination. It should run for much less than the + # sleep periods in the Trainable + trials = self.generate_trials({ + "run": B, + "config": { + "foo": 0 + }, + }, "grid_search") + trial = trials[0] + os.environ["TUNE_FORCE_TRIAL_CLEANUP_S"] = "1" + self.trial_executor = RayTrialExecutor(queue_trials=False) + os.environ["TUNE_FORCE_TRIAL_CLEANUP_S"] = "0" + self.trial_executor.start_trial(trial) + self.assertEqual(Trial.RUNNING, trial.status) + time.sleep(5) + print("Stop trial") + self.trial_executor.stop_trial(trial) + print("Start trial cleanup") + start = time.time() + self.trial_executor.cleanup([trial]) + self.assertLess(time.time() - start, 5.0) + + # also check if auto-filled metrics were returned + self.assertIn(PID, trial.last_result) + self.assertIn(TRIAL_ID, trial.last_result) + self.assertNotIn("my_metric", trial.last_result) + @staticmethod def generate_trials(spec, name): suggester = BasicVariantGenerator() @@ -480,6 +543,10 @@ def tearDown(self): ray.shutdown() _register_all() # re-register the evicted objects + def testForceTrialCleanup(self): + self.skipTest("Skipping as force trial cleanup is not applicable" + " for local mode.") + if __name__ == "__main__": import sys diff --git a/python/ray/tune/tests/test_trial_runner_3.py b/python/ray/tune/tests/test_trial_runner_3.py index e467eafa5e51e..44341ebf99cf6 100644 --- a/python/ray/tune/tests/test_trial_runner_3.py +++ b/python/ray/tune/tests/test_trial_runner_3.py @@ -555,7 +555,8 @@ def testTrialNoSave(self): self.assertTrue( runner2.get_trial("checkpoint").status == Trial.TERMINATED) self.assertTrue(runner2.get_trial("pending").status == Trial.PENDING) - self.assertTrue(not runner2.get_trial("pending").last_result) + self.assertTrue( + not runner2.get_trial("pending").has_reported_at_least_once) runner2.step() def testCheckpointWithFunction(self): diff --git a/python/ray/tune/tests/test_trial_runner_callbacks.py b/python/ray/tune/tests/test_trial_runner_callbacks.py index 16f40b7602712..f9cf300948ea6 100644 --- a/python/ray/tune/tests/test_trial_runner_callbacks.py +++ b/python/ray/tune/tests/test_trial_runner_callbacks.py @@ -154,7 +154,7 @@ def testCallbackSteps(self): result = {TRAINING_ITERATION: 1, "metric": 800, "done": False} self.executor.results[trials[1]] = result self.executor.next_trial = trials[1] - self.assertEqual(trials[1].last_result, {}) + self.assertTrue(not trials[1].has_reported_at_least_once) self.trial_runner.step() self.assertEqual(self.callback.state["trial_result"]["iteration"], 3) self.assertEqual(self.callback.state["trial_result"]["trial"].trial_id, diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py index 0e0a2dd65c701..798c08192ab21 100644 --- a/python/ray/tune/tests/test_trial_scheduler.py +++ b/python/ray/tune/tests/test_trial_scheduler.py @@ -847,6 +847,7 @@ def __init__(self, i, config): self.resources = Resources(1, 0) self.custom_trial_name = None self.custom_dirname = None + self._default_result_or_future = None def on_checkpoint(self, checkpoint): self.restored_checkpoint = 
checkpoint.value diff --git a/python/ray/tune/tests/test_trial_scheduler_pbt.py b/python/ray/tune/tests/test_trial_scheduler_pbt.py index 81b90dcfeebf2..31a8f02132101 100644 --- a/python/ray/tune/tests/test_trial_scheduler_pbt.py +++ b/python/ray/tune/tests/test_trial_scheduler_pbt.py @@ -192,8 +192,13 @@ def MockTrainingFuncSync(config, checkpoint_dir=None): "checkpoint") with open(checkpoint_path, "wb") as fp: pickle.dump((a, iter), fp) + # Different sleep times so that asynch test runs do not + # randomly succeed. If well performing trials finish later, + # then bad performing trials will already have continued + # to train, which is exactly what we want to test when + # comparing sync vs. async. + time.sleep(a / 20) # Score gets better every iteration. - time.sleep(1) tune.report(mean_accuracy=iter + a, a=a) self.MockTrainingFuncSync = MockTrainingFuncSync @@ -201,7 +206,10 @@ def MockTrainingFuncSync(config, checkpoint_dir=None): def tearDown(self): ray.shutdown() - def synchSetup(self, synch, param=[10, 20, 30]): + def synchSetup(self, synch, param=None): + if param is None: + param = [10, 20, 30] + scheduler = PopulationBasedTraining( time_attr="training_iteration", metric="mean_accuracy", diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index 7e63147ca4e00..3299d7aa4e861 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -8,17 +8,18 @@ import sys import tempfile import time -from typing import Any, Dict, Union +from typing import Any, Dict, Optional, Union import uuid import ray import ray.cloudpickle as pickle from ray.tune.resources import Resources from ray.tune.result import ( - DEFAULT_RESULTS_DIR, SHOULD_CHECKPOINT, TIME_THIS_ITER_S, - TIMESTEPS_THIS_ITER, DONE, TIMESTEPS_TOTAL, EPISODES_THIS_ITER, - EPISODES_TOTAL, TRAINING_ITERATION, RESULT_DUPLICATE, TRIAL_INFO, - STDOUT_FILE, STDERR_FILE) + DEBUG_METRICS, DEFAULT_RESULTS_DIR, HOSTNAME, NODE_IP, PID, + SHOULD_CHECKPOINT, TIME_THIS_ITER_S, TIME_TOTAL_S, TIMESTEPS_THIS_ITER, + DONE, TIMESTEPS_TOTAL, EPISODES_THIS_ITER, EPISODES_TOTAL, + TRAINING_ITERATION, RESULT_DUPLICATE, TRIAL_ID, TRIAL_INFO, STDOUT_FILE, + STDERR_FILE) from ray.tune.utils import UtilMonitor from ray.tune.utils.placement_groups import PlacementGroupFactory from ray.tune.utils.trainable import TrainableUtil @@ -154,6 +155,40 @@ def get_current_ip(self): self._local_ip = ray.util.get_node_ip_address() return self._local_ip + def get_auto_filled_metrics(self, + now: Optional[datetime] = None, + time_this_iter: Optional[float] = None, + debug_metrics_only: bool = False) -> dict: + """Return a dict with metrics auto-filled by the trainable. + + If ``debug_metrics_only`` is True, only metrics that don't + require at least one iteration will be returned + (``ray.tune.result.DEBUG_METRICS``). 
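+
+        A usage sketch (illustrative; the exact key set kept in debug
+        mode is whatever ``ray.tune.result.DEBUG_METRICS`` contains):
+
+            metrics = trainable.get_auto_filled_metrics(
+                debug_metrics_only=True)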
+ """ + if now is None: + now = datetime.today() + autofilled = { + TRIAL_ID: self.trial_id, + "experiment_id": self._experiment_id, + "date": now.strftime("%Y-%m-%d_%H-%M-%S"), + "timestamp": int(time.mktime(now.timetuple())), + TIME_THIS_ITER_S: time_this_iter, + TIME_TOTAL_S: self._time_total, + PID: os.getpid(), + HOSTNAME: platform.node(), + NODE_IP: self._local_ip, + "config": self.config, + "time_since_restore": self._time_since_restore, + "timesteps_since_restore": self._timesteps_since_restore, + "iterations_since_restore": self._iterations_since_restore + } + if debug_metrics_only: + autofilled = { + k: v + for k, v in autofilled.items() if k in DEBUG_METRICS + } + return autofilled + def is_actor(self): try: actor_id = ray.worker.global_worker.actor_id @@ -289,19 +324,7 @@ def train(self): result.setdefault("neg_mean_loss", -result["mean_loss"]) now = datetime.today() - result.update( - experiment_id=self._experiment_id, - date=now.strftime("%Y-%m-%d_%H-%M-%S"), - timestamp=int(time.mktime(now.timetuple())), - time_this_iter_s=time_this_iter, - time_total_s=self._time_total, - pid=os.getpid(), - hostname=platform.node(), - node_ip=self._local_ip, - config=self.config, - time_since_restore=self._time_since_restore, - timesteps_since_restore=self._timesteps_since_restore, - iterations_since_restore=self._iterations_since_restore) + result.update(self.get_auto_filled_metrics(now, time_this_iter)) monitor_data = self._monitor.get_data() if monitor_data: diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 6398b53f2292f..ede51f26ba5b1 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -8,19 +8,20 @@ import re import shutil import time -from typing import Callable, Dict, Sequence, Union +from typing import Callable, Dict, Optional, Sequence, Union import uuid import ray import ray.cloudpickle as cloudpickle -from ray.exceptions import GetTimeoutError +from ray.exceptions import GetTimeoutError, RayActorError from ray.tune import TuneError from ray.tune.checkpoint_manager import Checkpoint, CheckpointManager # NOTE(rkn): We import ray.tune.registry here instead of importing the names we # need because there are cyclic imports that may cause specific names to not # have been defined yet. See https://github.com/ray-project/ray/issues/1716. from ray.tune.registry import get_trainable_cls, validate_trainable -from ray.tune.result import DEFAULT_RESULTS_DIR, DONE, TRAINING_ITERATION +from ray.tune.result import (DEFAULT_RESULTS_DIR, DONE, NODE_IP, PID, + TRAINING_ITERATION, TRIAL_ID) from ray.tune.resources import Resources, \ json_to_resources, resources_to_json from ray.tune.utils.placement_groups import PlacementGroupFactory, \ @@ -299,7 +300,9 @@ def __init__(self, self.max_failures = max_failures # Local trial state that is updated during the run - self.last_result = {} + self._last_result = {} + self._default_result_or_future: Union[ray.ObjectRef, dict, None] = ( + None) self.last_update_time = -float("inf") # stores in memory max/min/avg/last-n-avg/last result for each @@ -394,6 +397,52 @@ def _setup_resources(self, log_always: bool = False): resource_kwargs["has_placement_group"] = True self.resources = Resources(**resource_kwargs) + def _get_default_result_or_future(self) -> Optional[dict]: + """Calls ray.get on self._default_result_or_future and assigns back. + + Returns None in case of exceptions. + Will also set the trial location if runner is set. 
+ """ + if self._default_result_or_future and isinstance( + self._default_result_or_future, ray.ObjectRef): + try: + self._default_result_or_future = ray.get( + self._default_result_or_future) + except RayActorError: # error during initialization + self._default_result_or_future = None + if self._default_result_or_future and self.runner: + self.set_location( + Location( + self._default_result_or_future.get(NODE_IP), + self._default_result_or_future.get(PID))) + return self._default_result_or_future + + @property + def last_result(self) -> dict: + # The logic in here is as follows: + # 1. If the trial has reported at least once, last_result would have + # been set and therefore would not be empty. We can just return it. + # 2. If the trial has not reported at least once but we have the + # future for the default results dict, (obtained through + # Trainable.get_auto_filled_metrics), we get that future + # and return it. + # 3. In the worst case where we have nothing, we just set the + # trial_id and return that. + result = self._last_result + if not {k for k in result if k != TRIAL_ID}: + self._get_default_result_or_future() + result = self._default_result_or_future or result + result.setdefault(TRIAL_ID, self.trial_id) + return result + + @last_result.setter + def last_result(self, val: dict): + self._last_result = val + + @property + def has_reported_at_least_once(self) -> bool: + return bool(self._last_result) + @property def node_ip(self): return self.location.hostname @@ -499,6 +548,11 @@ def update_resources( def set_runner(self, runner): self.runner = runner + if runner: + # Do not block here, the result will be gotten when last_result + # property is accessed + self._default_result_or_future = ( + runner.get_auto_filled_metrics.remote(debug_metrics_only=True)) self.checkpoint_manager.delete = CheckpointDeleter( self._trainable_name(), runner, self.node_ip) # No need to invalidate state cache: runner is not stored in json @@ -603,7 +657,7 @@ def update_last_result(self, result, terminate=False): if self.experiment_tag: result.update(experiment_tag=self.experiment_tag) - self.set_location(Location(result.get("node_ip"), result.get("pid"))) + self.set_location(Location(result.get(NODE_IP), result.get(PID))) self.last_result = result self.last_update_time = time.time() @@ -729,6 +783,7 @@ def __getstate__(self): state["_state_json"] = None state["_state_valid"] = False + state["_default_result_or_future"] = None return copy.deepcopy(state) diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 0d91ee3b8bc65..f4ce4ea70d001 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -15,8 +15,9 @@ from ray.tune.callback import CallbackList from ray.tune.stopper import NoopStopper from ray.tune.ray_trial_executor import RayTrialExecutor -from ray.tune.result import (DEFAULT_METRIC, TIME_THIS_ITER_S, - RESULT_DUPLICATE, SHOULD_CHECKPOINT) +from ray.tune.result import (DEBUG_METRICS, DEFAULT_METRIC, DONE, + TIME_THIS_ITER_S, RESULT_DUPLICATE, + SHOULD_CHECKPOINT) from ray.tune.syncer import CloudSyncer, get_cloud_syncer from ray.tune.trial import Checkpoint, Trial from ray.tune.schedulers import FIFOScheduler, TrialScheduler @@ -195,7 +196,9 @@ class TrialRunner: """ CKPT_FILE_TMPL = "experiment_state-{}.json" - VALID_RESUME_TYPES = [True, "LOCAL", "REMOTE", "PROMPT", "ERRORED_ONLY"] + VALID_RESUME_TYPES = [ + True, "LOCAL", "REMOTE", "PROMPT", "ERRORED_ONLY", "AUTO" + ] RAISE = "RAISE" def __init__(self, @@ -415,7 +418,7 @@ def 
_validate_resume(self, resume_type): Args: resume_type: One of True, "REMOTE", "LOCAL", - "PROMPT", "ERRORED_ONLY". + "PROMPT", "ERRORED_ONLY", "AUTO". """ # TODO: Consider supporting ERRORED_ONLY+REMOTE? if not resume_type: @@ -426,11 +429,54 @@ def _validate_resume(self, resume_type): # Not clear if we need this assertion, since we should always have a # local checkpoint dir. assert self._local_checkpoint_dir or self._remote_checkpoint_dir + + if resume_type == "AUTO": + if self._remote_checkpoint_dir: + logger.info( + f"Trying to find and download experiment checkpoint at " + f"{self._remote_checkpoint_dir}") + # Todo: This syncs the entire experiment including trial + # checkpoints. We should exclude these in the future. + try: + self._syncer.sync_down_if_needed() + self._syncer.wait() + except TuneError as e: + logger.warning( + f"Got error when trying to sync down: {e} " + f"\nPlease check this error message for potential " + f"access problems - if a directory was not found, " + f"that is expected at this stage when you're starting " + f"a new experiment.") + logger.info( + "No remote checkpoint was found or an error occurred " + "when trying to download the experiment checkpoint. " + "Please check the previous warning message for more " + "details. " + "Ray Tune will now start a new experiment.") + return False + logger.info( + "A remote experiment checkpoint was found and will be " + "used to restore the previous experiment state.") + return True + elif not self.checkpoint_exists(self._local_checkpoint_dir): + logger.info("No local checkpoint was found. " + "Ray Tune will now start a new experiment.") + return False + logger.info( + "A local experiment checkpoint was found and will be used " + "to restore the previous experiment state.") + return True + if resume_type in [True, "LOCAL", "PROMPT", "ERRORED_ONLY"]: if not self.checkpoint_exists(self._local_checkpoint_dir): raise ValueError( - f"Called resume ({resume_type}) when no checkpoint exists " - f"in local directory ({self._local_checkpoint_dir}).") + f"You called resume ({resume_type}) when no checkpoint " + f"exists in local directory " + f"({self._local_checkpoint_dir}). If you want to start " + f"a new experiment, use `resume=\"AUTO\"` or " + f"`resume=None`. If you expected an experiment to " + f"already exist, check if you supplied the correct " + f"`local_dir` to `tune.run()`.") elif resume_type == "PROMPT": if click.confirm(f"Resume from local directory? " f"({self._local_checkpoint_dir})"): @@ -448,12 +494,22 @@ def _validate_resume(self, resume_type): "`upload_dir` set to `tune.run(sync_config=...)`.") # Try syncing down the upload directory. - logger.info("Downloading from %s", self._remote_checkpoint_dir) - # TODO(ujvl): Note that this syncs down the entire directory, - # which may also contain trial checkpoints. We should selectively - # sync the necessary files instead. - self._syncer.sync_down_if_needed() - self._syncer.wait() + logger.info(f"Downloading experiment checkpoint from " + f"{self._remote_checkpoint_dir}") + # Todo: This syncs the entire experiment including trial + # checkpoints. We should exclude these in the future. + try: + self._syncer.sync_down_if_needed() + self._syncer.wait() + except TuneError as e: + raise RuntimeError( + "Syncing the remote experiment checkpoint to the driver " + "failed. Please check the error message. If you want to " + "start a new experiment, use `resume=\"AUTO\"` or " + "`resume=None`. 
If you expected an experiment to " + "already exist, check if you supplied the correct " + "`upload_dir` to the `tune.SyncConfig` passed to " + "`tune.run()`.") from e if not self.checkpoint_exists(self._local_checkpoint_dir): raise ValueError("Called resume when no checkpoint exists " @@ -863,6 +919,8 @@ def _process_trial_result(self, trial, result): flat_result = flatten_dict(result) self._validate_result_metrics(flat_result) + _trigger_callback_complete = False + if self._stopper(trial.trial_id, result) or trial.should_stop(flat_result): result.update(done=True) @@ -882,8 +940,7 @@ def _process_trial_result(self, trial, result): trial=trial, result=result.copy()) - self._callbacks.on_trial_complete( - iteration=self._iteration, trials=self._trials, trial=trial) + _trigger_callback_complete = True decision = TrialScheduler.STOP else: with warn_if_slow("scheduler.on_trial_result"): @@ -919,6 +976,10 @@ def _process_trial_result(self, trial, result): # the global checkpoint state. self._checkpoint_trial_if_needed(trial, force=force_checkpoint) + if _trigger_callback_complete: + self._callbacks.on_trial_complete( + iteration=self._iteration, trials=self._trials, trial=trial) + if trial.is_saving: # Cache decision to execute on after the save is processed. # This prevents changing the trial's state or kicking off @@ -932,15 +993,18 @@ def _process_trial_result(self, trial, result): def _validate_result_metrics(self, result): """ Check if any of the required metrics was not reported - in the last result. If the only item is `done=True`, this - means that no result was ever received and the trial just - returned. This is also okay and will not raise an error. + in the last result. If the only items are ``done`` or any of + DEBUG_METRICS, this means that no result was ever received and + the trial just returned. This is also okay and will not raise + an error. This will ignore checking for the DEFAULT_METRIC. """ - if int(os.environ.get("TUNE_DISABLE_STRICT_METRIC_CHECKING", - 0)) != 1 and (len(result) > 1 - or "done" not in result): + if int(os.environ.get( + "TUNE_DISABLE_STRICT_METRIC_CHECKING", 0)) != 1 and (len({ + k + for k in result if k not in list(DEBUG_METRICS) + [DONE] + }) > 1): base_metric = self._metric \ if self._metric != DEFAULT_METRIC else None scheduler_metric = self._scheduler_alg.metric \ diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index 8077f7c6e6cd2..a3553879633a9 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -11,13 +11,15 @@ import ray from ray.util.annotations import PublicAPI +from ray.util.queue import Queue, Empty from ray.tune.analysis import ExperimentAnalysis from ray.tune.callback import Callback from ray.tune.error import TuneError from ray.tune.experiment import Experiment, convert_to_experiment_list from ray.tune.logger import Logger -from ray.tune.progress_reporter import detect_reporter, ProgressReporter +from ray.tune.progress_reporter import (detect_reporter, ProgressReporter, + JupyterNotebookReporter) from ray.tune.ray_trial_executor import RayTrialExecutor from ray.tune.registry import get_trainable_cls from ray.tune.stopper import Stopper @@ -314,7 +316,48 @@ def run( # Make sure tune.run is called on the sever node. remote_run = force_on_current_node(remote_run) - return ray.get(remote_run.remote(_remote=False, **remote_run_kwargs)) + # JupyterNotebooks don't work with remote tune runs out of the box + # (e.g. via Ray client) as they don't have access to the main + # process stdout. 
So we introduce a queue here that accepts + # callables, which will then be executed on the driver side. + if isinstance(progress_reporter, JupyterNotebookReporter): + execute_queue = Queue(actor_options={ + "num_cpus": 0, + **force_on_current_node(None) + }) + progress_reporter.set_output_queue(execute_queue) + + def get_next_queue_item(): + try: + return execute_queue.get(block=False) + except Empty: + return None + + else: + # If we don't need a queue, use this dummy get fn instead of + # scheduling an unneeded actor + def get_next_queue_item(): + return None + + def _handle_execute_queue(): + execute_item = get_next_queue_item() + while execute_item: + if isinstance(execute_item, Callable): + execute_item() + + execute_item = get_next_queue_item() + + remote_future = remote_run.remote(_remote=False, **remote_run_kwargs) + + # ray.wait(...)[1] returns futures that are not ready, yet + while ray.wait([remote_future], timeout=0.2)[1]: + # Check if we have items to execute + _handle_execute_queue() + + # Handle queue one last time + _handle_execute_queue() + + return ray.get(remote_future) del remote_run_kwargs @@ -341,8 +384,34 @@ def run( if num_samples == -1: num_samples = sys.maxsize + result_buffer_length = None + + # Create scheduler here as we need access to some of its properties + if isinstance(scheduler, str): + # importing at top level causes a recursive dependency + from ray.tune.schedulers import create_scheduler + scheduler = create_scheduler(scheduler) + scheduler = scheduler or FIFOScheduler() + + if scheduler.supports_buffered_results: + # Result buffering with a Hyperband scheduler is a bad idea, as + # hyperband tries to stop trials when processing brackets. With result + # buffering, we might trigger this multiple times when evaluating + # a single trial, which leads to unexpected behavior. + env_result_buffer_length = os.getenv("TUNE_RESULT_BUFFER_LENGTH", "") + if env_result_buffer_length: + warnings.warn( + f"You are using a {type(scheduler)} scheduler, but " + f"TUNE_RESULT_BUFFER_LENGTH is set " + f"({env_result_buffer_length}). 
This can lead to undesired " + f"and faulty behavior, so the buffer length was forcibly set " + f"to 1 instead.") + result_buffer_length = 1 + trial_executor = trial_executor or RayTrialExecutor( - reuse_actors=reuse_actors, queue_trials=queue_trials) + reuse_actors=reuse_actors, + queue_trials=queue_trials, + result_buffer_length=result_buffer_length) if isinstance(run_or_experiment, list): experiments = run_or_experiment else: @@ -395,11 +464,6 @@ def run( if is_local_mode: max_concurrent_trials = 1 - if isinstance(scheduler, str): - # importing at top level causes a recursive dependency - from ray.tune.schedulers import create_scheduler - scheduler = create_scheduler(scheduler) - if not search_alg: search_alg = BasicVariantGenerator( max_concurrent=max_concurrent_trials or 0) @@ -447,7 +511,6 @@ def run( "does not contain any more parameter definitions - include " "them in the search algorithm's search space if necessary.") - scheduler = scheduler or FIFOScheduler() if not scheduler.set_search_properties(metric, mode): raise ValueError( "You passed a `metric` or `mode` argument to `tune.run()`, but " @@ -531,6 +594,7 @@ def sigint_handler(sig, frame): signal.signal(signal.SIGINT, sigint_handler) tune_start = time.time() + progress_reporter.set_start_time(tune_start) while not runner.is_finished() and not state[signal.SIGINT]: runner.step() if has_verbosity(Verbosity.V1_EXPERIMENT): diff --git a/python/ray/tune/utils/util.py b/python/ray/tune/utils/util.py index 0f4612c66c047..c41179c43b845 100644 --- a/python/ray/tune/utils/util.py +++ b/python/ray/tune/utils/util.py @@ -645,14 +645,15 @@ def get_current_node_resource_key() -> str: raise ValueError("Cannot found the node dictionary for current node.") -def force_on_current_node(task_or_actor): +def force_on_current_node(task_or_actor=None): """Given a task or actor, place it on the current node. If using Ray Client, the current node is the client server node. Args: task_or_actor: A Ray remote function or class to place on the - current node. + current node. If None, returns the options dict to pass to + another actor. Returns: The provided task or actor, but with options modified to force @@ -660,6 +661,10 @@ def force_on_current_node(task_or_actor): """ node_resource_key = get_current_node_resource_key() options = {"resources": {node_resource_key: 0.01}} + + if task_or_actor is None: + return options + return task_or_actor.options(**options) diff --git a/python/ray/util/__init__.py b/python/ray/util/__init__.py index 3177925e68fc0..7a326154bf1e3 100644 --- a/python/ray/util/__init__.py +++ b/python/ray/util/__init__.py @@ -19,7 +19,7 @@ @PublicAPI(stability="beta") -@client_mode_hook +@client_mode_hook(auto_init=True) def list_named_actors(all_namespaces: bool = False) -> List[str]: """List all named actors in the system. 
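The `force_on_current_node(None)` variant introduced in
`python/ray/tune/utils/util.py` above returns just the placement options,
which is how `tune.run` pins the Jupyter output queue actor to the driver
node. A minimal sketch of the call pattern (assumes a connected Ray
driver; `Queue` is `ray.util.queue.Queue`):

    from ray.tune.utils.util import force_on_current_node
    from ray.util.queue import Queue

    # With no argument, only the options dict is returned, e.g.
    # {"resources": {<current-node-resource-key>: 0.01}}.
    options = force_on_current_node(None)
    execute_queue = Queue(actor_options={"num_cpus": 0, **options})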
diff --git a/python/ray/util/client/__init__.py b/python/ray/util/client/__init__.py index f11b692d56f42..29b95c850c2a4 100644 --- a/python/ray/util/client/__init__.py +++ b/python/ray/util/client/__init__.py @@ -5,7 +5,6 @@ import os import sys import logging -import json import threading import grpc @@ -66,7 +65,7 @@ def connect(self, job_config = job_config or JobConfig() job_config.set_ray_namespace(namespace) if job_config is not None: - runtime_env = json.loads(job_config.get_serialized_runtime_env()) + runtime_env = job_config.runtime_env if runtime_env.get("pip") or runtime_env.get("conda"): logger.warning("The 'pip' or 'conda' field was specified in " "the runtime env, so it may take some time to " diff --git a/python/ray/util/client/client_pickler.py b/python/ray/util/client/client_pickler.py index 9c1ebef68d565..0faf3c99c68cd 100644 --- a/python/ray/util/client/client_pickler.py +++ b/python/ray/util/client/client_pickler.py @@ -49,12 +49,17 @@ else: import pickle # noqa: F401 + # NOTE(barakmich): These PickleStubs are really close to -# the data for an exectuion, with no arguments. Combine the two? -PickleStub = NamedTuple("PickleStub", - [("type", str), ("client_id", str), ("ref_id", bytes), - ("name", Optional[str]), - ("baseline_options", Optional[Dict])]) +# the data for an execution, with no arguments. Combine the two? +class PickleStub( + NamedTuple("PickleStub", [("type", str), ("client_id", str), + ("ref_id", bytes), ("name", Optional[str]), + ("baseline_options", Optional[Dict])])): + def __reduce__(self): + # PySpark's namedtuple monkey patch breaks compatibility with + # cloudpickle. Thus we revert this patch here if it exists. + return object.__reduce__(self) class ClientPickler(cloudpickle.CloudPickler): diff --git a/python/ray/util/client/options.py b/python/ray/util/client/options.py index 9c9df946d0cf5..ec6c568d5b347 100644 --- a/python/ray/util/client/options.py +++ b/python/ray/util/client/options.py @@ -36,7 +36,6 @@ "placement_group_bundle_index": (), "placement_group_capture_child_tasks": (), "runtime_env": (), - "override_environment_variables": (), } diff --git a/python/ray/util/client/server/proxier.py b/python/ray/util/client/server/proxier.py index 0fb2f07429b1d..98ad26c93d8b4 100644 --- a/python/ray/util/client/server/proxier.py +++ b/python/ray/util/client/server/proxier.py @@ -27,10 +27,10 @@ from ray.util.client.server.dataservicer import _get_reconnecting_from_context from ray._private.client_mode_hook import disable_client_hook from ray._private.parameter import RayParams -from ray._private.runtime_env import RuntimeEnvContext +from ray._private.runtime_env.context import RuntimeEnvContext from ray._private.services import ProcessInfo, start_ray_client_server -from ray._private.utils import (detect_fate_sharing_support, - add_port_to_grpc_server) +from ray._private.utils import detect_fate_sharing_support +from ray._private.tls_utils import add_port_to_grpc_server # Import psutil after ray so the packaged version is used. import psutil @@ -264,7 +264,9 @@ def start_specific_server(self, client_id: str, f"ray_client_server_{specific_server.port}", unique=True) serialized_runtime_env = job_config.get_serialized_runtime_env() - if serialized_runtime_env == "{}": + if not serialized_runtime_env or serialized_runtime_env == "{}": + # TODO(edoakes): can we just remove this case and always send it + # to the agent? 
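+            # For now, an empty or absent runtime env short-circuits
+            # here: no agent round-trip is made and the client server
+            # starts with a blank RuntimeEnvContext.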
serialized_runtime_env_context = RuntimeEnvContext().serialize() else: serialized_runtime_env_context = self._create_runtime_env( diff --git a/python/ray/util/client/server/server.py b/python/ray/util/client/server/server.py index 351b981d0a17c..27a10d18e3b11 100644 --- a/python/ray/util/client/server/server.py +++ b/python/ray/util/client/server/server.py @@ -35,7 +35,7 @@ from ray.ray_constants import env_integer from ray.util.placement_group import PlacementGroup from ray._private.client_mode_hook import disable_client_hook -from ray._private.utils import add_port_to_grpc_server +from ray._private.tls_utils import add_port_to_grpc_server logger = logging.getLogger(__name__) diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index b5c50215c5488..4b45ac0c761ee 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -128,6 +128,9 @@ def __init__( self._connect_channel() self._has_connected = True + # Has Ray been initialized on the server? + self._serverside_ray_initialized = False + # Initialize the streams to finish protocol negotiation. self.data_client = DataClient(self, self._client_id, self.metadata) self.reference_count: Dict[bytes, int] = defaultdict(int) @@ -359,8 +362,8 @@ def get(self, vals, *, timeout: Optional[float] = None) -> Any: logger.debug("Internal retry for get {}".format(to_get)) if len(to_get) != len(res): raise Exception( - "Mismatched number of items in request ({}) and response ({})" - .format(len(to_get), len(res))) + "Mismatched number of items in request ({}) and response ({})". + format(len(to_get), len(res))) if isinstance(vals, ClientObjectRef): res = res[0] return res @@ -647,10 +650,17 @@ def list_named_actors(self, all_namespaces: bool) -> List[Dict[str, str]]: return json.loads(self.data_client.ListNamedActors(req).actors_json) def is_initialized(self) -> bool: - if self.server is not None: - return self.get_cluster_info( + if not self.is_connected() or self.server is None: + return False + if not self._serverside_ray_initialized: + # We only check that Ray is initialized on the server once to + # avoid making an RPC every time this function is called. This is + # safe to do because Ray only 'un-initializes' on the server when + # the Client connection is torn down. + self._serverside_ray_initialized = self.get_cluster_info( ray_client_pb2.ClusterInfoType.IS_INITIALIZED) - return False + + return self._serverside_ray_initialized def ping_server(self, timeout=None) -> bool: """Simple health check. 
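The `PickleStub.__reduce__` override in `client_pickler.py` above works
because `object.__reduce__` routes tuple subclasses through the default
`copyreg` reconstruction path, sidestepping whatever `__reduce__` a
monkey patch (such as PySpark's namedtuple hack) may have installed. A
self-contained sketch of the same pattern (illustrative class, not Ray's
actual stub):

    import pickle
    from typing import NamedTuple

    class Stub(NamedTuple("Stub", [("type", str), ("ref_id", bytes)])):
        def __reduce__(self):
            # Fall back to the default object protocol instead of any
            # monkey-patched namedtuple __reduce__.
            return object.__reduce__(self)

    stub = Stub("Object", b"\x00")
    assert pickle.loads(pickle.dumps(stub)) == stub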
diff --git a/python/ray/util/dask/scheduler_utils.py b/python/ray/util/dask/scheduler_utils.py index a1805048c989b..dba0c660b0c5b 100644 --- a/python/ray/util/dask/scheduler_utils.py +++ b/python/ray/util/dask/scheduler_utils.py @@ -371,8 +371,11 @@ def fire_task(): return nested_get(result, state["cache"]) -def apply_sync(func, args=(), kwds={}, callback=None): +def apply_sync(func, args=(), kwds=None, callback=None): """ A naive synchronous version of apply_async """ + if kwds is None: + kwds = {} + res = func(*args, **kwds) if callback is not None: callback(res) diff --git a/python/ray/util/placement_group.py b/python/ray/util/placement_group.py index 43741556f54e1..933695ea0fbe1 100644 --- a/python/ray/util/placement_group.py +++ b/python/ray/util/placement_group.py @@ -25,7 +25,7 @@ def _export_bundle_reservation_check_method_if_needed(): if bundle_reservation_check: return - @ray.remote(num_cpus=0, max_calls=0) + @ray.remote(num_cpus=0) def bundle_reservation_check_func(placement_group): return placement_group @@ -307,7 +307,7 @@ def get_current_placement_group() -> Optional[PlacementGroup]: None if the current task or actor wasn't created with any placement group. """ - if client_mode_should_convert(): + if client_mode_should_convert(auto_init=True): # Client mode is only a driver. return None worker = ray.worker.global_worker diff --git a/python/ray/util/sgd/torch/torch_runner.py b/python/ray/util/sgd/torch/torch_runner.py index 77bf9e1454ea9..90b8c0adb44cc 100644 --- a/python/ray/util/sgd/torch/torch_runner.py +++ b/python/ray/util/sgd/torch/torch_runner.py @@ -5,7 +5,6 @@ import ray import torch -from ray.util.sgd.torch.constants import USE_FP16, NUM_STEPS from ray.util.sgd import utils from ray.util.sgd.torch.utils import choose_amp_backend @@ -63,6 +62,7 @@ def setup_operator(self): world_rank=0, local_rank=0, is_distributed=False, + device=None, use_gpu=self.use_gpu, use_fp16=self.use_fp16, use_tqdm=self.use_tqdm, @@ -121,11 +121,6 @@ def train_epoch(self, info = info or {} self._toggle_profiling(profile=profile) - info.update({ - NUM_STEPS: num_steps, - USE_FP16: self.use_fp16, - "epoch_idx": self.epochs, - }) with self.timers.record("train_epoch"): if iterator is not None: # Dataset will provide us with a list of tuples but we @@ -141,7 +136,11 @@ def format_batch(batch): else: iterator = self.make_iterator( training=True, num_steps=num_steps) - train_stats = self.training_operator.train_epoch(iterator, info) + train_stats = self.training_operator.train_epoch( + iterator, + info=info, + num_steps=num_steps, + epoch_idx=self.epochs) # This is so that `epochs` is first in ordering. 
stats = dict(epoch=self.epochs, **train_stats) @@ -151,7 +150,6 @@ def format_batch(batch): def validate(self, num_steps=None, profile=False, info=None): """Evaluates the model on the validation data set.""" - info = info or {} self._toggle_profiling(profile=profile) with self.timers.record("validation"): diff --git a/python/ray/util/sgd/torch/training_operator.py b/python/ray/util/sgd/torch/training_operator.py index 7143d5c558fd0..3a37436c43e92 100644 --- a/python/ray/util/sgd/torch/training_operator.py +++ b/python/ray/util/sgd/torch/training_operator.py @@ -11,11 +11,8 @@ from ray.util.annotations import PublicAPI from ray.util.sgd.utils import (TimerCollection, AverageMeterCollection, NUM_SAMPLES) -from ray.util.sgd.torch.constants import ( - SCHEDULER_STEP_EPOCH, - NUM_STEPS, - SCHEDULER_STEP_BATCH, -) +from ray.util.sgd.torch.constants import (SCHEDULER_STEP_EPOCH, NUM_STEPS, + SCHEDULER_STEP_BATCH, USE_FP16) from ray.util.sgd.torch.utils import choose_amp_backend from torch.nn.parallel import DistributedDataParallel @@ -131,14 +128,15 @@ def __init__(self, config, world_rank, local_rank, - is_distributed=False, - device=None, - use_gpu=False, + is_distributed, + use_gpu, + device, use_fp16=False, use_tqdm=False, wrap_ddp=False, add_dist_sampler=False, scheduler_step_freq=None): + # You are not expected to override this method. self._world_rank = world_rank self._local_rank = local_rank @@ -456,7 +454,7 @@ def should_wrap_dataloader(loader): self._validation_loader = with_sampler( self._validation_loader) - def train_epoch(self, iterator, info): + def train_epoch(self, iterator, info=None, num_steps=None, epoch_idx=0): """Runs one standard training pass over the training dataloader. By default, this method will iterate over the given iterator and @@ -489,8 +487,10 @@ def train_epoch(self, ...): Args: iterator (iter): Iterator over the training data for the entire epoch. This iterator is expected to be entirely consumed. - info (dict): Dictionary for information to be used for custom - training operations. + info (Optional[dict]): Dictionary for information to be used for + custom training operations. + num_steps (Optional[int]): Number of steps in the iterator. + epoch_idx (int): Index of current epoch. Returns: A dict of metrics from training. @@ -499,6 +499,14 @@ def train_epoch(self, ...): raise RuntimeError("Either set self.model in setup function or " "override this method to implement a custom " "training loop.") + + info = info or {} + + info.update({ + NUM_STEPS: num_steps, + USE_FP16: self.use_fp16, + "epoch_idx": epoch_idx + }) model = self.model scheduler = None if hasattr(self, "scheduler"): @@ -636,7 +644,7 @@ def train_batch(self, batch, batch_info): return {"train_loss": loss.item(), NUM_SAMPLES: target.size(0)} - def validate(self, val_iterator, info): + def validate(self, val_iterator, info=None): """Runs one standard validation pass over the val_iterator. This will call ``model.eval()`` and ``torch.no_grad`` when iterating @@ -648,8 +656,8 @@ def validate(self, val_iterator, info): Args: val_iterator (iter): Iterable constructed from the validation dataloader. - info: (dict): Dictionary for information to be used for custom - validation operations. + info: (Optional[dict]): Dictionary for information to be used for + custom validation operations. Returns: A dict of metrics from the evaluation. 
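With the signature changes above, custom operators should accept the new
keyword arguments instead of reading `num_steps` or `epoch_idx` out of
`info`. A minimal override sketch (assumes the usual
`ray.util.sgd.torch.TrainingOperator` base class):

    from ray.util.sgd.torch import TrainingOperator

    class MyOperator(TrainingOperator):
        def train_epoch(self, iterator, info=None, num_steps=None,
                        epoch_idx=0):
            info = info or {}
            # Delegate to the default loop, which now fills NUM_STEPS,
            # USE_FP16 and "epoch_idx" into `info` itself.
            return super().train_epoch(
                iterator, info=info, num_steps=num_steps,
                epoch_idx=epoch_idx)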
@@ -662,6 +670,8 @@ def validate(self, val_iterator, info): raise RuntimeError("Either set self.model in setup function or " "override this method to implement a custom " "validation loop.") + + info = info or {} model = self.model metric_meters = AverageMeterCollection() @@ -1151,13 +1161,13 @@ def schedulers(self): def get_test_operator(operator_cls): class _TestingOperator(operator_cls): - def train_epoch(self, iterator, info): + def train_epoch(self, iterator, info, **kwargs): func = self.config.get("custom_func") if callable(func): return func(self, iterator, info) return {"done": 1} - def validate(self, iterator, info): + def validate(self, iterator, info, **kwargs): return self.train_epoch(iterator, info) return _TestingOperator diff --git a/python/ray/util/sgd/v2/BUILD b/python/ray/util/sgd/v2/BUILD index 1f3bb55976689..7081a53b75591 100644 --- a/python/ray/util/sgd/v2/BUILD +++ b/python/ray/util/sgd/v2/BUILD @@ -24,6 +24,16 @@ py_test( "--max_train_steps=2", "--start_local", "--num_workers=2"] ) +py_test( + name = "tune_cifar_pytorch_pbt_example", + size = "medium", + main = "examples/tune_cifar_pytorch_pbt_example.py", + srcs = ["examples/tune_cifar_pytorch_pbt_example.py"], + tags = ["team:ml", "exclusive", "pytorch"], + deps = [":sgd_v2_lib"], + args = ["--smoke-test"] +) + py_test( name = "tune_linear_example", size = "medium", @@ -47,6 +57,14 @@ py_test( deps = [":sgd_v2_lib"] ) +py_test( + name = "test_gpu", + size = "medium", + srcs = ["tests/test_gpu.py"], + tags = ["team:ml", "exclusive", "gpu_only"], + deps = [":sgd_v2_lib"] +) + py_test( name = "test_session", size = "small", @@ -71,6 +89,15 @@ py_test( deps = [":sgd_v2_lib"] ) +py_test( + name = "test_utils", + size = "small", + srcs = ["tests/test_utils.py"], + tags = ["team:ml", "exclusive"], + deps = [":sgd_v2_lib"] +) + + py_test( name = "test_worker_group", size = "medium", diff --git a/python/ray/util/sgd/v2/__init__.py b/python/ray/util/sgd/v2/__init__.py index 49d68ce97309d..8fb122c160345 100644 --- a/python/ray/util/sgd/v2/__init__.py +++ b/python/ray/util/sgd/v2/__init__.py @@ -8,6 +8,6 @@ __all__ = [ "BackendConfig", "CheckpointStrategy", "HorovodConfig", "load_checkpoint", - "local_rank", "report", "save_checkpoint", "SGDCallback", "SGDIterator", - "TensorflowConfig", "TorchConfig", "Trainer", "world_rank" + "local_rank", "report", "save_checkpoint", "SGDIterator", + "TensorflowConfig", "SGDCallback", "TorchConfig", "Trainer", "world_rank" ] diff --git a/python/ray/util/sgd/v2/backends/backend.py b/python/ray/util/sgd/v2/backends/backend.py index 24d8b59f1e413..4feec51a5eb08 100644 --- a/python/ray/util/sgd/v2/backends/backend.py +++ b/python/ray/util/sgd/v2/backends/backend.py @@ -12,7 +12,7 @@ from ray.util.sgd.v2.checkpoint import CheckpointStrategy from ray.util.sgd.v2.constants import ENABLE_DETAILED_AUTOFILLED_METRICS_ENV, \ TUNE_INSTALLED, TUNE_CHECKPOINT_FILE_NAME, \ - TUNE_CHECKPOINT_ID + TUNE_CHECKPOINT_ID, ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV from ray.util.sgd.v2.session import TrainingResultType, TrainingResult from ray.util.sgd.v2.session import init_session, get_session, shutdown_session from ray.util.sgd.v2.utils import construct_path, check_for_failure @@ -275,15 +275,21 @@ def start(self, if initialization_hook: self._initialization_hook = initialization_hook self.worker_group.execute(initialization_hook) - if self._num_gpus_per_worker > 0: - self._setup_gpus() + + share_cuda_visible_devices_enabled = bool( + env_integer(ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV, + 
self._backend.share_cuda_visible_devices)) + + if (self._num_gpus_per_worker > 0 + and share_cuda_visible_devices_enabled): + self._share_cuda_visible_devices() self._backend.on_start(self.worker_group, self._backend_config) except RayActorError as exc: logger.exception(str(exc)) self._increment_failures() self._restart() - def _setup_gpus(self): + def _share_cuda_visible_devices(self): """Sets CUDA_VISIBLE_DEVICES on all workers. For each worker, CUDA_VISIBLE_DEVICES will be set to the GPU IDs @@ -685,6 +691,18 @@ def _increment_failures(self): class Backend(metaclass=abc.ABCMeta): + """Metaclass for distributed communication backend. + + Attributes: + share_cuda_visible_devices (bool): If True, each worker + process will have CUDA_VISIBLE_DEVICES set as the visible device + IDs of all workers on the same node for this training instance. + If False, each worker will have CUDA_VISIBLE_DEVICES set to the + device IDs allocated by Ray for that worker. + """ + + share_cuda_visible_devices: bool = False + def on_start(self, worker_group: WorkerGroup, backend_config: BackendConfig): """Logic for starting this backend.""" diff --git a/python/ray/util/sgd/v2/backends/horovod.py b/python/ray/util/sgd/v2/backends/horovod.py index 4f424d5212dec..4382130ae5749 100644 --- a/python/ray/util/sgd/v2/backends/horovod.py +++ b/python/ray/util/sgd/v2/backends/horovod.py @@ -52,6 +52,8 @@ def init_env_vars(world_rank: int, world_size: int, node_id: str): class HorovodBackend(Backend): + share_cuda_visible_devices: bool = True + def on_start(self, worker_group: WorkerGroup, backend_config: HorovodConfig): diff --git a/python/ray/util/sgd/v2/backends/torch.py b/python/ray/util/sgd/v2/backends/torch.py index 7d76b179c8d2d..1d1f0d39f366f 100644 --- a/python/ray/util/sgd/v2/backends/torch.py +++ b/python/ray/util/sgd/v2/backends/torch.py @@ -92,6 +92,8 @@ def shutdown_torch(destroy_process_group=False): class TorchBackend(Backend): + share_cuda_visible_devices: bool = True + def on_start(self, worker_group: WorkerGroup, backend_config: TorchConfig): if len(worker_group) > 1 and dist.is_available(): # Set the appropriate training backend. diff --git a/python/ray/util/sgd/v2/constants.py b/python/ray/util/sgd/v2/constants.py index 6ebd428f7b1cb..b0dc39e9cbfbc 100644 --- a/python/ray/util/sgd/v2/constants.py +++ b/python/ray/util/sgd/v2/constants.py @@ -44,3 +44,7 @@ # This needs to be added to the checkpoint dictionary so if the Tune trial # is restarted, the checkpoint_id can continue to increment. TUNE_CHECKPOINT_ID = "_current_checkpoint_id" + +# Integer value which if set will override the value of +# Backend.share_cuda_visible_devices. 1 for True, 0 for False. 
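+# For example, running with SGD_ENABLE_SHARE_CUDA_VISIBLE_DEVICES=0
+# forces per-worker device isolation even for backends such as Torch and
+# Horovod that default share_cuda_visible_devices to True.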
+ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV = "SGD_ENABLE_SHARE_CUDA_VISIBLE_DEVICES" diff --git a/python/ray/util/sgd/v2/examples/tensorflow_mnist_example.py b/python/ray/util/sgd/v2/examples/tensorflow_mnist_example.py index f87380cf9ce16..c299808c916aa 100644 --- a/python/ray/util/sgd/v2/examples/tensorflow_mnist_example.py +++ b/python/ray/util/sgd/v2/examples/tensorflow_mnist_example.py @@ -72,7 +72,7 @@ def train_func(config): return results -def train_tensorflow_mnist(num_workers=1, use_gpu=False): +def train_tensorflow_mnist(num_workers=2, use_gpu=False): trainer = Trainer( backend="tensorflow", num_workers=num_workers, use_gpu=use_gpu) trainer.start() @@ -98,7 +98,7 @@ def train_tensorflow_mnist(num_workers=1, use_gpu=False): "--num-workers", "-n", type=int, - default=1, + default=2, help="Sets number of workers for training.") parser.add_argument( "--use-gpu", diff --git a/python/ray/util/sgd/v2/examples/tune_cifar_pytorch_pbt_example.py b/python/ray/util/sgd/v2/examples/tune_cifar_pytorch_pbt_example.py new file mode 100644 index 0000000000000..1ff8054be367c --- /dev/null +++ b/python/ray/util/sgd/v2/examples/tune_cifar_pytorch_pbt_example.py @@ -0,0 +1,200 @@ +import numpy as np +import argparse +from filelock import FileLock + +import ray +from ray import tune +from ray.tune import CLIReporter +from ray.tune.schedulers import PopulationBasedTraining + +import torch +import torch.nn as nn +from torch.utils.data import DataLoader, DistributedSampler, Subset +from torchvision.datasets import CIFAR10 +import torchvision.transforms as transforms +from torch.nn.parallel import DistributedDataParallel + +from ray.util.sgd.torch.resnet import ResNet18 + +import ray.util.sgd.v2 as sgd +from ray.util.sgd.v2 import Trainer + + +def train(dataloader, model, loss_fn, optimizer, device): + size = len(dataloader.dataset) + for batch, (X, y) in enumerate(dataloader): + X, y = X.to(device), y.to(device) + + # Compute prediction error + pred = model(X) + loss = loss_fn(pred, y) + + # Backpropagation + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if batch % 100 == 0: + loss, current = loss.item(), batch * len(X) + print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]") + + +def validate(dataloader, model, loss_fn, device): + size = len(dataloader.dataset) + num_batches = len(dataloader) + model.eval() + test_loss, correct = 0, 0 + with torch.no_grad(): + for X, y in dataloader: + X, y = X.to(device), y.to(device) + pred = model(X) + test_loss += loss_fn(pred, y).item() + correct += (pred.argmax(1) == y).type(torch.float).sum().item() + test_loss /= num_batches + correct /= size + print(f"Test Error: \n " + f"Accuracy: {(100 * correct):>0.1f}%, " + f"Avg loss: {test_loss:>8f} \n") + return {"loss": test_loss} + + +def train_func(config): + device = torch.device(f"cuda:{sgd.local_rank()}" + if torch.cuda.is_available() else "cpu") + + epochs = config.pop("epochs", 3) + model = ResNet18(config) + model = model.to(device) + model = DistributedDataParallel( + model, + device_ids=[device.index] if torch.cuda.is_available() else None) + + # Create optimizer. + optimizer = torch.optim.SGD( + model.parameters(), + lr=config.get("lr", 0.1), + momentum=config.get("momentum", 0.9)) + + # Load in training and validation data. 
+ transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), + (0.2023, 0.1994, 0.2010)), + ]) # meanstd transformation + + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), + (0.2023, 0.1994, 0.2010)), + ]) + + with FileLock(".ray.lock"): + train_dataset = CIFAR10( + root="~/data", + train=True, + download=True, + transform=transform_train) + validation_dataset = CIFAR10( + root="~/data", + train=False, + download=False, + transform=transform_test) + + if config.get("test_mode"): + train_dataset = Subset(train_dataset, list(range(64))) + validation_dataset = Subset(validation_dataset, list(range(64))) + + train_loader = DataLoader( + train_dataset, + batch_size=config["batch_size"], + sampler=DistributedSampler(train_dataset)) + validation_loader = DataLoader( + validation_dataset, + batch_size=config["batch_size"], + sampler=DistributedSampler(validation_dataset)) + + # Create loss. + criterion = nn.CrossEntropyLoss() + + results = [] + + for _ in range(epochs): + train(train_loader, model, criterion, optimizer, device) + result = validate(validation_loader, model, criterion, device) + sgd.report(**result) + results.append(result) + + return results + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--address", + required=False, + type=str, + help="the address to use for Redis") + parser.add_argument( + "--num-workers", + "-n", + type=int, + default=2, + help="Sets number of workers for training.") + parser.add_argument( + "--num-epochs", type=int, default=5, help="Number of epochs to train.") + parser.add_argument( + "--smoke-test", + action="store_true", + default=False, + help="Finish quickly for testing.") + parser.add_argument( + "--use-gpu", + action="store_true", + default=False, + help="Enables GPU training") + + args, _ = parser.parse_known_args() + if args.smoke_test: + ray.init(num_cpus=4) + else: + ray.init(address=args.address) + + trainer = Trainer( + "torch", num_workers=args.num_workers, use_gpu=args.use_gpu) + Trainable = trainer.to_tune_trainable(train_func) + pbt_scheduler = PopulationBasedTraining( + time_attr="training_iteration", + metric="loss", + mode="min", + perturbation_interval=1, + hyperparam_mutations={ + # distribution for resampling + "lr": lambda: np.random.uniform(0.001, 1), + # allow perturbations within this set of categorical values + "momentum": [0.8, 0.9, 0.99], + }) + + reporter = CLIReporter() + reporter.add_metric_column("loss", "loss") + + analysis = tune.run( + Trainable, + num_samples=4, + config={ + "lr": tune.choice([0.001, 0.01, 0.1]), + "momentum": 0.8, + "batch_size": 128 * args.num_workers, + "epochs": args.num_epochs, + "test_mode": args.smoke_test # whether to to subset the data + }, + stop={"training_iteration": 2 if args.smoke_test else 100}, + max_failures=3, # used for fault tolerance + checkpoint_freq=3, # used for fault tolerance + keep_checkpoints_num=1, # used for fault tolerance + verbose=2, + progress_reporter=reporter, + scheduler=pbt_scheduler) + + print(analysis.get_best_config(metric="loss", mode="min")) diff --git a/python/ray/util/sgd/v2/tests/test_backend.py b/python/ray/util/sgd/v2/tests/test_backend.py index 985029b808118..65ac486dd9df1 100644 --- a/python/ray/util/sgd/v2/tests/test_backend.py +++ b/python/ray/util/sgd/v2/tests/test_backend.py @@ -8,6 +8,7 @@ from ray.util.sgd import 
v2 as sgd from ray.util.sgd.v2.backends.backend import BackendConfig, BackendExecutor from ray.util.sgd.v2.backends.tensorflow import TensorflowConfig +from ray.util.sgd.v2.constants import ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV from ray.util.sgd.v2.worker_group import WorkerGroup from ray.util.sgd.v2.backends.torch import TorchConfig @@ -321,6 +322,7 @@ def get_resources(): num_workers, expected_results = worker_results + os.environ[ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV] = "1" e = BackendExecutor( config, num_workers=num_workers, @@ -349,6 +351,7 @@ def get_resources(): num_workers, expected_results = worker_results + os.environ[ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV] = "1" e = BackendExecutor( config, num_workers=num_workers, @@ -374,6 +377,7 @@ def get_resources(): num_workers, expected_results = worker_results + os.environ[ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV] = "1" e = BackendExecutor( config, num_workers=num_workers, diff --git a/python/ray/util/sgd/v2/tests/test_gpu.py b/python/ray/util/sgd/v2/tests/test_gpu.py new file mode 100644 index 0000000000000..845e768cd6d47 --- /dev/null +++ b/python/ray/util/sgd/v2/tests/test_gpu.py @@ -0,0 +1,92 @@ +import pytest + +import ray +from ray.util.sgd.v2 import Trainer +from ray.util.sgd.v2.examples.horovod.horovod_example import train_func as \ + horovod_torch_train_func +from ray.util.sgd.v2.examples.tensorflow_mnist_example import train_func as \ + tensorflow_mnist_train_func +from ray.util.sgd.v2.examples.train_fashion_mnist_example import train_func \ + as fashion_mnist_train_func +from test_tune import torch_fashion_mnist, tune_tensorflow_mnist + + +@pytest.fixture +def ray_start_4_cpus_2_gpus(): + address_info = ray.init(num_cpus=4, num_gpus=2) + yield address_info + # The code after the yield will run as teardown code. 
+ ray.shutdown() + + +def test_tensorflow_mnist_gpu(ray_start_4_cpus_2_gpus): + num_workers = 2 + epochs = 3 + + trainer = Trainer("tensorflow", num_workers=num_workers, use_gpu=True) + config = {"lr": 1e-3, "batch_size": 64, "epochs": epochs} + trainer.start() + results = trainer.run(tensorflow_mnist_train_func, config) + trainer.shutdown() + + assert len(results) == num_workers + result = results[0] + + loss = result["loss"] + assert len(loss) == epochs + assert loss[-1] < loss[0] + + accuracy = result["accuracy"] + assert len(accuracy) == epochs + assert accuracy[-1] > accuracy[0] + + +def test_torch_fashion_mnist_gpu(ray_start_4_cpus_2_gpus): + num_workers = 2 + epochs = 3 + + trainer = Trainer("torch", num_workers=num_workers, use_gpu=True) + config = {"lr": 1e-3, "batch_size": 64, "epochs": epochs} + trainer.start() + results = trainer.run(fashion_mnist_train_func, config) + trainer.shutdown() + + assert len(results) == num_workers + + for result in results: + assert len(result) == epochs + assert result[-1] < result[0] + + +def test_horovod_torch_mnist_gpu(ray_start_4_cpus_2_gpus): + num_workers = 2 + num_epochs = 2 + trainer = Trainer("horovod", num_workers, use_gpu=True) + trainer.start() + results = trainer.run( + horovod_torch_train_func, + config={ + "num_epochs": num_epochs, + "lr": 1e-3 + }) + trainer.shutdown() + + assert len(results) == num_workers + for worker_result in results: + assert len(worker_result) == num_epochs + assert worker_result[num_epochs - 1] < worker_result[0] + + +def test_tune_fashion_mnist_gpu(ray_start_4_cpus_2_gpus): + torch_fashion_mnist(num_workers=2, use_gpu=True, num_samples=1) + + +def test_tune_tensorflow_mnist_gpu(ray_start_4_cpus_2_gpus): + tune_tensorflow_mnist(num_workers=2, use_gpu=True, num_samples=1) + + +if __name__ == "__main__": + import pytest + import sys + + sys.exit(pytest.main(["-v", "-x", "-s", __file__])) diff --git a/python/ray/util/sgd/v2/tests/test_trainer.py b/python/ray/util/sgd/v2/tests/test_trainer.py index f7da6310a2a96..9795017283a0b 100644 --- a/python/ray/util/sgd/v2/tests/test_trainer.py +++ b/python/ray/util/sgd/v2/tests/test_trainer.py @@ -5,26 +5,24 @@ import horovod.torch as hvd_torch import pytest + import ray import ray.util.sgd.v2 as sgd -import tensorflow as tf -import torch from ray._private.test_utils import wait_for_condition from ray.util.sgd.v2 import Trainer, TorchConfig, TensorflowConfig, \ HorovodConfig from ray.util.sgd.v2.backends.backend import BackendConfig, Backend, \ BackendExecutor from ray.util.sgd.v2.callbacks.callback import SGDCallback +from ray.util.sgd.v2.examples.horovod.horovod_example import train_func as \ + horovod_torch_train_func, HorovodTrainClass +from ray.util.sgd.v2.constants import ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV from ray.util.sgd.v2.examples.tensorflow_mnist_example import train_func as \ tensorflow_mnist_train_func from ray.util.sgd.v2.examples.train_fashion_mnist_example import train_func \ - as \ - fashion_mnist_train_func + as fashion_mnist_train_func from ray.util.sgd.v2.examples.train_linear_example import train_func as \ linear_train_func - -from ray.util.sgd.v2.examples.horovod.horovod_example import train_func as \ - horovod_torch_train_func, HorovodTrainClass from ray.util.sgd.v2.worker_group import WorkerGroup @@ -498,31 +496,6 @@ def test_tensorflow_mnist(ray_start_2_cpus): assert accuracy[-1] > accuracy[0] -@pytest.mark.skipif( - len(tf.config.list_physical_devices("GPU")) < 2, - reason="Only run if multiple GPUs are available.") -def 
test_tensorflow_mnist_gpu(ray_start_2_cpus_2_gpus): - num_workers = 2 - epochs = 3 - - trainer = Trainer("tensorflow", num_workers=num_workers, use_gpu=True) - config = {"lr": 1e-3, "batch_size": 64, "epochs": epochs} - trainer.start() - results = trainer.run(tensorflow_mnist_train_func, config) - trainer.shutdown() - - assert len(results) == num_workers - result = results[0] - - loss = result["loss"] - assert len(loss) == epochs - assert loss[-1] < loss[0] - - accuracy = result["accuracy"] - assert len(accuracy) == epochs - assert accuracy[-1] > accuracy[0] - - def test_torch_linear(ray_start_2_cpus): num_workers = 2 epochs = 3 @@ -557,26 +530,6 @@ def test_torch_fashion_mnist(ray_start_2_cpus): assert result[-1] < result[0] -@pytest.mark.skipif( - torch.cuda.device_count() < 2, - reason="Only run if multiple GPUs are available.") -def test_torch_fashion_mnist_gpu(ray_start_2_cpus_2_gpus): - num_workers = 2 - epochs = 3 - - trainer = Trainer("torch", num_workers=num_workers, use_gpu=True) - config = {"lr": 1e-3, "batch_size": 64, "epochs": epochs} - trainer.start() - results = trainer.run(fashion_mnist_train_func, config) - trainer.shutdown() - - assert len(results) == num_workers - - for result in results: - assert len(result) == epochs - assert result[-1] < result[0] - - def test_horovod_simple(ray_start_2_cpus): def simple_fn(): hvd_torch.init() @@ -610,28 +563,6 @@ def test_horovod_torch_mnist(ray_start_2_cpus): assert worker_result[num_epochs - 1] < worker_result[0] -@pytest.mark.skipif( - torch.cuda.device_count() < 2, - reason="Only run if multiple GPUs are available.") -def test_horovod_torch_mnist_gpu(ray_start_2_cpus_2_gpus): - num_workers = 2 - num_epochs = 2 - trainer = Trainer("horovod", num_workers, use_gpu=True) - trainer.start() - results = trainer.run( - horovod_torch_train_func, - config={ - "num_epochs": num_epochs, - "lr": 1e-3 - }) - trainer.shutdown() - - assert len(results) == num_workers - for worker_result in results: - assert len(worker_result) == num_epochs - assert worker_result[num_epochs - 1] < worker_result[0] - - def test_horovod_torch_mnist_stateful(ray_start_2_cpus): num_workers = 2 num_epochs = 2 @@ -986,7 +917,6 @@ def test_resources(ray_start_4_cpus_4_gpus_4_extra, resource, num_requested): def test_gpu_requests(ray_start_4_cpus_4_gpus_4_extra): - # GPUs should not be requested if `use_gpu` is False. with pytest.raises(ValueError): Trainer( @@ -1006,6 +936,8 @@ def test_gpu_requests(ray_start_4_cpus_4_gpus_4_extra): def get_resources(): return os.environ["CUDA_VISIBLE_DEVICES"] + os.environ[ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV] = "1" + # 0 GPUs will be requested and should not raise an error. 
trainer = Trainer(TestConfig(), num_workers=2, use_gpu=False) trainer.start() diff --git a/python/ray/util/sgd/v2/tests/test_tune.py b/python/ray/util/sgd/v2/tests/test_tune.py index fb9d39b6df8b0..0ec1db59542f8 100644 --- a/python/ray/util/sgd/v2/tests/test_tune.py +++ b/python/ray/util/sgd/v2/tests/test_tune.py @@ -1,18 +1,13 @@ import os import pytest - -import torch -import tensorflow as tf - import ray +import ray.util.sgd.v2 as sgd from ray import tune, cloudpickle from ray.tune import TuneError - -import ray.util.sgd.v2 as sgd from ray.util.sgd.v2 import Trainer -from ray.util.sgd.v2.constants import TUNE_CHECKPOINT_FILE_NAME from ray.util.sgd.v2.backends.backend import Backend, BackendConfig +from ray.util.sgd.v2.constants import TUNE_CHECKPOINT_FILE_NAME from ray.util.sgd.v2.examples.tensorflow_mnist_example import train_func as \ tensorflow_mnist_train_func from ray.util.sgd.v2.examples.train_fashion_mnist_example import train_func \ @@ -28,14 +23,6 @@ def ray_start_2_cpus(): ray.shutdown() -@pytest.fixture -def ray_start_4_cpus_4_gpus(): - address_info = ray.init(num_cpus=2, num_gpus=2) - yield address_info - # The code after the yield will run as teardown code. - ray.shutdown() - - @pytest.fixture def ray_start_8_cpus(): address_info = ray.init(num_cpus=8) @@ -83,13 +70,6 @@ def test_tune_torch_fashion_mnist(ray_start_8_cpus): torch_fashion_mnist(num_workers=2, use_gpu=False, num_samples=2) -@pytest.mark.skipif( - torch.cuda.device_count() < 2, - reason="Only run if multiple GPUs are available.") -def test_tune_fashion_mnist_gpu(ray_start_4_cpus_4_gpus): - torch_fashion_mnist(num_workers=2, use_gpu=True, num_samples=1) - - def tune_tensorflow_mnist(num_workers, use_gpu, num_samples): epochs = 2 trainer = Trainer("tensorflow", num_workers=num_workers, use_gpu=use_gpu) @@ -113,13 +93,6 @@ def test_tune_tensorflow_mnist(ray_start_8_cpus): tune_tensorflow_mnist(num_workers=2, use_gpu=False, num_samples=2) -@pytest.mark.skipif( - len(tf.config.list_physical_devices("GPU")) < 2, - reason="Only run if multiple GPUs are available.") -def test_tune_tensorflow_mnist_gpu(ray_start_4_cpus_4_gpus): - tune_tensorflow_mnist(num_workers=2, use_gpu=True, num_samples=1) - - def test_tune_error(ray_start_2_cpus): def train_func(config): raise RuntimeError("Error in training function!") diff --git a/python/ray/util/tracing/tracing_helper.py b/python/ray/util/tracing/tracing_helper.py index 73fb61c00767c..68696fe29c46d 100644 --- a/python/ray/util/tracing/tracing_helper.py +++ b/python/ray/util/tracing/tracing_helper.py @@ -290,6 +290,8 @@ def _invocation_remote_span( # If tracing feature flag is not on, perform a no-op. # Tracing doesn't work for cross lang yet. 
if not is_tracing_enabled() or self._is_cross_language: + if kwargs is not None: + assert "_ray_trace_ctx" not in kwargs return method(self, args, kwargs, *_args, **_kwargs) assert "_ray_trace_ctx" not in kwargs @@ -365,8 +367,7 @@ def _invocation_actor_class_remote_span( # If tracing feature flag is not on, perform a no-op if not is_tracing_enabled(): - if not self.__ray_metadata__.is_cross_language: - kwargs["_ray_trace_ctx"] = None + assert "_ray_trace_ctx" not in kwargs return method(self, args, kwargs, *_args, **_kwargs) class_name = self.__ray_metadata__.class_name @@ -404,6 +405,8 @@ def _start_span( # If tracing feature flag is not on, perform a no-op if (not is_tracing_enabled() or self._actor_ref()._ray_is_cross_language): + if kwargs is not None: + assert "_ray_trace_ctx" not in kwargs return method(self, args, kwargs, *_args, **_kwargs) class_name = (self._actor_ref() diff --git a/python/ray/worker.py b/python/ray/worker.py index 9f5dd31ca6da3..97849d8d2750f 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -191,8 +191,8 @@ def current_session_and_job(self): @property def runtime_env(self): """Get the runtime env in json format""" - return json.loads( - self.core_worker.get_job_config().runtime_env.raw_json) + return json.loads(self.core_worker.get_job_config() + .runtime_env.serialized_runtime_env) def get_serialization_context(self, job_id=None): """Get the SerializationContext of the job that this worker is processing. @@ -223,9 +223,6 @@ def check_connected(self): Exception: An exception is raised if the worker is not connected. """ if not self.connected: - if os.environ.get("RAY_ENABLE_AUTO_CONNECT", "") != "0": - ray.client().connect() - return raise RaySystemError("Ray has not been started yet. You can " "start Ray with 'ray.init()'.") @@ -479,7 +476,7 @@ def print_logs(self): @PublicAPI -@client_mode_hook +@client_mode_hook(auto_init=True) def get_gpu_ids(): """Get the IDs of the GPUs that are available to the worker. @@ -576,7 +573,7 @@ def get_dashboard_url(): @PublicAPI -@client_mode_hook +@client_mode_hook(auto_init=False) def init( address: Optional[str] = None, *, @@ -605,7 +602,6 @@ def init( _memory: Optional[int] = None, _redis_password: str = ray_constants.REDIS_DEFAULT_PASSWORD, _temp_dir: Optional[str] = None, - _lru_evict: bool = False, _metrics_export_port: Optional[int] = None, _system_config: Optional[Dict[str, str]] = None, _tracing_startup_hook: Optional[Callable] = None, @@ -882,7 +878,6 @@ def init( start_initial_python_workers_for_first_job=( job_config is None or job_config.runtime_env is None), _system_config=_system_config, - lru_evict=_lru_evict, enable_object_reconstruction=_enable_object_reconstruction, metrics_export_port=_metrics_export_port, tracing_startup_hook=_tracing_startup_hook) @@ -924,7 +919,6 @@ def init( object_ref_seed=None, temp_dir=_temp_dir, _system_config=_system_config, - lru_evict=_lru_evict, enable_object_reconstruction=_enable_object_reconstruction, metrics_export_port=_metrics_export_port) _global_node = ray.node.Node( @@ -974,7 +968,7 @@ def init( @PublicAPI -@client_mode_hook +@client_mode_hook(auto_init=False) def shutdown(_exiting_interpreter: bool = False): """Disconnect the worker, and terminate processes started by ray.init(). @@ -1240,7 +1234,7 @@ def listen_error_messages_raylet(worker, threads_stopped): @PublicAPI -@client_mode_hook +@client_mode_hook(auto_init=False) def is_initialized() -> bool: """Check if ray.init has been called yet. 
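
The auto_init split above is easiest to see from the caller's side. A minimal
sketch of the intended behavior, assuming the decorator semantics this patch
introduces (it replaces the RAY_ENABLE_AUTO_CONNECT branch deleted from
check_connected() above):

    import ray

    # Under Ray client mode with RAY_ENABLE_AUTO_CONNECT unset, the
    # auto_init=True hooks are expected to connect implicitly on first use:
    ref = ray.put(41)              # may auto-connect
    assert ray.get(ref) + 1 == 42  # may auto-connect as well
    # auto_init=False hooks never connect implicitly:
    if ray.is_initialized():
        ray.shutdown()
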
@@ -1559,7 +1553,7 @@ def show_in_dashboard(message: str, key: str = "", dtype: str = "text"): @PublicAPI -@client_mode_hook +@client_mode_hook(auto_init=True) def get(object_refs: Union[ray.ObjectRef, List[ray.ObjectRef]], *, timeout: Optional[float] = None) -> Union[Any, List[Any]]: @@ -1648,7 +1642,7 @@ def get(object_refs: Union[ray.ObjectRef, List[ray.ObjectRef]], @PublicAPI -@client_mode_hook +@client_mode_hook(auto_init=True) def put(value: Any, *, _owner: Optional["ray.actor.ActorHandle"] = None) -> ray.ObjectRef: """Store an object in the object store. @@ -1702,7 +1696,7 @@ def put(value: Any, *, @PublicAPI -@client_mode_hook +@client_mode_hook(auto_init=True) def wait(object_refs: List[ray.ObjectRef], *, num_returns: int = 1, @@ -1809,7 +1803,7 @@ def wait(object_refs: List[ray.ObjectRef], @PublicAPI -@client_mode_hook +@client_mode_hook(auto_init=True) def get_actor(name: str, namespace: Optional[str] = None) -> "ray.actor.ActorHandle": """Get a handle to a named actor. @@ -1841,7 +1835,7 @@ def get_actor(name: str, @PublicAPI -@client_mode_hook +@client_mode_hook(auto_init=True) def kill(actor: "ray.actor.ActorHandle", *, no_restart: bool = True): """Kill an actor forcefully. @@ -1870,7 +1864,7 @@ def kill(actor: "ray.actor.ActorHandle", *, no_restart: bool = True): @PublicAPI -@client_mode_hook +@client_mode_hook(auto_init=True) def cancel(object_ref: ray.ObjectRef, *, force: bool = False, @@ -1932,6 +1926,7 @@ def make_decorator(num_returns=None, max_restarts=None, max_task_retries=None, runtime_env=None, + placement_group="default", worker=None, retry_exceptions=None): def decorator(function_or_class): @@ -1963,7 +1958,7 @@ def decorator(function_or_class): Language.PYTHON, function_or_class, None, num_cpus, num_gpus, memory, object_store_memory, resources, accelerator_type, num_returns, max_calls, max_retries, retry_exceptions, - runtime_env) + runtime_env, placement_group) if inspect.isclass(function_or_class): if num_returns is not None: @@ -2101,15 +2096,6 @@ def method(self): retry_exceptions (bool): Only for *remote functions*. This specifies whether application-level errors should be retried up to max_retries times. - override_environment_variables (Dict[str, str]): (Deprecated in Ray - 1.4.0, will be removed in Ray 1.6--please use the ``env_vars`` - field of :ref:`runtime-environments` instead.) This specifies - environment variables to override for the actor or task. The - overrides are propagated to all child actors and tasks. This - is a dictionary mapping variable names to their values. Existing - variables can be overridden, new ones can be created, and an - existing variable can be unset by setting it to an empty string. - Note: can only be set via `.options()`. 
""" worker = global_worker @@ -2121,7 +2107,8 @@ def method(self): valid_kwargs = [ "num_returns", "num_cpus", "num_gpus", "memory", "object_store_memory", "resources", "accelerator_type", "max_calls", "max_restarts", - "max_task_retries", "max_retries", "runtime_env", "retry_exceptions" + "max_task_retries", "max_retries", "runtime_env", "retry_exceptions", + "placement_group" ] error_string = ("The @ray.remote decorator must be applied either " "with no arguments and no parentheses, for example " @@ -2154,6 +2141,7 @@ def method(self): object_store_memory = kwargs.get("object_store_memory") max_retries = kwargs.get("max_retries") runtime_env = kwargs.get("runtime_env") + placement_group = kwargs.get("placement_group", "default") retry_exceptions = kwargs.get("retry_exceptions") return make_decorator( @@ -2169,5 +2157,6 @@ def method(self): max_task_retries=max_task_retries, max_retries=max_retries, runtime_env=runtime_env, + placement_group=placement_group, worker=worker, retry_exceptions=retry_exceptions) diff --git a/python/ray/workers/setup_worker.py b/python/ray/workers/setup_worker.py index 23fbc6e8e150d..b40737c1a8ad0 100644 --- a/python/ray/workers/setup_worker.py +++ b/python/ray/workers/setup_worker.py @@ -3,7 +3,8 @@ import logging import os -from ray._private.runtime_env import RuntimeEnvContext +from ray._private.runtime_env.context import RuntimeEnvContext +from ray.core.generated.common_pb2 import Language logger = logging.getLogger(__name__) @@ -26,6 +27,9 @@ type=str, help="the worker allocated resource") +parser.add_argument( + "--language", type=str, help="the language type of the worker") + def get_tmp_dir(remaining_args): for arg in remaining_args: @@ -117,5 +121,5 @@ def start_worker_in_container(container_option, args, remaining_args): # probably not even go through this codepath. runtime_env_context = RuntimeEnvContext.deserialize( args.serialized_runtime_env_context or "{}") - - runtime_env_context.exec_worker(remaining_args) + runtime_env_context.exec_worker(remaining_args, + Language.Value(args.language)) diff --git a/python/ray/workflow/common.py b/python/ray/workflow/common.py index e883cdfacd0b4..169b318ed5d76 100644 --- a/python/ray/workflow/common.py +++ b/python/ray/workflow/common.py @@ -32,7 +32,8 @@ def get_qualname(f): def ensure_ray_initialized(): - ray.worker.global_worker.check_connected() + if not ray.is_initialized(): + ray.init() @dataclass diff --git a/python/ray/workflow/execution.py b/python/ray/workflow/execution.py index 6de22cef05943..b660c65fa8a9e 100644 --- a/python/ray/workflow/execution.py +++ b/python/ray/workflow/execution.py @@ -32,8 +32,9 @@ def run(entry_workflow: Workflow, # Workflow ID format: {Entry workflow UUID}.{Unix time to nanoseconds} workflow_id = f"{str(uuid.uuid4())}.{time.time():.9f}" - logger.info(f"Workflow job created. [id=\"{workflow_id}\", storage_url=" - f"\"{store.storage_url}\"].") + logger.info( + f"Workflow job created. [id=\"{workflow_id}\", storage_url=" + f"\"{store.storage_url}\"]. 
Type: {entry_workflow.data.step_type} ")
 
     with workflow_context.workflow_step_context(workflow_id,
                                                 store.storage_url):
@@ -51,7 +52,7 @@
         # - it's a new workflow
         # TODO (yic): follow up with force rerun
         if entry_workflow.data.step_type != StepType.FUNCTION or not wf_exists:
-            commit_step(ws, "", entry_workflow, None)
+            commit_step(ws, "", entry_workflow, exception=None)
         workflow_manager = get_or_create_management_actor()
         ignore_existing = (entry_workflow.data.step_type != StepType.FUNCTION)
         # NOTE: It is important to 'ray.get' the returned output. This
diff --git a/python/ray/workflow/recovery.py b/python/ray/workflow/recovery.py
index 58902b4419681..8c64c2cba4100 100644
--- a/python/ray/workflow/recovery.py
+++ b/python/ray/workflow/recovery.py
@@ -51,8 +51,8 @@ def _recover_workflow_step(args: List[Any], kwargs: Dict[str, Any],
 
 
 def _construct_resume_workflow_from_step(
-        reader: workflow_storage.WorkflowStorage,
-        step_id: StepID) -> Union[Workflow, StepID]:
+        reader: workflow_storage.WorkflowStorage, step_id: StepID,
+        input_map: Dict[StepID, Any]) -> Union[Workflow, StepID]:
     """Try to construct a workflow (step) that recovers the workflow step.
     If the workflow step already has an output checkpointing file, we return
     the workflow step id instead.
@@ -60,6 +60,8 @@ def _construct_resume_workflow_from_step(
     Args:
         reader: The storage reader for inspecting the step.
         step_id: The ID of the step we want to recover.
+        input_map: A context storing the inputs that have already been
+            loaded. This context is important for deduplication.
 
     Returns:
         A workflow that recovers the step, or a ID of a step
@@ -70,8 +72,8 @@ def _construct_resume_workflow_from_step(
         # we already have the output
         return step_id
     if isinstance(result.output_step_id, str):
-        return _construct_resume_workflow_from_step(reader,
-                                                    result.output_step_id)
+        return _construct_resume_workflow_from_step(
+            reader, result.output_step_id, input_map)
 
     # output does not exists or not valid. try to reconstruct it.
if not result.is_recoverable(): raise WorkflowStepNotRecoverableError(step_id) @@ -79,7 +81,14 @@ def _construct_resume_workflow_from_step( with serialization.objectref_cache(): input_workflows = [] for i, _step_id in enumerate(result.workflows): - r = _construct_resume_workflow_from_step(reader, _step_id) + # Check whether the step has been loaded or not to avoid + # duplication + if _step_id in input_map: + r = input_map[_step_id] + else: + r = _construct_resume_workflow_from_step( + reader, _step_id, input_map) + input_map[_step_id] = r if isinstance(r, Workflow): input_workflows.append(r) else: @@ -119,15 +128,15 @@ def _resume_workflow_step_executor(workflow_id: str, step_id: "StepID", try: store = storage.create_storage(store_url) wf_store = workflow_storage.WorkflowStorage(workflow_id, store) - r = _construct_resume_workflow_from_step(wf_store, step_id) + r = _construct_resume_workflow_from_step(wf_store, step_id, {}) except Exception as e: raise WorkflowNotResumableError(workflow_id) from e if isinstance(r, Workflow): - with workflow_context.workflow_step_context(workflow_id, - store.storage_url): - from ray.workflow.step_executor import (execute_workflow) - result = execute_workflow(r, last_step_of_workflow=True) + with workflow_context.workflow_step_context( + workflow_id, store.storage_url, last_step_of_workflow=True): + from ray.workflow.step_executor import execute_workflow + result = execute_workflow(r) return result.persisted_output, result.volatile_output assert isinstance(r, StepID) return wf_store.load_step_output(r), None diff --git a/python/ray/workflow/step_executor.py b/python/ray/workflow/step_executor.py index 878c7b40bf451..b5416b5a40218 100644 --- a/python/ray/workflow/step_executor.py +++ b/python/ray/workflow/step_executor.py @@ -134,33 +134,9 @@ def _resolve_step_inputs( return signature.recover_args(flattened_args) -def execute_workflow( - workflow: "Workflow", - outer_most_step_id: Optional[str] = None, - last_step_of_workflow: bool = False) -> "WorkflowExecutionResult": +def execute_workflow(workflow: "Workflow") -> "WorkflowExecutionResult": """Execute workflow. - To fully explain what we are doing, we need to introduce some syntax first. - The syntax for dependencies between workflow steps - "B.step(A.step())" is "A - B"; the syntax for nested workflow steps - "def A(): return B.step()" is "A / B". - - In a chain/DAG of step dependencies, the "output step" is the step of last - (topological) order. For example, in "A - B - C", C is the output step. - - In a chain of nested workflow steps, the initial "output step" is - called the "outer most step" for other "output steps". For example, in - "A / B / C / D", "A" is the outer most step for "B", "C", "D"; - in the hybrid workflow "((A - B) / C / D) - (E / (F - G) / H)", - "B" is the outer most step for "C", "D"; "E" is the outer most step - for "G", "H". - - Args: - workflow: The workflow to be executed. - outer_most_step_id: The ID of the outer most workflow. None if it - does not exists. - last_step_of_workflow: The step that generates the output of the - workflow (including nested steps). Returns: An object ref that represent the result. 
""" @@ -173,8 +149,8 @@ def execute_workflow( **workflow_data.ray_options).remote( workflow_data.step_type, workflow_data.func_body, workflow_context.get_workflow_step_context(), workflow.step_id, - baked_inputs, outer_most_step_id, workflow_data.catch_exceptions, - workflow_data.max_retries, last_step_of_workflow) + baked_inputs, workflow_data.catch_exceptions, + workflow_data.max_retries) if not isinstance(persisted_output, WorkflowOutputType): raise TypeError("Unexpected return type of the workflow.") @@ -197,7 +173,6 @@ async def _write_step_inputs(wf_storage: workflow_storage.WorkflowStorage, # TODO(suquark): in the future we should write to storage directly # with plasma store object in memory. args_obj = ray.get(inputs.inputs.args) - workflow_id = wf_storage._workflow_id storage = wf_storage._storage save_tasks = [ @@ -213,19 +188,13 @@ async def _write_step_inputs(wf_storage: workflow_storage.WorkflowStorage, await asyncio.gather(*save_tasks) -def commit_step(store: workflow_storage.WorkflowStorage, - step_id: "StepID", - ret: Union["Workflow", Any], - exception: Optional[Exception], - outer_most_step_id: Optional[str] = None): +def commit_step(store: workflow_storage.WorkflowStorage, step_id: "StepID", + ret: Union["Workflow", Any], exception: Optional[Exception]): """Checkpoint the step output. Args: store: The storage the current workflow is using. step_id: The ID of the step. ret: The returned object of the workflow step. - outer_most_step_id: The ID of the outer most workflow. None if it - does not exists. See "step_executor.execute_workflow" for detailed - explanation. """ from ray.workflow.common import Workflow if isinstance(ret, Workflow): @@ -236,7 +205,12 @@ def commit_step(store: workflow_storage.WorkflowStorage, ] asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks)) - store.save_step_output(step_id, ret, exception, outer_most_step_id) + context = workflow_context.get_workflow_step_context() + store.save_step_output( + step_id, + ret, + exception=exception, + outer_most_step_id=context.outer_most_step_id) def _wrap_run(func: Callable, step_type: StepType, step_id: "StepID", @@ -328,12 +302,11 @@ def _wrap_run(func: Callable, step_type: StepType, step_id: "StepID", @ray.remote(num_returns=2) -def _workflow_step_executor( - step_type: StepType, func: Callable, - context: workflow_context.WorkflowStepContext, step_id: "StepID", - baked_inputs: "_BakedWorkflowInputs", outer_most_step_id: "StepID", - catch_exceptions: bool, max_retries: int, - last_step_of_workflow: bool) -> Any: +def _workflow_step_executor(step_type: StepType, func: Callable, + context: workflow_context.WorkflowStepContext, + step_id: "StepID", + baked_inputs: "_BakedWorkflowInputs", + catch_exceptions: bool, max_retries: int) -> Any: """Executor function for workflow step. Args: @@ -342,13 +315,9 @@ def _workflow_step_executor( context: Workflow step context. Used to access correct storage etc. step_id: The ID of the step. baked_inputs: The processed inputs for the step. - outer_most_step_id: See "step_executor.execute_workflow" for - explanation. catch_exceptions: If set to be true, return (Optional[Result], Optional[Error]) instead of Result. max_retries: Max number of retries encounter of a failure. - last_step_of_workflow: The step that generates the output of the - workflow (including nested steps). Returns: Workflow step output. 
@@ -361,7 +330,7 @@ def _workflow_step_executor( func, step_type, step_id, catch_exceptions, max_retries, *args, **kwargs) except Exception as e: - commit_step(store, step_id, None, e, outer_most_step_id) + commit_step(store, step_id, None, e) raise e if step_type == StepType.READONLY_ACTOR_METHOD: if isinstance(volatile_output, Workflow): @@ -371,26 +340,28 @@ def _workflow_step_executor( assert not isinstance(persisted_output, Workflow) else: store = workflow_storage.get_workflow_storage() - commit_step(store, step_id, persisted_output, None, outer_most_step_id) + commit_step(store, step_id, persisted_output, None) + outer_most_step_id = context.outer_most_step_id if isinstance(persisted_output, Workflow): if step_type == StepType.FUNCTION: # Passing down outer most step so inner nested steps would # access the same outer most step. - if not outer_most_step_id: + if not context.outer_most_step_id: # The current workflow step returns a nested workflow, and # there is no outer step for the current step. So the # current step is the outer most step for the inner nested # workflow steps. outer_most_step_id = workflow_context.get_current_step_id() assert volatile_output is None - # execute sub-workflow - result = execute_workflow(persisted_output, outer_most_step_id, - last_step_of_workflow) + # Execute sub-workflow. Pass down "outer_most_step_id". + with workflow_context.fork_workflow_step_context( + outer_most_step_id=outer_most_step_id): + result = execute_workflow(persisted_output) # When virtual actor returns a workflow in the method, # the volatile_output and persisted_output will be put together persisted_output = result.persisted_output volatile_output = result.volatile_output - elif last_step_of_workflow: + elif context.last_step_of_workflow: # advance the progress of the workflow store.advance_progress(step_id) _record_step_status(step_id, WorkflowStatus.SUCCESSFUL) @@ -415,9 +386,11 @@ class _BakedWorkflowInputs: @classmethod def from_workflow_inputs(cls, inputs: "WorkflowInputs"): - workflow_outputs = [ - execute_workflow(w).persisted_output for w in inputs.workflows - ] + with workflow_context.fork_workflow_step_context( + outer_most_step_id=None, last_step_of_workflow=False): + workflow_outputs = [ + execute_workflow(w).persisted_output for w in inputs.workflows + ] return cls(inputs.args, workflow_outputs, inputs.workflow_refs) def __reduce__(self): @@ -427,7 +400,10 @@ def __reduce__(self): def _record_step_status(step_id: "StepID", status: "WorkflowStatus", - outputs: List["ObjectRef"] = []) -> None: + outputs: Optional[List["ObjectRef"]] = None) -> None: + if outputs is None: + outputs = [] + workflow_id = workflow_context.get_current_workflow_id() workflow_manager = get_management_actor() ray.get( diff --git a/python/ray/workflow/tests/test_basic_workflows_2.py b/python/ray/workflow/tests/test_basic_workflows_2.py index dad390635cab7..acecfb14dc014 100644 --- a/python/ray/workflow/tests/test_basic_workflows_2.py +++ b/python/ray/workflow/tests/test_basic_workflows_2.py @@ -1,10 +1,13 @@ +import os import pytest import ray import re from filelock import FileLock +from pathlib import Path from ray._private.test_utils import run_string_as_driver, SignalActor from ray import workflow from ray.tests.conftest import * # noqa +from unittest.mock import patch def test_init_twice(call_ray_start, reset_workflow, tmp_path): @@ -22,9 +25,11 @@ def test_init_twice(call_ray_start, reset_workflow, tmp_path): def test_init_twice_2(call_ray_start, reset_workflow, tmp_path): - 
run_string_as_driver(driver_script) - with pytest.raises(RuntimeError): - workflow.init(str(tmp_path)) + with patch.dict(os.environ, {"RAY_ADDRESS": call_ray_start}): + run_string_as_driver(driver_script) + with pytest.raises( + RuntimeError, match=".*different from the workflow manager.*"): + workflow.init(str(tmp_path)) @pytest.mark.parametrize( @@ -285,6 +290,38 @@ def f2(*w): f.run() +def test_dedupe_indirect(workflow_start_regular, tmp_path): + counter = Path(tmp_path) / "counter.txt" + lock = Path(tmp_path) / "lock.txt" + counter.write_text("0") + + @workflow.step + def incr(): + with FileLock(str(lock)): + c = int(counter.read_text()) + c += 1 + counter.write_text(f"{c}") + + @workflow.step + def identity(a): + return a + + @workflow.step + def join(*a): + return counter.read_text() + + # Here a is passed to two steps and we need to ensure + # it's only executed once + a = incr.step() + i1 = identity.step(a) + i2 = identity.step(a) + assert "1" == join.step(i1, i2).run() + assert "2" == join.step(i1, i2).run() + # pass a multiple times + assert "3" == join.step(a, a, a, a).run() + assert "4" == join.step(a, a, a, a).run() + + if __name__ == "__main__": import sys sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/workflow/tests/test_lifetime.py b/python/ray/workflow/tests/test_lifetime.py index 8d12399369ac9..64a519fa19a48 100644 --- a/python/ray/workflow/tests/test_lifetime.py +++ b/python/ray/workflow/tests/test_lifetime.py @@ -1,3 +1,4 @@ +import os import ray import time import pytest @@ -5,6 +6,7 @@ run_string_as_driver) from ray.tests.conftest import * # noqa from ray import workflow +from unittest.mock import patch driver_script = """ import time @@ -29,21 +31,23 @@ def foo(x): def test_workflow_lifetime_1(call_ray_start, reset_workflow): # Case 1: driver exits normally - run_string_as_driver(driver_script.format(5)) - workflow.init() - output = workflow.get_output("driver_terminated") - assert ray.get(output) == 20 + with patch.dict(os.environ, {"RAY_ADDRESS": call_ray_start}): + run_string_as_driver(driver_script.format(5)) + workflow.init() + output = workflow.get_output("driver_terminated") + assert ray.get(output) == 20 def test_workflow_lifetime_2(call_ray_start, reset_workflow): # Case 2: driver terminated - proc = run_string_as_driver_nonblocking(driver_script.format(100)) - time.sleep(10) - proc.kill() - time.sleep(1) - workflow.init() - output = workflow.get_output("driver_terminated") - assert ray.get(output) == 20 + with patch.dict(os.environ, {"RAY_ADDRESS": call_ray_start}): + proc = run_string_as_driver_nonblocking(driver_script.format(100)) + time.sleep(10) + proc.kill() + time.sleep(1) + workflow.init() + output = workflow.get_output("driver_terminated") + assert ray.get(output) == 20 if __name__ == "__main__": diff --git a/python/ray/workflow/workflow_access.py b/python/ray/workflow/workflow_access.py index 0524637cf08da..c1b5d78d253a0 100644 --- a/python/ray/workflow/workflow_access.py +++ b/python/ray/workflow/workflow_access.py @@ -327,8 +327,8 @@ def load(wf_store, workflow_id, step_id): actor = get_management_actor() return actor.get_output.remote(workflow_id, result.output_step_id) - raise ValueError( - f"No such step id {step_id} in workflow {workflow_id}") + raise ValueError(f"Cannot load output from step id {step_id} " + f"in workflow {workflow_id}") return ray.put( _SelfDereferenceObject(None, diff --git a/python/ray/workflow/workflow_context.py b/python/ray/workflow/workflow_context.py index ffbeaafb6ce7f..7dec0937695f5 100644 --- 
a/python/ray/workflow/workflow_context.py
+++ b/python/ray/workflow/workflow_context.py
@@ -1,40 +1,58 @@
+from dataclasses import dataclass, field
 import logging
-from typing import Optional, List
+from typing import Optional, List, TYPE_CHECKING
 from contextlib import contextmanager
 
 from ray.workflow.common import WorkflowStatus
 
 logger = logging.getLogger(__name__)
 
+if TYPE_CHECKING:
+    from ray.workflow.common import StepID
+
 
+@dataclass
 class WorkflowStepContext:
-    def __init__(self,
-                 workflow_id: str = None,
-                 storage_url: str = None,
-                 workflow_scope: List[str] = None):
-        """
-        The structure for saving workflow step context. The context provides
-        critical info (e.g. where to checkpoint, which is its parent step)
-        for the step to execute correctly.
-
-        Args:
-            workflow_id: The workflow job ID.
-            storage_url: The storage of the workflow, used for checkpointing.
-            workflow_scope: The "calling stack" of the current workflow step.
-                It describe the parent workflow steps.
-        """
-        self.workflow_id = workflow_id
-        self.storage_url = storage_url
-        self.workflow_scope = workflow_scope or []
-
-    def __reduce__(self):
-        return WorkflowStepContext, (self.workflow_id, self.storage_url,
-                                     self.workflow_scope)
+    """
+    The structure for saving workflow step context. The context provides
+    critical info (e.g. where to checkpoint, which is its parent step)
+    for the step to execute correctly.
+
+    To fully explain what we are doing, we need to introduce some syntax
+    first. The syntax for dependencies between workflow steps
+    "B.step(A.step())" is "A - B"; the syntax for nested workflow steps
+    "def A(): return B.step()" is "A / B".
+
+    In a chain/DAG of step dependencies, the "output step" is the step of
+    last (topological) order. For example, in "A - B - C", C is the
+    output step.
+
+    In a chain of nested workflow steps, the initial "output step" is
+    called the "outer most step" for other "output steps". For example, in
+    "A / B / C / D", "A" is the outer most step for "B", "C", "D";
+    in the hybrid workflow "((A - B) / C / D) - (E / (F - G) / H)",
+    "B" is the outer most step for "C", "D"; "E" is the outer most step
+    for "G", "H".
+    """
+    # ID of the workflow.
+    workflow_id: Optional[str] = None
+    # The storage of the workflow, used for checkpointing.
+    storage_url: Optional[str] = None
+    # The "calling stack" of the current workflow step. It describes
+    # the parent workflow steps.
+    workflow_scope: List[str] = field(default_factory=list)
+    # The ID of the outer most workflow. "None" if it does not exist.
+    outer_most_step_id: "Optional[StepID]" = None
+    # The step that generates the output of the workflow (including all
+    # nested steps).
+    last_step_of_workflow: bool = False
 
 
 _context: Optional[WorkflowStepContext] = None
 
 
 @contextmanager
-def workflow_step_context(workflow_id, storage_url) -> None:
+def workflow_step_context(workflow_id,
+                          storage_url,
+                          last_step_of_workflow=False) -> None:
     """Initialize the workflow step context.
Args: @@ -45,7 +63,48 @@ def workflow_step_context(workflow_id, storage_url) -> None: original_context = _context assert workflow_id is not None try: - _context = WorkflowStepContext(workflow_id, storage_url) + _context = WorkflowStepContext( + workflow_id, + storage_url, + last_step_of_workflow=last_step_of_workflow) + yield + finally: + _context = original_context + + +_sentinel = object() + + +@contextmanager +def fork_workflow_step_context( + workflow_id: Optional[str] = _sentinel, + storage_url: Optional[str] = _sentinel, + workflow_scope: Optional[List[str]] = _sentinel, + outer_most_step_id: Optional[str] = _sentinel, + last_step_of_workflow: Optional[bool] = _sentinel): + """Fork the workflow step context. + Inherits the original value if no value is provided. + + Args: + workflow_id: The ID of the workflow. + storage_url: The storage the workflow is using. + """ + global _context + original_context = _context + assert workflow_id is not None + try: + _context = WorkflowStepContext( + workflow_id=original_context.workflow_id + if workflow_id is _sentinel else workflow_id, + storage_url=original_context.storage_url + if storage_url is _sentinel else storage_url, + workflow_scope=original_context.workflow_scope + if workflow_scope is _sentinel else workflow_scope, + outer_most_step_id=original_context.outer_most_step_id + if outer_most_step_id is _sentinel else outer_most_step_id, + last_step_of_workflow=original_context.last_step_of_workflow + if last_step_of_workflow is _sentinel else last_step_of_workflow, + ) yield finally: _context = original_context diff --git a/python/ray/workflow/workflow_storage.py b/python/ray/workflow/workflow_storage.py index bf18f471483de..5a188cca1a3f2 100644 --- a/python/ray/workflow/workflow_storage.py +++ b/python/ray/workflow/workflow_storage.py @@ -117,9 +117,9 @@ def load_step_output(self, step_id: StepID) -> Any: # In this case, there is no such step raise output_err - def save_step_output(self, step_id: StepID, ret: Union[Workflow, Any], + def save_step_output(self, step_id: StepID, ret: Union[Workflow, Any], *, exception: Optional[Exception], - outer_most_step_id: Optional[StepID]) -> None: + outer_most_step_id: StepID) -> None: """When a workflow step returns, 1. If the returned object is a workflow, this means we are a nested workflow. We save the output metadata that points to the workflow. @@ -130,8 +130,7 @@ def save_step_output(self, step_id: StepID, ret: Union[Workflow, Any], it means we are in the workflow job driver process. ret: The returned object from a workflow step. exception: This step should throw exception. - outer_most_step_id: See - "step_executor.execute_workflow" for explanation. + outer_most_step_id: See WorkflowStepContext. """ tasks = [] if isinstance(ret, Workflow): @@ -154,14 +153,9 @@ def save_step_output(self, step_id: StepID, ret: Union[Workflow, Any], # tasks.append(self._put(self._key_step_output(step_id), ret)) dynamic_output_id = step_id # TODO (yic): Delete exception file - - # outer_most_step_id == "" indicates the root step of a - # workflow. This would directly update "outputs.json" in - # the workflow dir, and we want to avoid it. 
- if outer_most_step_id is not None and outer_most_step_id != "": - tasks.append( - self._update_dynamic_output(outer_most_step_id, - dynamic_output_id)) + tasks.append( + self._update_dynamic_output(outer_most_step_id, + dynamic_output_id)) else: assert ret is None promise = serialization.dump_to_storage( @@ -271,10 +265,15 @@ async def _update_dynamic_output(self, outer_most_step_id: StepID, critical for scalability of virtual actors. Args: - outer_most_step_id: ID of outer_most_step. See - "step_executor.execute_workflow" for explanation. + outer_most_step_id: See WorkflowStepContext for explanation. dynamic_output_step_id: ID of dynamic_step. """ + # outer_most_step_id == "" indicates the root step of a + # workflow. This would directly update "outputs.json" in + # the workflow dir, and we want to avoid it. + if outer_most_step_id is None or outer_most_step_id == "": + return + metadata = await self._get( self._key_step_output_metadata(outer_most_step_id), True) if (dynamic_output_step_id != metadata["output_step_id"] diff --git a/python/requirements.txt b/python/requirements.txt index 4d0baeaf9ef80..2f683373fbc05 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -5,7 +5,7 @@ # In short, if you change it here, PLEASE also change it in setup.py. # # setup.py install_requires -aiohttp==3.7 +aiohttp>=3.7 aioredis < 2 click >= 7.0 cloudpickle @@ -27,7 +27,7 @@ requests ## setup.py extras dm_tree flask -gym +gym==0.19 lz4 scikit-image opencv-python-headless==4.3.0.36 @@ -68,6 +68,7 @@ opentelemetry-exporter-otlp==1.1.0 pexpect Pillow; platform_system != "Windows" pygments +pyspark pytest==5.4.3 pytest-asyncio pytest-rerunfailures diff --git a/python/requirements/ml/requirements_rllib.txt b/python/requirements/ml/requirements_rllib.txt index a81e52c9c1f08..6bba94e49fc99 100644 --- a/python/requirements/ml/requirements_rllib.txt +++ b/python/requirements/ml/requirements_rllib.txt @@ -10,9 +10,9 @@ kaggle_environments==1.7.11 # Unity3D testing mlagents_envs==0.27.0 # For tests on PettingZoo's multi-agent envs. -pettingzoo==1.11.0 +pettingzoo==1.11.1 pymunk==6.0.0 -supersuit +supersuit==2.6.6 # For testing in MuJoCo-like envs (in PyBullet). pybullet==3.1.7 # For tests on RecSim and Kaggle envs. 
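
A note on the recovery.py change earlier in this patch: the new input_map
argument is plain memoization over the recursive reconstruction, which is what
test_dedupe_indirect exercises. A self-contained sketch of the pattern (names
hypothetical):

    from typing import Any, Dict, List

    def resolve(step_id: str, deps: Dict[str, List[str]],
                input_map: Dict[str, Any]) -> Any:
        # Each step is resolved at most once per resume; repeated references
        # to the same step hit the cache instead of being rebuilt.
        if step_id in input_map:
            return input_map[step_id]
        result = [resolve(d, deps, input_map) for d in deps.get(step_id, [])]
        input_map[step_id] = result
        return result
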
diff --git a/python/requirements/requirements_default.txt b/python/requirements/requirements_default.txt
index 4537b9f9ea2a6..2df14c6e7588d 100644
--- a/python/requirements/requirements_default.txt
+++ b/python/requirements/requirements_default.txt
@@ -1,4 +1,4 @@
-aiohttp
+aiohttp>=3.7
 aiohttp_cors
 aioredis<2
 colorful
diff --git a/python/requirements_linters.txt b/python/requirements_linters.txt
index 6f5661b1f2b2f..69f457fea1688 100644
--- a/python/requirements_linters.txt
+++ b/python/requirements_linters.txt
@@ -1,5 +1,6 @@
 flake8==3.9.1
 flake8-comprehensions
 flake8-quotes==2.0.0
+flake8-bugbear==21.9.2
 mypy==0.782
 yapf==0.23.0
diff --git a/python/setup.py b/python/setup.py
index 62d1e4e36fa46..3fb8ff43ab262 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -23,7 +23,7 @@
 logger = logging.getLogger(__name__)
 
 SUPPORTED_PYTHONS = [(3, 6), (3, 7), (3, 8), (3, 9)]
-SUPPORTED_BAZEL = (3, 4, 1)
+SUPPORTED_BAZEL = (4, 2, 1)
 
 ROOT_DIR = os.path.dirname(__file__)
 BUILD_JAVA = os.getenv("RAY_INSTALL_JAVA") == "1"
@@ -184,8 +184,13 @@ def get_packages(self):
     # in this directory
     if setup_spec.type == SetupType.RAY:
         setup_spec.extras = {
+            "data": [
+                "pandas",
+                "pyarrow>=4.0.1",
+                "fsspec",
+            ],
             "default": [
-                "aiohttp",
+                "aiohttp >= 3.7",
                 "aiohttp_cors",
                 "aioredis < 2",
                 "colorful",
@@ -534,6 +539,19 @@ def copy_file(target_dir, filename, rootdir):
     return 0
 
 
+def add_system_dlls(dlls, target_dir):
+    """
+    Copy any dlls required by the c-extension module and not already
+    provided by python. They will end up in the wheel next to the c-extension
+    module, which will guarantee they are available at runtime.
+    """
+    for dll in dlls:
+        # Installing Visual Studio will copy the runtime dlls to system32
+        src = os.path.join(r"c:\Windows\system32", dll)
+        assert os.path.exists(src)
+        shutil.copy(src, target_dir)
+
+
 def pip_run(build_ext):
     build(True, BUILD_JAVA, True)
 
@@ -558,6 +576,13 @@ def pip_run(build_ext):
     copied_files = 0
     for filename in setup_spec.files_to_include:
         copied_files += copy_file(build_ext.build_lib, filename, ROOT_DIR)
+    if sys.platform == "win32":
+        # _raylet.pyd links to some MSVC runtime DLLs; this one may not be
+        # present on a user's machine. While vcruntime140.dll and
+        # vcruntime140_1.dll are also required, they are provided by CPython.
+ runtime_dlls = ["msvcp140.dll"] + add_system_dlls(runtime_dlls, os.path.join(build_ext.build_lib, "ray")) + copied_files += len(runtime_dlls) print("# of files copied to {}: {}".format(build_ext.build_lib, copied_files)) diff --git a/release/.buildkite/build_pipeline.py b/release/.buildkite/build_pipeline.py index 96d58a2b54f2a..9cbfdbdc08a0c 100644 --- a/release/.buildkite/build_pipeline.py +++ b/release/.buildkite/build_pipeline.py @@ -99,6 +99,7 @@ def __init__(self, name: str, retry: int = 0): "~/ray/release/nightly_tests/nightly_tests.yaml": [ "dask_on_ray_large_scale_test_no_spilling", "dask_on_ray_large_scale_test_spilling", + "pg_autoscaling_regression_test", ], "~/ray/release/long_running_tests/long_running_tests.yaml": [ SmokeTest("actor_deaths"), diff --git a/release/RELEASE_CHECKLIST.md b/release/RELEASE_CHECKLIST.md index f8a55bfff9aec..e4770a1cb6fdd 100644 --- a/release/RELEASE_CHECKLIST.md +++ b/release/RELEASE_CHECKLIST.md @@ -31,6 +31,7 @@ This checklist is meant to be used in conjunction with the RELEASE_PROCESS.rst d - [ ] Test passing - [ ] Results added to `release/release_logs` - [ ] microbenchmark +- [ ] `kubernetes` manual release tests pass - [ ] ``weekly`` release test suite - [ ] Test passing diff --git a/release/RELEASE_PROCESS.rst b/release/RELEASE_PROCESS.rst index 59da95846cf3c..b2f7d05db5492 100644 --- a/release/RELEASE_PROCESS.rst +++ b/release/RELEASE_PROCESS.rst @@ -172,6 +172,9 @@ Release tests are added and maintained by the respective teams. As another example, if you just want to kick off all nightly RLLib tests, select the respective test suite and specify ``rllib`` in the test file filter. +6. **Kubernetes tests must be run manually.** Refer to ``kubernetes_manual_tests/README.md``. + Feel free to ping code owner(s) of OSS Kubernetes support to run these. 
+
 Identify and Resolve Release Blockers
 -------------------------------------
 If a release blocking issue arises in the course of testing, you should
diff --git a/release/alerts/xgboost_tests.py b/release/alerts/xgboost_tests.py
index 59ab2880adf76..8b77cc17f49c7 100644
--- a/release/alerts/xgboost_tests.py
+++ b/release/alerts/xgboost_tests.py
@@ -43,7 +43,9 @@ def handle_result(created_on: datetime.datetime, category: str,
     else:
         # train scripts
         if test_name == "train_small":
-            target_time = 30
+            # Leave a couple of seconds for ray connect setup
+            # (without connect it should finish in < 30)
+            target_time = 45
         elif test_name == "train_moderate":
             target_time = 60
         elif test_name == "train_gpu":
diff --git a/release/e2e.py b/release/e2e.py
index 1b5fe71d15923..f47a0bbeecf08 100644
--- a/release/e2e.py
+++ b/release/e2e.py
@@ -264,11 +264,30 @@ def getenv_default(key: str, default: Optional[str] = None):
 }
 
 REPORT_S = 30
+RETRY_MULTIPLIER = 2
+
+
+def exponential_backoff_retry(f, retry_exceptions, initial_retry_delay_s,
+                              max_retries):
+    retry_cnt = 0
+    retry_delay_s = initial_retry_delay_s
+    while True:
+        try:
+            return f()
+        except retry_exceptions as e:
+            retry_cnt += 1
+            if retry_cnt > max_retries:
+                raise
+            logger.info(f"Function call failed due to {e}; "
+                        f"retrying in {retry_delay_s} seconds...")
+            time.sleep(retry_delay_s)
+            retry_delay_s *= RETRY_MULTIPLIER
 
 
 def maybe_fetch_api_token():
     if GLOBAL_CONFIG["ANYSCALE_CLI_TOKEN"] is None:
-        print("Missing ANYSCALE_CLI_TOKEN, retrieving from AWS secrets store")
+        logger.info(
+            "Missing ANYSCALE_CLI_TOKEN, retrieving from AWS secrets store")
         # NOTE(simon) This should automatically retrieve
         # release-automation@anyscale.com's anyscale token
         GLOBAL_CONFIG["ANYSCALE_CLI_TOKEN"] = boto3.client(
@@ -405,7 +424,8 @@ def populate_wheels_sanity_check(commit: Optional[str] = None):
         raise RuntimeError(f"Could not populate wheels sanity check command: "
                            f"Commit hash missing. Got: {commit}")
 
-    cmd = f"python -c 'import ray; assert ray.__commit__ == \"{commit}\"'"
+    cmd = (f"python -c 'import ray; "
+           f"assert ray.__commit__ == \"{commit}\", ray.__commit__'")
 
     os.environ["RAY_WHEELS_SANITY_CHECK"] = cmd
 
@@ -463,7 +483,7 @@ def has_errored(result: Dict[Any, Any]) -> bool:
     return result.get("status", "invalid") != "finished"
 
 
-def report_result(test_suite: str, test_name: str, status: str, logs: str,
+def report_result(test_suite: str, test_name: str, status: str, last_logs: str,
                   results: Dict[Any, Any], artifacts: Dict[Any, Any],
                   category: str):
     now = datetime.datetime.utcnow()
@@ -477,67 +497,66 @@
         f"results, artifacts, category) "
         f"VALUES (:created_on, :test_suite, :test_name, :status, :last_logs, "
         f":results, :artifacts, :category)")
-
-    rds_data_client.execute_statement(
-        database=GLOBAL_CONFIG["RELEASE_AWS_DB_NAME"],
-        parameters=[
-            {
-                "name": "created_on",
-                "typeHint": "TIMESTAMP",
-                "value": {
-                    "stringValue": now.strftime("%Y-%m-%d %H:%M:%S")
-                },
-            },
-            {
-                "name": "test_suite",
-                "value": {
-                    "stringValue": test_suite
-                }
-            },
-            {
-                "name": "test_name",
-                "value": {
-                    "stringValue": test_name
-                }
-            },
-            {
-                "name": "status",
-                "value": {
-                    "stringValue": status
-                }
-            },
-            {
-                "name": "last_logs",
-                "value": {
-                    "stringValue": logs
-                }
-            },
-            {
-                "name": "results",
-                "typeHint": "JSON",
-                "value": {
-                    "stringValue": json.dumps(results)
-                },
-            },
-            {
-                "name": "artifacts",
-                "typeHint": "JSON",
-                "value": {
-                    "stringValue": json.dumps(artifacts)
-                },
-            },
-            {
-                "name": "category",
-                "value": {
-                    "stringValue": category
-                }
-            },
-        ],
-        secretArn=GLOBAL_CONFIG["RELEASE_AWS_DB_SECRET_ARN"],
-        resourceArn=GLOBAL_CONFIG["RELEASE_AWS_DB_RESOURCE_ARN"],
-        schema=schema,
-        sql=sql,
-    )
+    parameters = [{
+        "name": "created_on",
+        "typeHint": "TIMESTAMP",
+        "value": {
+            "stringValue": now.strftime("%Y-%m-%d %H:%M:%S")
+        },
+    }, {
+        "name": "test_suite",
+        "value": {
+            "stringValue": test_suite
+        }
+    }, {
+        "name": "test_name",
+        "value": {
+            "stringValue": test_name
+        }
+    }, {
+        "name": "status",
+        "value": {
+            "stringValue": status
+        }
+    }, {
+        "name": "last_logs",
+        "value": {
+            "stringValue": last_logs
+        }
+    }, {
+        "name": "results",
+        "typeHint": "JSON",
+        "value": {
+            "stringValue": json.dumps(results)
+        },
+    }, {
+        "name": "artifacts",
+        "typeHint": "JSON",
+        "value": {
+            "stringValue": json.dumps(artifacts)
+        },
+    }, {
+        "name": "category",
+        "value": {
+            "stringValue": category
+        }
+    }]
+
+    # Default boto3 call timeout is 45 seconds.
+    retry_delay_s = 64
+    MAX_RDS_RETRY = 3
+    exponential_backoff_retry(
+        lambda: rds_data_client.execute_statement(
+            database=GLOBAL_CONFIG["RELEASE_AWS_DB_NAME"],
+            parameters=parameters,
+            secretArn=GLOBAL_CONFIG["RELEASE_AWS_DB_SECRET_ARN"],
+            resourceArn=GLOBAL_CONFIG["RELEASE_AWS_DB_RESOURCE_ARN"],
+            schema=schema,
+            sql=sql),
+        retry_exceptions=rds_data_client.exceptions.StatementTimeoutException,
+        initial_retry_delay_s=retry_delay_s,
+        max_retries=MAX_RDS_RETRY)
+    logger.info("Result has been persisted to the database")
 
 
 def log_results_and_artifacts(result: Dict):
@@ -903,7 +922,11 @@ def wait_for_session_command_to_complete(create_session_command_result,
         # Sleep 1 sec before next check.
time.sleep(1) - result = sdk.get_session_command(session_command_id=scd_id) + result = exponential_backoff_retry( + lambda: sdk.get_session_command(session_command_id=scd_id), + retry_exceptions=Exception, + initial_retry_delay_s=10, + max_retries=3) completed = result.result.finished_at if state_str == "CMD_RUN": @@ -934,10 +957,14 @@ def wait_for_session_command_to_complete(create_session_command_result, def get_command_logs(session_controller: SessionController, scd_id: str, lines: int = 50): - result = session_controller.api_client.get_execution_logs_api_v2_session_commands_session_command_id_execution_logs_get( # noqa: E501 - session_command_id=scd_id, - start_line=-1 * lines, - end_line=0) + result = exponential_backoff_retry( + lambda: session_controller.api_client.get_execution_logs_api_v2_session_commands_session_command_id_execution_logs_get( # noqa: E501 + session_command_id=scd_id, + start_line=-1 * lines, + end_line=0), + retry_exceptions=Exception, + initial_retry_delay_s=10, + max_retries=3) return result.result.lines @@ -1777,7 +1804,7 @@ def run_test(test_config_file: str, report: bool = True, keep_results_dir: bool = False, session_name: Optional[str] = None, - app_config_id_override=None): + app_config_id_override=None) -> Dict[str, Any]: with open(test_config_file, "rt") as f: test_configs = yaml.load(f, Loader=yaml.FullLoader) @@ -1836,18 +1863,18 @@ def run_test(test_config_file: str, logger.info("Kicked off test. It's now up to the `--check` " "part of the script to track its process.") - return + return {} else: # `--check` or no kick off only if status == "nosession": logger.info(f"No running session found for test {test_name}, so " f"assuming everything is fine.") - return + return {} if status == "kickoff": logger.info(f"Test {test_name} is still running.") - return + return {} last_logs = result.get("last_logs", "No logs.") @@ -1857,7 +1884,7 @@ def run_test(test_config_file: str, test_suite=test_suite, test_name=test_name, status=status, - logs=last_logs, + last_logs=last_logs, results=result.get("results", {}), artifacts=result.get("artifacts", {}), category=category, @@ -1872,7 +1899,7 @@ def run_test(test_config_file: str, if has_errored(result): raise RuntimeError(last_logs) - return + return report_kwargs if __name__ == "__main__": @@ -1935,7 +1962,6 @@ def run_test(test_config_file: str, "You have to set the ANYSCALE_PROJECT environment variable!") maybe_fetch_api_token() - if args.ray_wheels: os.environ["RAY_WHEELS"] = str(args.ray_wheels) url = str(args.ray_wheels) @@ -1955,7 +1981,7 @@ def run_test(test_config_file: str, test_config_file = os.path.abspath(os.path.expanduser(args.test_config)) - run_test( + result_dict = run_test( test_config_file=test_config_file, test_name=args.test_name, project_id=GLOBAL_CONFIG["ANYSCALE_PROJECT"], @@ -1970,3 +1996,30 @@ def run_test(test_config_file: str, keep_results_dir=args.keep_results_dir, app_config_id_override=args.app_config_id_override, ) + + if result_dict: + # If we get a result dict, check if any alerts should be raised + from alert import SUITE_TO_FN, default_handle_result + + logger.info("Checking if results are valid...") + + handle_result_kwargs = result_dict.copy() + handle_result_kwargs["created_on"] = None + + test_suite = handle_result_kwargs.get("test_suite", None) + test_name = handle_result_kwargs.get("test_name", None) + category = handle_result_kwargs.get("category", None) + + handle_fn = SUITE_TO_FN.get(test_suite, None) + if not handle_fn: + logger.warning(f"No handle for suite 
{test_suite}") + alert = default_handle_result(**handle_result_kwargs) + else: + alert = handle_fn(**handle_result_kwargs) + + if alert: + # If we get an alert, the test failed. + raise RuntimeError(alert) + else: + logger.info(f"No alert raised for test {test_suite}/{test_name} " + f"({category}) - the test successfully passed!") diff --git a/release/golden_notebook_tests/dask_xgboost_app_config.yaml b/release/golden_notebook_tests/dask_xgboost_app_config.yaml index 072b183099476..a05da857edef8 100755 --- a/release/golden_notebook_tests/dask_xgboost_app_config.yaml +++ b/release/golden_notebook_tests/dask_xgboost_app_config.yaml @@ -5,9 +5,8 @@ debian_packages: python: pip_packages: - - pytest - pandas>=1.3.0 # otherwise, a version mismatch between local and remote will cause an exception - - xgboost_ray[default] + - git+https://github.com/ray-project/xgboost_ray.git#egg=xgboost_ray[default] - dask - fastapi - uvicorn @@ -16,5 +15,5 @@ python: post_build_cmds: - pip uninstall -y ray || true - - pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }} + - pip install -U {{ env["RAY_WHEELS"] | default("ray") }} - {{ env["RAY_WHEELS_SANITY_CHECK"] | default("echo No Ray wheels sanity check") }} diff --git a/release/golden_notebook_tests/golden_notebook_tests.yaml b/release/golden_notebook_tests/golden_notebook_tests.yaml index 1fae1e1d65824..e6d5838d10333 100644 --- a/release/golden_notebook_tests/golden_notebook_tests.yaml +++ b/release/golden_notebook_tests/golden_notebook_tests.yaml @@ -1,4 +1,7 @@ - name: dask_xgboost_test + owner: + mail: "antoni@anyscale.com" + slack: "@team_ml" cluster: app_config: dask_xgboost_app_config.yaml compute_template: compute_tpl.yaml @@ -8,8 +11,18 @@ autosuspend_mins: 10 timeout: 1200 script: python workloads/dask_xgboost_test.py + args: + [ + "--num-actors 4", + "--cpus-per-actor 4", + "--num-actors-inference 16", + "--cpus-per-actor-inference 1", + ] - name: modin_xgboost_test + owner: + mail: "antoni@anyscale.com" + slack: "@team_ml" cluster: app_config: modin_xgboost_app_config.yaml compute_template: compute_tpl.yaml @@ -19,6 +32,13 @@ autosuspend_mins: 10 timeout: 1200 script: python workloads/modin_xgboost_test.py + args: + [ + "--num-actors 4", + "--cpus-per-actor 4", + "--num-actors-inference 16", + "--cpus-per-actor-inference 1", + ] - name: torch_tune_serve_test owner: @@ -34,4 +54,3 @@ autosuspend_mins: 10 timeout: 1200 script: python workloads/torch_tune_serve_test.py - diff --git a/release/golden_notebook_tests/modin_xgboost_app_config.yaml b/release/golden_notebook_tests/modin_xgboost_app_config.yaml index c17fa85ca0144..5fb35e7b03fdd 100755 --- a/release/golden_notebook_tests/modin_xgboost_app_config.yaml +++ b/release/golden_notebook_tests/modin_xgboost_app_config.yaml @@ -5,7 +5,8 @@ debian_packages: python: pip_packages: - - pytest + - pandas>=1.3.0 # otherwise, a version mismatch between local and remote will cause an exception + - git+https://github.com/ray-project/xgboost_ray.git#egg=xgboost_ray[default] - modin - s3fs - fastapi @@ -16,4 +17,4 @@ python: post_build_cmds: - pip uninstall -y ray || true - pip install -U {{ env["RAY_WHEELS"] | default("ray") }} - - pip install git+https://github.com/ray-project/xgboost_ray.git#egg=xgboost_ray[default] + - {{ env["RAY_WHEELS_SANITY_CHECK"] | default("echo No Ray wheels sanity check") }} diff --git a/release/golden_notebook_tests/workloads/dask_xgboost_test.py b/release/golden_notebook_tests/workloads/dask_xgboost_test.py index 99755eb4399bb..c10bf91d96754 100644 --- 
a/release/golden_notebook_tests/workloads/dask_xgboost_test.py +++ b/release/golden_notebook_tests/workloads/dask_xgboost_test.py @@ -1,135 +1,28 @@ -import argparse -import json +import ray import os import time +import json +from util import import_and_execute_test_script, wait_for_cluster_client -import dask -import dask.dataframe as dd -import ray -from ray import tune - -from ray.util.dask import ray_dask_get - -from xgboost_ray import RayDMatrix, RayParams, train, predict - -from utils.utils import is_anyscale_connect - -FILE_URL = "https://ray-ci-higgs.s3.us-west-2.amazonaws.com/" \ - "simpleHIGGS.csv" - - -def train_xgboost(config, train_df, test_df, target_column, ray_params): - # distributed loading of a parquet dataset - train_set = RayDMatrix(train_df, target_column) - test_set = RayDMatrix(test_df, target_column) - - evals_result = {} - - start_time = time.time() - # Train the classifier - bst = train( - params=config, - dtrain=train_set, - evals=[(test_set, "eval")], - evals_result=evals_result, - verbose_eval=False, - num_boost_round=100, - ray_params=ray_params) - print(f"Total time taken: {time.time()-start_time}") - - model_path = "model.xgb" - bst.save_model(model_path) - print("Final validation error: {:.4f}".format( - evals_result["eval"]["error"][-1])) - - return bst - - -def tune_xgboost(train_df, test_df, target_column): - # Set XGBoost config. - config = { - "tree_method": "approx", - "objective": "binary:logistic", - "eval_metric": ["logloss", "error"], - "eta": tune.loguniform(1e-4, 1e-1), - "subsample": tune.uniform(0.5, 1.0), - "max_depth": tune.randint(1, 9) - } - - ray_params = RayParams( - max_actor_restarts=1, gpus_per_actor=0, cpus_per_actor=4, num_actors=4) - - analysis = tune.run( - tune.with_parameters( - train_xgboost, - train_df=train_df, - test_df=test_df, - target_column=target_column, - ray_params=ray_params), - # Use the `get_tune_resources` helper function to set the resources. - resources_per_trial=ray_params.get_tune_resources(), - config=config, - num_samples=1, - metric="eval-error", - mode="min", - verbose=1) - - accuracy = 1. 
- analysis.best_result["eval-error"] - print(f"Best model parameters: {analysis.best_config}") - print(f"Best model total accuracy: {accuracy:.4f}") - - return analysis.best_config +NOTEBOOK_PATH_RELATIVE_TO_RAY_REPO = ( + "doc/examples/dask_xgboost/dask_xgboost.py") def main(): - print("Loading HIGGS data.") - - dask.config.set(scheduler=ray_dask_get) - colnames = ["label"] + ["feature-%02d" % i for i in range(1, 29)] - data = dd.read_csv(FILE_URL, names=colnames) - - print("Loaded HIGGS data.") - - # partition on a column - df_train = data[(data["feature-01"] < 0.4)] - df_validation = data[(data["feature-01"] >= 0.4) - & (data["feature-01"] < 0.8)] - - config = { - "tree_method": "approx", - "objective": "binary:logistic", - "eval_metric": ["logloss", "error"], - } - - bst = train_xgboost( - config, df_train, df_validation, "label", - RayParams(max_actor_restarts=1, cpus_per_actor=4, num_actors=4)) - tune_xgboost(df_train, df_validation, "label") - inference_df = RayDMatrix( - df_train[sorted(df_train.columns)], ignore=["label", "partition"]) - predict( - bst, - inference_df, - ray_params=RayParams(cpus_per_actor=2, num_actors=16)) + import_and_execute_test_script(NOTEBOOK_PATH_RELATIVE_TO_RAY_REPO) if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--smoke-test", - action="store_true", - help="Finish quickly for testing.") - args = parser.parse_args() - start = time.time() addr = os.environ.get("RAY_ADDRESS") job_name = os.environ.get("RAY_JOB_NAME", "dask_xgboost_test") - if is_anyscale_connect(addr): + if addr is not None and addr.startswith("anyscale://"): ray.init(address=addr, job_name=job_name) else: ray.init(address="auto") + wait_for_cluster_client(4, 600) main() taken = time.time() - start diff --git a/release/golden_notebook_tests/workloads/modin_xgboost_test.py b/release/golden_notebook_tests/workloads/modin_xgboost_test.py index 4180351e7cb40..d5fb36f07b23e 100644 --- a/release/golden_notebook_tests/workloads/modin_xgboost_test.py +++ b/release/golden_notebook_tests/workloads/modin_xgboost_test.py @@ -1,131 +1,28 @@ -import argparse -import json +import ray import os import time +import json +from util import import_and_execute_test_script, wait_for_cluster_client -import modin.pandas as pd -import ray -from ray import tune -from xgboost_ray import RayDMatrix, RayParams, train, predict - -from utils.utils import is_anyscale_connect - -FILE_URL = "https://ray-ci-higgs.s3.us-west-2.amazonaws.com/" \ - "simpleHIGGS.csv" - - -def train_xgboost(config, train_df, test_df, target_column, ray_params): - # distributed loading of a parquet dataset - train_set = RayDMatrix(train_df, target_column) - test_set = RayDMatrix(test_df, target_column) - - evals_result = {} - - start_time = time.time() - # Train the classifier - bst = train( - params=config, - dtrain=train_set, - evals=[(test_set, "eval")], - evals_result=evals_result, - verbose_eval=False, - num_boost_round=100, - ray_params=ray_params) - print(f"Total time taken: {time.time()-start_time}") - - model_path = "model.xgb" - bst.save_model(model_path) - print("Final validation error: {:.4f}".format( - evals_result["eval"]["error"][-1])) - - return bst - - -def tune_xgboost(train_df, test_df, target_column): - # Set XGBoost config. 
- config = { - "tree_method": "approx", - "objective": "binary:logistic", - "eval_metric": ["logloss", "error"], - "eta": tune.loguniform(1e-4, 1e-1), - "subsample": tune.uniform(0.5, 1.0), - "max_depth": tune.randint(1, 9) - } - - ray_params = RayParams( - max_actor_restarts=1, gpus_per_actor=0, cpus_per_actor=1, num_actors=2) - - analysis = tune.run( - tune.with_parameters( - train_xgboost, - train_df=train_df, - test_df=test_df, - target_column=target_column, - ray_params=ray_params), - # Use the `get_tune_resources` helper function to set the resources. - resources_per_trial=ray_params.get_tune_resources(), - config=config, - num_samples=1, - metric="eval-error", - mode="min", - verbose=1) - - accuracy = 1. - analysis.best_result["eval-error"] - print(f"Best model parameters: {analysis.best_config}") - print(f"Best model total accuracy: {accuracy:.4f}") - - return analysis.best_config +NOTEBOOK_PATH_RELATIVE_TO_RAY_REPO = ( + "doc/examples/modin_xgboost/modin_xgboost.py") def main(): - print("Loading HIGGS data.") - - colnames = ["label"] + ["feature-%02d" % i for i in range(1, 29)] - - data = pd.read_csv(FILE_URL, names=colnames) - - print("Loaded HIGGS data.") - - # partition on a column - df_train = data[(data["feature-01"] < 0.4)] - df_validation = data[(data["feature-01"] >= 0.4) - & (data["feature-01"] < 0.8)] - - config = { - "tree_method": "approx", - "objective": "binary:logistic", - "eval_metric": ["logloss", "error"], - } - - bst = train_xgboost( - config, df_train, df_validation, "label", - RayParams(max_actor_restarts=1, cpus_per_actor=4, num_actors=4)) - # tune_xgboost(df_train, df_validation, "label") # broken atm - inference_df = RayDMatrix( - df_train[sorted(df_train.columns)], ignore=["label", "partition"]) - predict( - bst, - inference_df, - ray_params=RayParams(cpus_per_actor=1, num_actors=16)) + import_and_execute_test_script(NOTEBOOK_PATH_RELATIVE_TO_RAY_REPO) if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--smoke-test", - action="store_true", - help="Finish quickly for testing.") - args = parser.parse_args() - start = time.time() addr = os.environ.get("RAY_ADDRESS") job_name = os.environ.get("RAY_JOB_NAME", "modin_xgboost_test") - if is_anyscale_connect(addr): + if addr is not None and addr.startswith("anyscale://"): ray.init(address=addr, job_name=job_name) else: ray.init(address="auto") + wait_for_cluster_client(4, 600) main() taken = time.time() - start diff --git a/release/golden_notebook_tests/workloads/torch_tune_serve_test.py b/release/golden_notebook_tests/workloads/torch_tune_serve_test.py index 15bd43a575a7a..9b511d5765ae6 100644 --- a/release/golden_notebook_tests/workloads/torch_tune_serve_test.py +++ b/release/golden_notebook_tests/workloads/torch_tune_serve_test.py @@ -17,8 +17,6 @@ from torch.utils.data import DataLoader, Subset from torchvision.datasets import MNIST -from utils.utils import is_anyscale_connect - def load_mnist_data(train: bool, download: bool): transform = transforms.Compose( @@ -200,7 +198,7 @@ def test_predictions(test_mode=False): addr = os.environ.get("RAY_ADDRESS") job_name = os.environ.get("RAY_JOB_NAME", "torch_tune_serve_test") - if is_anyscale_connect(addr): + if addr is not None and addr.startswith("anyscale://"): client = ray.init(address=addr, job_name=job_name) else: client = ray.init(address="auto") diff --git a/release/golden_notebook_tests/workloads/util.py b/release/golden_notebook_tests/workloads/util.py new file mode 100644 index 0000000000000..a0efc28b0e73a --- 
/dev/null
+++ b/release/golden_notebook_tests/workloads/util.py
@@ -0,0 +1,49 @@
+from pathlib import Path
+import importlib.util
+import ray
+import time
+
+
+def import_and_execute_test_script(relative_path_to_test_script: str):
+    """Imports and executes a module from a path relative to Ray repo root."""
+    # get the ray folder
+    ray_path = next(
+        x for x in Path(__file__).resolve().parents if str(x).endswith("/ray"))
+    notebook_path = ray_path.joinpath(relative_path_to_test_script)
+    assert notebook_path.exists()
+
+    spec = importlib.util.spec_from_file_location("notebook_test",
+                                                  notebook_path)
+    notebook_test_module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(notebook_test_module)
+
+
+def wait_for_cluster_client(num_nodes: int,
+                            max_time_s: int,
+                            feedback_interval_s: int = 10):
+    assert ray.is_initialized()
+    curr_nodes = 0
+    start = time.time()
+    next_feedback = start
+    max_time = start + max_time_s
+    while not curr_nodes >= num_nodes:
+        now = time.time()
+
+        if now >= max_time:
+            raise RuntimeError(
+                f"Maximum wait time reached, but only "
+                f"{curr_nodes}/{num_nodes} nodes came up. Aborting.")
+
+        if now >= next_feedback:
+            passed = now - start
+            print(f"Waiting for more nodes to come up: "
+                  f"{curr_nodes}/{num_nodes} "
+                  f"({passed:.0f} seconds passed)")
+            next_feedback = now + feedback_interval_s
+
+        time.sleep(5)
+        curr_nodes = len(ray.nodes())
+
+    passed = time.time() - start
+    print(f"Cluster is up: {curr_nodes}/{num_nodes} nodes online after "
+          f"{passed:.0f} seconds")
diff --git a/release/golden_notebook_tests/workloads/utils/utils.py b/release/golden_notebook_tests/workloads/utils/utils.py
deleted file mode 100644
index 071f076c72aee..0000000000000
--- a/release/golden_notebook_tests/workloads/utils/utils.py
+++ /dev/null
@@ -1,5 +0,0 @@
-def is_anyscale_connect(address: str) -> bool:
-    """Returns whether or not the Ray Address points to an Anyscale cluster."""
-    is_anyscale_connect = address is not None and address.startswith(
-        "anyscale://")
-    return is_anyscale_connect
diff --git a/release/kubernetes_manual_tests/README.md b/release/kubernetes_manual_tests/README.md
new file mode 100644
index 0000000000000..12b61f272b079
--- /dev/null
+++ b/release/kubernetes_manual_tests/README.md
@@ -0,0 +1,25 @@
+# ray-k8s-tests
+
+These tests are not automated and thus **must be run manually** for each release.
+If you have issues running them, bug the code owner(s) for OSS Kubernetes support.
+
+## How to run
+1. Configure kubectl and Helm 3 to access a K8s cluster.
+2. `git checkout releases/`
+3. You might have to locally pip install the Ray wheel for the relevant commit (or pip install -e) in a conda env; see the Ray client note below.
+4. `cd` to this directory.
+5. `IMAGE=rayproject/ray: bash k8s_release_tests.sh`
+6. Test outcomes will be reported at the end of the output.
+
+This runs three tests and does the necessary resource creation/teardown. The tests typically take about 15 minutes to finish.
+
+## Notes
+0. Anyscale employees: You should have access to create a K8s cluster using either GKE or EKS; ask the OSS Kubernetes code owner if in doubt.
+1. Your Ray cluster should be able to accommodate 30 1-CPU pods to run all of the tests.
+2. These tests use basic Ray client functionality -- your locally installed Ray version may need to be updated to match the one in the release image.
+3. The tests do a poor job of cleaning up Ray client port-forwarding processes -- if a test fails, a port-forwarding process may be left running in the background. To identify the rogue process, run `ps aux | grep "port-forward"`, then `kill` it.
+4. Some errors will appear on the screen during the run -- that's normal; error recovery is being tested.
+
+## Running individual tests
+To run any of the three tests individually, substitute `k8s-test.sh`, `helm-test.sh`, or `k8s-test-scale.sh` in step 5 of **How to run**.
+Only the scale test needs 30 1-CPU pods; 10 are enough for either of the other two. The scale test is currently somewhat flaky; rerun it if it fails.
diff --git a/release/kubernetes_manual_tests/helm-test.sh b/release/kubernetes_manual_tests/helm-test.sh
new file mode 100755
index 0000000000000..273ddb5c1cc11
--- /dev/null
+++ b/release/kubernetes_manual_tests/helm-test.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+set -x
+kubectl create namespace helm-test
+kubectl create namespace helm-test2
+KUBERNETES_OPERATOR_TEST_NAMESPACE=helm-test KUBERNETES_OPERATOR_TEST_IMAGE="$IMAGE" python ../../python/ray/tests/kubernetes_e2e/test_helm.py
+kubectl delete namespace helm-test
+kubectl delete namespace helm-test2
+kubectl delete -f ../../deploy/charts/ray/crds/cluster_crd.yaml
diff --git a/release/kubernetes_manual_tests/k8s-test-scale.sh b/release/kubernetes_manual_tests/k8s-test-scale.sh
new file mode 100755
index 0000000000000..59ea06c80f5f1
--- /dev/null
+++ b/release/kubernetes_manual_tests/k8s-test-scale.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+set -x
+kubectl create namespace scale-test
+kubectl create namespace scale-test2
+KUBERNETES_OPERATOR_TEST_NAMESPACE=scale-test KUBERNETES_OPERATOR_TEST_IMAGE="$IMAGE" python ../../python/ray/tests/kubernetes_e2e/test_k8s_operator_scaling.py
+kubectl -n scale-test delete --all rayclusters
+kubectl -n scale-test2 delete --all rayclusters
+kubectl delete -f ../../deploy/components/operator_cluster_scoped.yaml
+kubectl delete namespace scale-test
+kubectl delete namespace scale-test2
+kubectl delete -f ../../deploy/charts/ray/crds/cluster_crd.yaml
diff --git a/release/kubernetes_manual_tests/k8s-test.sh b/release/kubernetes_manual_tests/k8s-test.sh
new file mode 100755
index 0000000000000..aa0ec6325d880
--- /dev/null
+++ b/release/kubernetes_manual_tests/k8s-test.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+set -x
+kubectl create namespace basic-test
+kubectl apply -f ../../deploy/charts/ray/crds/cluster_crd.yaml
+KUBERNETES_OPERATOR_TEST_NAMESPACE=basic-test KUBERNETES_OPERATOR_TEST_IMAGE="$IMAGE" python ../../python/ray/tests/kubernetes_e2e/test_k8s_operator_basic.py
+kubectl -n basic-test delete --all rayclusters
+kubectl -n basic-test delete deployment ray-operator
+kubectl delete namespace basic-test
+kubectl delete -f ../../deploy/charts/ray/crds/cluster_crd.yaml
diff --git a/release/kubernetes_manual_tests/k8s_release_tests.sh b/release/kubernetes_manual_tests/k8s_release_tests.sh
new file mode 100644
index 0000000000000..6576dcdabfa39
--- /dev/null
+++ b/release/kubernetes_manual_tests/k8s_release_tests.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+set -x
+IMAGE="$IMAGE" bash k8s-test.sh
+BASIC_SUCCEEDED=$?
+IMAGE="$IMAGE" bash helm-test.sh
+HELM_SUCCEEDED=$?
+IMAGE="$IMAGE" bash k8s-test-scale.sh
+SCALE_SUCCEEDED=$?
+ +if (( BASIC_SUCCEEDED == 0 )) +then + echo "k8s-test.sh succeeded" +else + echo "k8s-test.sh test failed" +fi + +if (( HELM_SUCCEEDED == 0 )) +then + echo "helm-test.sh test succeeded"; +else + echo "helm-test.sh test failed" +fi + +if (( SCALE_SUCCEEDED == 0)) +then + echo "k8s-test-scale.sh test succeeded"; +else + echo "k8s-test-scale.sh failed. Try re-running just the k8s-test-scale.sh. It's expected to be flaky." +fi + diff --git a/release/long_running_tests/tpl_cpu_1.yaml b/release/long_running_tests/tpl_cpu_1.yaml index 1045aa8948456..a22bc5dfc95a7 100644 --- a/release/long_running_tests/tpl_cpu_1.yaml +++ b/release/long_running_tests/tpl_cpu_1.yaml @@ -22,3 +22,8 @@ aws: Value: '{{env["ANYSCALE_USER"]}}' - Key: anyscale-expiration Value: '{{env["EXPIRATION_2D"]}}' + + BlockDeviceMappings: + - DeviceName: /dev/sda1 + Ebs: + VolumeSize: 202 \ No newline at end of file diff --git a/release/nightly_tests/dask_on_ray/large_scale_dask_on_ray_app_config.yaml b/release/nightly_tests/dask_on_ray/large_scale_dask_on_ray_app_config.yaml index 9b7a0a9a11d3f..1aa0b86782476 100644 --- a/release/nightly_tests/dask_on_ray/large_scale_dask_on_ray_app_config.yaml +++ b/release/nightly_tests/dask_on_ray/large_scale_dask_on_ray_app_config.yaml @@ -1,5 +1,4 @@ base_image: "anyscale/ray-ml:pinned-nightly-py37" -env_vars: {} debian_packages: [] python: diff --git a/release/nightly_tests/dataset/app_config.yaml b/release/nightly_tests/dataset/app_config.yaml index c0cc753990de9..5f311fbabfe87 100644 --- a/release/nightly_tests/dataset/app_config.yaml +++ b/release/nightly_tests/dataset/app_config.yaml @@ -1,5 +1,4 @@ base_image: "anyscale/ray-ml:pinned-nightly-py37-gpu" -env_vars: {} python: pip_packages: [] diff --git a/release/nightly_tests/dataset/dataset_shuffle_data_loader.py b/release/nightly_tests/dataset/dataset_shuffle_data_loader.py index da3a7d74649f0..e917624a4712b 100644 --- a/release/nightly_tests/dataset/dataset_shuffle_data_loader.py +++ b/release/nightly_tests/dataset/dataset_shuffle_data_loader.py @@ -85,7 +85,7 @@ def create_torch_iterator(split, batch_size, rank=None): def create_dataset(filenames, repeat_times): pipeline = ray.data.read_parquet(list(filenames))\ - .repeat(times=repeat_times).random_shuffle() + .repeat(times=repeat_times).random_shuffle_each_window() return pipeline diff --git a/release/nightly_tests/dataset/pipelined_ingestion_app.yaml b/release/nightly_tests/dataset/pipelined_ingestion_app.yaml index 2fbda804b9b50..23ee18a1008b7 100644 --- a/release/nightly_tests/dataset/pipelined_ingestion_app.yaml +++ b/release/nightly_tests/dataset/pipelined_ingestion_app.yaml @@ -1,5 +1,4 @@ base_image: "anyscale/ray-ml:pinned-nightly-py37-gpu" -env_vars: {} python: pip_packages: [] diff --git a/release/nightly_tests/dataset/pipelined_training.py b/release/nightly_tests/dataset/pipelined_training.py index d9a4b9245bee1..c8c7486724755 100644 --- a/release/nightly_tests/dataset/pipelined_training.py +++ b/release/nightly_tests/dataset/pipelined_training.py @@ -244,12 +244,12 @@ def __next__(self): i * num_rows // num_windows // num_workers for i in range(1, num_workers) ] - pipe = pipe.random_shuffle(_spread_resource_prefix="node:") + pipe = pipe.random_shuffle_each_window(_spread_resource_prefix="node:") pipe_shards = pipe.split_at_indices(split_indices) else: ds = ray.data.read_parquet(files, _spread_resource_prefix="node:") pipe = ds.repeat(epochs) - pipe = pipe.random_shuffle(_spread_resource_prefix="node:") + pipe = 
pipe.random_shuffle_each_window(_spread_resource_prefix="node:") pipe_shards = pipe.split(num_workers, equal=True) return pipe_shards diff --git a/release/nightly_tests/dataset/pipelined_training_app.yaml b/release/nightly_tests/dataset/pipelined_training_app.yaml index 2fbda804b9b50..23ee18a1008b7 100644 --- a/release/nightly_tests/dataset/pipelined_training_app.yaml +++ b/release/nightly_tests/dataset/pipelined_training_app.yaml @@ -1,5 +1,4 @@ base_image: "anyscale/ray-ml:pinned-nightly-py37-gpu" -env_vars: {} python: pip_packages: [] diff --git a/release/nightly_tests/dataset/shuffle_app_config.yaml b/release/nightly_tests/dataset/shuffle_app_config.yaml index ac02d79b90415..d89acec77a973 100644 --- a/release/nightly_tests/dataset/shuffle_app_config.yaml +++ b/release/nightly_tests/dataset/shuffle_app_config.yaml @@ -1,5 +1,4 @@ base_image: "anyscale/ray-ml:pinned-nightly-py37-gpu" -env_vars: {} python: pip_packages: ["boto3", "numpy", "torch", "tqdm", "pyarrow"] diff --git a/release/nightly_tests/decision_tree/decision_tree_app_config.yaml b/release/nightly_tests/decision_tree/decision_tree_app_config.yaml index 92f5d3707fe1c..70ae8eb896d16 100644 --- a/release/nightly_tests/decision_tree/decision_tree_app_config.yaml +++ b/release/nightly_tests/decision_tree/decision_tree_app_config.yaml @@ -1,5 +1,4 @@ base_image: "anyscale/ray-ml:pinned-nightly-py37" -env_vars: {} debian_packages: [] python: diff --git a/release/nightly_tests/many_nodes_tests/app_config.yaml b/release/nightly_tests/many_nodes_tests/app_config.yaml index 67eb10caac1e7..9586d050b0418 100644 --- a/release/nightly_tests/many_nodes_tests/app_config.yaml +++ b/release/nightly_tests/many_nodes_tests/app_config.yaml @@ -1,5 +1,5 @@ base_image: "anyscale/ray-ml:pinned-nightly-py37" -env_vars: {} +env_vars: {"RAY_gcs_server_rpc_server_thread_num": "8", "RAY_GCS_ACTOR_SCHEDULING_ENABLED": "true"} debian_packages: [] python: diff --git a/release/nightly_tests/nightly_tests.yaml b/release/nightly_tests/nightly_tests.yaml index d932924ffa6a3..9482eade1e713 100644 --- a/release/nightly_tests/nightly_tests.yaml +++ b/release/nightly_tests/nightly_tests.yaml @@ -317,13 +317,24 @@ prepare: python wait_cluster.py 32 1000 script: python dask_on_ray/dask_on_ray_sort.py --nbytes 1_000_000_000_000 --npartitions 1000 --num-nodes 31 --ray --data-dir /tmp/ray --s3-bucket core-nightly-test -- name: many_nodes_actor_test +# TODO (yic): Add this back when we make it stable +# - name: many_nodes_actor_test +# cluster: +# app_config: many_nodes_tests/app_config.yaml +# compute_template: many_nodes_tests/compute_config.yaml + +# run: +# timeout: 7200 +# prepare: python wait_cluster.py 500 5400 +# script: python many_nodes_tests/actor_test.py --cpus-per-actor=0.1 --total-actors=10000 +# # TODO: enable failure test later +# #&& python many_nodes_tests/actor_test.py --cpus-per-actor=0.1 --total-actors=10000 --fail --no-report && python many_nodes_tests/actor_test.py --cpus-per-actor=0.1 --total-actors=10000 --no-report + +- name: pg_autoscaling_regression_test cluster: - app_config: many_nodes_tests/app_config.yaml - compute_template: many_nodes_tests/compute_config.yaml + app_config: placement_group_tests/app_config.yaml + compute_template: placement_group_tests/compute.yaml run: - timeout: 7200 - prepare: python wait_cluster.py 500 5400 - script: python many_nodes_tests/actor_test.py --cpus-per-actor=0.1 --total-actors=10000 - # TODO(yic): Add extra test for python many_nodes_tests/actor_test.py --cpus-per-actor=0.1 --total-actors=10000 --fail 
--no-report && python many_nodes_tests/actor_test.py --cpus-per-actor=0.1 --total-actors=10000 --no-report + timeout: 1200 + script: python placement_group_tests/pg_run.py diff --git a/release/nightly_tests/placement_group_tests/app_config.yaml b/release/nightly_tests/placement_group_tests/app_config.yaml new file mode 100644 index 0000000000000..d30247838e1e9 --- /dev/null +++ b/release/nightly_tests/placement_group_tests/app_config.yaml @@ -0,0 +1,12 @@ +base_image: "anyscale/ray-ml:pinned-nightly-py37" +debian_packages: [] + +python: + pip_packages: [] + conda_packages: [] + +post_build_cmds: + - pip uninstall -y ray + - pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }} + - pip3 install -U ray[default] + - {{ env["RAY_WHEELS_SANITY_CHECK"] | default("echo No Ray wheels sanity check") }} diff --git a/release/nightly_tests/placement_group_tests/cluster.py b/release/nightly_tests/placement_group_tests/cluster.py new file mode 100644 index 0000000000000..a12ed798a4e99 --- /dev/null +++ b/release/nightly_tests/placement_group_tests/cluster.py @@ -0,0 +1,13 @@ +import time +from ray.cluster_utils import Cluster + +cluster = Cluster() + +cluster.add_node(num_cpus=16) + +time.sleep(20) +print("Scaling up.") +cluster.add_node(num_cpus=16, num_gpus=1) + +print("Scaled up. Waiting for 1000 seconds until done.") +time.sleep(1000) diff --git a/release/nightly_tests/placement_group_tests/compute.yaml b/release/nightly_tests/placement_group_tests/compute.yaml new file mode 100644 index 0000000000000..5b619db7651a4 --- /dev/null +++ b/release/nightly_tests/placement_group_tests/compute.yaml @@ -0,0 +1,27 @@ +cloud_id: {{env["ANYSCALE_CLOUD_ID"]}} +region: us-west-2 + +aws: + BlockDeviceMappings: + - DeviceName: /dev/sda1 + Ebs: + VolumeSize: 500 + +head_node_type: + name: head_node + instance_type: m5.4xlarge + +worker_node_types: + - name: cpu_node + instance_type: m5.4xlarge + min_workers: 0 + max_workers: 2 + use_spot: false + - name: fake_gpu_node + instance_type: m5.4xlarge + min_workers: 0 + max_workers: 2 + use_spot: false + resources: + cpu: 16 + gpu: 1 diff --git a/release/nightly_tests/placement_group_tests/pg_run.py b/release/nightly_tests/placement_group_tests/pg_run.py new file mode 100644 index 0000000000000..7bb616c2dcaa3 --- /dev/null +++ b/release/nightly_tests/placement_group_tests/pg_run.py @@ -0,0 +1,65 @@ +import os +import time +import json + +import ray +from ray.util.placement_group import placement_group + +# Tests are supposed to run for 10 minutes. 
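# RUNTIME bounds the main work loop below (in seconds); NUM_CPU_BUNDLES is
# the number of 1-CPU worker bundles packed into one placement group,
# alongside a single 1-CPU/1-GPU trainer bundle.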
+RUNTIME = 600 +NUM_CPU_BUNDLES = 30 + + +@ray.remote(num_cpus=1) +class Worker(object): + def __init__(self, i): + self.i = i + + def work(self): + time.sleep(0.1) + print("work ", self.i) + + +@ray.remote(num_cpus=1, num_gpus=1) +class Trainer(object): + def __init__(self, i): + self.i = i + + def train(self): + time.sleep(0.2) + print("train ", self.i) + + +def main(): + ray.init(address="auto") + + bundles = [{"CPU": 1, "GPU": 1}] + bundles += [{"CPU": 1} for _ in range(NUM_CPU_BUNDLES)] + + pg = placement_group(bundles, strategy="PACK") + + ray.get(pg.ready()) + + workers = [ + Worker.options(placement_group=pg).remote(i) + for i in range(NUM_CPU_BUNDLES) + ] + + trainer = Trainer.options(placement_group=pg).remote(0) + + start = time.time() + while True: + ray.get([workers[i].work.remote() for i in range(NUM_CPU_BUNDLES)]) + ray.get(trainer.train.remote()) + end = time.time() + if end - start > RUNTIME: + break + + if "TEST_OUTPUT_JSON" in os.environ: + out_file = open(os.environ["TEST_OUTPUT_JSON"], "w") + results = {} + json.dump(results, out_file) + + +if __name__ == "__main__": + main() diff --git a/release/nightly_tests/shuffle/shuffle_app_config.yaml b/release/nightly_tests/shuffle/shuffle_app_config.yaml index 67eb10caac1e7..d30247838e1e9 100644 --- a/release/nightly_tests/shuffle/shuffle_app_config.yaml +++ b/release/nightly_tests/shuffle/shuffle_app_config.yaml @@ -1,5 +1,4 @@ base_image: "anyscale/ray-ml:pinned-nightly-py37" -env_vars: {} debian_packages: [] python: @@ -10,5 +9,4 @@ post_build_cmds: - pip uninstall -y ray - pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }} - pip3 install -U ray[default] - - echo {{env["DATESTAMP"]}} - {{ env["RAY_WHEELS_SANITY_CHECK"] | default("echo No Ray wheels sanity check") }} diff --git a/release/nightly_tests/shuffle_data_loader/shuffle_data_loader_app_config.yaml b/release/nightly_tests/shuffle_data_loader/shuffle_data_loader_app_config.yaml index 2fea571c90f77..536c7b6da27f4 100644 --- a/release/nightly_tests/shuffle_data_loader/shuffle_data_loader_app_config.yaml +++ b/release/nightly_tests/shuffle_data_loader/shuffle_data_loader_app_config.yaml @@ -1,5 +1,4 @@ base_image: "anyscale/ray-ml:pinned-nightly-py37" -env_vars: {} debian_packages: [] python: diff --git a/release/nightly_tests/stress_tests/stress_tests_app_config.yaml b/release/nightly_tests/stress_tests/stress_tests_app_config.yaml index 1f264f9fa1e44..66c99bb3bfe5a 100644 --- a/release/nightly_tests/stress_tests/stress_tests_app_config.yaml +++ b/release/nightly_tests/stress_tests/stress_tests_app_config.yaml @@ -1,5 +1,4 @@ base_image: "anyscale/ray-ml:pinned-nightly-py37" -env_vars: {} debian_packages: [] python: diff --git a/release/release_logs/1.7.0/benchmarks/many_actors.txt b/release/release_logs/1.7.0/benchmarks/many_actors.txt new file mode 100644 index 0000000000000..2995df9b7f18d --- /dev/null +++ b/release/release_logs/1.7.0/benchmarks/many_actors.txt @@ -0,0 +1,10 @@ +{ + "actors_per_second": 333.2797984180003, + "num_actors": 10000, + "time": 30.0048189163208, + "success": "1", + "_runtime": 43.551865577697754, + "_session_url": "https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_han7mApDaGYvrbvhuLKBSGBz", + "_commit_url": "https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.7.0/2367a2cb9033913b68b1230316496ae273c25b54/ray-1.7.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable": true +} diff --git a/release/release_logs/1.7.0/benchmarks/many_nodes.txt 
b/release/release_logs/1.7.0/benchmarks/many_nodes.txt new file mode 100644 index 0000000000000..d6d5a3c0b6631 --- /dev/null +++ b/release/release_logs/1.7.0/benchmarks/many_nodes.txt @@ -0,0 +1,10 @@ +{ + "tasks_per_second": 3.224712885579051, + "num_tasks": 1000, + "time": 610.1051273345947, + "success": "1", + "_runtime": 620.4832813739777, + "_session_url": "https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_6f82dxdGaxTV4uZNSamTYGLY", + "_commit_url": "https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.7.0/2367a2cb9033913b68b1230316496ae273c25b54/ray-1.7.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable": true +} diff --git a/release/release_logs/1.7.0/benchmarks/many_pgs.txt b/release/release_logs/1.7.0/benchmarks/many_pgs.txt new file mode 100644 index 0000000000000..560c050dcecb4 --- /dev/null +++ b/release/release_logs/1.7.0/benchmarks/many_pgs.txt @@ -0,0 +1,10 @@ +{ + "pgs_per_second": 17.06879130613137, + "num_pgs": 1000, + "time": 58.586456537246704, + "success": "1", + "_runtime": 69.5553240776062, + "_session_url": "https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_gr3X2VEThCAQrtiHrJRd8yxW", + "_commit_url": "https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.7.0/2367a2cb9033913b68b1230316496ae273c25b54/ray-1.7.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable": true +} diff --git a/release/release_logs/1.7.0/benchmarks/many_tasks.txt b/release/release_logs/1.7.0/benchmarks/many_tasks.txt new file mode 100644 index 0000000000000..fa9c7d8d41db2 --- /dev/null +++ b/release/release_logs/1.7.0/benchmarks/many_tasks.txt @@ -0,0 +1,10 @@ +{ + "tasks_per_second": 27.508657888123608, + "num_tasks": 10000, + "time": 663.5219151973724, + "success": "1", + "_runtime": 674.2678966522217, + "_session_url": "https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_XCJkRqS4HkuHLXehx7i6Fwvc", + "_commit_url": "https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.7.0/2367a2cb9033913b68b1230316496ae273c25b54/ray-1.7.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable": true +} diff --git a/release/release_logs/1.7.0/microbenchmark.txt b/release/release_logs/1.7.0/microbenchmark.txt new file mode 100644 index 0000000000000..b5fa29117583d --- /dev/null +++ b/release/release_logs/1.7.0/microbenchmark.txt @@ -0,0 +1,134 @@ +{ + "single_client_get_calls": [ + 34647.91400708946, + 311.7390971967917 + ], + "single_client_put_calls": [ + 58969.83872190603, + 869.618205663433 + ], + "multi_client_put_calls": [ + 199832.5755298421, + 2482.9205035774476 + ], + "single_client_get_calls_Plasma_Store": [ + 7082.757370159696, + 146.62873820799672 + ], + "single_client_put_calls_Plasma_Store": [ + 6321.65654587901, + 11.077913617295936 + ], + "multi_client_put_calls_Plasma_Store": [ + 9186.218655830648, + 112.23231532820908 + ], + "single_client_put_gigabytes": [ + 20.299125005168346, + 5.063681202623047 + ], + "single_client_tasks_and_get_batch": [ + 13.14018865978927, + 0.3152301478634011 + ], + "multi_client_put_gigabytes": [ + 36.56441662881655, + 1.843382220404724 + ], + "single_client_get_object_containing_10k_refs": [ + 10.351906653488715, + 0.23442465466734483 + ], + "single_client_tasks_sync": [ + 1257.4155346823063, + 16.879731074181798 + ], + "single_client_tasks_async": [ + 13436.707639489237, + 467.0229967004351 + ], + "multi_client_tasks_async": [ + 37893.82918345513, + 2501.210898297811 + ], + "1_1_actor_calls_sync": [ + 2018.517206134362, + 
4.133444448098185 + ], + "1_1_actor_calls_async": [ + 5107.498479502846, + 155.05763494606228 + ], + "1_1_actor_calls_concurrent": [ + 4974.868578485068, + 46.89895438701842 + ], + "1_n_actor_calls_async": [ + 13035.656413458306, + 263.67959962428176 + ], + "n_n_actor_calls_async": [ + 42424.91241384691, + 909.2063842725172 + ], + "n_n_actor_calls_with_arg_async": [ + 2910.8727809194884, + 142.55651461439174 + ], + "1_1_async_actor_calls_sync": [ + 1434.0111494545497, + 15.145616176257736 + ], + "1_1_async_actor_calls_async": [ + 3227.631490168903, + 74.52309737428871 + ], + "1_1_async_actor_calls_with_args_async": [ + 2417.18007329992, + 42.010241468147406 + ], + "1_n_async_actor_calls_async": [ + 13212.476889889944, + 280.91562344862103 + ], + "n_n_async_actor_calls_async": [ + 32212.030653578477, + 4172.2556150359205 + ], + "client__get_calls": [ + 1518.5267029642152, + 18.33838666361156 + ], + "client__put_calls": [ + 869.7170835067376, + 8.603084105450836 + ], + "client__put_gigabytes": [ + 0.11768745420143228, + 0.002542373184018965 + ], + "client__tasks_and_put_batch": [ + 58861.12144186892, + 546.7701167395176 + ], + "client__1_1_actor_calls_sync": [ + 472.8343418119895, + 6.16968890867776 + ], + "client__1_1_actor_calls_async": [ + 742.6478263697102, + 2.886810073788351 + ], + "client__1_1_actor_calls_concurrent": [ + 729.3572241473628, + 19.903703549912592 + ], + "client__tasks_and_get_batch": [ + 0.6990944804839968, + 0.00738047968242822 + ], + "_runtime": 558.9188287258148, + "_session_url": "https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_AHVUzrAzUMiLZ4p9EEAbL68s", + "_commit_url": "https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.7.0/2367a2cb9033913b68b1230316496ae273c25b54/ray-1.7.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable": true +} diff --git a/release/release_logs/1.7.0/scalability/object_store.txt b/release/release_logs/1.7.0/scalability/object_store.txt new file mode 100644 index 0000000000000..6917229b88dc5 --- /dev/null +++ b/release/release_logs/1.7.0/scalability/object_store.txt @@ -0,0 +1,10 @@ +{ + "broadcast_time": 611.015479593, + "object_size": 1073741824, + "num_nodes": 50, + "success": "1", + "_runtime": 620.4363269805908, + "_session_url": "https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_Chj4PHZqrEjbzc8Ni4RY1Fev", + "_commit_url": "https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.7.0/2367a2cb9033913b68b1230316496ae273c25b54/ray-1.7.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable": true +} diff --git a/release/release_logs/1.7.0/scalability/single_node.txt b/release/release_logs/1.7.0/scalability/single_node.txt new file mode 100644 index 0000000000000..c868fa3c8eb4e --- /dev/null +++ b/release/release_logs/1.7.0/scalability/single_node.txt @@ -0,0 +1,16 @@ +{ + "args_time": 17.256289814000013, + "num_args": 10000, + "returns_time": 5.854934190999984, + "num_returns": 3000, + "get_time": 25.88724605799996, + "queued_time": 140.99555420300004, + "num_queued": 1000000, + "large_object_time": 294.249499343, + "large_object_size": 107374182400, + "success": "1", + "_runtime": 528.4356288909912, + "_session_url": "https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_ELgpggWSHiqhksawLcz4urEP", + "_commit_url": "https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.7.0/2367a2cb9033913b68b1230316496ae273c25b54/ray-1.7.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable": true +} diff --git 
a/release/release_logs/1.7.0/stress_tests/dead_actors.txt b/release/release_logs/1.7.0/stress_tests/dead_actors.txt new file mode 100644 index 0000000000000..ab763e4173b75 --- /dev/null +++ b/release/release_logs/1.7.0/stress_tests/dead_actors.txt @@ -0,0 +1,11 @@ +{ + "success": 1, + "total_time": 130.34314274787903, + "avg_iteration_time": 1.303428828716278, + "max_iteration_time": 3.651247501373291, + "min_iteration_time": 0.09438443183898926, + "_runtime": 902.0143933296204, + "_session_url": "https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_pxDnaxYFzDNsyifjJNV1qhqs", + "_commit_url": "https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.7.0/2367a2cb9033913b68b1230316496ae273c25b54/ray-1.7.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable": true +} diff --git a/release/release_logs/1.7.0/stress_tests/many_tasks.txt b/release/release_logs/1.7.0/stress_tests/many_tasks.txt new file mode 100644 index 0000000000000..a0244c5b28489 --- /dev/null +++ b/release/release_logs/1.7.0/stress_tests/many_tasks.txt @@ -0,0 +1,19 @@ +{ + "success": 1, + "stage_0_time": 5.256332874298096, + "stage_1_time": 174.50774693489075, + "stage_1_avg_iteration_time": 17.450765538215638, + "stage_1_max_iteration_time": 17.627604961395264, + "stage_1_min_iteration_time": 17.23277997970581, + "stage_2_time": 268.01243686676025, + "stage_2_avg_iteration_time": 53.60213441848755, + "stage_2_max_iteration_time": 59.097413063049316, + "stage_2_min_iteration_time": 48.71518564224243, + "stage_3_creation_time": 0.5777060985565186, + "stage_3_time": 2066.70570230484, + "stage_4_spread": 3.2197082901427945, + "_runtime": 5045.744384527206, + "_session_url": "https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_b8v2V4Tr7vwee6tCDjTjdXLL", + "_commit_url": "https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.7.0/2367a2cb9033913b68b1230316496ae273c25b54/ray-1.7.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable": true +} diff --git a/release/release_logs/1.7.0/stress_tests/placement_group.txt b/release/release_logs/1.7.0/stress_tests/placement_group.txt new file mode 100644 index 0000000000000..cbe7c99c54a04 --- /dev/null +++ b/release/release_logs/1.7.0/stress_tests/placement_group.txt @@ -0,0 +1,9 @@ +{ + "success": 1, + "avg_pg_create_time_ms": 0.9874122837809874, + "avg_pg_remove_time_ms": 4.4027920900909265, + "_runtime": 458.8596382141113, + "_session_url": "https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_7uQL743cWCzdDT3ZYTpRDETi", + "_commit_url": "https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.7.0/2367a2cb9033913b68b1230316496ae273c25b54/ray-1.7.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable": true +} diff --git a/release/serve_tests/serve_tests.yaml b/release/serve_tests/serve_tests.yaml index 06edd31be95eb..4362ca296d909 100644 --- a/release/serve_tests/serve_tests.yaml +++ b/release/serve_tests/serve_tests.yaml @@ -27,6 +27,7 @@ - name: serve_micro_benchmark cluster: app_config: app_config.yaml + # 16 CPUS compute_template: compute_tpl_single_node.yaml run: @@ -34,5 +35,19 @@ long_running: False script: python workloads/serve_micro_benchmark.py + smoke_test: + timeout: 600 + +- name: serve_cluster_fault_tolerance + cluster: + app_config: app_config.yaml + # 16 CPUS + compute_template: compute_tpl_single_node.yaml + + run: + timeout: 7200 + long_running: False + script: python workloads/serve_cluster_fault_tolerance.py + smoke_test: timeout: 600 \ No newline at 
end of file diff --git a/release/serve_tests/workloads/serve_cluster_fault_tolerance.py b/release/serve_tests/workloads/serve_cluster_fault_tolerance.py new file mode 100644 index 0000000000000..431c78b9c5df3 --- /dev/null +++ b/release/serve_tests/workloads/serve_cluster_fault_tolerance.py @@ -0,0 +1,119 @@ +""" +Test that a serve deployment can recover from cluster failures by resuming +from checkpoints of external source, such as s3. + +For product testing, we skip the part of actually starting new cluster as +it's Job Manager's responsibility, and only re-deploy to the same cluster +with remote checkpoint. +""" + +import click +import time +import requests +import uuid +import os + +from serve_test_cluster_utils import setup_local_single_node_cluster + +from serve_test_utils import (save_test_results) + +import ray +from ray import serve +from ray.serve.utils import logger + +# Deployment configs +DEFAULT_NUM_REPLICAS = 4 +DEFAULT_MAX_BATCH_SIZE = 16 + + +def request_with_retries(endpoint, timeout=3): + start = time.time() + while True: + try: + return requests.get( + "http://127.0.0.1:8000" + endpoint, timeout=timeout) + except requests.RequestException: + if time.time() - start > timeout: + raise TimeoutError + time.sleep(0.1) + + +@click.command() +def main(): + # Setup local cluster, note this cluster setup is the same for both + # local and product ray cluster env. + # Each test uses different ray namespace, thus kv storage key for each + # checkpoint is different to avoid collision. + namespace = uuid.uuid4().hex + + # IS_SMOKE_TEST is set by args of releaser's e2e.py + smoke_test = os.environ.get("IS_SMOKE_TEST", "1") + if smoke_test == "1": + checkpoint_path = "file://checkpoint.db" + else: + checkpoint_path = "s3://serve-nightly-tests/fault-tolerant-test-checkpoint" # noqa: E501 + + _, cluster = setup_local_single_node_cluster( + 1, checkpoint_path=checkpoint_path, namespace=namespace) + + # Deploy for the first time + @serve.deployment(name="echo", num_replicas=DEFAULT_NUM_REPLICAS) + class Echo: + def __init__(self): + return True + + def __call__(self, request): + return "hii" + + Echo.deploy() + + # Ensure endpoint is working + for _ in range(5): + response = request_with_retries("/echo/", timeout=3) + assert response.text == "hii" + + logger.info("Initial deployment successful with working endpoint.") + + # Kill current cluster, recover from remote checkpoint and ensure endpoint + # is still available with expected results + + ray.kill(serve.api._global_client._controller, no_restart=True) + ray.shutdown() + cluster.shutdown() + serve.api._set_global_client(None) + + # Start another ray cluster with same namespace to resume from previous + # checkpoints with no new deploy() call. + setup_local_single_node_cluster( + 1, checkpoint_path=checkpoint_path, namespace=namespace) + + for _ in range(5): + response = request_with_retries("/echo/", timeout=3) + assert response.text == "hii" + + logger.info("Deployment recovery from s3 checkpoint is successful " + "with working endpoint.") + + # Delete dangling checkpoints. If script failed before this step, it's up + # to the TTL policy on s3 to clean up, but won't lead to collision with + # subsequent tests since each test run in different uuid namespace. + serve.shutdown() + ray.shutdown() + cluster.shutdown() + + # Checkpoints in S3 bucket are moved after 7 days with explicit lifecycle + # rules. Each checkpoint is ~260 Bytes in size from this test. 
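The 7-day expiry described in the comment above lives in the bucket's lifecycle configuration, not in this script. A sketch of how such a rule could be created with boto3 (the bucket name comes from the checkpoint path above; the rule ID and prefix are assumptions):

    import boto3

    s3 = boto3.client("s3")
    s3.put_bucket_lifecycle_configuration(
        Bucket="serve-nightly-tests",
        LifecycleConfiguration={
            "Rules": [{
                "ID": "expire-stale-checkpoints",  # Assumed rule name.
                "Filter": {
                    "Prefix": "fault-tolerant-test-checkpoint"
                },
                "Status": "Enabled",
                # Remove checkpoint objects 7 days after creation.
                "Expiration": {
                    "Days": 7
                },
            }]
        })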
+ + # Save results + save_test_results( + { + "result": "success" + }, + default_output_file="/tmp/serve_cluster_fault_tolerance.json") + + +if __name__ == "__main__": + main() + import pytest + import sys + sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/release/serve_tests/workloads/serve_test_cluster_utils.py b/release/serve_tests/workloads/serve_test_cluster_utils.py index 22e4e30cfdf35..3d9ccc44ae7f5 100644 --- a/release/serve_tests/workloads/serve_test_cluster_utils.py +++ b/release/serve_tests/workloads/serve_test_cluster_utils.py @@ -6,13 +6,16 @@ from ray.cluster_utils import Cluster from ray.serve.utils import logger from ray.serve.config import DeploymentMode - +from ray.serve.constants import DEFAULT_CHECKPOINT_PATH # Cluster setup configs NUM_CPU_PER_NODE = 10 NUM_CONNECTIONS = 10 -def setup_local_single_node_cluster(num_nodes): +def setup_local_single_node_cluster( + num_nodes: int, + checkpoint_path: str = DEFAULT_CHECKPOINT_PATH, + namespace="serve"): """Setup ray cluster locally via ray.init() and Cluster() Each actor is simulated in local process on single node, @@ -21,19 +24,23 @@ def setup_local_single_node_cluster(num_nodes): cluster = Cluster() for i in range(num_nodes): cluster.add_node( - redis_port=6379 if i == 0 else None, + redis_port=6380 if i == 0 else None, num_cpus=NUM_CPU_PER_NODE, num_gpus=0, resources={str(i): 2}, ) - ray.init(address=cluster.address, dashboard_host="0.0.0.0") + ray.init( + address=cluster.address, dashboard_host="0.0.0.0", namespace=namespace) serve_client = serve.start( - http_options={"location": DeploymentMode.EveryNode}) + detached=True, + http_options={"location": DeploymentMode.EveryNode}, + _checkpoint_path=checkpoint_path, + ) - return serve_client + return serve_client, cluster -def setup_anyscale_cluster(): +def setup_anyscale_cluster(checkpoint_path: str = DEFAULT_CHECKPOINT_PATH): """Setup ray cluster at anyscale via ray.client() Note this is by default large scale and should be kicked off @@ -44,7 +51,9 @@ def setup_anyscale_cluster(): # ray.client().env({}).connect() ray.init(address="auto") serve_client = serve.start( - http_options={"location": DeploymentMode.EveryNode}) + http_options={"location": DeploymentMode.EveryNode}, + _checkpoint_path=checkpoint_path, + ) return serve_client diff --git a/release/util/pip_download_test.sh b/release/util/pip_download_test.sh index 6ab91732ab255..c1d998b44e2b1 100755 --- a/release/util/pip_download_test.sh +++ b/release/util/pip_download_test.sh @@ -56,7 +56,7 @@ do else failed=true fi - if sh sanity_check_cpp.sh; then + if bash sanity_check_cpp.sh; then echo "PYTHON ${PYTHON_VERSION} succeed sanity check C++." 
else cpp_failed=true diff --git a/rllib/BUILD b/rllib/BUILD index f4c527bbb8099..b09e149b14a22 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -876,25 +876,13 @@ py_test( srcs = ["agents/ddpg/tests/test_ddpg.py"] ) -# DQNTrainer/SimpleQTrainer +# DQNTrainer py_test( name = "test_dqn", tags = ["team:ml", "trainers_dir"], size = "large", srcs = ["agents/dqn/tests/test_dqn.py"] ) -py_test( - name = "test_r2d2", - tags = ["team:ml", "trainers_dir"], - size = "large", - srcs = ["agents/dqn/tests/test_r2d2.py"] -) -py_test( - name = "test_simple_q", - tags = ["team:ml", "trainers_dir"], - size = "medium", - srcs = ["agents/dqn/tests/test_simple_q.py"] -) # Dreamer py_test( @@ -1002,6 +990,22 @@ py_test( srcs = ["agents/qmix/tests/test_qmix.py"] ) +# R2D2Trainer +py_test( + name = "test_r2d2", + tags = ["team:ml", "trainers_dir"], + size = "large", + srcs = ["agents/dqn/tests/test_r2d2.py"] +) + +# RNNSACTrainer +py_test( + name = "test_rnnsac", + tags = ["team:ml", "trainers_dir"], + size = "medium", + srcs = ["agents/sac/tests/test_rnnsac.py"] +) + # SACTrainer py_test( name = "test_sac", @@ -1010,6 +1014,14 @@ py_test( srcs = ["agents/sac/tests/test_sac.py"] ) +# SimpleQTrainer +py_test( + name = "test_simple_q", + tags = ["team:ml", "trainers_dir"], + size = "medium", + srcs = ["agents/dqn/tests/test_simple_q.py"] +) + # TD3Trainer py_test( name = "test_td3", @@ -1328,18 +1340,38 @@ py_test( # -------------------------------------------------------------------- sh_test( - name = "env/tests/test_local_inference", + name = "env/tests/test_local_inference_cartpole", tags = ["team:ml", "env"], size = "medium", - srcs = ["env/tests/test_local_inference.sh"], + srcs = ["env/tests/test_policy_client_server_setup.sh"], + args = ["local", "cartpole"], data = glob(["examples/serving/*.py"]), ) sh_test( - name = "env/tests/test_remote_inference", + name = "env/tests/test_remote_inference_cartpole", tags = ["team:ml", "env"], size = "medium", - srcs = ["env/tests/test_remote_inference.sh"], + srcs = ["env/tests/test_policy_client_server_setup.sh"], + args = ["remote", "cartpole"], + data = glob(["examples/serving/*.py"]), +) + +sh_test( + name = "env/tests/test_local_inference_unity3d", + tags = ["team:ml", "env"], + size = "medium", + srcs = ["env/tests/test_policy_client_server_setup.sh"], + args = ["local", "unity3d"], + data = glob(["examples/serving/*.py"]), +) + +sh_test( + name = "env/tests/test_remote_inference_unity3d", + tags = ["team:ml", "env"], + size = "medium", + srcs = ["env/tests/test_policy_client_server_setup.sh"], + args = ["remote", "unity3d"], data = glob(["examples/serving/*.py"]), ) @@ -1350,6 +1382,13 @@ py_test( srcs = ["env/tests/test_record_env_wrapper.py"] ) +py_test( + name = "env/tests/test_remote_worker_envs", + tags = ["team:ml", "env"], + size = "medium", + srcs = ["env/tests/test_remote_worker_envs.py"] +) + py_test( name = "env/wrappers/tests/test_unity3d_env", tags = ["team:ml", "env"], @@ -1847,14 +1886,14 @@ py_test( args = ["TestSupportedMultiAgentOffPolicy"] ) -# py_test( -# name = "tests/test_supported_spaces_pg", -# main = "tests/test_supported_spaces.py", -# tags = ["team:ml", "tests_dir", "tests_dir_S"], -# size = "enormous", -# srcs = ["tests/test_supported_spaces.py"], -# args = ["TestSupportedSpacesPG"] -# ) +py_test( + name = "tests/test_supported_spaces_pg", + main = "tests/test_supported_spaces.py", + tags = ["team:ml", "tests_dir", "tests_dir_S"], + size = "large", + srcs = ["tests/test_supported_spaces.py"], + args = ["TestSupportedSpacesPG"] + ) 
py_test( name = "tests/test_supported_spaces_off_policy", diff --git a/rllib/agents/a3c/a3c_tf_policy.py b/rllib/agents/a3c/a3c_tf_policy.py index cbc5bbbd797d6..6e7b362a4fd95 100644 --- a/rllib/agents/a3c/a3c_tf_policy.py +++ b/rllib/agents/a3c/a3c_tf_policy.py @@ -111,7 +111,7 @@ def grad_stats(policy: Policy, train_batch: SampleBatch, "grad_gnorm": tf.linalg.global_norm(grads), "vf_explained_var": explained_variance( train_batch[Postprocessing.VALUE_TARGETS], - policy.model.value_function()), + policy.model.value_function()) } diff --git a/rllib/agents/a3c/a3c_torch_policy.py b/rllib/agents/a3c/a3c_torch_policy.py index 99172adb814e0..ea44f4767cfdc 100644 --- a/rllib/agents/a3c/a3c_torch_policy.py +++ b/rllib/agents/a3c/a3c_torch_policy.py @@ -72,19 +72,25 @@ def actor_critic_loss(policy: Policy, model: ModelV2, total_loss = (pi_err + value_err * policy.config["vf_loss_coeff"] - entropy * policy.config["entropy_coeff"]) - policy.entropy = entropy - policy.pi_err = pi_err - policy.value_err = value_err + # Store values for stats function in model (tower), such that for + # multi-GPU, we do not override them during the parallel loss phase. + model.tower_stats["entropy"] = entropy + model.tower_stats["pi_err"] = pi_err + model.tower_stats["value_err"] = value_err return total_loss def loss_and_entropy_stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]: + return { - "policy_entropy": policy.entropy, - "policy_loss": policy.pi_err, - "vf_loss": policy.value_err, + "policy_entropy": torch.mean( + torch.stack(policy.get_tower_stats("entropy"))), + "policy_loss": torch.mean( + torch.stack(policy.get_tower_stats("pi_err"))), + "vf_loss": torch.mean( + torch.stack(policy.get_tower_stats("value_err"))), } diff --git a/rllib/agents/a3c/tests/test_a2c.py b/rllib/agents/a3c/tests/test_a2c.py index 2394b3f5812b7..4c8a259245adc 100644 --- a/rllib/agents/a3c/tests/test_a2c.py +++ b/rllib/agents/a3c/tests/test_a2c.py @@ -3,7 +3,7 @@ import ray import ray.rllib.agents.a3c as a3c from ray.rllib.utils.test_utils import check_compute_single_action, \ - framework_iterator + check_train_results, framework_iterator class TestA2C(unittest.TestCase): @@ -29,6 +29,7 @@ def test_a2c_compilation(self): trainer = a3c.A2CTrainer(config=config, env=env) for i in range(num_iterations): results = trainer.train() + check_train_results(results) print(results) check_compute_single_action(trainer) trainer.stop() @@ -37,7 +38,9 @@ def test_a2c_exec_impl(ray_start_regular): config = {"min_iter_time_s": 0} for _ in framework_iterator(config): trainer = a3c.A2CTrainer(env="CartPole-v0", config=config) - assert isinstance(trainer.train(), dict) + results = trainer.train() + check_train_results(results) + print(results) check_compute_single_action(trainer) trainer.stop() @@ -48,7 +51,9 @@ def test_a2c_exec_impl_microbatch(ray_start_regular): } for _ in framework_iterator(config): trainer = a3c.A2CTrainer(env="CartPole-v0", config=config) - assert isinstance(trainer.train(), dict) + results = trainer.train() + check_train_results(results) + print(results) check_compute_single_action(trainer) trainer.stop() diff --git a/rllib/agents/a3c/tests/test_a3c.py b/rllib/agents/a3c/tests/test_a3c.py index 6ffbab01f955f..59147f213a7a5 100644 --- a/rllib/agents/a3c/tests/test_a3c.py +++ b/rllib/agents/a3c/tests/test_a3c.py @@ -3,7 +3,7 @@ import ray import ray.rllib.agents.a3c as a3c from ray.rllib.utils.test_utils import check_compute_single_action, \ - framework_iterator + check_train_results, framework_iterator class 
TestA3C(unittest.TestCase): @@ -31,6 +31,7 @@ def test_a3c_compilation(self): trainer = a3c.A3CTrainer(config=config, env=env) for i in range(num_iterations): results = trainer.train() + check_train_results(results) print(results) check_compute_single_action( trainer, include_state=config["model"]["use_lstm"]) diff --git a/rllib/agents/ars/tests/test_ars.py b/rllib/agents/ars/tests/test_ars.py index b6bb3c8df7277..a78353de44ac4 100644 --- a/rllib/agents/ars/tests/test_ars.py +++ b/rllib/agents/ars/tests/test_ars.py @@ -7,9 +7,16 @@ class TestARS(unittest.TestCase): + @classmethod + def setUpClass(cls): + ray.init(num_cpus=3) + + @classmethod + def tearDownClass(cls): + ray.shutdown() + def test_ars_compilation(self): """Test whether an ARSTrainer can be built on all frameworks.""" - ray.init(num_cpus=3) config = ars.DEFAULT_CONFIG.copy() # Keep it simple. config["model"]["fcnet_hiddens"] = [10] @@ -30,7 +37,6 @@ def test_ars_compilation(self): check_compute_single_action(trainer) trainer.stop() - ray.shutdown() if __name__ == "__main__": diff --git a/rllib/agents/cql/cql.py b/rllib/agents/cql/cql.py index 3c9c026c7bc34..19f1573e29ba9 100644 --- a/rllib/agents/cql/cql.py +++ b/rllib/agents/cql/cql.py @@ -14,10 +14,11 @@ from ray.rllib.execution.train_ops import MultiGPUTrainOneStep, TrainOneStep, \ UpdateTargetNetwork from ray.rllib.offline.shuffled_input import ShuffledInput -from ray.rllib.policy.policy import LEARNER_STATS_KEY, Policy +from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils import merge_dicts from ray.rllib.utils.framework import try_import_tf, try_import_tfp +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.typing import TrainerConfigDict tf1, tf, tfv = try_import_tf() diff --git a/rllib/agents/cql/cql_torch_policy.py b/rllib/agents/cql/cql_torch_policy.py index fed6470dc585e..f62b23069a4fd 100644 --- a/rllib/agents/cql/cql_torch_policy.py +++ b/rllib/agents/cql/cql_torch_policy.py @@ -14,12 +14,12 @@ build_sac_model_and_action_dist, optimizer_fn, ComputeTDErrorMixin, \ TargetNetworkMixin, setup_late_mixins, action_distribution_fn from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper -from ray.rllib.policy.policy import LEARNER_STATS_KEY from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.typing import LocalOptimizer, TensorType, \ TrainerConfigDict from ray.rllib.utils.torch_ops import apply_grad_clipping, \ @@ -250,23 +250,29 @@ def cql_loss(policy: Policy, model: ModelV2, critic_loss[1].backward(retain_graph=False) policy.critic_optims[1].step() - # Save for stats function. - policy.q_t = q_t_selected - policy.policy_t = policy_t - policy.log_pis_t = log_pis_t - model.td_error = td_error - policy.actor_loss = actor_loss - policy.critic_loss = critic_loss - policy.alpha_loss = alpha_loss - policy.log_alpha_value = model.log_alpha - policy.alpha_value = alpha - policy.target_entropy = model.target_entropy - # CQL Stats. - policy.cql_loss = cql_loss + # Store values for stats function in model (tower), such that for + # multi-GPU, we do not override them during the parallel loss phase. + # SAC stats. 
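    # (Each multi-GPU tower holds its own copy of the model; stats stored on
    # the tower are collected via policy.get_tower_stats() and averaged in
    # cql_stats() below instead of being overwritten on the shared policy.)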
+ model.tower_stats["q_t"] = q_t_selected + model.tower_stats["policy_t"] = policy_t + model.tower_stats["log_pis_t"] = log_pis_t + model.tower_stats["actor_loss"] = actor_loss + model.tower_stats["critic_loss"] = critic_loss + model.tower_stats["alpha_loss"] = alpha_loss + model.tower_stats["log_alpha_value"] = model.log_alpha + model.tower_stats["alpha_value"] = alpha + model.tower_stats["target_entropy"] = model.target_entropy + # CQL stats. + model.tower_stats["cql_loss"] = cql_loss + + # TD-error tensor in final stats + # will be concatenated and retrieved for each individual batch item. + model.tower_stats["td_error"] = td_error + if use_lagrange: - policy.log_alpha_prime_value = model.log_alpha_prime[0] - policy.alpha_prime_value = alpha_prime - policy.alpha_prime_loss = alpha_prime_loss + model.tower_stats["log_alpha_prime_value"] = model.log_alpha_prime[0] + model.tower_stats["alpha_prime_value"] = alpha_prime + model.tower_stats["alpha_prime_loss"] = alpha_prime_loss if obs.shape[0] == policy.config["train_batch_size"]: policy.alpha_prime_optim.zero_grad() @@ -274,22 +280,27 @@ def cql_loss(policy: Policy, model: ModelV2, policy.alpha_prime_optim.step() # Return all loss terms corresponding to our optimizers. - if use_lagrange: - return tuple([policy.actor_loss] + policy.critic_loss + - [policy.alpha_loss] + [policy.alpha_prime_loss]) - return tuple([policy.actor_loss] + policy.critic_loss + - [policy.alpha_loss]) + return tuple([actor_loss] + critic_loss + [alpha_loss] + + ([alpha_prime_loss] if use_lagrange else [])) def cql_stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]: - sac_dict = stats(policy, train_batch) - sac_dict["cql_loss"] = torch.mean(torch.stack(policy.cql_loss)) + # Get SAC loss stats. + stats_dict = stats(policy, train_batch) + + # Add CQL loss stats to the dict. 
+ stats_dict["cql_loss"] = torch.mean( + torch.stack(*policy.get_tower_stats("cql_loss"))) + if policy.config["lagrangian"]: - sac_dict["log_alpha_prime_value"] = policy.log_alpha_prime_value - sac_dict["alpha_prime_value"] = policy.alpha_prime_value - sac_dict["alpha_prime_loss"] = policy.alpha_prime_loss - return sac_dict + stats_dict["log_alpha_prime_value"] = torch.mean( + torch.stack(policy.get_tower_stats("log_alpha_prime_value"))) + stats_dict["alpha_prime_value"] = torch.mean( + torch.stack(policy.get_tower_stats("alpha_prime_value"))) + stats_dict["alpha_prime_loss"] = torch.mean( + torch.stack(policy.get_tower_stats("alpha_prime_loss"))) + return stats_dict def cql_optimizer_fn(policy: Policy, config: TrainerConfigDict) -> \ diff --git a/rllib/agents/cql/tests/test_cql.py b/rllib/agents/cql/tests/test_cql.py index 9f8466a220e00..7e3ef58896f67 100644 --- a/rllib/agents/cql/tests/test_cql.py +++ b/rllib/agents/cql/tests/test_cql.py @@ -7,7 +7,7 @@ import ray.rllib.agents.cql as cql from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.test_utils import check_compute_single_action, \ - framework_iterator + check_train_results, framework_iterator tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() @@ -69,10 +69,13 @@ def test_cql_compilation(self): for fw in framework_iterator(config): trainer = cql.CQLTrainer(config=config) for i in range(num_iterations): - results = trainer.train().get("evaluation") - if results: + results = trainer.train() + check_train_results(results) + print(results) + eval_results = results.get("evaluation") + if eval_results: print(f"iter={trainer.iteration} " - f"R={results['episode_reward_mean']}") + f"R={eval_results['episode_reward_mean']}") check_compute_single_action(trainer) diff --git a/rllib/agents/ddpg/ddpg_tf_model.py b/rllib/agents/ddpg/ddpg_tf_model.py index 53d2d666dc60c..f3c4a3ece6e9b 100644 --- a/rllib/agents/ddpg/ddpg_tf_model.py +++ b/rllib/agents/ddpg/ddpg_tf_model.py @@ -1,6 +1,6 @@ import numpy as np import gym -from typing import List +from typing import List, Optional from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.utils.framework import try_import_tf @@ -29,9 +29,9 @@ def __init__( model_config: ModelConfigDict, name: str, # Extra DDPGActionModel args: - actor_hiddens: List[int] = [256, 256], + actor_hiddens: Optional[List[int]] = None, actor_hidden_activation: str = "relu", - critic_hiddens: List[int] = [256, 256], + critic_hiddens: Optional[List[int]] = None, critic_hidden_activation: str = "relu", twin_q: bool = False, add_layer_norm: bool = False): @@ -48,6 +48,12 @@ def __init__( should be defined in subclasses of DDPGActionModel. 
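The rewritten return statement above keeps the CQL loss terms aligned one-to-one with the optimizers that `cql_optimizer_fn` hands back. A toy sketch of that contract, assuming a hypothetical three-parameter setup (plain torch):

import torch

# Three toy "networks", one per optimizer, mirroring actor/critics/alpha.
params = [torch.nn.Parameter(torch.ones(1)) for _ in range(3)]
optims = [torch.optim.SGD([p], lr=0.1) for p in params]

def loss_fn():
    # Must return one scalar loss per optimizer, in matching order.
    return tuple((i + 1.0) * p.sum() for i, p in enumerate(params))

losses = loss_fn()
for loss, opt in zip(losses, optims):
    opt.zero_grad()
    loss.backward()
    opt.step()

# Gradients were 1, 2, 3 respectively, so each param moved by -0.1 * grad.
print([p.item() for p in params])  # [0.9, 0.8, 0.7]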
""" + if actor_hiddens is None: + actor_hiddens = [256, 256] + + if critic_hiddens is None: + critic_hiddens = [256, 256] + super(DDPGTFModel, self).__init__(obs_space, action_space, num_outputs, model_config, name) diff --git a/rllib/agents/ddpg/ddpg_tf_policy.py b/rllib/agents/ddpg/ddpg_tf_policy.py index 8c24a84c04a5e..d3c295feba940 100644 --- a/rllib/agents/ddpg/ddpg_tf_policy.py +++ b/rllib/agents/ddpg/ddpg_tf_policy.py @@ -28,7 +28,7 @@ from ray.rllib.utils.spaces.simplex import Simplex from ray.rllib.utils.tf_ops import huber_loss, make_tf_callable from ray.rllib.utils.typing import TrainerConfigDict, TensorType, \ - LocalOptimizer, ModelGradients, PolicyID + LocalOptimizer, ModelGradients from ray.util.debug import log_once tf1, tf, tfv = try_import_tf() @@ -429,17 +429,17 @@ def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space, TargetNetworkMixin.__init__(policy, config) -def validate_spaces(pid: PolicyID, observation_space: gym.spaces.Space, +def validate_spaces(policy: Policy, observation_space: gym.spaces.Space, action_space: gym.spaces.Space, config: TrainerConfigDict) -> None: if not isinstance(action_space, Box): raise UnsupportedSpaceException( "Action space ({}) of {} is not supported for " - "DDPG.".format(action_space, pid)) + "DDPG.".format(action_space, policy)) elif len(action_space.shape) > 1: raise UnsupportedSpaceException( "Action space ({}) of {} has multiple dimensions " - "{}. ".format(action_space, pid, action_space.shape) + + "{}. ".format(action_space, policy, action_space.shape) + "Consider reshaping this into a single dimension, " "using a Tuple action space, or the multi-agent API.") diff --git a/rllib/agents/ddpg/ddpg_torch_model.py b/rllib/agents/ddpg/ddpg_torch_model.py index 2297ee0b2a815..615e0ea8b5814 100644 --- a/rllib/agents/ddpg/ddpg_torch_model.py +++ b/rllib/agents/ddpg/ddpg_torch_model.py @@ -1,6 +1,6 @@ import numpy as np import gym -from typing import List, Dict, Union +from typing import List, Dict, Union, Optional from ray.rllib.models.torch.misc import SlimFC from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 @@ -31,9 +31,9 @@ def __init__( model_config: ModelConfigDict, name: str, # Extra DDPGActionModel args: - actor_hiddens: List[int] = [256, 256], + actor_hiddens: Optional[List[int]] = None, actor_hidden_activation: str = "relu", - critic_hiddens: List[int] = [256, 256], + critic_hiddens: Optional[List[int]] = None, critic_hidden_activation: str = "relu", twin_q: bool = False, add_layer_norm: bool = False): @@ -51,6 +51,12 @@ def __init__( only defines the layers for the output heads. Those layers for forward() should be defined in subclasses of DDPGTorchModel. """ + if actor_hiddens is None: + actor_hiddens = [256, 256] + + if critic_hiddens is None: + critic_hiddens = [256, 256] + nn.Module.__init__(self) super(DDPGTorchModel, self).__init__(obs_space, action_space, num_outputs, model_config, name) diff --git a/rllib/agents/ddpg/ddpg_torch_policy.py b/rllib/agents/ddpg/ddpg_torch_policy.py index ef22a5e75fd47..c6eb6bddbda6e 100644 --- a/rllib/agents/ddpg/ddpg_torch_policy.py +++ b/rllib/agents/ddpg/ddpg_torch_policy.py @@ -172,18 +172,17 @@ def ddpg_actor_critic_loss(policy: Policy, model: ModelV2, _, [actor_loss, critic_loss] = model.custom_loss( [actor_loss, critic_loss], input_dict) - # Store values for stats function. 
diff --git a/rllib/agents/ddpg/ddpg_torch_policy.py b/rllib/agents/ddpg/ddpg_torch_policy.py
index ef22a5e75fd47..c6eb6bddbda6e 100644
--- a/rllib/agents/ddpg/ddpg_torch_policy.py
+++ b/rllib/agents/ddpg/ddpg_torch_policy.py
@@ -172,18 +172,17 @@ def ddpg_actor_critic_loss(policy: Policy, model: ModelV2, _,
         _, [actor_loss, critic_loss] = model.custom_loss(
             [actor_loss, critic_loss], input_dict)
 
-    # Store values for stats function.
-    policy.q_t = q_t
-    policy.actor_loss = actor_loss
-    policy.critic_loss = critic_loss
-
-    # Store td-error in model, such that for multi-GPU, we do not override
-    # them during the parallel loss phase. TD-error tensor in final stats
-    # can then be concatenated and retrieved for each individual batch item.
-    model.td_error = td_error
+    # Store values for stats function in model (tower), such that for
+    # multi-GPU, we do not override them during the parallel loss phase.
+    model.tower_stats["q_t"] = q_t
+    model.tower_stats["actor_loss"] = actor_loss
+    model.tower_stats["critic_loss"] = critic_loss
+    # TD-error tensor in final stats
+    # will be concatenated and retrieved for each individual batch item.
+    model.tower_stats["td_error"] = td_error
 
     # Return two loss terms (corresponding to the two optimizers, we create).
-    return policy.actor_loss, policy.critic_loss
+    return actor_loss, critic_loss
 
 
 def make_ddpg_optimizers(policy: Policy,
@@ -217,12 +216,16 @@ def apply_gradients_fn(policy: Policy, gradients: GradInfoDict) -> None:
 
 def build_ddpg_stats(policy: Policy,
                      batch: SampleBatch) -> Dict[str, TensorType]:
+
+    q_t = torch.stack(policy.get_tower_stats("q_t"))
     stats = {
-        "actor_loss": policy.actor_loss,
-        "critic_loss": policy.critic_loss,
-        "mean_q": torch.mean(policy.q_t),
-        "max_q": torch.max(policy.q_t),
-        "min_q": torch.min(policy.q_t),
+        "actor_loss": torch.mean(
+            torch.stack(policy.get_tower_stats("actor_loss"))),
+        "critic_loss": torch.mean(
+            torch.stack(policy.get_tower_stats("critic_loss"))),
+        "mean_q": torch.mean(q_t),
+        "max_q": torch.max(q_t),
+        "min_q": torch.min(q_t),
     }
     return stats
@@ -251,8 +254,8 @@ def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
             # (one TD-error value per item in batch to update PR weights).
             loss_fn(self, self.model, None, input_dict)
 
-            # Self.td_error is set within actor_critic_loss call.
-            return self.model.td_error
+            # `self.model.td_error` is set within actor_critic_loss call.
+            return self.model.tower_stats["td_error"]
 
         self.compute_td_error = compute_td_error
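`compute_td_error` above works purely by side effect: re-running the loss refreshes the per-batch-item TD-error tensor in `tower_stats`, which the helper then reads back for prioritized-replay weight updates. A stripped-down sketch of that flow (hypothetical `ToyModel`/`td_loss` names, plain torch):

import torch

class ToyModel:
    def __init__(self):
        self.tower_stats = {}

def td_loss(model, q_pred, q_target):
    td_error = q_pred - q_target          # one value per batch item
    model.tower_stats["td_error"] = td_error
    return torch.mean(td_error ** 2)      # scalar loss for the optimizer

def compute_td_error(model, q_pred, q_target):
    # Run the loss only for its side effect of refreshing tower_stats.
    td_loss(model, q_pred, q_target)
    return model.tower_stats["td_error"]

model = ToyModel()
per_item = compute_td_error(
    model, torch.tensor([1.0, 2.0]), torch.tensor([0.5, 2.5]))
print(per_item)  # tensor([ 0.5000, -0.5000]) -> drives PR weight updates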
diff --git a/rllib/agents/ddpg/tests/test_apex_ddpg.py b/rllib/agents/ddpg/tests/test_apex_ddpg.py
index 61556fb9b961b..16ebab9a1f9ae 100644
--- a/rllib/agents/ddpg/tests/test_apex_ddpg.py
+++ b/rllib/agents/ddpg/tests/test_apex_ddpg.py
@@ -4,7 +4,7 @@
 import ray
 import ray.rllib.agents.ddpg.apex as apex_ddpg
 from ray.rllib.utils.test_utils import check, check_compute_single_action, \
-    framework_iterator
+    check_train_results, framework_iterator
 
 
 class TestApexDDPG(unittest.TestCase):
@@ -40,7 +40,9 @@ def test_apex_ddpg_compilation_and_per_worker_epsilon_values(self):
             check(scale, [0.0] + expected)
 
             for _ in range(num_iterations):
-                print(trainer.train())
+                results = trainer.train()
+                check_train_results(results)
+                print(results)
             check_compute_single_action(trainer)
 
             # Test again per-worker scale distribution
diff --git a/rllib/agents/ddpg/tests/test_ddpg.py b/rllib/agents/ddpg/tests/test_ddpg.py
index be404e720d48e..7f72e03d0e30c 100644
--- a/rllib/agents/ddpg/tests/test_ddpg.py
+++ b/rllib/agents/ddpg/tests/test_ddpg.py
@@ -13,7 +13,7 @@
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
 from ray.rllib.utils.numpy import fc, huber_loss, l2_loss, relu, sigmoid
 from ray.rllib.utils.test_utils import check, check_compute_single_action, \
-    framework_iterator
+    check_train_results, framework_iterator
 from ray.rllib.utils.torch_ops import convert_to_torch_tensor
 
 tf1, tf, tfv = try_import_tf()
@@ -45,6 +45,7 @@ def test_ddpg_compilation(self):
             trainer = ddpg.DDPGTrainer(config=config, env="Pendulum-v0")
             for i in range(num_iterations):
                 results = trainer.train()
+                check_train_results(results)
                 print(results)
             check_compute_single_action(trainer)
             # Ensure apply_gradient_fn is being called and updating global_step
@@ -288,8 +289,9 @@ def test_ddpg_loss_function(self):
             elif fw == "torch":
                 loss_torch(policy, policy.model, None, input_)
-                c, a, t = policy.critic_loss, policy.actor_loss, \
-                    policy.model.td_error
+                c, a, t = policy.get_tower_stats("critic_loss")[0], \
+                    policy.get_tower_stats("actor_loss")[0], \
+                    policy.get_tower_stats("td_error")[0]
                 # Check pure loss values.
                 check(c, expect_c)
                 check(a, expect_a)
diff --git a/rllib/agents/ddpg/tests/test_td3.py b/rllib/agents/ddpg/tests/test_td3.py
index 75b84e4ddc57e..a542cf5a1574d 100644
--- a/rllib/agents/ddpg/tests/test_td3.py
+++ b/rllib/agents/ddpg/tests/test_td3.py
@@ -5,7 +5,7 @@
 import ray.rllib.agents.ddpg.td3 as td3
 from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.test_utils import check, check_compute_single_action, \
-    framework_iterator
+    check_train_results, framework_iterator
 
 tf1, tf, tfv = try_import_tf()
 
@@ -30,6 +30,7 @@ def test_td3_compilation(self):
             num_iterations = 1
             for i in range(num_iterations):
                 results = trainer.train()
+                check_train_results(results)
                 print(results)
             check_compute_single_action(trainer)
             trainer.stop()
diff --git a/rllib/agents/dqn/apex.py b/rllib/agents/dqn/apex.py
index 74afc564f1708..49c24b07ed3e7 100644
--- a/rllib/agents/dqn/apex.py
+++ b/rllib/agents/dqn/apex.py
@@ -33,6 +33,7 @@
 from ray.rllib.utils import merge_dicts
 from ray.rllib.utils.actors import create_colocated
 from ray.rllib.utils.annotations import override
+from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
 from ray.rllib.utils.typing import SampleBatchType
 from ray.tune.trainable import Trainable
 from ray.tune.utils.placement_groups import PlacementGroupFactory
@@ -227,7 +228,7 @@ def add_apex_metrics(result: dict) -> dict:
             result["info"].update({
                 "exploration_infos": exploration_infos,
                 "learner_queue": learner_thread.learner_queue_size.stats(),
-                "learner": copy.deepcopy(learner_thread.stats),
+                LEARNER_INFO: copy.deepcopy(learner_thread.learner_info),
                 "replay_shard_0": replay_stats,
             })
             return result
diff --git a/rllib/agents/dqn/dqn.py b/rllib/agents/dqn/dqn.py
index ac4b8f0dbb8e5..5f1eadf020a39 100644
--- a/rllib/agents/dqn/dqn.py
+++ b/rllib/agents/dqn/dqn.py
@@ -25,7 +25,8 @@
 from ray.rllib.execution.rollout_ops import ParallelRollouts
 from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork, \
     MultiGPUTrainOneStep
-from ray.rllib.policy.policy import LEARNER_STATS_KEY, Policy
+from ray.rllib.policy.policy import Policy
+from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
 from ray.rllib.utils.typing import TrainerConfigDict
 from ray.util.iter import LocalIterator
 
@@ -200,8 +201,17 @@ def update_prio(item):
                 td_error = info.get("td_error",
                                     info[LEARNER_STATS_KEY].get("td_error"))
                 samples.policy_batches[policy_id].set_get_interceptor(None)
-                prio_dict[policy_id] = (samples.policy_batches[policy_id]
-                                        .get("batch_indexes"), td_error)
+                batch_indices = samples.policy_batches[policy_id].get(
+                    "batch_indexes")
+                # In case the buffer stores sequences, TD-error could already
+                # be calculated per sequence chunk.
+                if len(batch_indices) != len(td_error):
+                    T = local_replay_buffer.replay_sequence_length
+                    assert len(batch_indices) > len(
+                        td_error) and len(batch_indices) % T == 0
+                    batch_indices = batch_indices.reshape([-1, T])[:, 0]
+                    assert len(batch_indices) == len(td_error)
+                prio_dict[policy_id] = (batch_indices, td_error)
             local_replay_buffer.update_priorities(prio_dict)
         return info_dict
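The `batch_indices` reshape above maps one TD-error per stored sequence back onto the first timestep index of that sequence. A small numeric check of the same arithmetic (numpy, hypothetical values):

import numpy as np

# The buffer stored 3 sequences of length T=4, so sampling returned
# 12 per-timestep indices but only 3 per-sequence TD-errors.
T = 4
batch_indices = np.arange(100, 112)          # 12 timestep indices
td_error = np.array([0.5, 1.0, 0.25])        # one value per sequence

assert len(batch_indices) > len(td_error) and len(batch_indices) % T == 0
# Keep only the first index of each length-T chunk.
batch_indices = batch_indices.reshape([-1, T])[:, 0]
print(batch_indices)                         # [100 104 108]
assert len(batch_indices) == len(td_error)   # one priority per sequence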
diff --git a/rllib/agents/dqn/dqn_torch_policy.py b/rllib/agents/dqn/dqn_torch_policy.py
index d060a1ce4012a..a7826d0da489c 100644
--- a/rllib/agents/dqn/dqn_torch_policy.py
+++ b/rllib/agents/dqn/dqn_torch_policy.py
@@ -121,7 +121,7 @@ def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
             # Do forward pass on loss to update td error attribute
             build_q_losses(self, self.model, None, input_dict)
 
-            return self.q_loss.td_error
+            return self.model.tower_stats["q_loss"].td_error
 
         self.compute_td_error = compute_td_error
 
@@ -216,8 +216,9 @@ def get_distribution_inputs_and_class(
         is_training=is_training)
 
     q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals
-    policy.q_values = q_vals
-    return policy.q_values, TorchCategorical, []  # state-out
+    model.tower_stats["q_values"] = q_vals
+
+    return q_vals, TorchCategorical, []  # state-out
 
 
 def build_q_losses(policy: Policy, model, _,
@@ -286,19 +287,21 @@ def build_q_losses(policy: Policy, model, _,
         q_probs_tp1_best = torch.sum(
             q_probs_tp1 * torch.unsqueeze(q_tp1_best_one_hot_selection, -1), 1)
 
-    policy.q_loss = QLoss(
-        q_t_selected, q_logits_t_selected, q_tp1_best, q_probs_tp1_best,
-        train_batch[PRIO_WEIGHTS], train_batch[SampleBatch.REWARDS],
-        train_batch[SampleBatch.DONES].float(), config["gamma"],
-        config["n_step"], config["num_atoms"], config["v_min"],
-        config["v_max"])
+    q_loss = QLoss(q_t_selected, q_logits_t_selected, q_tp1_best,
+                   q_probs_tp1_best, train_batch[PRIO_WEIGHTS],
+                   train_batch[SampleBatch.REWARDS],
+                   train_batch[SampleBatch.DONES].float(), config["gamma"],
+                   config["n_step"], config["num_atoms"], config["v_min"],
+                   config["v_max"])
 
-    # Store td-error in model, such that for multi-GPU, we do not override
-    # them during the parallel loss phase. TD-error tensor in final stats
-    # can then be concatenated and retrieved for each individual batch item.
-    model.td_error = policy.q_loss.td_error
+    # Store values for stats function in model (tower), such that for
+    # multi-GPU, we do not override them during the parallel loss phase.
+    model.tower_stats["td_error"] = q_loss.td_error
+    # TD-error tensor in final stats
+    # will be concatenated and retrieved for each individual batch item.
+    model.tower_stats["q_loss"] = q_loss
 
-    return policy.q_loss.loss
+    return q_loss.loss
 
 
 def adam_optimizer(policy: Policy,
@@ -314,9 +317,16 @@ def build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]:
-    return dict({
-        "cur_lr": policy.cur_lr,
-    }, **policy.q_loss.stats)
+    stats = {}
+    for stats_key in policy.model_gpu_towers[0].tower_stats[
+            "q_loss"].stats.keys():
+        stats[stats_key] = torch.mean(
+            torch.stack([
+                t.tower_stats["q_loss"].stats[stats_key].to(policy.device)
+                for t in policy.model_gpu_towers if "q_loss" in t.tower_stats
+            ]))
+    stats["cur_lr"] = policy.cur_lr
+    return stats
 
 
 def setup_early_mixins(policy: Policy, obs_space, action_space,
@@ -385,7 +395,7 @@ def grad_process_and_td_error_fn(policy: Policy,
 
 def extra_action_out_fn(policy: Policy, input_dict, state_batches, model,
                         action_dist) -> Dict[str, TensorType]:
-    return {"q_values": policy.q_values}
+    return {"q_values": model.tower_stats["q_values"]}
 
 
 DQNTorchPolicy = build_policy_class(
diff --git a/rllib/agents/dqn/learner_thread.py b/rllib/agents/dqn/learner_thread.py
index 0f8d6f15bd79a..93bed4b18de5e 100644
--- a/rllib/agents/dqn/learner_thread.py
+++ b/rllib/agents/dqn/learner_thread.py
@@ -1,9 +1,8 @@
 import queue
 import threading
 
-from ray.rllib.evaluation.metrics import get_learner_stats
-from ray.rllib.policy.policy import LEARNER_STATS_KEY
 from ray.rllib.utils.framework import try_import_tf
+from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
 from ray.rllib.utils.timer import TimerStat
 from ray.rllib.utils.window_stat import WindowStat
 
@@ -33,7 +32,7 @@ def __init__(self, local_worker):
         self.daemon = True
         self.weights_updated = False
         self.stopped = False
-        self.stats = {}
+        self.learner_info = {}
 
     def run(self):
         # Switch on eager mode if configured.
@@ -49,11 +48,18 @@ def step(self):
         if replay is not None:
             prio_dict = {}
             with self.grad_timer:
-                grad_out = self.local_worker.learn_on_batch(replay)
-                for pid, info in grad_out.items():
-                    td_error = info.get(
-                        "td_error",
-                        info[LEARNER_STATS_KEY].get("td_error"))
+                # Use LearnerInfoBuilder as a unified way to build the
+                # final results dict from `learn_on_loaded_batch` call(s).
+                # This makes sure results dicts always have the same
+                # structure no matter the setup (multi-GPU, multi-agent,
+                # minibatch SGD, tf vs torch).
+                learner_info_builder = LearnerInfoBuilder(num_devices=1)
+                multi_agent_results = self.local_worker.learn_on_batch(
+                    replay)
+                for pid, results in multi_agent_results.items():
+                    learner_info_builder.add_learn_on_batch_results(
+                        results, pid)
+                    td_error = results["td_error"]
                     # Switch off auto-conversion from numpy to torch/tf
                     # tensors for the indices. This may lead to errors
                     # when sent to the buffer for processing
@@ -62,7 +68,7 @@ def step(self):
                     prio_dict[pid] = (
                         replay.policy_batches[pid].get("batch_indexes"),
                         td_error)
-                    self.stats[pid] = get_learner_stats(info)
+                self.learner_info = learner_info_builder.finalize()
             self.grad_timer.push_units_processed(replay.count)
             self.outqueue.put((ra, prio_dict, replay.count))
         self.learner_queue_size.push(self.inqueue.qsize())
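For reference, `LearnerInfoBuilder` collapses the per-policy `learn_on_batch` results into one nested info dict. A usage sketch restricted to the calls visible in this patch, under the assumption that per-policy results carry their stats under LEARNER_STATS_KEY (the toy results dict is hypothetical):

from ray.rllib.utils.metrics.learner_info import (
    LEARNER_STATS_KEY, LearnerInfoBuilder)

builder = LearnerInfoBuilder(num_devices=1)
# Hypothetical learn_on_batch() output for one policy.
results = {LEARNER_STATS_KEY: {"cur_lr": 5e-4}, "td_error": [0.1, -0.2]}
builder.add_learn_on_batch_results(results, "default_policy")

# finalize() yields one dict keyed by policy id, with the same structure
# regardless of multi-GPU / multi-agent / minibatch-SGD setup.
learner_info = builder.finalize()
print(learner_info["default_policy"])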
"lr": 1e-4, # Discount factor. "gamma": 0.997, @@ -40,8 +40,6 @@ "num_workers": 2, # Batch mode must be complete_episodes. "batch_mode": "complete_episodes", - # R2D2 does not suport n-step > 1 yet! - "n_step": 1, # If True, assume a zero-initialized state input (no matter where in # the episode the sequence is located). @@ -71,7 +69,6 @@ # Size of the replay buffer (in sequences, not timesteps). "buffer_size": 100000, # If True prioritized replay buffer will be used. - # Note: Not supported yet by R2D2! "prioritized_replay": False, # Set automatically: The number of contiguous environment steps to # replay at once. Will be calculated via @@ -91,7 +88,8 @@ def validate_config(config: TrainerConfigDict) -> None: """Checks and updates the config based on settings. - Rewrites rollout_fragment_length to take into account n_step truncation. + Rewrites rollout_fragment_length to take into account burn-in and + max_seq_len truncation. """ if config["replay_sequence_length"] != -1: raise ValueError( @@ -102,15 +100,9 @@ def validate_config(config: TrainerConfigDict) -> None: config["replay_sequence_length"] = \ config["burn_in"] + config["model"]["max_seq_len"] - if config.get("prioritized_replay"): - raise ValueError("Prioritized replay is not supported for R2D2 yet!") - if config.get("batch_mode") != "complete_episodes": raise ValueError("`batch_mode` must be 'complete_episodes'!") - if config["n_step"] > 1: - raise ValueError("`n_step` > 1 not yet supported by R2D2!") - def calculate_rr_weights(config: TrainerConfigDict) -> List[float]: """Calculate the round robin weights for the rollout and train steps""" diff --git a/rllib/agents/dqn/r2d2_tf_policy.py b/rllib/agents/dqn/r2d2_tf_policy.py index 1d72d12e7e25b..d34c35a44976b 100644 --- a/rllib/agents/dqn/r2d2_tf_policy.py +++ b/rllib/agents/dqn/r2d2_tf_policy.py @@ -156,7 +156,7 @@ def r2d2_loss(policy: Policy, model, _, def reduce_mean_valid(t): return tf.reduce_mean(tf.boolean_mask(t, seq_mask)) - # Make sure use the correct time indices: + # Make sure to use the correct time indices: # Q(t) - [gamma * r + Q^(t+1)] q_selected = tf.reshape(q_selected, [B, T])[:, :-1] td_error = q_selected - tf.stop_gradient( @@ -164,7 +164,9 @@ def reduce_mean_valid(t): td_error = td_error * tf.cast(seq_mask, tf.float32) weights = tf.reshape(weights, [B, T])[:, :-1] policy._total_loss = reduce_mean_valid(weights * huber_loss(td_error)) - policy._td_error = tf.reshape(td_error, [-1]) + # Store the TD-error per time chunk (b/c we need only one mean + # prioritized replay weight per stored sequence). 
+ policy._td_error = tf.reduce_mean(td_error, axis=-1) policy._loss_stats = { "mean_q": reduce_mean_valid(q_selected), "min_q": tf.reduce_min(q_selected), diff --git a/rllib/agents/dqn/r2d2_torch_policy.py b/rllib/agents/dqn/r2d2_torch_policy.py index 894c6dc2fb729..97c34327f7215 100644 --- a/rllib/agents/dqn/r2d2_torch_policy.py +++ b/rllib/agents/dqn/r2d2_torch_policy.py @@ -19,8 +19,8 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy import LearningRateSchedule from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.torch_ops import apply_grad_clipping, FLOAT_MIN, \ - huber_loss, sequence_mask +from ray.rllib.utils.torch_ops import apply_grad_clipping, \ + concat_multi_gpu_td_errors, FLOAT_MIN, huber_loss, sequence_mask from ray.rllib.utils.typing import TensorType, TrainerConfigDict torch, nn = try_import_torch() @@ -170,16 +170,20 @@ def reduce_mean_valid(t): td_error = q_selected - target.reshape([B, T])[:, :-1].detach() td_error = td_error * seq_mask weights = weights.reshape([B, T])[:, :-1] - policy._total_loss = reduce_mean_valid(weights * huber_loss(td_error)) - policy._td_error = td_error.reshape([-1]) - policy._loss_stats = { - "mean_q": reduce_mean_valid(q_selected), - "min_q": torch.min(q_selected), - "max_q": torch.max(q_selected), - "mean_td_error": reduce_mean_valid(td_error), - } + total_loss = reduce_mean_valid(weights * huber_loss(td_error)) - return policy._total_loss + # Store values for stats function in model (tower), such that for + # multi-GPU, we do not override them during the parallel loss phase. + model.tower_stats["total_loss"] = total_loss + model.tower_stats["mean_q"] = reduce_mean_valid(q_selected) + model.tower_stats["min_q"] = torch.min(q_selected) + model.tower_stats["max_q"] = torch.max(q_selected) + model.tower_stats["mean_td_error"] = reduce_mean_valid(td_error) + # Store per time chunk (b/c we need only one mean + # prioritized replay weight per stored sequence). 
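Both R2D2 policies now emit one TD-error per stored sequence rather than per timestep, since the prioritized buffer holds whole sequences and needs exactly one priority each. A quick numeric illustration of the reduction (torch, hypothetical shapes):

import torch

B, T = 2, 3  # 2 stored sequences, 3 timesteps each (after truncation)
td_error = torch.tensor([[0.2, 0.4, 0.6],
                         [1.0, 1.0, 1.0]])  # per-timestep TD-errors

# One mean TD-error per sequence -> one priority per buffer entry.
per_sequence = torch.mean(td_error, dim=-1)
print(per_sequence)  # tensor([0.4000, 1.0000])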
diff --git a/rllib/agents/dqn/r2d2_torch_policy.py b/rllib/agents/dqn/r2d2_torch_policy.py
index 894c6dc2fb729..97c34327f7215 100644
--- a/rllib/agents/dqn/r2d2_torch_policy.py
+++ b/rllib/agents/dqn/r2d2_torch_policy.py
@@ -19,8 +19,8 @@
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.torch_policy import LearningRateSchedule
 from ray.rllib.utils.framework import try_import_torch
-from ray.rllib.utils.torch_ops import apply_grad_clipping, FLOAT_MIN, \
-    huber_loss, sequence_mask
+from ray.rllib.utils.torch_ops import apply_grad_clipping, \
+    concat_multi_gpu_td_errors, FLOAT_MIN, huber_loss, sequence_mask
 from ray.rllib.utils.typing import TensorType, TrainerConfigDict
 
 torch, nn = try_import_torch()
@@ -170,16 +170,20 @@ def reduce_mean_valid(t):
     td_error = q_selected - target.reshape([B, T])[:, :-1].detach()
     td_error = td_error * seq_mask
     weights = weights.reshape([B, T])[:, :-1]
-    policy._total_loss = reduce_mean_valid(weights * huber_loss(td_error))
-    policy._td_error = td_error.reshape([-1])
-    policy._loss_stats = {
-        "mean_q": reduce_mean_valid(q_selected),
-        "min_q": torch.min(q_selected),
-        "max_q": torch.max(q_selected),
-        "mean_td_error": reduce_mean_valid(td_error),
-    }
+    total_loss = reduce_mean_valid(weights * huber_loss(td_error))
 
-    return policy._total_loss
+    # Store values for stats function in model (tower), such that for
+    # multi-GPU, we do not override them during the parallel loss phase.
+    model.tower_stats["total_loss"] = total_loss
+    model.tower_stats["mean_q"] = reduce_mean_valid(q_selected)
+    model.tower_stats["min_q"] = torch.min(q_selected)
+    model.tower_stats["max_q"] = torch.max(q_selected)
+    model.tower_stats["mean_td_error"] = reduce_mean_valid(td_error)
+    # Store per time chunk (b/c we need only one mean
+    # prioritized replay weight per stored sequence).
+    model.tower_stats["td_error"] = torch.mean(td_error, dim=-1)
+
+    return total_loss
 
 
 def h_function(x, epsilon=1.0):
@@ -233,15 +237,23 @@ def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
             # Do forward pass on loss to update td error attribute
             r2d2_loss(self, self.model, None, input_dict)
 
-            return self._td_error
+            return self.model.tower_stats["td_error"]
 
         self.compute_td_error = compute_td_error
 
 
-def build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]:
-    return dict({
+def build_q_stats(policy: Policy,
+                  batch: SampleBatch) -> Dict[str, TensorType]:
+
+    return {
         "cur_lr": policy.cur_lr,
-    }, **policy._loss_stats)
+        "total_loss": torch.mean(
+            torch.stack(policy.get_tower_stats("total_loss"))),
+        "mean_q": torch.mean(torch.stack(policy.get_tower_stats("mean_q"))),
+        "min_q": torch.mean(torch.stack(policy.get_tower_stats("min_q"))),
+        "max_q": torch.mean(torch.stack(policy.get_tower_stats("max_q"))),
+        "mean_td_error": torch.mean(
+            torch.stack(policy.get_tower_stats("mean_td_error"))),
+    }
 
 
 def setup_early_mixins(policy: Policy, obs_space, action_space,
@@ -279,7 +291,7 @@ def extra_action_out_fn(policy: Policy, input_dict, state_batches, model,
     postprocess_fn=postprocess_nstep_and_prio,
     optimizer_fn=adam_optimizer,
     extra_grad_process_fn=grad_process_and_td_error_fn,
-    extra_learn_fetches_fn=lambda policy: {"td_error": policy._td_error},
+    extra_learn_fetches_fn=concat_multi_gpu_td_errors,
     extra_action_out_fn=extra_action_out_fn,
     before_init=setup_early_mixins,
     before_loss_init=before_loss_init,
diff --git a/rllib/agents/dqn/simple_q_tf_policy.py b/rllib/agents/dqn/simple_q_tf_policy.py
index 0801b6fd26e63..13e62bca1fd9a 100644
--- a/rllib/agents/dqn/simple_q_tf_policy.py
+++ b/rllib/agents/dqn/simple_q_tf_policy.py
@@ -181,7 +181,7 @@ def compute_q_values(policy: Policy,
                      explore,
                      is_training=None) -> TensorType:
     model_out, _ = model({
-        SampleBatch.CUR_OBS: obs,
+        SampleBatch.OBS: obs,
         "is_training": is_training
         if is_training is not None else policy._get_is_training_placeholder(),
     }, [], None)
diff --git a/rllib/agents/dqn/simple_q_torch_policy.py b/rllib/agents/dqn/simple_q_torch_policy.py
index 055ce51598265..205fa6042e09e 100644
--- a/rllib/agents/dqn/simple_q_torch_policy.py
+++ b/rllib/agents/dqn/simple_q_torch_policy.py
@@ -16,7 +16,7 @@
 from ray.rllib.policy.torch_policy import TorchPolicy
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.framework import try_import_torch
-from ray.rllib.utils.torch_ops import huber_loss
+from ray.rllib.utils.torch_ops import concat_multi_gpu_td_errors, huber_loss
 from ray.rllib.utils.typing import TensorType, TrainerConfigDict
 
 torch, nn = try_import_torch()
@@ -112,12 +112,20 @@ def build_q_losses(policy: Policy, model, dist_class,
     td_error = q_t_selected - q_t_selected_target.detach()
     loss = torch.mean(huber_loss(td_error))
 
-    # save TD error as an attribute for outside access
-    policy.td_error = td_error
+    # Store values for stats function in model (tower), such that for
+    # multi-GPU, we do not override them during the parallel loss phase.
+    model.tower_stats["loss"] = loss
+    # TD-error tensor in final stats
+    # will be concatenated and retrieved for each individual batch item.
+    model.tower_stats["td_error"] = td_error
 
     return loss
 
 
+def stats_fn(policy: Policy, batch: SampleBatch) -> Dict[str, TensorType]:
+    return {"loss": torch.mean(torch.stack(policy.get_tower_stats("loss")))}
+
+
 def extra_action_out_fn(policy: Policy, input_dict, state_batches, model,
                         action_dist) -> Dict[str, TensorType]:
     """Adds q-values to the action out dict."""
@@ -144,10 +152,11 @@ def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
     framework="torch",
     loss_fn=build_q_losses,
     get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG,
+    stats_fn=stats_fn,
     extra_action_out_fn=extra_action_out_fn,
     after_init=setup_late_mixins,
     make_model_and_action_dist=build_q_model_and_distribution,
     mixins=[TargetNetworkMixin],
     action_distribution_fn=get_distribution_inputs_and_class,
-    extra_learn_fetches_fn=lambda policy: {"td_error": policy.td_error},
+    extra_learn_fetches_fn=concat_multi_gpu_td_errors,
 )
diff --git a/rllib/agents/dqn/tests/test_apex_dqn.py b/rllib/agents/dqn/tests/test_apex_dqn.py
index 63c051310baec..93702bf8d7c1b 100644
--- a/rllib/agents/dqn/tests/test_apex_dqn.py
+++ b/rllib/agents/dqn/tests/test_apex_dqn.py
@@ -4,8 +4,10 @@
 import ray
 import ray.rllib.agents.dqn.apex as apex
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
+from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \
+    LEARNER_STATS_KEY
 from ray.rllib.utils.test_utils import check, check_compute_single_action, \
-    framework_iterator
+    check_train_results, framework_iterator
 
 
 class TestApexDQN(unittest.TestCase):
@@ -26,7 +28,9 @@ def test_apex_zero_workers(self):
         config["optimizer"]["num_replay_buffer_shards"] = 1
         for _ in framework_iterator(config):
             trainer = apex.ApexTrainer(config=config, env="CartPole-v0")
-            trainer.train()
+            results = trainer.train()
+            check_train_results(results)
+            print(results)
             trainer.stop()
 
     def test_apex_dqn_compilation_and_per_worker_epsilon_values(self):
@@ -53,7 +57,9 @@ def test_apex_dqn_compilation_and_per_worker_epsilon_values(self):
             check_compute_single_action(trainer)
 
             for i in range(2):
-                print(trainer.train())
+                results = trainer.train()
+                check_train_results(results)
+                print(results)
 
             # Test again per-worker epsilon distribution
             # (should not have changed).
@@ -97,7 +103,8 @@ def _step_n_times(trainer, n: int):
             """
             for _ in range(n):
                 results = trainer.train()
-            return results["info"]["learner"][DEFAULT_POLICY_ID]["cur_lr"]
+            return results["info"][LEARNER_INFO][DEFAULT_POLICY_ID][
+                LEARNER_STATS_KEY]["cur_lr"]
 
         # Check eager execution frameworks here, since it's easier to control
        # exact timesteps with these frameworks.
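The test changes above all chase the same relocation: per-policy learner stats now sit one level deeper in the train results, under LEARNER_STATS_KEY inside LEARNER_INFO. A sketch of reading them back (the results dict shown is a hypothetical shape, not a real train() output):

from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \
    LEARNER_STATS_KEY

# Hypothetical shape of a trainer.train() result after this change.
results = {
    "info": {
        LEARNER_INFO: {
            DEFAULT_POLICY_ID: {
                LEARNER_STATS_KEY: {"cur_lr": 5e-4, "vf_loss": 0.7},
            },
        },
    },
}

# The old access path was results["info"]["learner"][pid]["cur_lr"];
# per-policy stats are now nested under LEARNER_STATS_KEY as well.
stats = results["info"][LEARNER_INFO][DEFAULT_POLICY_ID][LEARNER_STATS_KEY]
print(stats["cur_lr"])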
diff --git a/rllib/agents/dqn/tests/test_dqn.py b/rllib/agents/dqn/tests/test_dqn.py
index dbf4876742b1f..fbf029a511243 100644
--- a/rllib/agents/dqn/tests/test_dqn.py
+++ b/rllib/agents/dqn/tests/test_dqn.py
@@ -4,7 +4,7 @@
 import ray
 import ray.rllib.agents.dqn as dqn
 from ray.rllib.utils.test_utils import check, check_compute_single_action, \
-    framework_iterator
+    check_train_results, framework_iterator
 
 
 class TestDQN(unittest.TestCase):
@@ -30,6 +30,7 @@ def test_dqn_compilation(self):
             trainer = dqn.DQNTrainer(config=plain_config, env="CartPole-v0")
             for i in range(num_iterations):
                 results = trainer.train()
+                check_train_results(results)
                 print(results)
             check_compute_single_action(trainer)
@@ -46,6 +47,7 @@ def test_dqn_compilation(self):
             trainer = dqn.DQNTrainer(config=rainbow_config, env="CartPole-v0")
             for i in range(num_iterations):
                 results = trainer.train()
+                check_train_results(results)
                 print(results)
             check_compute_single_action(trainer)
diff --git a/rllib/agents/dqn/tests/test_r2d2.py b/rllib/agents/dqn/tests/test_r2d2.py
index d6e0d52d285e8..44b2e0887a1c5 100644
--- a/rllib/agents/dqn/tests/test_r2d2.py
+++ b/rllib/agents/dqn/tests/test_r2d2.py
@@ -4,7 +4,7 @@
 import ray.rllib.agents.dqn as dqn
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
 from ray.rllib.utils.test_utils import check_compute_single_action, \
-    framework_iterator
+    check_train_results, framework_iterator
 
 tf1, tf, tfv = try_import_tf()
 torch, nn = try_import_torch()
@@ -43,6 +43,7 @@ def test_r2d2_compilation(self):
             trainer = dqn.R2D2Trainer(config=config, env="CartPole-v0")
             for i in range(num_iterations):
                 results = trainer.train()
+                check_train_results(results)
                 print(results)
             check_compute_single_action(trainer, include_state=True)
diff --git a/rllib/agents/dqn/tests/test_simple_q.py b/rllib/agents/dqn/tests/test_simple_q.py
index 12cddac283208..299bf39f63e51 100644
--- a/rllib/agents/dqn/tests/test_simple_q.py
+++ b/rllib/agents/dqn/tests/test_simple_q.py
@@ -10,7 +10,7 @@
 from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.numpy import fc, one_hot, huber_loss
 from ray.rllib.utils.test_utils import check, check_compute_single_action, \
-    framework_iterator
+    check_train_results, framework_iterator
 
 tf1, tf, tfv = try_import_tf()
 
@@ -41,6 +41,7 @@ def test_simple_q_compilation(self):
             sb = rw.sample()
             assert sb.count == config["rollout_fragment_length"]
             results = trainer.train()
+            check_train_results(results)
             print(results)
             check_compute_single_action(trainer)
diff --git a/rllib/agents/dreamer/dreamer.py b/rllib/agents/dreamer/dreamer.py
index 4a8170f527875..b3433f62cd5a0 100644
--- a/rllib/agents/dreamer/dreamer.py
+++ b/rllib/agents/dreamer/dreamer.py
@@ -7,11 +7,12 @@
 from ray.rllib.agents.dreamer.dreamer_torch_policy import DreamerTorchPolicy
 from ray.rllib.agents.trainer_template import build_trainer
 from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
-    LEARNER_INFO, _get_shared_metrics
+    _get_shared_metrics
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
 from ray.rllib.evaluation.metrics import collect_metrics
 from ray.rllib.agents.dreamer.dreamer_model import DreamerModel
 from ray.rllib.execution.rollout_ops import ParallelRollouts
+from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
 from ray.rllib.utils.typing import SampleBatchType
 
 logger = logging.getLogger(__name__)
diff --git a/rllib/agents/impala/tests/test_impala.py b/rllib/agents/impala/tests/test_impala.py
index ba5e28e82073c..13e22240e34aa 100644
--- a/rllib/agents/impala/tests/test_impala.py
+++ b/rllib/agents/impala/tests/test_impala.py
@@ -4,8 +4,10 @@
 import ray.rllib.agents.impala as impala
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
 from ray.rllib.utils.framework import try_import_tf
+from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \
+    LEARNER_STATS_KEY
 from ray.rllib.utils.test_utils import check, \
-    check_compute_single_action, framework_iterator
+    check_compute_single_action, check_train_results, framework_iterator
 
 tf1, tf, tfv = try_import_tf()
 
@@ -39,7 +41,10 @@ def test_impala_compilation(self):
                 # to do with LSTMs, though).
                 trainer = impala.ImpalaTrainer(config=local_cfg, env=env)
                 for i in range(num_iterations):
-                    print(trainer.train())
+                    results = trainer.train()
+                    check_train_results(results)
+                    print(results)
+
                 check_compute_single_action(
                     trainer,
                     include_state=lstm,
@@ -61,7 +66,8 @@ def test_impala_lr_schedule(self):
         config["env"] = "CartPole-v0"
 
         def get_lr(result):
-            return result["info"]["learner"][DEFAULT_POLICY_ID]["cur_lr"]
+            return result["info"][LEARNER_INFO][DEFAULT_POLICY_ID][
+                LEARNER_STATS_KEY]["cur_lr"]
 
         for fw in framework_iterator(config, frameworks=("tf", "torch")):
             trainer = impala.ImpalaTrainer(config=config)
diff --git a/rllib/agents/impala/vtrace_tf_policy.py b/rllib/agents/impala/vtrace_tf_policy.py
index 99960a3206b2c..f5b5ddc4192db 100644
--- a/rllib/agents/impala/vtrace_tf_policy.py
+++ b/rllib/agents/impala/vtrace_tf_policy.py
@@ -111,8 +111,13 @@ def __init__(self,
         self.mean_entropy = tf.reduce_mean(masked_entropy)
 
         # The summed weighted loss.
-        self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff -
-                           self.entropy * entropy_coeff)
+        self.total_loss = self.pi_loss - self.entropy * entropy_coeff
+
+        # Optional vf loss (or in a separate term due to separate
+        # optimizers/networks).
+        self.loss_wo_vf = self.total_loss
+        if not config["_separate_vf_optimizer"]:
+            self.total_loss += self.vf_loss * vf_loss_coeff
 
 
 def _make_time_major(policy, seq_lens, tensor, drop_last=False):
@@ -220,7 +225,10 @@ def make_time_major(*args, **kw):
         clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
         clip_pg_rho_threshold=policy.config["vtrace_clip_pg_rho_threshold"])
 
-    return policy.loss.total_loss
+    if policy.config.get("_separate_vf_optimizer"):
+        return policy.loss.loss_wo_vf, policy.loss.vf_loss
+    else:
+        return policy.loss.total_loss
 
 
 def stats(policy, train_batch):
@@ -239,13 +247,21 @@ def stats(policy, train_batch):
         "vf_loss": policy.loss.mean_vf_loss,
         "vf_explained_var": explained_variance(
             tf.reshape(policy.loss.value_targets, [-1]),
-            tf.reshape(values_batched, [-1])),
+            tf.reshape(values_batched, [-1]))
     }
 
 
 def grad_stats(policy, train_batch, grads):
+    # We have support for more than one loss (list of lists of grads).
+    if policy.config.get("_tf_policy_handles_more_than_one_loss"):
+        grad_gnorm = [tf.linalg.global_norm(g) for g in grads]
+    # Old case: We have a single list of grads (only one loss term and
+    # optimizer).
+    else:
+        grad_gnorm = tf.linalg.global_norm(grads)
+
     return {
-        "grad_gnorm": tf.linalg.global_norm(grads),
+        "grad_gnorm": grad_gnorm,
     }
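With `_tf_policy_handles_more_than_one_loss`, `grads` arrives as one gradient list per loss term, so `grad_stats` reports one global norm per list. A numpy stand-in for `tf.linalg.global_norm` that shows both shapes (hypothetical gradient values):

import numpy as np

def global_norm(grads):
    # Same quantity tf.linalg.global_norm computes:
    # sqrt of the sum of all squared gradient entries.
    return float(np.sqrt(sum(np.sum(g ** 2) for g in grads)))

# Single-loss case: one flat list of gradient arrays.
single = [np.array([3.0]), np.array([4.0])]
print(global_norm(single))  # 5.0

# Multi-loss case: one gradient list per loss term -> one norm each.
multi = [[np.array([3.0, 4.0])], [np.array([1.0])]]
print([global_norm(g) for g in multi])  # [5.0, 1.0]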
diff --git a/rllib/agents/impala/vtrace_torch_policy.py b/rllib/agents/impala/vtrace_torch_policy.py
index ec279cd5573b0..c8738d1875f63 100644
--- a/rllib/agents/impala/vtrace_torch_policy.py
+++ b/rllib/agents/impala/vtrace_torch_policy.py
@@ -1,10 +1,12 @@
 import gym
 import logging
 import numpy as np
+from typing import Any, Dict
 
 import ray
 import ray.rllib.agents.impala.vtrace_torch as vtrace
 from ray.rllib.models.torch.torch_action_dist import TorchCategorical
+from ray.rllib.policy.policy import Policy
 from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.torch_policy import LearningRateSchedule, \
@@ -182,17 +184,22 @@ def _make_time_major(*args, **kw):
         clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
         clip_pg_rho_threshold=policy.config["vtrace_clip_pg_rho_threshold"])
 
-    # Store loss object only for multi-GPU tower 0.
-    if model is policy.model_gpu_towers[0]:
-        policy.loss = loss
-        values_batched = make_time_major(
-            policy,
-            train_batch.get(SampleBatch.SEQ_LENS),
-            values,
-            drop_last=policy.config["vtrace"])
-        policy._vf_explained_var = explained_variance(
-            torch.reshape(loss.value_targets, [-1]),
-            torch.reshape(values_batched, [-1])),
+    # Store values for stats function in model (tower), such that for
+    # multi-GPU, we do not override them during the parallel loss phase.
+    model.tower_stats["pi_loss"] = loss.pi_loss
+    model.tower_stats["vf_loss"] = loss.vf_loss
+    model.tower_stats["entropy"] = loss.entropy
+    model.tower_stats["mean_entropy"] = loss.mean_entropy
+    model.tower_stats["total_loss"] = loss.total_loss
+
+    values_batched = make_time_major(
+        policy,
+        train_batch.get(SampleBatch.SEQ_LENS),
+        values,
+        drop_last=policy.config["vtrace"])
+    model.tower_stats["vf_explained_var"] = explained_variance(
+        torch.reshape(loss.value_targets, [-1]),
+        torch.reshape(values_batched, [-1]))
 
     return loss.total_loss
 
@@ -236,15 +243,21 @@ def make_time_major(policy, seq_lens, tensor, drop_last=False):
     return res
 
 
-def stats(policy, train_batch):
+def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, Any]:
+
     return {
         "cur_lr": policy.cur_lr,
-        "policy_loss": policy.loss.pi_loss,
-        "entropy": policy.loss.mean_entropy,
+        "total_loss": torch.mean(
+            torch.stack(policy.get_tower_stats("total_loss"))),
+        "policy_loss": torch.mean(
+            torch.stack(policy.get_tower_stats("pi_loss"))),
+        "entropy": torch.mean(
+            torch.stack(policy.get_tower_stats("mean_entropy"))),
         "entropy_coeff": policy.entropy_coeff,
         "var_gnorm": global_norm(policy.model.trainable_variables()),
-        "vf_loss": policy.loss.vf_loss,
-        "vf_explained_var": policy._vf_explained_var,
+        "vf_loss": torch.mean(torch.stack(policy.get_tower_stats("vf_loss"))),
+        "vf_explained_var": torch.mean(
+            torch.stack(policy.get_tower_stats("vf_explained_var"))),
     }
diff --git a/rllib/agents/maml/maml.py b/rllib/agents/maml/maml.py
index 9d82a0e192cc5..c85d4f158b3c5 100644
--- a/rllib/agents/maml/maml.py
+++ b/rllib/agents/maml/maml.py
@@ -8,11 +8,12 @@
 from ray.rllib.agents.trainer_template import build_trainer
 from ray.rllib.evaluation.metrics import get_learner_stats
 from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
-    STEPS_TRAINED_COUNTER, LEARNER_INFO, _get_shared_metrics
+    STEPS_TRAINED_COUNTER, _get_shared_metrics
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.execution.metric_ops import CollectMetrics
 from ray.rllib.evaluation.metrics import collect_metrics
 from ray.rllib.utils.deprecation import DEPRECATED_VALUE
+from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
 from ray.util.iter import from_actors
 
 logger = logging.getLogger(__name__)
@@ -98,9 +99,10 @@ def __call__(self, data_tuple):
         # Metric Updating
         metrics = _get_shared_metrics()
         metrics.counters[STEPS_SAMPLED_COUNTER] += samples.count
+        fetches = None
         for i in range(self.maml_optimizer_steps):
             fetches = self.workers.local_worker().learn_on_batch(samples)
-        fetches = get_learner_stats(fetches)
+        learner_stats = get_learner_stats(fetches)
 
         # Sync workers with meta policy
         self.workers.sync_weights()
@@ -110,11 +112,12 @@ def __call__(self, data_tuple):
 
         # Update KLS
         def update(pi, pi_id):
-            assert "inner_kl" not in fetches, (
-                "inner_kl should be nested under policy id key", fetches)
-            if pi_id in fetches:
-                assert "inner_kl" in fetches[pi_id], (fetches, pi_id)
-                pi.update_kls(fetches[pi_id]["inner_kl"])
+            assert "inner_kl" not in learner_stats, (
+                "inner_kl should be nested under policy id key", learner_stats)
+            if pi_id in learner_stats:
+                assert "inner_kl" in learner_stats[pi_id], (learner_stats,
+                                                            pi_id)
+                pi.update_kls(learner_stats[pi_id]["inner_kl"])
             else:
                 logger.warning("No data for {}, not updating kl".format(pi_id))
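The `update` closure above (repeated almost verbatim in mbmpo.py below) insists that `inner_kl` stats arrive nested under a policy id before feeding them to `update_kls`. A reduced sketch of that guard, with a hypothetical policy class and stats dict:

import logging

logger = logging.getLogger(__name__)

class ToyPolicy:
    def __init__(self):
        self.kl_coeffs = [1.0]

    def update_kls(self, inner_kls):
        # Hypothetical stand-in for the real KL-coefficient adaptation.
        self.kl_coeffs = [c * (2.0 if kl > 0.01 else 0.5)
                          for c, kl in zip(self.kl_coeffs, inner_kls)]

def update(pi, pi_id, learner_stats):
    # Stats must be nested under the policy id, never at the top level.
    assert "inner_kl" not in learner_stats
    if pi_id in learner_stats:
        pi.update_kls(learner_stats[pi_id]["inner_kl"])
    else:
        logger.warning("No data for {}, not updating kl".format(pi_id))

pi = ToyPolicy()
update(pi, "default_policy", {"default_policy": {"inner_kl": [0.05]}})
print(pi.kl_coeffs)  # [2.0] -> coefficient raised; inner KL was too high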
diff --git a/rllib/agents/maml/tests/test_maml.py b/rllib/agents/maml/tests/test_maml.py
index b84e028571907..e1905b5cc853f 100644
--- a/rllib/agents/maml/tests/test_maml.py
+++ b/rllib/agents/maml/tests/test_maml.py
@@ -3,7 +3,7 @@
 import ray
 import ray.rllib.agents.maml as maml
 from ray.rllib.utils.test_utils import check_compute_single_action, \
-    framework_iterator
+    check_train_results, framework_iterator
 
 
 class TestMAML(unittest.TestCase):
@@ -34,7 +34,9 @@ def test_maml_compilation(self):
             env_ = "ray.rllib.examples.env.{}".format(env)
             trainer = maml.MAMLTrainer(config=config, env=env_)
             for i in range(num_iterations):
-                trainer.train()
+                results = trainer.train()
+                check_train_results(results)
+                print(results)
             check_compute_single_action(
                 trainer, include_prev_action_reward=True)
             trainer.stop()
diff --git a/rllib/agents/marwil/tests/test_bc.py b/rllib/agents/marwil/tests/test_bc.py
index c6508330e43de..d6ac234897839 100644
--- a/rllib/agents/marwil/tests/test_bc.py
+++ b/rllib/agents/marwil/tests/test_bc.py
@@ -6,7 +6,7 @@
 import ray.rllib.agents.marwil as marwil
 from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.test_utils import check_compute_single_action, \
-    framework_iterator
+    check_train_results, framework_iterator
 
 tf1, tf, tfv = try_import_tf()
 
@@ -51,7 +51,11 @@ def test_bc_compilation_and_learning_from_offline_file(self):
             trainer = marwil.BCTrainer(config=config, env="CartPole-v0")
             learnt = False
             for i in range(num_iterations):
-                eval_results = trainer.train().get("evaluation")
+                results = trainer.train()
+                check_train_results(results)
+                print(results)
+
+                eval_results = results.get("evaluation")
                 if eval_results:
                     print("iter={} R={}".format(
                         i, eval_results["episode_reward_mean"]))
diff --git a/rllib/agents/marwil/tests/test_marwil.py b/rllib/agents/marwil/tests/test_marwil.py
index 29c6b678ecf2c..b8ca7af86ae21 100644
--- a/rllib/agents/marwil/tests/test_marwil.py
+++ b/rllib/agents/marwil/tests/test_marwil.py
@@ -9,7 +9,7 @@
 from ray.rllib.offline import JsonReader
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
 from ray.rllib.utils.test_utils import check, check_compute_single_action, \
-    framework_iterator
+    check_train_results, framework_iterator
 
 tf1, tf, tfv = try_import_tf()
 torch, _ = try_import_torch()
@@ -57,7 +57,11 @@ def test_marwil_compilation_and_learning_from_offline_file(self):
             trainer = marwil.MARWILTrainer(config=config, env="CartPole-v0")
             learnt = False
             for i in range(num_iterations):
-                eval_results = trainer.train().get("evaluation")
+                results = trainer.train()
+                check_train_results(results)
+                print(results)
+
+                eval_results = results.get("evaluation")
                 if eval_results:
                     print("iter={} R={} ".format(
                         i, eval_results["episode_reward_mean"]))
diff --git a/rllib/agents/mbmpo/mbmpo.py b/rllib/agents/mbmpo/mbmpo.py
index 0a537213ac193..aaf2d835e6c1f 100644
--- a/rllib/agents/mbmpo/mbmpo.py
+++ b/rllib/agents/mbmpo/mbmpo.py
@@ -26,10 +26,11 @@
     get_learner_stats
 from ray.rllib.evaluation.worker_set import WorkerSet
 from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
-    STEPS_TRAINED_COUNTER, LEARNER_INFO, _get_shared_metrics
+    STEPS_TRAINED_COUNTER, _get_shared_metrics
 from ray.rllib.execution.metric_ops import CollectMetrics
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
 from ray.rllib.utils.deprecation import DEPRECATED_VALUE
+from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
 from ray.rllib.utils.sgd import standardized
 from ray.rllib.utils.torch_ops import convert_to_torch_tensor
 from ray.rllib.utils.typing import EnvType, TrainerConfigDict
@@ -160,17 +161,19 @@ def __call__(self, data_tuple):
             adapt_metrics_dict, prefix="MAMLIter{}".format(self.step_counter))
 
         # MAML Meta-update.
+        fetches = None
         for i in range(self.maml_optimizer_steps):
             fetches = self.workers.local_worker().learn_on_batch(samples)
-        fetches = get_learner_stats(fetches)
+        learner_stats = get_learner_stats(fetches)
 
         # Update KLs.
         def update(pi, pi_id):
-            assert "inner_kl" not in fetches, (
-                "inner_kl should be nested under policy id key", fetches)
-            if pi_id in fetches:
-                assert "inner_kl" in fetches[pi_id], (fetches, pi_id)
-                pi.update_kls(fetches[pi_id]["inner_kl"])
+            assert "inner_kl" not in learner_stats, (
+                "inner_kl should be nested under policy id key", learner_stats)
+            if pi_id in learner_stats:
+                assert "inner_kl" in learner_stats[pi_id], (learner_stats,
+                                                            pi_id)
+                pi.update_kls(learner_stats[pi_id]["inner_kl"])
             else:
                 logger.warning("No data for {}, not updating kl".format(pi_id))
diff --git a/rllib/agents/mbmpo/tests/test_mbmpo.py b/rllib/agents/mbmpo/tests/test_mbmpo.py
index de708fd50d58c..941686c3e717b 100644
--- a/rllib/agents/mbmpo/tests/test_mbmpo.py
+++ b/rllib/agents/mbmpo/tests/test_mbmpo.py
@@ -3,7 +3,7 @@
 import ray
 import ray.rllib.agents.mbmpo as mbmpo
 from ray.rllib.utils.test_utils import check_compute_single_action, \
-    framework_iterator
+    check_train_results, framework_iterator
 
 
 class TestMBMPO(unittest.TestCase):
@@ -28,8 +28,12 @@ def test_mbmpo_compilation(self):
             trainer = mbmpo.MBMPOTrainer(
                 config=config,
                 env="ray.rllib.examples.env.mbmpo_env.CartPoleWrapper")
+
             for i in range(num_iterations):
-                trainer.train()
+                results = trainer.train()
+                check_train_results(results)
+                print(results)
+
             check_compute_single_action(
                 trainer, include_prev_action_reward=False)
             trainer.stop()
diff --git a/rllib/agents/pg/pg_torch_policy.py b/rllib/agents/pg/pg_torch_policy.py
index d707f01f2364e..34a17c5e03f97 100644
--- a/rllib/agents/pg/pg_torch_policy.py
+++ b/rllib/agents/pg/pg_torch_policy.py
@@ -44,11 +44,15 @@ def pg_torch_loss(
     # L = -E[ log(pi(a|s)) * A]
     log_probs = action_dist.logp(train_batch[SampleBatch.ACTIONS])
 
-    # Save the loss in the policy object for the stats_fn below.
-    policy.pi_err = -torch.mean(
+    # Final policy loss.
+    policy_loss = -torch.mean(
         log_probs * train_batch[Postprocessing.ADVANTAGES])
 
-    return policy.pi_err
+    # Store values for stats function in model (tower), such that for
+    # multi-GPU, we do not override them during the parallel loss phase.
+    model.tower_stats["policy_loss"] = policy_loss
+
+    return policy_loss
 
 
 def pg_loss_stats(policy: Policy,
@@ -64,8 +68,8 @@ def pg_loss_stats(policy: Policy,
     """
 
     return {
-        # `pi_err` (the loss) is stored inside `pg_torch_loss()`.
-        "policy_loss": policy.pi_err.item(),
+        "policy_loss": torch.mean(
+            torch.stack(policy.get_tower_stats("policy_loss"))),
    }
diff --git a/rllib/agents/pg/tests/test_pg.py b/rllib/agents/pg/tests/test_pg.py
index 44a52829beaf3..40b985cc8e488 100644
--- a/rllib/agents/pg/tests/test_pg.py
+++ b/rllib/agents/pg/tests/test_pg.py
@@ -7,8 +7,9 @@
 from ray.rllib.models.tf.tf_action_dist import Categorical
 from ray.rllib.models.torch.torch_action_dist import TorchCategorical
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.utils import check, check_compute_single_action, fc, \
-    framework_iterator
+from ray.rllib.utils.numpy import fc
+from ray.rllib.utils.test_utils import check, check_compute_single_action, \
+    check_train_results, framework_iterator
 
 
 class TestPG(unittest.TestCase):
@@ -31,7 +32,10 @@ def test_pg_compilation(self):
             for env in ["FrozenLake-v0", "CartPole-v0"]:
                 trainer = pg.PGTrainer(config=config, env=env)
                 for i in range(num_iterations):
-                    print(trainer.train())
+                    results = trainer.train()
+                    check_train_results(results)
+                    print(results)
+
                 check_compute_single_action(
                     trainer, include_prev_action_reward=True)
diff --git a/rllib/agents/ppo/appo_tf_policy.py b/rllib/agents/ppo/appo_tf_policy.py
index 455044bebfe1d..142b96d6e247f 100644
--- a/rllib/agents/ppo/appo_tf_policy.py
+++ b/rllib/agents/ppo/appo_tf_policy.py
@@ -304,7 +304,7 @@ def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
         "vf_loss": policy._mean_vf_loss,
         "vf_explained_var": explained_variance(
             tf.reshape(policy._value_targets, [-1]),
-            tf.reshape(values_batched, [-1])),
+            tf.reshape(values_batched, [-1]))
     }
 
     if policy.config["vtrace"]:
diff --git a/rllib/agents/ppo/appo_torch_policy.py b/rllib/agents/ppo/appo_torch_policy.py
index f8ee24989d825..324b73bf5a6b7 100644
--- a/rllib/agents/ppo/appo_torch_policy.py
+++ b/rllib/agents/ppo/appo_torch_policy.py
@@ -159,7 +159,7 @@ def reduce_mean_valid(t):
                 torch.clamp(logp_ratio, 1 - policy.config["clip_param"],
                             1 + policy.config["clip_param"]))
 
-        mean_kl = reduce_mean_valid(action_kl)
+        mean_kl_loss = reduce_mean_valid(action_kl)
         mean_policy_loss = -reduce_mean_valid(surrogate_loss)
 
         # The value function loss.
@@ -188,7 +188,7 @@ def reduce_mean_valid(t):
                 torch.clamp(logp_ratio, 1 - policy.config["clip_param"],
                             1 + policy.config["clip_param"]))
 
-        mean_kl = reduce_mean_valid(action_kl)
+        mean_kl_loss = reduce_mean_valid(action_kl)
         mean_policy_loss = -reduce_mean_valid(surrogate_loss)
 
         # The value function loss.
@@ -208,16 +208,17 @@ def reduce_mean_valid(t):
 
     # Optional additional KL Loss
     if policy.config["use_kl_loss"]:
-        total_loss += policy.kl_coeff * mean_kl
-
-    policy._total_loss = total_loss
-    policy._mean_policy_loss = mean_policy_loss
-    # Backward compatibility: Deprecate policy._mean_kl.
-    policy._mean_kl_loss = policy._mean_kl = mean_kl
-    policy._mean_vf_loss = mean_vf_loss
-    policy._mean_entropy = mean_entropy
-    policy._value_targets = value_targets
-    policy._vf_explained_var = explained_variance(
+        total_loss += policy.kl_coeff * mean_kl_loss
+
+    # Store values for stats function in model (tower), such that for
+    # multi-GPU, we do not override them during the parallel loss phase.
+    model.tower_stats["total_loss"] = total_loss
+    model.tower_stats["mean_policy_loss"] = mean_policy_loss
+    model.tower_stats["mean_kl_loss"] = mean_kl_loss
+    model.tower_stats["mean_vf_loss"] = mean_vf_loss
+    model.tower_stats["mean_entropy"] = mean_entropy
+    model.tower_stats["value_targets"] = value_targets
+    model.tower_stats["vf_explained_var"] = explained_variance(
         torch.reshape(value_targets, [-1]),
         torch.reshape(
             values_time_major[:-1]
@@ -239,22 +240,28 @@ def stats(policy: Policy, train_batch: SampleBatch):
     """
     stats_dict = {
         "cur_lr": policy.cur_lr,
-        "policy_loss": policy._mean_policy_loss,
-        "entropy": policy._mean_entropy,
+        "total_loss": torch.mean(
+            torch.stack(policy.get_tower_stats("total_loss"))),
+        "policy_loss": torch.mean(
+            torch.stack(policy.get_tower_stats("mean_policy_loss"))),
+        "entropy": torch.mean(
+            torch.stack(policy.get_tower_stats("mean_entropy"))),
         "var_gnorm": global_norm(policy.model.trainable_variables()),
-        "vf_loss": policy._mean_vf_loss,
-        "vf_explained_var": policy._vf_explained_var,
+        "vf_loss": torch.mean(
+            torch.stack(policy.get_tower_stats("mean_vf_loss"))),
+        "vf_explained_var": torch.mean(
+            torch.stack(policy.get_tower_stats("vf_explained_var"))),
     }
 
     if policy.config["vtrace"]:
         is_stat_mean = torch.mean(policy._is_ratio, [0, 1])
         is_stat_var = torch.var(policy._is_ratio, [0, 1])
-        stats_dict.update({"mean_IS": is_stat_mean})
-        stats_dict.update({"var_IS": is_stat_var})
+        stats_dict["mean_IS"] = is_stat_mean
+        stats_dict["var_IS"] = is_stat_var
 
     if policy.config["use_kl_loss"]:
-        stats_dict.update({"kl": policy._mean_kl_loss})
-        stats_dict.update({"KL_Coeff": policy.kl_coeff})
+        stats_dict["kl"] = policy.get_tower_stats("mean_kl_loss")
+        stats_dict["KL_Coeff"] = policy.kl_coeff
 
     return stats_dict
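`vf_explained_var`, reported by several of these stats functions, is the standard explained-variance score 1 - Var[targets - predictions] / Var[targets]. A quick numpy check of the definition (hypothetical values, not RLlib's exact helper):

import numpy as np

value_targets = np.array([1.0, 2.0, 3.0, 4.0])
vf_predictions = np.array([1.5, 2.0, 2.5, 4.0])

# 1.0 = perfect prediction; 0.0 = no better than predicting the mean.
explained_var = 1.0 - np.var(value_targets - vf_predictions) / np.var(
    value_targets)
print(round(explained_var, 3))  # 0.9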
+ if config["kl_coeff"] != 0.0 or config["kl_target"] != 0.0: + raise ValueError("DDPPO doesn't support KL penalties like PPO-1") def execution_plan(workers: WorkerSet, diff --git a/rllib/agents/ppo/ppo.py b/rllib/agents/ppo/ppo.py index e0ced5d82cdeb..e43d460087b84 100644 --- a/rllib/agents/ppo/ppo.py +++ b/rllib/agents/ppo/ppo.py @@ -20,9 +20,11 @@ StandardizeFields, SelectExperiences from ray.rllib.execution.train_ops import TrainOneStep, MultiGPUTrainOneStep from ray.rllib.execution.metric_ops import StandardMetricsReporting -from ray.rllib.policy.policy import LEARNER_STATS_KEY, Policy +from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.deprecation import DEPRECATED_VALUE +from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \ + LEARNER_STATS_KEY from ray.rllib.utils.typing import TrainerConfigDict from ray.util.iter import LocalIterator @@ -217,12 +219,12 @@ def warn_about_bad_reward_scales(config, result): return result # Punt on handling multiagent case. # Warn about excessively high VF loss. - learner_stats = result["info"]["learner"] - if DEFAULT_POLICY_ID in learner_stats: + learner_info = result["info"][LEARNER_INFO] + if DEFAULT_POLICY_ID in learner_info: scaled_vf_loss = config["vf_loss_coeff"] * \ - learner_stats[DEFAULT_POLICY_ID][LEARNER_STATS_KEY]["vf_loss"] + learner_info[DEFAULT_POLICY_ID][LEARNER_STATS_KEY]["vf_loss"] - policy_loss = learner_stats[DEFAULT_POLICY_ID][LEARNER_STATS_KEY][ + policy_loss = learner_info[DEFAULT_POLICY_ID][LEARNER_STATS_KEY][ "policy_loss"] if config.get("model", {}).get("vf_share_layers") and \ scaled_vf_loss > 100: diff --git a/rllib/agents/ppo/ppo_torch_policy.py b/rllib/agents/ppo/ppo_torch_policy.py index f8f310e6b07e3..69e19e33d7817 100644 --- a/rllib/agents/ppo/ppo_torch_policy.py +++ b/rllib/agents/ppo/ppo_torch_policy.py @@ -105,15 +105,15 @@ def reduce_mean_valid(t): policy.config["vf_loss_coeff"] * vf_loss - policy.entropy_coeff * curr_entropy) - # Store stats in policy for stats_fn. - policy._total_loss = total_loss - policy._mean_policy_loss = mean_policy_loss - policy._mean_vf_loss = mean_vf_loss - policy._vf_explained_var = explained_variance( + # Store values for stats function in model (tower), such that for + # multi-GPU, we do not override them during the parallel loss phase. + model.tower_stats["total_loss"] = total_loss + model.tower_stats["mean_policy_loss"] = mean_policy_loss + model.tower_stats["mean_vf_loss"] = mean_vf_loss + model.tower_stats["vf_explained_var"] = explained_variance( train_batch[Postprocessing.VALUE_TARGETS], model.value_function()) - policy._mean_entropy = mean_entropy - # Backward compatibility: Deprecate policy._mean_kl. 
- policy._mean_kl_loss = policy._mean_kl = mean_kl_loss + model.tower_stats["mean_entropy"] = mean_entropy + model.tower_stats["mean_kl_loss"] = mean_kl_loss return total_loss @@ -132,12 +132,17 @@ def kl_and_loss_stats(policy: Policy, return { "cur_kl_coeff": policy.kl_coeff, "cur_lr": policy.cur_lr, - "total_loss": policy._total_loss, - "policy_loss": policy._mean_policy_loss, - "vf_loss": policy._mean_vf_loss, - "vf_explained_var": policy._vf_explained_var, - "kl": policy._mean_kl_loss, - "entropy": policy._mean_entropy, + "total_loss": torch.mean( + torch.stack(policy.get_tower_stats("total_loss"))), + "policy_loss": torch.mean( + torch.stack(policy.get_tower_stats("mean_policy_loss"))), + "vf_loss": torch.mean( + torch.stack(policy.get_tower_stats("mean_vf_loss"))), + "vf_explained_var": torch.mean( + torch.stack(policy.get_tower_stats("vf_explained_var"))), + "kl": torch.mean(torch.stack(policy.get_tower_stats("mean_kl_loss"))), + "entropy": torch.mean( + torch.stack(policy.get_tower_stats("mean_entropy"))), "entropy_coeff": policy.entropy_coeff, } diff --git a/rllib/agents/ppo/tests/test_appo.py b/rllib/agents/ppo/tests/test_appo.py index 32a5989263f7c..be007f3dd9995 100644 --- a/rllib/agents/ppo/tests/test_appo.py +++ b/rllib/agents/ppo/tests/test_appo.py @@ -3,7 +3,7 @@ import ray import ray.rllib.agents.ppo as ppo from ray.rllib.utils.test_utils import check_compute_single_action, \ - framework_iterator + check_train_results, framework_iterator class TestAPPO(unittest.TestCase): @@ -27,7 +27,9 @@ def test_appo_compilation(self): _config["vtrace"] = False trainer = ppo.APPOTrainer(config=_config, env="CartPole-v0") for i in range(num_iterations): - print(trainer.train()) + results = trainer.train() + check_train_results(results) + print(results) check_compute_single_action(trainer) trainer.stop() @@ -36,7 +38,9 @@ def test_appo_compilation(self): _config["vtrace"] = True trainer = ppo.APPOTrainer(config=_config, env="CartPole-v0") for i in range(num_iterations): - print(trainer.train()) + results = trainer.train() + check_train_results(results) + print(results) check_compute_single_action(trainer) trainer.stop() @@ -55,10 +59,12 @@ def test_appo_two_tf_optimizers(self): num_iterations = 2 # Only supported for tf so far. 
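        # Illustration (made-up numbers) of the nested result layout that the
        # updated stats functions and tests in this patch rely on: learner
        # stats moved from results["info"]["learner"] to the LEARNER_INFO
        # constant (still the string "learner") imported from
        # ray.rllib.utils.metrics.learner_info:
        #
        #   results = trainer.train()
        #   stats = results["info"][LEARNER_INFO][DEFAULT_POLICY_ID][
        #       LEARNER_STATS_KEY]
        #   stats["cur_lr"]       # e.g. 5e-05
        #   stats["policy_loss"]  # e.g. -0.023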
- for _ in framework_iterator(config, frameworks="tf"): + for _ in framework_iterator(config, frameworks=("tf2", "tf")): trainer = ppo.APPOTrainer(config=config, env="CartPole-v0") for i in range(num_iterations): - print(trainer.train()) + results = trainer.train() + check_train_results(results) + print(results) check_compute_single_action(trainer) trainer.stop() diff --git a/rllib/agents/ppo/tests/test_ddppo.py b/rllib/agents/ppo/tests/test_ddppo.py index e1191cfb2cd35..0e8154a662d12 100644 --- a/rllib/agents/ppo/tests/test_ddppo.py +++ b/rllib/agents/ppo/tests/test_ddppo.py @@ -1,11 +1,13 @@ import unittest +import pytest import ray import ray.rllib.agents.ppo as ppo from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID -from ray.rllib.policy.policy import LEARNER_STATS_KEY +from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \ + LEARNER_STATS_KEY from ray.rllib.utils.test_utils import check, check_compute_single_action, \ - framework_iterator + check_train_results, framework_iterator class TestDDPPO(unittest.TestCase): @@ -26,7 +28,9 @@ def test_ddppo_compilation(self): for _ in framework_iterator(config, frameworks="torch"): trainer = ppo.ddppo.DDPPOTrainer(config=config, env="CartPole-v0") for i in range(num_iterations): - trainer.train() + results = trainer.train() + check_train_results(results) + print(results) # Make sure, weights on all workers are the same (including # local one). weights = trainer.workers.foreach_worker( @@ -48,13 +52,25 @@ def test_ddppo_schedule(self): trainer = ppo.ddppo.DDPPOTrainer(config=config, env="CartPole-v0") for _ in range(num_iterations): result = trainer.train() - lr = result["info"]["learner"][DEFAULT_POLICY_ID][ + lr = result["info"][LEARNER_INFO][DEFAULT_POLICY_ID][ LEARNER_STATS_KEY]["cur_lr"] trainer.stop() assert lr == 0.0, "lr should anneal to 0.0" + def test_validate_config(self): + """Test if DDPPO will raise errors after invalid configs are passed.""" + config = ppo.ddppo.DEFAULT_CONFIG.copy() + config["kl_coeff"] = 1. + msg = "DDPPO doesn't support KL penalties like PPO-1" + # import ipdb; ipdb.set_trace() + with pytest.raises(ValueError, match=msg): + ppo.ddppo.DDPPOTrainer(config=config, env="CartPole-v0") + config["kl_coeff"] = 0. + config["kl_target"] = 1. + with pytest.raises(ValueError, match=msg): + ppo.ddppo.DDPPOTrainer(config=config, env="CartPole-v0") + if __name__ == "__main__": - import pytest import sys sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/agents/ppo/tests/test_ppo.py b/rllib/agents/ppo/tests/test_ppo.py index 2dfcec41010b5..198922ee7a338 100644 --- a/rllib/agents/ppo/tests/test_ppo.py +++ b/rllib/agents/ppo/tests/test_ppo.py @@ -14,11 +14,12 @@ from ray.rllib.models.tf.tf_action_dist import Categorical from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.models.torch.torch_action_dist import TorchCategorical -from ray.rllib.policy.policy import LEARNER_STATS_KEY from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch +from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \ + LEARNER_STATS_KEY from ray.rllib.utils.numpy import fc -from ray.rllib.utils.test_utils import check, framework_iterator, \ - check_compute_single_action +from ray.rllib.utils.test_utils import check, check_compute_single_action, \ + check_train_results, framework_iterator # Fake CartPole episode of n time steps. FAKE_BATCH = SampleBatch({ @@ -59,7 +60,8 @@ def _check_lr_tf(policy, policy_id): assert lr == optim_lr, "LR scheduling error!" 
def on_train_result(self, *, trainer, result: dict, **kwargs): - stats = result["info"]["learner"][DEFAULT_POLICY_ID][LEARNER_STATS_KEY] + stats = result["info"][LEARNER_INFO][DEFAULT_POLICY_ID][ + LEARNER_STATS_KEY] # Learning rate should go to 0 after 1 iter. check(stats["cur_lr"], 5e-5 if trainer.iteration == 1 else 0.0) # Entropy coeff goes to 0.05, then 0.0 (per iter). @@ -90,7 +92,7 @@ def test_ppo_compilation_and_schedule_mixins(self): config["model"]["lstm_cell_size"] = 10 config["model"]["max_seq_len"] = 20 # Use default-native keras models whenever possible. - config["model"]["_use_default_native_models"] = True + # config["model"]["_use_default_native_models"] = True # Setup lr- and entropy schedules for testing. config["lr_schedule"] = [[0, config["lr"]], [128, 0.0]] @@ -124,7 +126,9 @@ def test_ppo_compilation_and_schedule_mixins(self): check(lr, config["lr"]) for i in range(num_iterations): - print(trainer.train()) + results = trainer.train() + check_train_results(results) + print(results) check_compute_single_action( trainer, @@ -313,6 +317,19 @@ def test_ppo_loss_function(self): check(pl, np.mean(-pg_loss)) check(v, np.mean(vf_loss), decimals=4) check(tl, overall_loss, decimals=4) + elif fw == "torch": + check(policy.model.tower_stats["mean_kl_loss"], kl) + check(policy.model.tower_stats["mean_entropy"], entropy) + check(policy.model.tower_stats["mean_policy_loss"], + np.mean(-pg_loss)) + check( + policy.model.tower_stats["mean_vf_loss"], + np.mean(vf_loss), + decimals=4) + check( + policy.model.tower_stats["total_loss"], + overall_loss, + decimals=4) else: check(policy._mean_kl_loss, kl) check(policy._mean_entropy, entropy) diff --git a/rllib/agents/qmix/qmix_policy.py b/rllib/agents/qmix/qmix_policy.py index 6c1078cbad314..ca0324ce4d08f 100644 --- a/rllib/agents/qmix/qmix_policy.py +++ b/rllib/agents/qmix/qmix_policy.py @@ -8,7 +8,6 @@ from ray.rllib.agents.qmix.model import RNNModel, _get_size from ray.rllib.env.multi_agent_env import ENV_STATE from ray.rllib.env.wrappers.group_agents_wrapper import GROUP_REWARDS -from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY from ray.rllib.models.torch.torch_action_dist import TorchCategorical from ray.rllib.policy.policy import Policy from ray.rllib.policy.rnn_sequencing import chop_into_sequences @@ -16,6 +15,7 @@ from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.modelv2 import _unpack_obs from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.annotations import override # Torch must be installed. diff --git a/rllib/agents/sac/rnnsac.py b/rllib/agents/sac/rnnsac.py index 3fb67e50d8ced..79bf6cdc816bc 100644 --- a/rllib/agents/sac/rnnsac.py +++ b/rllib/agents/sac/rnnsac.py @@ -11,10 +11,6 @@ { # Batch mode (see common config) "batch_mode": "complete_episodes", - # If True prioritized replay buffer will be used. - "prioritized_replay": False, - # RNNSAC does not suport n-step > 1 yet! - "n_step": 1, # If True, assume a zero-initialized state input (no matter where in # the episode the sequence is located). 
# If False, store the initial states along with each SampleBatch, use @@ -50,9 +46,6 @@ def validate_config(config: TrainerConfigDict) -> None: config["replay_sequence_length"] = \ config["burn_in"] + config["model"]["max_seq_len"] - if config["n_step"] > 1: - raise ValueError("`n_step` > 1 not yet supported by RNNSAC!") - def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]: """Policy class picker function. Class is chosen based on DL-framework. diff --git a/rllib/agents/sac/rnnsac_torch_policy.py b/rllib/agents/sac/rnnsac_torch_policy.py index c0d223c0a4766..faef59e1bee67 100644 --- a/rllib/agents/sac/rnnsac_torch_policy.py +++ b/rllib/agents/sac/rnnsac_torch_policy.py @@ -371,6 +371,7 @@ def reduce_mean_valid(t): critic_loss.append( reduce_mean_valid( train_batch[PRIO_WEIGHTS] * huber_loss(twin_td_error))) + td_error = td_error * seq_mask # Alpha- and actor losses. # Note: In the papers, alpha is used directly, here we take the log. @@ -401,26 +402,21 @@ def reduce_mean_valid(t): actor_loss = reduce_mean_valid(alpha.detach() * log_pis_t - q_t_det_policy) - # Save for stats function. - policy.q_t = q_t * seq_mask[..., None] - policy.policy_t = policy_t * seq_mask[..., None] - policy.log_pis_t = log_pis_t * seq_mask[..., None] - - # Store td-error in model, such that for multi-GPU, we do not override - # them during the parallel loss phase. TD-error tensor in final stats - # can then be concatenated and retrieved for each individual batch item. - model.td_error = td_error * seq_mask - - policy.actor_loss = actor_loss - policy.critic_loss = critic_loss - policy.alpha_loss = alpha_loss - policy.log_alpha_value = model.log_alpha - policy.alpha_value = alpha - policy.target_entropy = model.target_entropy + # Store values for stats function in model (tower), such that for + # multi-GPU, we do not override them during the parallel loss phase. + model.tower_stats["q_t"] = q_t * seq_mask[..., None] + model.tower_stats["policy_t"] = policy_t * seq_mask[..., None] + model.tower_stats["log_pis_t"] = log_pis_t * seq_mask[..., None] + model.tower_stats["actor_loss"] = actor_loss + model.tower_stats["critic_loss"] = critic_loss + model.tower_stats["alpha_loss"] = alpha_loss + # Store per time chunk (b/c we need only one mean + # prioritized replay weight per stored sequence). + model.tower_stats["td_error"] = torch.mean( + td_error.reshape([-1, T]), dim=-1) # Return all loss terms corresponding to our optimizers. - return tuple([policy.actor_loss] + policy.critic_loss + - [policy.alpha_loss]) + return tuple([actor_loss] + critic_loss + [alpha_loss]) RNNSACTorchPolicy = SACTorchPolicy.with_updates( diff --git a/rllib/agents/sac/sac_tf_model.py b/rllib/agents/sac/sac_tf_model.py index 546de04ab47c9..0b78f65a526fb 100644 --- a/rllib/agents/sac/sac_tf_model.py +++ b/rllib/agents/sac/sac_tf_model.py @@ -1,6 +1,7 @@ import gym from gym.spaces import Box, Discrete import numpy as np +import tree # pip install dm_tree from typing import Dict, List, Optional from ray.rllib.models.catalog import ModelCatalog @@ -267,13 +268,18 @@ def get_policy_output(self, model_out: TensorType) -> TensorType: Returns: TensorType: Distribution inputs for sampling actions. """ - # Model outs may come as original Tuple observations, concat them + # Model outs may come as original Tuple/Dict observations, concat them # here if this is the case. 
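        # Illustrative NumPy stand-in (made-up shapes) for the
        # flatten-and-concat step below: 1-D leaves of batch shape (B,)
        # get an extra axis so they can be concatenated with (B, d) leaves:
        #
        #   import numpy as np
        #   import tree  # pip install dm_tree
        #   vals = {"a": np.zeros(4), "b": np.zeros((4, 2))}
        #   flat = [np.expand_dims(v, 1) if v.ndim == 1 else v
        #           for v in tree.flatten(list(vals.values()))]
        #   np.concatenate(flat, axis=-1).shape  # -> (4, 3)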
if isinstance(self.action_model.obs_space, Box): if isinstance(model_out, (list, tuple)): model_out = tf.concat(model_out, axis=-1) elif isinstance(model_out, dict): - model_out = tf.concat(list(model_out.values()), axis=-1) + model_out = tf.concat( + [ + tf.expand_dims(val, 1) if len(val.shape) == 1 else val + for val in tree.flatten(model_out.values()) + ], + axis=-1) out, _ = self.action_model({"obs": model_out}, [], None) return out diff --git a/rllib/agents/sac/sac_tf_policy.py b/rllib/agents/sac/sac_tf_policy.py index 111d8b717f494..629de0efce536 100644 --- a/rllib/agents/sac/sac_tf_policy.py +++ b/rllib/agents/sac/sac_tf_policy.py @@ -6,7 +6,6 @@ from gym.spaces import Box, Discrete from functools import partial import logging -import numpy as np from typing import Dict, List, Optional, Tuple, Type, Union import ray @@ -53,9 +52,6 @@ def build_sac_model(policy: Policy, obs_space: gym.spaces.Space, target model will be created in this function and assigned to `policy.target_model`. """ - # With separate state-preprocessor (before obs+action concat). - num_outputs = int(np.product(obs_space.shape)) - # Force-ignore any additionally provided hidden layer sizes. # Everything should be configured using SAC's "Q_model" and "policy_model" # settings. @@ -70,7 +66,7 @@ def build_sac_model(policy: Policy, obs_space: gym.spaces.Space, model = ModelCatalog.get_model_v2( obs_space=obs_space, action_space=action_space, - num_outputs=num_outputs, + num_outputs=None, model_config=config["model"], framework=config["framework"], default_model=default_model_cls, @@ -90,7 +86,7 @@ def build_sac_model(policy: Policy, obs_space: gym.spaces.Space, policy.target_model = ModelCatalog.get_model_v2( obs_space=obs_space, action_space=action_space, - num_outputs=num_outputs, + num_outputs=None, model_config=config["model"], framework=config["framework"], default_model=default_model_cls, diff --git a/rllib/agents/sac/sac_torch_model.py b/rllib/agents/sac/sac_torch_model.py index 64bbb40920453..1fdc09412da13 100644 --- a/rllib/agents/sac/sac_torch_model.py +++ b/rllib/agents/sac/sac_torch_model.py @@ -1,6 +1,7 @@ import gym from gym.spaces import Box, Discrete import numpy as np +import tree # pip install dm_tree from typing import Dict, List, Optional from ray.rllib.models.catalog import ModelCatalog @@ -281,7 +282,12 @@ def get_policy_output(self, model_out: TensorType) -> TensorType: if isinstance(model_out, (list, tuple)): model_out = torch.cat(model_out, dim=-1) elif isinstance(model_out, dict): - model_out = torch.cat(list(model_out.values()), dim=-1) + model_out = torch.cat( + [ + torch.unsqueeze(val, 1) if len(val.shape) == 1 else val + for val in tree.flatten(model_out.values()) + ], + dim=-1) out, _ = self.action_model({"obs": model_out}, [], None) return out diff --git a/rllib/agents/sac/sac_torch_policy.py b/rllib/agents/sac/sac_torch_policy.py index 6bfdb98decc7b..dee2693abf29e 100644 --- a/rllib/agents/sac/sac_torch_policy.py +++ b/rllib/agents/sac/sac_torch_policy.py @@ -5,6 +5,7 @@ import gym from gym.spaces import Box, Discrete import logging +import tree # pip install dm_tree from typing import Dict, List, Optional, Tuple, Type, Union import ray @@ -314,26 +315,21 @@ def actor_critic_loss( # the Q-net(s)' variables. actor_loss = torch.mean(alpha.detach() * log_pis_t - q_t_det_policy) - # Save for stats function. 
- policy.q_t = q_t - policy.policy_t = policy_t - policy.log_pis_t = log_pis_t + # Store values for stats function in model (tower), such that for + # multi-GPU, we do not override them during the parallel loss phase. + model.tower_stats["q_t"] = q_t + model.tower_stats["policy_t"] = policy_t + model.tower_stats["log_pis_t"] = log_pis_t + model.tower_stats["actor_loss"] = actor_loss + model.tower_stats["critic_loss"] = critic_loss + model.tower_stats["alpha_loss"] = alpha_loss - # Store td-error in model, such that for multi-GPU, we do not override - # them during the parallel loss phase. TD-error tensor in final stats - # can then be concatenated and retrieved for each individual batch item. - model.td_error = td_error - - policy.actor_loss = actor_loss - policy.critic_loss = critic_loss - policy.alpha_loss = alpha_loss - policy.log_alpha_value = model.log_alpha - policy.alpha_value = alpha - policy.target_entropy = model.target_entropy + # TD-error tensor in final stats + # will be concatenated and retrieved for each individual batch item. + model.tower_stats["td_error"] = td_error # Return all loss terms corresponding to our optimizers. - return tuple([policy.actor_loss] + policy.critic_loss + - [policy.alpha_loss]) + return tuple([actor_loss] + critic_loss + [alpha_loss]) def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]: @@ -346,17 +342,23 @@ def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]: Returns: Dict[str, TensorType]: The stats dict. """ + q_t = torch.stack(policy.get_tower_stats("q_t")) + return { - "actor_loss": torch.mean(policy.actor_loss), - "critic_loss": torch.mean(torch.stack(policy.critic_loss)), - "alpha_loss": torch.mean(policy.alpha_loss), - "alpha_value": torch.mean(policy.alpha_value), - "log_alpha_value": torch.mean(policy.log_alpha_value), - "target_entropy": policy.target_entropy, - "policy_t": torch.mean(policy.policy_t), - "mean_q": torch.mean(policy.q_t), - "max_q": torch.max(policy.q_t), - "min_q": torch.min(policy.q_t), + "actor_loss": torch.mean( + torch.stack(policy.get_tower_stats("actor_loss"))), + "critic_loss": torch.mean( + torch.stack(tree.flatten(policy.get_tower_stats("critic_loss")))), + "alpha_loss": torch.mean( + torch.stack(policy.get_tower_stats("alpha_loss"))), + "alpha_value": torch.exp(policy.model.log_alpha), + "log_alpha_value": policy.model.log_alpha, + "target_entropy": policy.model.target_entropy, + "policy_t": torch.mean( + torch.stack(policy.get_tower_stats("policy_t"))), + "mean_q": torch.mean(q_t), + "max_q": torch.max(q_t), + "min_q": torch.min(q_t), } @@ -430,9 +432,9 @@ def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask, # (one TD-error value per item in batch to update PR weights). actor_critic_loss(self, self.model, None, input_dict) - # `self.td_error` is set within actor_critic_loss call. Return - # its updated value here. - return self.td_error + # `self.model.td_error` is set within actor_critic_loss call. + # Return its updated value here. + return self.model.tower_stats["td_error"] # Assign the method to policy (self) for later usage. 
        self.compute_td_error = compute_td_error

diff --git a/rllib/agents/sac/tests/test_rnnsac.py b/rllib/agents/sac/tests/test_rnnsac.py
new file mode 100644
index 0000000000000..f0e8c5a750c57
--- /dev/null
+++ b/rllib/agents/sac/tests/test_rnnsac.py
@@ -0,0 +1,73 @@
+import unittest
+
+import ray
+import ray.rllib.agents.sac as sac
+from ray.rllib.utils.framework import try_import_tf, try_import_torch
+from ray.rllib.utils.test_utils import check_compute_single_action, \
+    framework_iterator
+
+tf1, tf, tfv = try_import_tf()
+torch, nn = try_import_torch()
+
+
+class TestRNNSAC(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls) -> None:
+        ray.init()
+
+    @classmethod
+    def tearDownClass(cls) -> None:
+        ray.shutdown()
+
+    def test_rnnsac_compilation(self):
+        """Test whether an RNNSACTrainer can be built on all frameworks."""
+        config = sac.RNNSAC_DEFAULT_CONFIG.copy()
+        config["num_workers"] = 0  # Run locally.
+
+        # Wrap with an LSTM and use a very simple base-model.
+        config["model"] = {
+            "max_seq_len": 20,
+        }
+        config["policy_model"] = {
+            "use_lstm": True,
+            "lstm_cell_size": 64,
+            "fcnet_hiddens": [10],
+            "lstm_use_prev_action": True,
+            "lstm_use_prev_reward": True,
+        }
+        config["Q_model"] = {
+            "use_lstm": True,
+            "lstm_cell_size": 64,
+            "fcnet_hiddens": [10],
+            "lstm_use_prev_action": True,
+            "lstm_use_prev_reward": True,
+        }
+
+        # Test with PR activated.
+        config["prioritized_replay"] = True
+
+        config["burn_in"] = 20
+        config["zero_init_states"] = True
+
+        config["lr"] = 5e-4
+
+        num_iterations = 1
+
+        # Test building an RNNSAC agent in all frameworks.
+        for _ in framework_iterator(config, frameworks="torch"):
+            trainer = sac.RNNSACTrainer(config=config, env="CartPole-v0")
+            for i in range(num_iterations):
+                results = trainer.train()
+                print(results)
+
+            check_compute_single_action(
+                trainer,
+                include_state=True,
+                include_prev_action_reward=True,
+            )
+
+
+if __name__ == "__main__":
+    import pytest
+    import sys
+    sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/agents/sac/tests/test_sac.py b/rllib/agents/sac/tests/test_sac.py
index d9b1de208af33..06083b33e3fa9 100644
--- a/rllib/agents/sac/tests/test_sac.py
+++ b/rllib/agents/sac/tests/test_sac.py
@@ -1,5 +1,5 @@
 from gym import Env
-from gym.spaces import Box, Discrete, Tuple
+from gym.spaces import Box, Dict, Discrete, Tuple
 import numpy as np
 import re
 import unittest
@@ -21,8 +21,9 @@
 from ray.rllib.utils.numpy import fc, huber_loss, relu
 from ray.rllib.utils.spaces.simplex import Simplex
 from ray.rllib.utils.test_utils import check, check_compute_single_action, \
-    framework_iterator
+    check_train_results, framework_iterator
 from ray.rllib.utils.torch_ops import convert_to_torch_tensor
+from ray import tune

 tf1, tf, tfv = try_import_tf()
 torch, _ = try_import_torch()
@@ -71,8 +72,6 @@ def test_sac_compilation(self):
         config["num_workers"] = 0  # Run locally.
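        # Worked example of the derived setting from rnnsac.validate_config()
        # (see the rnnsac.py diff above): with the values used in the new
        # RNNSAC test,
        #
        #   replay_sequence_length = burn_in + model["max_seq_len"]
        #                          = 20 + 20 = 40
        #
        # i.e. each stored replay sequence holds 20 burn-in steps (consumed
        # only to build up the RNN state) followed by 20 trainable steps.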
config["n_step"] = 3 config["twin_q"] = True - config["clip_actions"] = False - config["normalize_actions"] = True config["learning_starts"] = 0 config["prioritized_replay"] = True config["rollout_fragment_length"] = 10 @@ -92,22 +91,28 @@ def test_sac_compilation(self): image_space = Box(-1.0, 1.0, shape=(84, 84, 3)) simple_space = Box(-1.0, 1.0, shape=(3, )) + tune.register_env( + "random_dict_env", lambda _: RandomEnv({ + "observation_space": Dict({ + "a": simple_space, + "b": Discrete(2), + "c": image_space, }), + "action_space": Box(-1.0, 1.0, shape=(1, )), })) + tune.register_env( + "random_tuple_env", lambda _: RandomEnv({ + "observation_space": Tuple([ + simple_space, Discrete(2), image_space]), + "action_space": Box(-1.0, 1.0, shape=(1, )), })) + for fw in framework_iterator(config): # Test for different env types (discrete w/ and w/o image, + cont). for env in [ - RandomEnv, + "random_dict_env", + "random_tuple_env", "MsPacmanNoFrameskip-v4", "CartPole-v0", ]: print("Env={}".format(env)) - if env == RandomEnv: - config["env_config"] = { - "observation_space": Tuple((simple_space, Discrete(2), - image_space)), - "action_space": Box(-1.0, 1.0, shape=(1, )), - } - else: - config["env_config"] = {} # Test making the Q-model a custom one for CartPole, otherwise, # use the default model. config["Q_model"]["custom_model"] = "batch_norm{}".format( @@ -116,6 +121,7 @@ def test_sac_compilation(self): trainer = sac.SACTrainer(config=config, env=env) for i in range(num_iterations): results = trainer.train() + check_train_results(results) print(results) check_compute_single_action(trainer) @@ -306,8 +312,10 @@ def test_sac_loss_function(self): elif fw == "torch": loss_torch(policy, policy.model, None, input_) - c, a, e, t = policy.critic_loss, policy.actor_loss, \ - policy.alpha_loss, policy.model.td_error + c, a, e, t = policy.get_tower_stats("critic_loss")[0], \ + policy.get_tower_stats("actor_loss")[0], \ + policy.get_tower_stats("alpha_loss")[0], \ + policy.get_tower_stats("td_error")[0] # Test actor gradients. 
policy.actor_optim.zero_grad() diff --git a/rllib/agents/tests/test_trainer.py b/rllib/agents/tests/test_trainer.py index baf7b665963ca..937206deac138 100644 --- a/rllib/agents/tests/test_trainer.py +++ b/rllib/agents/tests/test_trainer.py @@ -13,6 +13,7 @@ from ray.rllib.examples.env.multi_agent import MultiAgentCartPole from ray.rllib.examples.parallel_evaluation_and_training import \ AssertNumEvalEpisodesCallback +from ray.rllib.utils.metrics.learner_info import LEARNER_INFO from ray.rllib.utils.test_utils import framework_iterator @@ -72,7 +73,7 @@ def test_add_delete_policy(self): trainer = pg.PGTrainer(config=config) pol0 = trainer.get_policy("p0") r = trainer.train() - self.assertTrue("p0" in r["info"]["learner"]) + self.assertTrue("p0" in r["info"][LEARNER_INFO]) for i in range(1, 3): def new_mapping_fn(agent_id, episode, worker, **kwargs): diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index 7147ba9ea85c7..a1f4b64ee2426 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -8,22 +8,24 @@ import pickle import tempfile import time -from typing import Callable, Dict, List, Optional, Type, Union +from typing import Callable, Dict, List, Optional, Tuple, Type, Union import ray from ray.actor import ActorHandle from ray.exceptions import RayError from ray.rllib.agents.callbacks import DefaultCallbacks from ray.rllib.env.env_context import EnvContext +from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.env.utils import gym_env_creator from ray.rllib.evaluation.collectors.simple_list_collector import \ SimpleListCollector +from ray.rllib.evaluation.episode import MultiAgentEpisode from ray.rllib.evaluation.metrics import collect_metrics from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.evaluation.worker_set import WorkerSet from ray.rllib.models import MODEL_DEFAULTS from ray.rllib.policy.policy import Policy, PolicySpec -from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch from ray.rllib.utils import deep_update, FilterManager, merge_dicts from ray.rllib.utils.annotations import Deprecated, DeveloperAPI, override, \ PublicAPI @@ -36,7 +38,7 @@ from ray.rllib.utils.spaces import space_utils from ray.rllib.utils.typing import AgentID, EnvInfoDict, EnvType, EpisodeID, \ PartialTrainerConfigDict, PolicyID, ResultDict, TensorStructType, \ - TrainerConfigDict + TensorType, TrainerConfigDict from ray.tune.logger import Logger, UnifiedLogger from ray.tune.registry import ENV_CREATOR, register_env, _global_registry from ray.tune.resources import Resources @@ -1007,17 +1009,29 @@ def _sync_weights_to_workers( @PublicAPI def compute_single_action( self, - observation: TensorStructType, - state: List[TensorStructType] = None, - prev_action: TensorStructType = None, - prev_reward: float = None, - info: EnvInfoDict = None, + observation: Optional[TensorStructType] = None, + state: Optional[List[TensorStructType]] = None, + *, + prev_action: Optional[TensorStructType] = None, + prev_reward: Optional[float] = None, + info: Optional[EnvInfoDict] = None, + input_dict: Optional[SampleBatch] = None, policy_id: PolicyID = DEFAULT_POLICY_ID, full_fetch: bool = False, - explore: bool = None, - unsquash_actions: Optional[bool] = None, - clip_actions: Optional[bool] = None, - ) -> TensorStructType: + explore: Optional[bool] = None, + timestep: Optional[int] = None, + episode: Optional[MultiAgentEpisode] = None, + unsquash_action: Optional[bool] = None, + 
clip_action: Optional[bool] = None, + + # Deprecated args. + unsquash_actions=DEPRECATED_VALUE, + clip_actions=DEPRECATED_VALUE, + + # Kwargs placeholder for future compatibility. + **kwargs, + ) -> Union[TensorStructType, Tuple[TensorStructType, List[TensorType], + Dict[str, TensorType]]]: """Computes an action for the specified policy on the local worker. Note that you can also access the policy object through @@ -1025,70 +1039,123 @@ def compute_single_action( directly. Args: - observation (TensorStructType): observation from the environment. - state (List[TensorStructType]): RNN hidden state, if any. If state - is not None, then all of compute_single_action(...) is returned - (computed action, rnn state(s), logits dictionary). - Otherwise compute_single_action(...)[0] is returned - (computed action). - prev_action (TensorStructType): Previous action value, if any. - prev_reward (float): Previous reward, if any. - info (EnvInfoDict): info object, if any - policy_id (PolicyID): Policy to query (only applies to - multi-agent). - full_fetch (bool): Whether to return extra action fetch results. - This is always set to True if RNN state is specified. - explore (bool): Whether to pick an exploitation or exploration - action (default: None -> use self.config["explore"]). - unsquash_actions (bool): Should actions be unsquashed according to - the env's/Policy's action space? - clip_actions (bool): Should actions be clipped according to the - env's/Policy's action space? + observation: Single (unbatched) observation from the + environment. + state: List of all RNN hidden (single, unbatched) state tensors. + prev_action: Single (unbatched) previous action value. + prev_reward: Single (unbatched) previous reward value. + info: Env info dict, if any. + input_dict: An optional SampleBatch that holds all the values + for: obs, state, prev_action, and prev_reward, plus maybe + custom defined views of the current env trajectory. Note + that only one of `obs` or `input_dict` must be non-None. + policy_id: Policy to query (only applies to multi-agent). + Default: "default_policy". + full_fetch: Whether to return extra action fetch results. + This is always set to True if `state` is specified. + explore: Whether to apply exploration to the action. + Default: None -> use self.config["explore"]. + timestep: The current (sampling) time step. + episode: This provides access to all of the internal episodes' + state, which may be useful for model-based or multi-agent + algorithms. + unsquash_action: Should actions be unsquashed according to the + env's/Policy's action space? If None, use the value of + self.config["normalize_actions"]. + clip_action: Should actions be clipped according to the + env's/Policy's action space? If None, use the value of + self.config["clip_actions"]. + + Keyword Args: + kwargs: forward compatibility placeholder Returns: - any: The computed action if full_fetch=False, or - tuple: The full output of policy.compute_actions() if - full_fetch=True or we have an RNN-based Policy. + The computed action if full_fetch=False, or a tuple of a) the + full output of policy.compute_actions() if full_fetch=True + or we have an RNN-based Policy. Raises: KeyError: If the `policy_id` cannot be found in this Trainer's local worker. 
""" + if clip_actions != DEPRECATED_VALUE: + deprecation_warning( + old="Trainer.compute_single_action(`clip_actions`=...)", + new="Trainer.compute_single_action(`clip_action`=...)", + error=False) + clip_action = clip_actions + if unsquash_actions != DEPRECATED_VALUE: + deprecation_warning( + old="Trainer.compute_single_action(`unsquash_actions`=...)", + new="Trainer.compute_single_action(`unsquash_action`=...)", + error=False) + unsquash_action = unsquash_actions + + # User provided an input-dict: Assert that `obs`, `prev_a|r`, `state` + # are all None. + err_msg = "Provide either `input_dict` OR [`observation`, ...] as " \ + "args to Trainer.compute_single_action!" + if input_dict is not None: + assert observation is None and prev_action is None and \ + prev_reward is None and state is None, err_msg + observation = input_dict[SampleBatch.OBS] + else: + assert observation is not None, err_msg + + # Get the policy to compute the action for (in the multi-agent case, + # Trainer may hold >1 policies). policy = self.get_policy(policy_id) if policy is None: raise KeyError( f"PolicyID '{policy_id}' not found in PolicyMap of the " f"Trainer's local worker!") - local_worker = self.workers.local_worker() - if state is None: - state = [] - # Check the preprocessor and preprocess, if necessary. pp = local_worker.preprocessors[policy_id] if pp and type(pp).__name__ != "NoPreprocessor": observation = pp.transform(observation) - filtered_observation = local_worker.filters[policy_id]( + observation = local_worker.filters[policy_id]( observation, update=False) - # Compute the action. - result = policy.compute_single_action( - filtered_observation, - state, - prev_action, - prev_reward, - info, - unsquash_actions=unsquash_actions, - clip_actions=clip_actions, - explore=explore) + # Input-dict. + if input_dict is not None: + input_dict[SampleBatch.OBS] = observation + action, state, extra = policy.compute_single_action( + input_dict=input_dict, + explore=explore, + timestep=timestep, + episode=episode, + ) + # Individual args. + else: + action, state, extra = policy.compute_single_action( + obs=observation, + state=state, + prev_action=prev_action, + prev_reward=prev_reward, + info=info, + explore=explore, + timestep=timestep, + episode=episode, + ) + + # If we work in normalized action space (normalize_actions=True), + # we re-translate here into the env's action space. + if unsquash_action: + action = space_utils.unsquash_action(action, + policy.action_space_struct) + # Clip, according to env's action space. + elif clip_action: + action = space_utils.clip_action(action, + policy.action_space_struct) # Return 3-Tuple: Action, states, and extra-action fetches. if state or full_fetch: - return result + return action, state, extra # Ensure backward compatibility. 
else: - return result[0] + return action @Deprecated(new="compute_single_action", error=False) def compute_action(self, *args, **kwargs): @@ -1098,15 +1165,21 @@ def compute_action(self, *args, **kwargs): def compute_actions( self, observations: TensorStructType, - state: List[TensorStructType] = None, - prev_action: TensorStructType = None, - prev_reward: TensorStructType = None, - info=None, - policy_id=DEFAULT_POLICY_ID, - full_fetch=False, - explore=None, + state: Optional[List[TensorStructType]] = None, + *, + prev_action: Optional[TensorStructType] = None, + prev_reward: Optional[TensorStructType] = None, + info: Optional[EnvInfoDict] = None, + policy_id: PolicyID = DEFAULT_POLICY_ID, + full_fetch: bool = False, + explore: Optional[bool] = None, + timestep: Optional[int] = None, + episodes: Optional[List[MultiAgentEpisode]] = None, + unsquash_actions: Optional[bool] = None, + clip_actions: Optional[bool] = None, + # Deprecated. normalize_actions=None, - clip_actions=None, + **kwargs, ): """Computes an action for the specified policy on the local Worker. @@ -1114,30 +1187,46 @@ def compute_actions( self.get_policy(policy_id) and call compute_actions() on it directly. Args: - observation (obj): observation from the environment. - state (dict): RNN hidden state, if any. If state is not None, + observation: observation from the environment. + state: RNN hidden state, if any. If state is not None, then all of compute_single_action(...) is returned (computed action, rnn state(s), logits dictionary). Otherwise compute_single_action(...)[0] is returned (computed action). - prev_action (obj): previous action value, if any - prev_reward (int): previous reward, if any - info (dict): info object, if any - policy_id (str): Policy to query (only applies to multi-agent). - full_fetch (bool): Whether to return extra action fetch results. + prev_action: Previous action value, if any. + prev_reward: Previous reward, if any. + info: Env info dict, if any. + policy_id: Policy to query (only applies to multi-agent). + full_fetch: Whether to return extra action fetch results. This is always set to True if RNN state is specified. - explore (bool): Whether to pick an exploitation or exploration + explore: Whether to pick an exploitation or exploration action (default: None -> use self.config["explore"]). - normalize_actions (bool): Should actions be unsquashed according - to the env's/Policy's action space? - clip_actions (bool): Should actions be clipped according to the - env's/Policy's action space? + timestep: The current (sampling) time step. + episodes: This provides access to all of the internal episodes' + state, which may be useful for model-based or multi-agent + algorithms. + unsquash_actions: Should actions be unsquashed according + to the env's/Policy's action space? If None, use + self.config["normalize_actions"]. + clip_actions: Should actions be clipped according to the + env's/Policy's action space? If None, use + self.config["clip_actions"]. + + Keyword Args: + kwargs: forward compatibility placeholder Returns: any: The computed action if full_fetch=False, or tuple: The full output of policy.compute_actions() if full_fetch=True or we have an RNN-based Policy. """ + if normalize_actions is not None: + deprecation_warning( + old="Trainer.compute_actions(`normalize_actions`=...)", + new="Trainer.compute_actions(`unsquash_actions`=...)", + error=False) + unsquash_actions = normalize_actions + # Preprocess obs and states. 
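        # Illustration (hypothetical Box space) of the two post-processing
        # modes handled above for single actions and below for batched ones:
        #
        #   space = gym.spaces.Box(-2.0, 2.0, shape=(1,))
        #   # unsquash: map a normalized action from [-1, 1] into the space:
        #   #   0.5  ->  -2.0 + (0.5 + 1) / 2 * (2.0 - (-2.0)) = 1.0
        #   # clip: leave in-bounds values untouched, clamp outliers:
        #   #   3.7  ->  2.0
        #
        # `unsquash_action` pairs with config["normalize_actions"];
        # `clip_action` pairs with config["clip_actions"].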
state_defined = state is not None policy = self.get_policy(policy_id) @@ -1162,23 +1251,38 @@ def compute_actions( state = list(zip(*filtered_state)) state = [np.stack(s) for s in state] + input_dict = {SampleBatch.OBS: obs_batch} + if prev_action: + input_dict[SampleBatch.PREV_ACTIONS] = prev_action + if prev_reward: + input_dict[SampleBatch.PREV_REWARDS] = prev_reward + if info: + input_dict[SampleBatch.INFOS] = info + for i, s in enumerate(state): + input_dict[f"state_in_{i}"] = s + # Batch compute actions - actions, states, infos = policy.compute_actions( - obs_batch, - state, - prev_action, - prev_reward, - info, - normalize_actions=normalize_actions, - clip_actions=clip_actions, - explore=explore) - - # Unbatch actions for the environment - atns, actions = space_utils.unbatch(actions), {} - for key, atn in zip(observations, atns): - actions[key] = atn - - # Unbatch states into a dict + actions, states, infos = policy.compute_actions_from_input_dict( + input_dict=input_dict, + explore=explore, + timestep=timestep, + episodes=episodes, + ) + + # Unbatch actions for the environment into a multi-agent dict. + single_actions = space_utils.unbatch(actions) + actions = {} + for key, a in zip(observations, single_actions): + # If we work in normalized action space (normalize_actions=True), + # we re-translate here into the env's action space. + if unsquash_actions: + a = space_utils.unsquash_action(a, policy.action_space_struct) + # Clip, according to env's action space. + elif clip_actions: + a = space_utils.clip_action(a, policy.action_space_struct) + actions[key] = a + + # Unbatch states into a multi-agent dict. unbatched_states = {} for idx, agent_id in enumerate(observations): unbatched_states[agent_id] = [s[idx] for s in states] @@ -1403,6 +1507,7 @@ def collect_metrics(self, selected_workers=selected_workers) @classmethod + @override(Trainable) def resource_help(cls, config: TrainerConfigDict) -> str: return ("\n\nYou can adjust the resource requests of RLlib agents by " "setting `num_workers`, `num_gpus`, and other configs. See " @@ -1738,23 +1843,25 @@ def with_updates(**overrides) -> Type["Trainer"]: "build_trainer()` function!") def _register_if_needed(self, env_object: Union[str, EnvType, None], - config): + config) -> Optional[str]: if isinstance(env_object, str): return env_object elif isinstance(env_object, type): name = env_object.__name__ - # Add convenience `_get_spaces` method. + if config.get("remote_worker_envs"): - def _get_spaces(s): - return s.observation_space, s.action_space + @ray.remote(num_cpus=0) + class _wrapper(env_object): + # Add convenience `_get_spaces` and `_is_multi_agent` + # methods. 
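        # Illustrative usage (taken from the new remote-worker-env tests in
        # this patch): passing an env *class* with remote_worker_envs=True
        # turns each sub-env into its own @ray.remote actor, while the
        # wrapper's convenience methods stay queryable from the driver:
        #
        #   config = pg.DEFAULT_CONFIG.copy()
        #   config["remote_worker_envs"] = True
        #   config["num_envs_per_worker"] = 4
        #   config["env"] = RandomEnv  # class, not a registered string
        #   trainer = pg.PGTrainer(config=config)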
+ def _get_spaces(self): + return self.observation_space, self.action_space - env_object._get_spaces = _get_spaces + def _is_multi_agent(self): + return isinstance(self, MultiAgentEnv) - if config.get("remote_worker_envs"): - register_env( - name, - lambda cfg: ray.remote(num_cpus=0)(env_object).remote(cfg)) + register_env(name, lambda cfg: _wrapper.remote(cfg)) else: register_env(name, lambda cfg: env_object(cfg)) return name diff --git a/rllib/contrib/alpha_zero/core/alpha_zero_policy.py b/rllib/contrib/alpha_zero/core/alpha_zero_policy.py index 7b3b46e74747e..ad97829c04ba3 100644 --- a/rllib/contrib/alpha_zero/core/alpha_zero_policy.py +++ b/rllib/contrib/alpha_zero/core/alpha_zero_policy.py @@ -1,10 +1,11 @@ import numpy as np -from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY +from ray.rllib.policy.policy import Policy from ray.rllib.policy.torch_policy import TorchPolicy from ray.rllib.contrib.alpha_zero.core.mcts import Node, RootParentNode from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY torch, _ = try_import_torch() @@ -39,9 +40,9 @@ def compute_actions(self, **kwargs): input_dict = {"obs": obs_batch} - if prev_action_batch: + if prev_action_batch is not None: input_dict["prev_actions"] = prev_action_batch - if prev_reward_batch: + if prev_reward_batch is not None: input_dict["prev_rewards"] = prev_reward_batch return self.compute_actions_from_input_dict( diff --git a/rllib/contrib/bandits/agents/policy.py b/rllib/contrib/bandits/agents/policy.py index e47c91005232c..07d837b4fc150 100644 --- a/rllib/contrib/bandits/agents/policy.py +++ b/rllib/contrib/bandits/agents/policy.py @@ -9,11 +9,11 @@ ParametricLinearModelThompsonSampling, ParametricLinearModelUCB from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.modelv2 import restore_original_dimensions -from ray.rllib.policy.policy import LEARNER_STATS_KEY from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy import TorchPolicy from ray.rllib.utils.annotations import override +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.util.debug import log_once logger = logging.getLogger(__name__) diff --git a/rllib/contrib/bandits/examples/LinTS_train_wheel_env.py b/rllib/contrib/bandits/examples/LinTS_train_wheel_env.py index dfe3b8c85156d..4501a04357fee 100644 --- a/rllib/contrib/bandits/examples/LinTS_train_wheel_env.py +++ b/rllib/contrib/bandits/examples/LinTS_train_wheel_env.py @@ -7,6 +7,7 @@ from ray.rllib.contrib.bandits.agents import LinTSTrainer from ray.rllib.contrib.bandits.envs import WheelBanditEnv +from ray.rllib.utils.metrics.learner_info import LEARNER_INFO def plot_model_weights(means, covs): @@ -43,7 +44,7 @@ def plot_model_weights(means, covs): trainer.train() info = trainer.train() - print(info["info"]["learner"]) + print(info["info"][LEARNER_INFO]) # Get model parameters means = [model.arms[i].theta.numpy() for i in range(5)] diff --git a/rllib/contrib/maddpg/maddpg_policy.py b/rllib/contrib/maddpg/maddpg_policy.py index 86e417e5d3112..51a02f35afaea 100644 --- a/rllib/contrib/maddpg/maddpg_policy.py +++ b/rllib/contrib/maddpg/maddpg_policy.py @@ -1,6 +1,5 @@ import ray from ray.rllib.agents.dqn.dqn_tf_policy import minimize_and_clip -from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY from ray.rllib.evaluation.postprocessing import 
adjust_nstep from ray.rllib.models import ModelCatalog from ray.rllib.policy.sample_batch import SampleBatch @@ -9,6 +8,7 @@ from ray.rllib.policy.policy import Policy from ray.rllib.policy.tf_policy import TFPolicy from ray.rllib.utils.framework import try_import_tf, try_import_tfp +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY import logging from gym.spaces import Box, Discrete diff --git a/rllib/contrib/sumo/connector.py b/rllib/contrib/sumo/connector.py index 6b1d3d1d47e35..0b795c45c8421 100644 --- a/rllib/contrib/sumo/connector.py +++ b/rllib/contrib/sumo/connector.py @@ -162,7 +162,7 @@ def _stopping_condition(self, current_step_counter, until_end): return True return False - def step(self, until_end=False, agents=set()): + def step(self, until_end=False, agents=None): """ Runs a "learning" step and returns if the simulation has finished. This function in meant to be called by the RLLIB Environment. @@ -176,6 +176,9 @@ def step(self, until_end=False, agents=set()): Return: Bool. True iff the simulation is still ongoing. """ + if agents is None: + agents = set() + # Execute SUMO steps until the learning needs to happen current_step_counter = 0 logger.debug( diff --git a/rllib/env/base_env.py b/rllib/env/base_env.py index 4b2c77fe1532b..8ee302eb24683 100644 --- a/rllib/env/base_env.py +++ b/rllib/env/base_env.py @@ -1,5 +1,6 @@ from typing import Callable, Tuple, Optional, List, Dict, Any, TYPE_CHECKING +import ray from ray.rllib.env.external_env import ExternalEnv from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv from ray.rllib.env.multi_agent_env import MultiAgentEnv @@ -121,10 +122,14 @@ def to_base_env( env = _VectorEnvToBaseEnv(env) else: if remote_envs: + # Determine, whether the already existing sub-env (could + # be a ray.actor) is multi-agent or not. + multiagent = ray.get(env._is_multi_agent.remote()) if \ + hasattr(env, "_is_multi_agent") else False env = RemoteVectorEnv( make_env, num_envs, - multiagent=False, + multiagent=multiagent, remote_env_batch_wait_ms=remote_env_batch_wait_ms, existing_envs=[env], ) diff --git a/rllib/env/multi_agent_env.py b/rllib/env/multi_agent_env.py index ed5705bf725d0..4840de357585a 100644 --- a/rllib/env/multi_agent_env.py +++ b/rllib/env/multi_agent_env.py @@ -163,7 +163,8 @@ def make_multi_agent(env_name_or_creator): """ class MultiEnv(MultiAgentEnv): - def __init__(self, config): + def __init__(self, config=None): + config = config or {} num = config.pop("num_agents", 1) if isinstance(env_name_or_creator, str): self.agents = [ diff --git a/rllib/env/policy_server_input.py b/rllib/env/policy_server_input.py index 26e96673adb5c..c7148a94a8a2d 100644 --- a/rllib/env/policy_server_input.py +++ b/rllib/env/policy_server_input.py @@ -89,10 +89,18 @@ def get_metrics(): # and sends data and metrics into the queues. handler = _make_handler(self.rollout_worker, self.samples_queue, self.metrics_queue) - HTTPServer.__init__(self, (address, port), handler) - - logger.info("Starting connector server at {}:{}".format( - self.server_name, self.server_port)) + try: + import time + time.sleep(1) + HTTPServer.__init__(self, (address, port), handler) + except OSError: + print(f"Creating a PolicyServer on {address}:{port} failed!") + import time + time.sleep(1) + raise + + logger.info("Starting connector server at " + f"{self.server_name}:{self.server_port}") # Start the serving thread, listening on socket and handling commands. 
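        # Sketch of the duck-typing check added in base_env.py above: if the
        # already-created sub-env is a ray actor exposing `_is_multi_agent`
        # (as the wrapper in trainer.py now does), ask it remotely; otherwise
        # fall back to single-agent. Actor handles expose their remote
        # methods as attributes, which is why `hasattr` works here:
        #
        #   multiagent = ray.get(env._is_multi_agent.remote()) if \
        #       hasattr(env, "_is_multi_agent") else False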
        serving_thread = threading.Thread(
diff --git a/rllib/env/remote_vector_env.py b/rllib/env/remote_vector_env.py
index aa2e958efee5a..2d09302f59c15 100644
--- a/rllib/env/remote_vector_env.py
+++ b/rllib/env/remote_vector_env.py
@@ -29,6 +29,8 @@ def __init__(self,
                  existing_envs: Optional[List[ray.actor.ActorHandle]] = None):
         # Could be creating local or remote envs.
         self.make_env = make_env
+        # Whether the given `make_env` callable already returns ray.remote
+        # objects or not.
         self.make_env_creates_actors = False
         # Already existing env objects (generated by the RolloutWorker).
         self.existing_envs = existing_envs or []
@@ -50,9 +52,13 @@ def poll(self) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict,
                 self.actors = []
                 while len(self.actors) < self.num_envs:
                     self.actors.append(self.make_env(len(self.actors)))
-            # `self.make_env` produces gym.Envs (or other similar types, such
+            # `self.make_env` produces gym.Envs (or children thereof, such
             # as MultiAgentEnv): Need to auto-wrap it here. The problem with
-            # this is that custom methods wil get lost.
+            # this is that custom methods will get lost. If you would like to
+            # keep your custom methods in your envs, you should provide the
+            # env class directly in your config (w/o tune.register_env()),
+            # such that your class will directly be made a @ray.remote
+            # (w/o the wrapping via `_Remote[Multi|Single]AgentEnv`).
         else:

             def make_remote_env(i):
@@ -125,7 +131,15 @@ def make_remote_env(i):
     def send_actions(self, action_dict: MultiEnvDict) -> None:
         for env_id, actions in action_dict.items():
             actor = self.actors[env_id]
-            obj_ref = actor.step.remote(actions)
+            # `actor` is a simple single-agent (remote) env, e.g. a gym.Env
+            # that was made a @ray.remote.
+            if not self.multiagent and self.make_env_creates_actors:
+                obj_ref = actor.step.remote(actions[_DUMMY_AGENT_ID])
+            # `actor` is already a _RemoteSingleAgentEnv or
+            # _RemoteMultiAgentEnv wrapper
+            # (handles the multi-agent action_dict automatically).
+            else:
+                obj_ref = actor.step.remote(actions)
             self.pending[obj_ref] = actor

     @override(BaseEnv)
diff --git a/rllib/env/tests/test_local_inference.sh b/rllib/env/tests/test_local_inference.sh
deleted file mode 100755
index be910f173c620..0000000000000
--- a/rllib/env/tests/test_local_inference.sh
+++ /dev/null
@@ -1,42 +0,0 @@
-#!/bin/bash
-
-rm -f last_checkpoint.out
-pkill -f cartpole_server.py
-sleep 1
-
-if [ -f test_local_inference.sh ]; then
-    basedir="../../examples/serving"
-else
-    basedir="rllib/examples/serving"  # In bazel.
-fi
-
-# Start server with 2 workers (will listen on ports 9900 and 9901 for client
-# connections).
-# Do not attempt to restore from checkpoint; leads to errors on travis.
-(python $basedir/cartpole_server.py --run=PPO --num-workers=2 --no-restore 2>&1 | grep -v 200) &
-server_pid=$!
-
-echo "Waiting for server to start"
-while ! curl localhost:9900; do
-  sleep 1
-done
-while ! curl localhost:9901; do
-  sleep 1
-done
-
-# Start client 1 (port 9900).
-sleep 2
-(python $basedir/cartpole_client.py --inference-mode=local --port=9900) &
-client1_pid=$!
-
-# Start client 2 (port 9901).
-sleep 2
-(python $basedir/cartpole_client.py --inference-mode=local --port=9901) &
-client2_pid=$!
-
-# Start client 3 (also port 9901) and run it until it reaches 150.0
-# reward. Then stop everything.
-sleep 2 -python $basedir/cartpole_client.py --stop-reward=150.0 --inference-mode=local --port=9901 - -kill $server_pid $client1_pid $client2_pid || true diff --git a/rllib/env/tests/test_policy_client_server_setup.sh b/rllib/env/tests/test_policy_client_server_setup.sh new file mode 100755 index 0000000000000..4d458ee5b8dba --- /dev/null +++ b/rllib/env/tests/test_policy_client_server_setup.sh @@ -0,0 +1,63 @@ +#!/bin/bash + +rm -f last_checkpoint.out + +if [ "$1" == "local" ]; then + inference_mode=local +else + inference_mode=remote +fi + +if [ "$2" == "cartpole" ]; then + server_script=cartpole_server.py + client_script=cartpole_client.py + stop_criterion="--stop-reward=150.0" +else + server_script=unity3d_server.py + client_script=unity3d_dummy_client.py + stop_criterion="--num-episodes=10" +fi + +pkill -f $server_script +sleep 1 + +if [ -f test_policy_client_server_setup.sh ]; then + basedir="../../examples/serving" +else + basedir="rllib/examples/serving" # In bazel. +fi + + +# Start server with 2 workers (will listen on ports 9900 and 9901 for client +# connections). +# Do not attempt to restore from checkpoint; leads to errors on travis. +(python $basedir/$server_script --run=PPO --num-workers=2 --no-restore 2>&1 | grep -v 200) & +server_pid=$! + +echo "Waiting for server to start ..." +while ! curl localhost:9900; do + sleep 1 +done +echo "Remote worker #1 on port 9900 is up!" +while ! curl localhost:9901; do + sleep 1 +done +echo "Remote worker #2 on port 9901 is up!" + +# Start client 1 (connect to port 9900). +sleep 2 +(python $basedir/$client_script --inference-mode=$inference_mode --port=9900) & +client1_pid=$! + +# Start client 2 (connect to port 9901). +sleep 2 +(python $basedir/$client_script --inference-mode=$inference_mode --port=9901) & +client2_pid=$! + +# Start client 3 (also connecting to port 9901) and run it until it reaches +# x reward (CartPole) or n episodes (dummy Unity3D). +# Then stop everything. +sleep 2 +python $basedir/$client_script $stop_criterion --inference-mode=$inference_mode --port=9901 + +kill $server_pid $client1_pid $client2_pid || true diff --git a/rllib/env/tests/test_remote_inference.sh b/rllib/env/tests/test_remote_inference.sh deleted file mode 100755 index 1a9ead838576c..0000000000000 --- a/rllib/env/tests/test_remote_inference.sh +++ /dev/null @@ -1,41 +0,0 @@ -#!/bin/bash - -rm -f last_checkpoint.out -pkill -f cartpole_server.py -sleep 1 - -if [ -f test_local_inference.sh ]; then - basedir="../../examples/serving" -else - basedir="rllib/examples/serving" # In bazel. -fi - -# Do not attempt to restore from checkpoint; leads to errors on travis. -(python $basedir/cartpole_server.py --run=DQN --num-workers=2 --no-restore 2>&1 | grep -v 200) & -server_pid=$! - -echo "Waiting for server to start" -while ! curl localhost:9900; do - sleep 1 -done -while ! curl localhost:9901; do - sleep 1 -done - -# Start client 1 (port 9900). -sleep 2 -(python $basedir/cartpole_client.py --inference-mode=remote --port=9900) & -client1_pid=$! - -# Start client 2 (port 9901). -sleep 2 -(python $basedir/cartpole_client.py --inference-mode=remote --port=9901) & -client2_pid=$! - -# Start client 3 (also port 9901) and run it until it reaches 150.0 -# reward. Then stop everything. 
-sleep 2 -python $basedir/cartpole_client.py --stop-reward=150.0 --inference-mode=remote --port=9901 - -kill $server_pid $client1_pid $client2_pid || true - diff --git a/rllib/env/tests/test_remote_worker_envs.py b/rllib/env/tests/test_remote_worker_envs.py new file mode 100644 index 0000000000000..ba80c7e4cede1 --- /dev/null +++ b/rllib/env/tests/test_remote_worker_envs.py @@ -0,0 +1,98 @@ +import gym +import numpy as np +from pettingzoo.butterfly import pistonball_v4 +from supersuit import normalize_obs_v0, dtype_v0, color_reduction_v0 +import unittest + +import ray +from ray.rllib.agents.pg import pg +from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv +from ray.rllib.examples.env.random_env import RandomEnv, RandomMultiAgentEnv +from ray.rllib.examples.remote_vector_env_with_custom_api import \ + NonVectorizedEnvToBeVectorizedIntoRemoteVectorEnv +from ray import tune + + +# Function that outputs the environment you wish to register. +def env_creator(config): + env = pistonball_v4.env(local_ratio=config.get("local_ratio", 0.2)) + env = dtype_v0(env, dtype=np.float32) + env = color_reduction_v0(env, mode="R") + env = normalize_obs_v0(env) + return env + + +tune.register_env("cartpole", lambda env_ctx: gym.make("CartPole-v0")) + +tune.register_env("pistonball", + lambda config: PettingZooEnv(env_creator(config))) + + +class TestRemoteWorkerEnvSetting(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + ray.init(num_cpus=4) + + @classmethod + def tearDownClass(cls) -> None: + ray.shutdown() + + def test_remote_worker_env(self): + config = pg.DEFAULT_CONFIG.copy() + config["remote_worker_envs"] = True + config["num_envs_per_worker"] = 4 + + # Simple string env definition (gym.make(...)). + config["env"] = "CartPole-v0" + trainer = pg.PGTrainer(config=config) + print(trainer.train()) + trainer.stop() + + # Using tune.register. + config["env"] = "cartpole" + trainer = pg.PGTrainer(config=config) + print(trainer.train()) + trainer.stop() + + # Using class directly. + config["env"] = RandomEnv + trainer = pg.PGTrainer(config=config) + print(trainer.train()) + trainer.stop() + + # Using class directly: Sub-class of gym.Env, + # which implements its own API. + config["env"] = NonVectorizedEnvToBeVectorizedIntoRemoteVectorEnv + trainer = pg.PGTrainer(config=config) + print(trainer.train()) + trainer.stop() + + def test_remote_worker_env_multi_agent(self): + config = pg.DEFAULT_CONFIG.copy() + config["remote_worker_envs"] = True + config["num_envs_per_worker"] = 4 + + # Full classpath provided. + config["env"] = \ + "ray.rllib.examples.env.random_env.RandomMultiAgentEnv" + trainer = pg.PGTrainer(config=config) + print(trainer.train()) + trainer.stop() + + # Using tune.register. + config["env"] = "pistonball" + trainer = pg.PGTrainer(config=config) + print(trainer.train()) + trainer.stop() + + # Using class directly. 
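        # For reference, the consolidated test_policy_client_server_setup.sh
        # above is parameterized by inference mode and env; illustrative
        # invocations (following the script's own $1/$2 conventions):
        #
        #   ./test_policy_client_server_setup.sh local cartpole
        #   ./test_policy_client_server_setup.sh remote unity3d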
+ config["env"] = RandomMultiAgentEnv + trainer = pg.PGTrainer(config=config) + print(trainer.train()) + trainer.stop() + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/env/wrappers/unity3d_env.py b/rllib/env/wrappers/unity3d_env.py index 2ec8fd6282945..2f9f75e79cb1e 100644 --- a/rllib/env/wrappers/unity3d_env.py +++ b/rllib/env/wrappers/unity3d_env.py @@ -6,6 +6,7 @@ from typing import Callable, Optional, Tuple from ray.rllib.env.multi_agent_env import MultiAgentEnv +from ray.rllib.policy.policy import PolicySpec from ray.rllib.utils.typing import MultiAgentDict, PolicyID, AgentID logger = logging.getLogger(__name__) @@ -304,10 +305,12 @@ def get_policy_configs_for_game( # Policies (Unity: "behaviors") and agent-to-policy mapping fns. if game_name == "SoccerStrikersVsGoalie": policies = { - "Goalie": (None, obs_spaces["Goalie"], action_spaces["Goalie"], - {}), - "Striker": (None, obs_spaces["Striker"], - action_spaces["Striker"], {}), + "Goalie": PolicySpec( + observation_space=obs_spaces["Goalie"], + action_space=action_spaces["Goalie"]), + "Striker": PolicySpec( + observation_space=obs_spaces["Striker"], + action_space=action_spaces["Striker"]), } def policy_mapping_fn(agent_id, episode, worker, **kwargs): @@ -315,8 +318,9 @@ def policy_mapping_fn(agent_id, episode, worker, **kwargs): else: policies = { - game_name: (None, obs_spaces[game_name], - action_spaces[game_name], {}), + game_name: PolicySpec( + observation_space=obs_spaces[game_name], + action_space=action_spaces[game_name]), } def policy_mapping_fn(agent_id, episode, worker, **kwargs): diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index 40745251e2b64..7c5415375d230 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -756,8 +756,11 @@ def postprocess_episode( "True. 
Alternatively, set no_done_at_end=True to " "allow this.") - other_batches = pre_batches.copy() - del other_batches[agent_id] + if len(pre_batches) > 1: + other_batches = pre_batches.copy() + del other_batches[agent_id] + else: + other_batches = {} pid = self.agent_key_to_policy_id[(episode_id, agent_id)] policy = self.policy_map[pid] if any(pre_batch[SampleBatch.DONES][:-1]) or len( diff --git a/rllib/evaluation/metrics.py b/rllib/evaluation/metrics.py index 06afe96d3fc6f..73c25f916f0bb 100644 --- a/rllib/evaluation/metrics.py +++ b/rllib/evaluation/metrics.py @@ -9,8 +9,8 @@ from ray.rllib.evaluation.rollout_metrics import RolloutMetrics from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.offline.off_policy_estimator import OffPolicyEstimate -from ray.rllib.policy.policy import LEARNER_STATS_KEY from ray.rllib.utils.annotations import DeveloperAPI +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.typing import GradInfoDict, LearnerStatsDict, ResultDict if TYPE_CHECKING: @@ -42,7 +42,6 @@ def get_learner_stats(grad_info: GradInfoDict) -> LearnerStatsDict: >>> print(get_stats(grad_info)) {"vf_loss": ..., "policy_loss": ...} """ - if LEARNER_STATS_KEY in grad_info: return grad_info[LEARNER_STATS_KEY] @@ -57,10 +56,15 @@ def get_learner_stats(grad_info: GradInfoDict) -> LearnerStatsDict: @DeveloperAPI def collect_metrics(local_worker: Optional["RolloutWorker"] = None, - remote_workers: List[ActorHandle] = [], - to_be_collected: List[ObjectRef] = [], + remote_workers: Optional[List[ActorHandle]] = None, + to_be_collected: Optional[List[ObjectRef]] = None, timeout_seconds: int = 180) -> ResultDict: """Gathers episode metrics from RolloutWorker instances.""" + if remote_workers is None: + remote_workers = [] + + if to_be_collected is None: + to_be_collected = [] episodes, to_be_collected = collect_episodes( local_worker, @@ -74,11 +78,16 @@ def collect_metrics(local_worker: Optional["RolloutWorker"] = None, @DeveloperAPI def collect_episodes( local_worker: Optional["RolloutWorker"] = None, - remote_workers: List[ActorHandle] = [], - to_be_collected: List[ObjectRef] = [], + remote_workers: Optional[List[ActorHandle]] = None, + to_be_collected: Optional[List[ObjectRef]] = None, timeout_seconds: int = 180 ) -> Tuple[List[Union[RolloutMetrics, OffPolicyEstimate]], List[ObjectRef]]: """Gathers new episodes metrics tuples from the given evaluators.""" + if remote_workers is None: + remote_workers = [] + + if to_be_collected is None: + to_be_collected = [] if remote_workers: pending = [ diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index a703b9f0a66e1..7151851587f73 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -860,14 +860,15 @@ def compute_gradients( summarize(samples))) if isinstance(samples, MultiAgentBatch): grad_out, info_out = {}, {} - if self.tf_sess is not None: - builder = TFRunBuilder(self.tf_sess, "compute_gradients") + if self.policy_config.get("framework") == "tf": for pid, batch in samples.policy_batches.items(): if pid not in self.policies_to_train: continue + policy = self.policy_map[pid] + builder = TFRunBuilder(policy.get_session(), + "compute_gradients") grad_out[pid], info_out[pid] = ( - self.policy_map[pid]._build_compute_gradients( - builder, batch)) + policy._build_compute_gradients(builder, batch)) grad_out = {k: builder.get(v) for k, v in grad_out.items()} info_out = {k: builder.get(v) for k, v in info_out.items()} else: @@ -897,14 
+898,21 @@ def apply_gradients(self, grads: ModelGradients) -> Dict[PolicyID, Any]: if log_once("apply_gradients"): logger.info("Apply gradients:\n\n{}\n".format(summarize(grads))) if isinstance(grads, dict): - if self.tf_sess is not None: - builder = TFRunBuilder(self.tf_sess, "apply_gradients") - outputs = { - pid: self.policy_map[pid]._build_apply_gradients( - builder, grad) - for pid, grad in grads.items() + if self.policy_config.get("framework") == "tf": + builders = {} + outputs = {} + for pid, grad in grads.items(): + if pid not in self.policies_to_train: + continue + policy = self.policy_map[pid] + builders[pid] = TFRunBuilder(policy.get_session(), + "apply_gradients") + outputs[pid] = policy._build_apply_gradients( + builders[pid], grad) + return { + pid: builders[pid].get(op) + for pid, op in outputs.items() } - return {k: builder.get(v) for k, v in outputs.items()} else: return { pid: self.policy_map[pid].apply_gradients(g) diff --git a/rllib/examples/centralized_critic.py b/rllib/examples/centralized_critic.py index 0737355dc0dfe..09fdb3b968dea 100644 --- a/rllib/examples/centralized_critic.py +++ b/rllib/examples/centralized_critic.py @@ -179,7 +179,7 @@ def central_vf_stats(policy, train_batch, grads): return { "vf_explained_var": explained_variance( train_batch[Postprocessing.VALUE_TARGETS], - policy._central_value_out), + policy._central_value_out) } diff --git a/rllib/examples/custom_keras_model.py b/rllib/examples/custom_keras_model.py index cec793dd17bb6..c1c419d50e545 100644 --- a/rllib/examples/custom_keras_model.py +++ b/rllib/examples/custom_keras_model.py @@ -11,9 +11,10 @@ from ray.rllib.models.tf.misc import normc_initializer from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.models.tf.visionnet import VisionNetwork as MyVisionNetwork -from ray.rllib.policy.policy import LEARNER_STATS_KEY from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \ + LEARNER_STATS_KEY tf1, tf, tfv = try_import_tf() @@ -110,7 +111,7 @@ def metrics(self): # Tests https://github.com/ray-project/ray/issues/7293 def check_has_custom_metric(result): - r = result["result"]["info"]["learner"] + r = result["result"]["info"][LEARNER_INFO] if DEFAULT_POLICY_ID in r: r = r[DEFAULT_POLICY_ID].get(LEARNER_STATS_KEY, r[DEFAULT_POLICY_ID]) diff --git a/rllib/examples/custom_model_loss_and_metrics.py b/rllib/examples/custom_model_loss_and_metrics.py index 9cea42cdf639a..6a38084f01188 100644 --- a/rllib/examples/custom_model_loss_and_metrics.py +++ b/rllib/examples/custom_model_loss_and_metrics.py @@ -19,9 +19,10 @@ from ray.rllib.examples.models.custom_loss_model import CustomLossModel, \ TorchCustomLossModel from ray.rllib.models import ModelCatalog -from ray.rllib.policy.policy import LEARNER_STATS_KEY from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \ + LEARNER_STATS_KEY tf1, tf, tfv = try_import_tf() @@ -83,9 +84,9 @@ # Torch metrics structure. 
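Several examples around this point switch from the hard-coded "learner" key to the new LEARNER_INFO constant. A minimal sketch of reading the restructured result dict (key layout as used in these hunks):

    from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
    from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \
        LEARNER_STATS_KEY

    def default_policy_stats(result: dict) -> dict:
        # result["info"][LEARNER_INFO] maps PolicyID -> per-policy info;
        # LEARNER_STATS_KEY nests the loss/optimizer stats inside each.
        return result["info"][LEARNER_INFO][DEFAULT_POLICY_ID][
            LEARNER_STATS_KEY]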
if args.framework == "torch": - assert LEARNER_STATS_KEY in info["learner"][DEFAULT_POLICY_ID] - assert "model" in info["learner"][DEFAULT_POLICY_ID] - assert "custom_metrics" in info["learner"][DEFAULT_POLICY_ID] + assert LEARNER_STATS_KEY in info[LEARNER_INFO][DEFAULT_POLICY_ID] + assert "model" in info[LEARNER_INFO][DEFAULT_POLICY_ID] + assert "custom_metrics" in info[LEARNER_INFO][DEFAULT_POLICY_ID] # TODO: (sven) Make sure the metrics structure gets unified between # tf and torch. Tf should work like current torch: @@ -96,4 +97,5 @@ # model: [return values of ModelV2's `metrics` method] # custom_metrics: [return values of callback: `on_learn_on_batch`] else: - assert "model" in info["learner"][DEFAULT_POLICY_ID][LEARNER_STATS_KEY] + assert "model" in info[LEARNER_INFO][DEFAULT_POLICY_ID][ + LEARNER_STATS_KEY] diff --git a/rllib/examples/deterministic_training.py b/rllib/examples/deterministic_training.py index 528e002971c43..e6fd21e56a9c3 100644 --- a/rllib/examples/deterministic_training.py +++ b/rllib/examples/deterministic_training.py @@ -10,6 +10,7 @@ from ray.rllib.examples.env.env_using_remote_actor import \ CartPoleWithRemoteParamServer, ParameterStorage from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID +from ray.rllib.utils.metrics.learner_info import LEARNER_INFO from ray.rllib.utils.test_utils import check parser = argparse.ArgumentParser() @@ -60,6 +61,7 @@ check(results1["hist_stats"], results2["hist_stats"]) # As well as training behavior (minibatch sequence during SGD # iterations). - check(results1["info"]["learner"][DEFAULT_POLICY_ID]["learner_stats"], - results2["info"]["learner"][DEFAULT_POLICY_ID]["learner_stats"]) + check( + results1["info"][LEARNER_INFO][DEFAULT_POLICY_ID]["learner_stats"], + results2["info"][LEARNER_INFO][DEFAULT_POLICY_ID]["learner_stats"]) ray.shutdown() diff --git a/rllib/examples/env/coin_game_non_vectorized_env.py b/rllib/examples/env/coin_game_non_vectorized_env.py index 5d725ade56d5d..e773bab36a6b9 100644 --- a/rllib/examples/env/coin_game_non_vectorized_env.py +++ b/rllib/examples/env/coin_game_non_vectorized_env.py @@ -13,7 +13,7 @@ from gym.utils import seeding from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.utils import override -from typing import Dict +from typing import Dict, Optional from ray.rllib.examples.env.utils.interfaces import InfoAccumulationInterface @@ -36,7 +36,9 @@ class CoinGame(InfoAccumulationInterface, MultiAgentEnv, gym.Env): np.array([-1, 0]), ] - def __init__(self, config: Dict = {}): + def __init__(self, config: Optional[Dict] = None): + if config is None: + config = {} self._validate_config(config) @@ -325,7 +327,10 @@ def _init_info(self): class AsymCoinGame(CoinGame): NAME = "AsymCoinGame" - def __init__(self, config: dict = {}): + def __init__(self, config: Optional[dict] = None): + if config is None: + config = {} + if "asymmetric" in config: assert config["asymmetric"] else: diff --git a/rllib/examples/env/coin_game_vectorized_env.py b/rllib/examples/env/coin_game_vectorized_env.py index a71fa4327d399..546a9b1a815b0 100644 --- a/rllib/examples/env/coin_game_vectorized_env.py +++ b/rllib/examples/env/coin_game_vectorized_env.py @@ -21,7 +21,9 @@ class VectorizedCoinGame(CoinGame): Vectorized Coin Game environment. 
""" - def __init__(self, config={}): + def __init__(self, config=None): + if config is None: + config = {} super().__init__(config) @@ -159,7 +161,10 @@ def _load_env(self, env_state): class AsymVectorizedCoinGame(VectorizedCoinGame): NAME = "AsymCoinGame" - def __init__(self, config={}): + def __init__(self, config=None): + if config is None: + config = {} + if "asymmetric" in config: assert config["asymmetric"] else: diff --git a/rllib/examples/env/matrix_sequential_social_dilemma.py b/rllib/examples/env/matrix_sequential_social_dilemma.py index 97d222b3cff20..9348a184890b8 100644 --- a/rllib/examples/env/matrix_sequential_social_dilemma.py +++ b/rllib/examples/env/matrix_sequential_social_dilemma.py @@ -8,7 +8,7 @@ import logging from abc import ABC from collections import Iterable -from typing import Dict +from typing import Dict, Optional import numpy as np from gym.spaces import Discrete @@ -39,7 +39,9 @@ class MatrixSequentialSocialDilemma(InfoAccumulationInterface, MultiAgentEnv, episode. """ - def __init__(self, config: Dict = {}): + def __init__(self, config: Optional[Dict] = None): + if config is None: + config = {} assert "reward_randomness" not in config.keys() assert self.PAYOUT_MATRIX is not None diff --git a/rllib/examples/env/random_env.py b/rllib/examples/env/random_env.py index b6b451fef7c33..ceeca23424c24 100644 --- a/rllib/examples/env/random_env.py +++ b/rllib/examples/env/random_env.py @@ -14,7 +14,9 @@ class RandomEnv(gym.Env): configured as well. """ - def __init__(self, config): + def __init__(self, config=None): + config = config or {} + # Action space. self.action_space = config.get("action_space", Discrete(2)) # Observation space from which to sample. @@ -63,3 +65,25 @@ def step(self, action): # Multi-agent version of the RandomEnv. RandomMultiAgentEnv = make_multi_agent(lambda c: RandomEnv(c)) + + +# Large observation space "pre-compiled" random env (for testing). +class RandomLargeObsSpaceEnv(RandomEnv): + def __init__(self, config=None): + config = config or {} + config.update({ + "observation_space": gym.spaces.Box(-1.0, 1.0, (5000, )) + }) + super().__init__(config=config) + + +# Large observation space + cont. actions "pre-compiled" random env +# (for testing). +class RandomLargeObsSpaceEnvContActions(RandomEnv): + def __init__(self, config=None): + config = config or {} + config.update({ + "observation_space": gym.spaces.Box(-1.0, 1.0, (5000, )), + "action_space": gym.spaces.Box(-1.0, 1.0, (5, )), + }) + super().__init__(config=config) diff --git a/rllib/examples/pettingzoo_env.py b/rllib/examples/pettingzoo_env.py index 5eeb962200849..661f03f012088 100644 --- a/rllib/examples/pettingzoo_env.py +++ b/rllib/examples/pettingzoo_env.py @@ -42,19 +42,17 @@ def env_creator(config): # Register env register_env("pistonball", lambda config: PettingZooEnv(env_creator(config))) - env = PettingZooEnv(env_creator(config)) - observation_space = env.observation_space - action_space = env.action_space - del env # Configuration for multiagent setup with policy sharing: config["multiagent"] = { - # Setup a single, shared policy for all agents. - "policies": { - "av": (None, observation_space, action_space, {}) - }, - # Map all agents to that policy. - "policy_mapping_fn": lambda agent_id, episode, **kwargs: "av", + # Setup a single, shared policy for all agents: "av". + # Use a simple set of strings (PolicyID) here. 
RLlib will + # automatically determine the policy class (Trainer's default class), + # observation- and action spaces (inferred from the env), and + # config overrides ({} in this case). + "policies": {"av"}, + # Map all agents to the "av" PolicyID. + "policy_mapping_fn": lambda agent_id, episode, worker, **kwargs: "av", } # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0. diff --git a/rllib/examples/remote_vector_env_with_custom_api.py b/rllib/examples/remote_vector_env_with_custom_api.py index 1dcc65eda89f8..c212249990611 100644 --- a/rllib/examples/remote_vector_env_with_custom_api.py +++ b/rllib/examples/remote_vector_env_with_custom_api.py @@ -65,7 +65,7 @@ class NonVectorizedEnvToBeVectorizedIntoRemoteVectorEnv(TaskSettableEnv): of gym.Env). """ - def __init__(self, config): + def __init__(self, config=None): self.action_space = gym.spaces.Box(0, 1, shape=(1, )) self.observation_space = gym.spaces.Box(0, 1, shape=(2, )) self.task = 1 @@ -108,7 +108,6 @@ def on_train_result(self, *, trainer, result: dict, **kwargs) -> None: # Specify your custom (single, non-vectorized) env directly as a # class. This way, RLlib can auto-create Actors from this class # and handle everything correctly. - # TODO: Test for multi-agent case. "env": NonVectorizedEnvToBeVectorizedIntoRemoteVectorEnv, # Set up our own callbacks. "callbacks": TaskSettingCallback, diff --git a/rllib/examples/rock_paper_scissors_multiagent.py b/rllib/examples/rock_paper_scissors_multiagent.py index bc7477a7f0716..0905314c1140b 100644 --- a/rllib/examples/rock_paper_scissors_multiagent.py +++ b/rllib/examples/rock_paper_scissors_multiagent.py @@ -9,19 +9,19 @@ import argparse import os +from pettingzoo.classic import rps_v2 import random from ray import tune from ray.rllib.agents.pg import PGTrainer, PGTFPolicy, PGTorchPolicy from ray.rllib.agents.registry import get_trainer_class +from ray.rllib.env import PettingZooEnv from ray.rllib.examples.policy.rock_paper_scissors_dummies import \ BeatLastHeuristic, AlwaysSameHeuristic from ray.rllib.policy.policy import PolicySpec from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.test_utils import check_learning_achieved from ray.tune.registry import register_env -from ray.rllib.env import PettingZooEnv -from pettingzoo.classic import rps_v2 tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() @@ -149,8 +149,8 @@ def entropy_policy_gradient_loss(policy, model, dist_class, train_batch): logits, _ = model.from_batch(train_batch) action_dist = dist_class(logits, model) if args.framework == "torch": - # required by PGTorchPolicy's stats fn. - policy.pi_err = torch.tensor([0.0]) + # Required by PGTorchPolicy's stats fn. 
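The `model.tower_stats` write just below replaces stashing stats on the policy object. For context, a sketch of the pattern in a custom torch loss (the loss body mirrors the example being patched; nothing beyond it is part of the patch):

    import torch

    def entropy_policy_gradient_loss(policy, model, dist_class, train_batch):
        logits, _ = model.from_batch(train_batch)
        action_dist = dist_class(logits, model)
        loss = torch.mean(
            -0.1 * action_dist.entropy() -
            action_dist.logp(train_batch["actions"]) *
            train_batch["advantages"])
        # Stored on the model copy (one per GPU tower), so stats fns can
        # reduce across towers without racing on the shared policy object.
        model.tower_stats["policy_loss"] = loss
        return loss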
+ model.tower_stats["policy_loss"] = torch.tensor([0.0]) return torch.mean(-0.1 * action_dist.entropy() - (action_dist.logp(train_batch["actions"]) * train_batch["advantages"])) diff --git a/rllib/examples/serving/cartpole_client.py b/rllib/examples/serving/cartpole_client.py index 4f9f247eda49b..a368e6b44b852 100755 --- a/rllib/examples/serving/cartpole_client.py +++ b/rllib/examples/serving/cartpole_client.py @@ -54,7 +54,7 @@ "(Policy-computed) ones.") parser.add_argument( "--stop-reward", - type=int, + type=float, default=9999, help="Stop once the specified reward is reached.") parser.add_argument( diff --git a/rllib/examples/serving/unity3d_client.py b/rllib/examples/serving/unity3d_client.py index 8c8784ebf18ab..f3089abd402ae 100644 --- a/rllib/examples/serving/unity3d_client.py +++ b/rllib/examples/serving/unity3d_client.py @@ -52,9 +52,13 @@ parser.add_argument( "--server", type=str, - default=SERVER_ADDRESS + ":" + str(SERVER_PORT), - help="The Policy server's address and port to connect to from this client." -) + default=SERVER_ADDRESS, + help="The Policy server's address to connect to from this client.") +parser.add_argument( + "--port", + type=int, + default=SERVER_PORT, + help="The port to use (on --server).") parser.add_argument( "--no-train", action="store_true", @@ -75,7 +79,7 @@ "learnt policy weights from the server?") parser.add_argument( "--stop-reward", - type=int, + type=float, default=9999, help="Stop once the specified reward is reached.") @@ -85,7 +89,7 @@ # Start the client for sending environment information (e.g. observations, # actions) to a policy server (listening on port 9900). client = PolicyClient( - "http://" + args.server, + "http://" + args.server + ":" + str(args.port), inference_mode=args.inference_mode, update_interval=args.update_interval_local_mode) diff --git a/rllib/examples/serving/unity3d_dummy_client.py b/rllib/examples/serving/unity3d_dummy_client.py new file mode 100644 index 0000000000000..93e7245f31a43 --- /dev/null +++ b/rllib/examples/serving/unity3d_dummy_client.py @@ -0,0 +1,144 @@ +""" +Dummy in-place replacement for the unity3d_client.py script +in case you don't have an actual Unity3D engine installed or just want +to test client/server connectivity with the unity3d_server.py script. + +This client script simply uses RLlib's RandomMultiAgentEnv to mimic +one of the ML Agents (Unity3D) example games (e.g. "3DBall"). + +To run this script on possibly different machines +against a central Policy server: + +1) Run (two separate shells/machines): +$ python unity3d_server.py --env 3DBall +$ python unity3d_dummy_client.py --env 3DBall --inference-mode=local +""" + +import argparse + +from ray.rllib.env.policy_client import PolicyClient +from ray.rllib.env.wrappers.unity3d_env import Unity3DEnv +from ray.rllib.examples.env.random_env import RandomMultiAgentEnv + +SERVER_ADDRESS = "localhost" +SERVER_PORT = 9900 + +parser = argparse.ArgumentParser() +parser.add_argument( + "--env", + type=str, + default="3DBall", + choices=[ + "3DBall", "3DBallHard", "FoodCollector", "GridFoodCollector", + "Pyramids", "Sorter", "Tennis", "VisualHallway", "Walker" + ], + help="The name of the Env to mimic. Only those examples supported so " + "far for which all agents have the same " + "observation- and action spaces (feel free to add more to this script!)") +parser.add_argument( + "--horizon", + type=int, + default=200, + help="The max. 
number of `step()`s for any episode (per agent) before " + "it'll be reset again automatically.") +parser.add_argument( + "--server", + type=str, + default=SERVER_ADDRESS, + help="The Policy server's address to connect to from this client.") +parser.add_argument( + "--port", + type=int, + default=SERVER_PORT, + help="The port to use (on --server).") +parser.add_argument( + "--no-train", + action="store_true", + help="Whether to disable training (on the server side).") +parser.add_argument( + "--inference-mode", + type=str, + default="local", + choices=["local", "remote"], + help="Whether to compute actions `local`ly or `remote`ly. Note that " + "`local` is much faster b/c observations/actions do not have to be " + "sent via the network.") +parser.add_argument( + "--update-interval-local-mode", + type=float, + default=10.0, + help="For `inference-mode=local`, every how many seconds do we update " + "learnt policy weights from the server?") +parser.add_argument( + "--num-episodes", + type=int, + default=10, + help="Stop once the specified number of episodes have been played.") + +if __name__ == "__main__": + args = parser.parse_args() + + # Start the client for sending environment information (e.g. observations, + # actions) to a policy server (listening on port 9900). + client = PolicyClient( + "http://" + args.server + ":" + str(args.port), + inference_mode=args.inference_mode, + update_interval=args.update_interval_local_mode) + + # Get the multi-agent policies dict and agent->policy + # mapping-fn. + policies, policy_mapping_fn = \ + Unity3DEnv.get_policy_configs_for_game(args.env) + + # Make sure all policies' obs- and action spaces are the same. + # If not, we won't be able to mimic the Unity3D env using RLlib's + # RandomMultiAgentEnv. + first_policy_spec = next(iter(policies.values())) + for pid, policy_spec in policies.items(): + assert policy_spec.observation_space == \ + first_policy_spec.observation_space + assert policy_spec.action_space == first_policy_spec.action_space + + # Start and reset the actual Unity3DEnv (either already running Unity3D + # editor or a binary (game) to be started automatically). + env = RandomMultiAgentEnv({ + # Same number of agents as the actual Unity3D game would have. + "num_agents": len(policies), + # Make sure we stick to the user given horizons using our + # RandomMultiAgentEnv options. + "max_episode_len": args.horizon, + "p_done": 0.0, + # Same obs- action spaces as the actual Unity3D game would have. + "observation_space": first_policy_spec.observation_space, + "action_space": first_policy_spec.action_space, + }) + obs = env.reset() + eid = client.start_episode(training_enabled=not args.no_train) + + # Keep track of the total reward per episode. + total_rewards_this_episode = 0.0 + + # Loop through the env until n episodes completed. + num_episodes = 0 + while True: + # Get actions from the Policy server given our current obs. + actions = client.get_action(eid, obs) + # Apply actions to our env. + obs, rewards, dones, infos = env.step(actions) + total_rewards_this_episode += sum(rewards.values()) + # Log rewards and single-agent dones. + client.log_returns(eid, rewards, infos, multiagent_done_dict=dones) + # Check whether all agents are done and end the episode, if necessary. + if dones["__all__"]: + print("Episode done: Reward={}".format(total_rewards_this_episode)) + + num_episodes += 1 + if num_episodes >= args.num_episodes: + quit(0) + + # End the episode and reset dummy Env. 
+ total_rewards_this_episode = 0.0 + client.end_episode(eid, obs) + obs = env.reset() + # Start a new episode. + eid = client.start_episode(training_enabled=not args.no_train) diff --git a/rllib/examples/serving/unity3d_server.py b/rllib/examples/serving/unity3d_server.py index 56c1a0089fe50..04ce5567fc165 100755 --- a/rllib/examples/serving/unity3d_server.py +++ b/rllib/examples/serving/unity3d_server.py @@ -31,24 +31,42 @@ import os import ray -from ray.tune import register_env -from ray.rllib.agents.ppo import PPOTrainer +from ray.rllib.agents.registry import get_trainer_class from ray.rllib.env.policy_server_input import PolicyServerInput from ray.rllib.env.wrappers.unity3d_env import Unity3DEnv -from ray.rllib.examples.env.random_env import RandomMultiAgentEnv SERVER_ADDRESS = "localhost" SERVER_PORT = 9900 CHECKPOINT_FILE = "last_checkpoint_{}.out" parser = argparse.ArgumentParser() +parser.add_argument( + "--run", + default="PPO", + choices=["DQN", "PPO"], + help="The RLlib-registered algorithm to use.") +parser.add_argument( + "--framework", + choices=["tf", "tf2", "tfe", "torch"], + default="tf", + help="The DL framework specifier.") +parser.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of workers to use. Each worker will create " + "its own listening socket for incoming experiences.") parser.add_argument( "--env", type=str, default="3DBall", - choices=["3DBall", "SoccerStrikersVsGoalie"], - help="The name of the Env to run in the Unity3D editor. Either `3DBall` " - "or `SoccerStrikersVsGoalie` (feel free to add more to this script!)") + choices=[ + "3DBall", "3DBallHard", "FoodCollector", "GridFoodCollector", + "Pyramids", "SoccerStrikersVsGoalie", "Sorter", "Tennis", + "VisualHallway", "Walker" + ], + help="The name of the Env to run in the Unity3D editor " + "(feel free to add more to this script!)") parser.add_argument( "--port", type=int, @@ -71,11 +89,21 @@ args = parser.parse_args() ray.init() - # Create a fake-env for the server. This env will never be used (neither - # for sampling, nor for evaluation) and its obs/action Spaces do not - # matter either (multi-agent config below defines Spaces per Policy). - register_env("fake_unity", lambda c: RandomMultiAgentEnv(c)) - + # `InputReader` generator (returns None if no input reader is needed on + # the respective worker). + def _input(ioctx): + # We are remote worker or we are local worker with num_workers=0: + # Create a PolicyServerInput. + if ioctx.worker_index > 0 or ioctx.worker.num_workers == 0: + return PolicyServerInput( + ioctx, SERVER_ADDRESS, args.port + ioctx.worker_index - + (1 if ioctx.worker_index > 0 else 0)) + # No InputReader (PolicyServerInput) needed. + else: + return None + + # Get the multi-agent policies dict and agent->policy + # mapping-fn. policies, policy_mapping_fn = \ Unity3DEnv.get_policy_configs_for_game(args.env) @@ -83,27 +111,31 @@ # build their own samplers (and also Policy objects iff # `inference_mode=local` on clients' command line). config = { - # Use the connector server to generate experiences. - "input": ( - lambda ioctx: PolicyServerInput(ioctx, SERVER_ADDRESS, args.port)), - # Use a single worker process (w/ SyncSampler) to run the server. - "num_workers": 0, + # Indicate that the Trainer we setup here doesn't need an actual env. + # Allow spaces to be determined by user (see below). + "env": None, + + # Use the `PolicyServerInput` to generate experiences. + "input": _input, + # Use n worker processes to listen on different ports. 
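For clarity, the port scheme the `_input` generator above implements, as a tiny self-checking sketch (the helper name is ours, not part of the patch): the local worker (index 0) only serves when `num_workers == 0`; remote worker i listens on `base_port + i - 1`.

    def server_port(base_port, worker_index, num_workers):
        # Mirrors: base_port + worker_index - (1 if worker_index > 0 else 0)
        if worker_index > 0:
            return base_port + worker_index - 1
        return base_port if num_workers == 0 else None  # None: no server

    assert server_port(9900, 0, 0) == 9900  # single local worker
    assert server_port(9900, 1, 2) == 9900  # first remote worker
    assert server_port(9900, 2, 2) == 9901  # second remote worker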
+ "num_workers": args.num_workers, # Disable OPE, since the rollouts are coming from online clients. "input_evaluation": [], # Other settings. "train_batch_size": 256, "rollout_fragment_length": 20, - # Multi-agent setup for the particular env. + # Multi-agent setup for the given env. "multiagent": { "policies": policies, "policy_mapping_fn": policy_mapping_fn, }, - "framework": "tf", + # DL framework to use. + "framework": args.framework, } # Create the Trainer used for Policy serving. - trainer = PPOTrainer(env="fake_unity", config=config) + trainer = get_trainer_class(args.run)(config=config) # Attempt to restore from checkpoint if possible. checkpoint_path = CHECKPOINT_FILE.format(args.env) diff --git a/rllib/examples/trajectory_view_api.py b/rllib/examples/trajectory_view_api.py index 31ce04e879126..b4a288e013bd5 100644 --- a/rllib/examples/trajectory_view_api.py +++ b/rllib/examples/trajectory_view_api.py @@ -1,13 +1,15 @@ import argparse +import numpy as np import ray -from ray import tune +from ray.rllib.agents.ppo import PPOTrainer from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole from ray.rllib.examples.models.trajectory_view_utilizing_models import \ FrameStackingCartPoleModel, TorchFrameStackingCartPoleModel from ray.rllib.models.catalog import ModelCatalog from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.test_utils import check_learning_achieved +from ray import tune tf1, tf, tfv = try_import_tf() @@ -47,18 +49,19 @@ args = parser.parse_args() ray.init(num_cpus=3) + num_frames = 16 + ModelCatalog.register_custom_model( "frame_stack_model", FrameStackingCartPoleModel if args.framework != "torch" else TorchFrameStackingCartPoleModel) - tune.register_env("stateless_cartpole", lambda c: StatelessCartPole()) config = { - "env": "stateless_cartpole", + "env": StatelessCartPole, "model": { "vf_share_layers": True, "custom_model": "frame_stack_model", "custom_model_config": { - "num_frames": 16, + "num_frames": num_frames, }, # To compare against a simple LSTM: @@ -81,8 +84,45 @@ "timesteps_total": args.stop_timesteps, "episode_reward_mean": args.stop_reward, } - results = tune.run(args.run, config=config, stop=stop, verbose=2) + results = tune.run( + args.run, config=config, stop=stop, verbose=2, checkpoint_at_end=True) if args.as_test: check_learning_achieved(results, args.stop_reward) + + checkpoints = results.get_trial_checkpoints_paths( + trial=results.get_best_trial("episode_reward_mean", mode="max"), + metric="episode_reward_mean") + + checkpoint_path = checkpoints[0][0] + trainer = PPOTrainer(config) + trainer.restore(checkpoint_path) + + # Inference loop. + env = StatelessCartPole() + + # Run manual inference loop for n episodes. + for _ in range(10): + episode_reward = 0.0 + reward = 0.0 + action = 0 + done = False + obs = env.reset() + while not done: + # Create a dummy action using the same observation n times, + # as well as dummy prev-n-actions and prev-n-rewards. 
+ action, state, logits = trainer.compute_single_action( + input_dict={ + "obs": obs, + "prev_n_obs": np.stack([obs for _ in range(num_frames)]), + "prev_n_actions": np.stack([0 for _ in range(num_frames)]), + "prev_n_rewards": np.stack( + [1.0 for _ in range(num_frames)]), + }, + full_fetch=True) + obs, reward, done, info = env.step(action) + episode_reward += reward + + print(f"Episode reward={episode_reward}") + ray.shutdown() diff --git a/rllib/execution/common.py b/rllib/execution/common.py index 3349541dac2f6..25e4428bffb63 100644 --- a/rllib/execution/common.py +++ b/rllib/execution/common.py @@ -22,9 +22,6 @@ LEARN_ON_BATCH_TIMER = "learn" LOAD_BATCH_TIMER = "load" -# Instant metrics (keys for metrics.info). -LEARNER_INFO = "learner" - # Asserts that an object is a type of SampleBatch. def _check_sample_batch_type(batch: SampleBatchType) -> None: diff --git a/rllib/execution/learner_thread.py b/rllib/execution/learner_thread.py index be7b028cdb04f..d8c6f93c146b1 100644 --- a/rllib/execution/learner_thread.py +++ b/rllib/execution/learner_thread.py @@ -3,10 +3,11 @@ import threading from typing import Dict, Optional -from ray.rllib.evaluation.metrics import get_learner_stats from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.execution.minibatch_buffer import MinibatchBuffer from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder, \ + LEARNER_INFO, LEARNER_STATS_KEY from ray.rllib.utils.timer import TimerStat from ray.rllib.utils.window_stat import WindowStat from ray.util.iter import _NextValueNotReady @@ -56,7 +57,7 @@ def __init__(self, local_worker: RolloutWorker, minibatch_buffer_size: int, self.load_wait_timer = TimerStat() self.daemon = True self.weights_updated = False - self.stats = {} + self.learner_info = {} self.stopped = False self.num_steps = 0 @@ -75,12 +76,24 @@ def step(self) -> Optional[_NextValueNotReady]: return _NextValueNotReady() with self.grad_timer: - fetches = self.local_worker.learn_on_batch(batch) + # Use LearnerInfoBuilder as a unified way to build the final + # results dict from `learn_on_loaded_batch` call(s). + # This makes sure results dicts always have the same structure + # no matter the setup (multi-GPU, multi-agent, minibatch SGD, + # tf vs torch). 
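A condensed sketch of the builder pattern this comment describes, mirroring the hunk that follows (`local_worker` and `batch` assumed in scope):

    from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder, \
        LEARNER_STATS_KEY

    builder = LearnerInfoBuilder(num_devices=1)
    multi_agent_results = local_worker.learn_on_batch(batch)
    for pid, results in multi_agent_results.items():
        builder.add_learn_on_batch_results(results, pid)
    learner_info = builder.finalize()
    # Per-policy stats, as pushed to the outqueue below:
    learner_stats = {
        pid: info[LEARNER_STATS_KEY] for pid, info in learner_info.items()
    }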
+ learner_info_builder = LearnerInfoBuilder(num_devices=1) + multi_agent_results = self.local_worker.learn_on_batch(batch) + for pid, results in multi_agent_results.items(): + learner_info_builder.add_learn_on_batch_results(results, pid) + self.learner_info = learner_info_builder.finalize() + learner_stats = { + pid: info[LEARNER_STATS_KEY] + for pid, info in self.learner_info.items() + } self.weights_updated = True - self.stats = get_learner_stats(fetches) self.num_steps += 1 - self.outqueue.put((batch.count, self.stats)) + self.outqueue.put((batch.count, learner_stats)) self.learner_queue_size.push(self.inqueue.qsize()) def add_learner_metrics(self, result: Dict) -> Dict: @@ -91,7 +104,7 @@ def timer_to_ms(timer): result["info"].update({ "learner_queue": self.learner_queue_size.stats(), - "learner": copy.deepcopy(self.stats), + LEARNER_INFO: copy.deepcopy(self.learner_info), "timing_breakdown": { "learner_grad_time_ms": timer_to_ms(self.grad_timer), "learner_load_time_ms": timer_to_ms(self.load_timer), diff --git a/rllib/execution/multi_gpu_learner_thread.py b/rllib/execution/multi_gpu_learner_thread.py index 0d230878ff609..1120be7a77d4b 100644 --- a/rllib/execution/multi_gpu_learner_thread.py +++ b/rllib/execution/multi_gpu_learner_thread.py @@ -1,15 +1,15 @@ import logging -import threading - from six.moves import queue +import threading -from ray.rllib.evaluation.metrics import get_learner_stats -from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.execution.learner_thread import LearnerThread from ray.rllib.execution.minibatch_buffer import MinibatchBuffer +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.annotations import override from ray.rllib.utils.deprecation import deprecation_warning from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder, \ + LEARNER_STATS_KEY from ray.rllib.utils.timer import TimerStat from ray.rllib.evaluation.rollout_worker import RolloutWorker @@ -103,18 +103,14 @@ def __init__( self.train_batch_size = train_batch_size - # TODO: (sven) Allow multi-GPU to work for multi-agent as well. - self.policy = self.local_worker.policy_map[DEFAULT_POLICY_ID] + self.policy_map = self.local_worker.policy_map + self.devices = next(iter(self.policy_map.values())).devices - logger.info("MultiGPULearnerThread devices {}".format( - self.policy.devices)) - assert self.train_batch_size % len(self.policy.devices) == 0 - assert self.train_batch_size >= len(self.policy.devices),\ + logger.info("MultiGPULearnerThread devices {}".format(self.devices)) + assert self.train_batch_size % len(self.devices) == 0 + assert self.train_batch_size >= len(self.devices),\ "batch too small" - if set(self.local_worker.policy_map.keys()) != {DEFAULT_POLICY_ID}: - raise NotImplementedError("Multi-gpu mode for multi-agent") - self.tower_stack_indices = list(range(num_multi_gpu_tower_stacks)) # Two queues for tower stacks: @@ -146,18 +142,39 @@ def step(self) -> None: with self.load_wait_timer: buffer_idx, released = self.ready_tower_stacks_buffer.get() + get_num_samples_loaded_into_buffer = 0 with self.grad_timer: - fetches = self.policy.learn_on_loaded_batch( - offset=0, buffer_index=buffer_idx) - self.weights_updated = True - self.stats = {DEFAULT_POLICY_ID: get_learner_stats(fetches)} + # Use LearnerInfoBuilder as a unified way to build the final + # results dict from `learn_on_loaded_batch` call(s). 
+ # This makes sure results dicts always have the same structure + # no matter the setup (multi-GPU, multi-agent, minibatch SGD, + # tf vs torch). + learner_info_builder = LearnerInfoBuilder( + num_devices=len(self.devices)) + + for pid in self.policy_map.keys(): + # Not a policy-to-train. + if pid not in self.local_worker.policies_to_train: + continue + policy = self.policy_map[pid] + default_policy_results = policy.learn_on_loaded_batch( + offset=0, buffer_index=buffer_idx) + learner_info_builder.add_learn_on_batch_results( + default_policy_results) + self.weights_updated = True + get_num_samples_loaded_into_buffer += \ + policy.get_num_samples_loaded_into_buffer(buffer_idx) + + self.learner_info = learner_info_builder.finalize() + learner_stats = { + pid: self.learner_info[pid][LEARNER_STATS_KEY] + for pid in self.learner_info.keys() + } if released: self.idle_tower_stacks.put(buffer_idx) - self.outqueue.put( - (self.policy.get_num_samples_loaded_into_buffer(buffer_idx), - self.stats)) + self.outqueue.put((get_num_samples_loaded_into_buffer, learner_stats)) self.learner_queue_size.push(self.inqueue.qsize()) @@ -180,7 +197,7 @@ def run(self) -> None: def _step(self) -> None: s = self.multi_gpu_learner_thread - policy = s.policy + policy_map = s.policy_map # Get a new batch from the data (inqueue). with self.queue_timer: @@ -191,7 +208,14 @@ def _step(self) -> None: # Load the batch into the idle stack. with self.load_timer: - policy.load_batch_into_buffer(batch=batch, buffer_index=buffer_idx) + for pid in policy_map.keys(): + if pid not in s.local_worker.policies_to_train: + continue + policy = policy_map[pid] + policy.load_batch_into_buffer( + batch=batch if isinstance(batch, SampleBatch) else + batch.policy_batches[pid], + buffer_index=buffer_idx) # Tag just-loaded stack as "ready". 
s.ready_tower_stacks.put(buffer_idx) diff --git a/rllib/execution/rollout_ops.py b/rllib/execution/rollout_ops.py index 364a814c8c996..1f65620b115eb 100644 --- a/rllib/execution/rollout_ops.py +++ b/rllib/execution/rollout_ops.py @@ -4,14 +4,15 @@ from ray.util.iter import from_actors, LocalIterator from ray.util.iter_metrics import SharedMetrics -from ray.rllib.evaluation.metrics import get_learner_stats from ray.rllib.evaluation.rollout_worker import get_global_worker from ray.rllib.evaluation.worker_set import WorkerSet from ray.rllib.execution.common import AGENT_STEPS_SAMPLED_COUNTER, \ - STEPS_SAMPLED_COUNTER, LEARNER_INFO, SAMPLE_TIMER, GRAD_WAIT_TIMER, \ + STEPS_SAMPLED_COUNTER, SAMPLE_TIMER, GRAD_WAIT_TIMER, \ _check_sample_batch_type, _get_shared_metrics from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ MultiAgentBatch +from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \ + LEARNER_STATS_KEY from ray.rllib.utils.sgd import standardized from ray.rllib.utils.typing import PolicyID, SampleBatchType, ModelGradients @@ -130,7 +131,9 @@ def __call__(self, item): (grads, info), count = item metrics = _get_shared_metrics() metrics.counters[STEPS_SAMPLED_COUNTER] += count - metrics.info[LEARNER_INFO] = get_learner_stats(info) + metrics.info[LEARNER_INFO] = { + DEFAULT_POLICY_ID: info + } if LEARNER_STATS_KEY in info else info metrics.timers[GRAD_WAIT_TIMER].push(time.perf_counter() - self.fetch_start_time) return grads, count @@ -162,15 +165,24 @@ def __init__(self, min_batch_size: int, count_steps_by: str = "env_steps"): def __call__(self, batch: SampleBatchType) -> List[SampleBatchType]: _check_sample_batch_type(batch) - self.buffer.append(batch) if self.count_steps_by == "env_steps": - self.count += batch.count + size = batch.count else: assert isinstance(batch, MultiAgentBatch), \ "`count_steps_by=agent_steps` only allowed in multi-agent " \ "environments!" - self.count += batch.agent_steps() + size = batch.agent_steps() + + # Incoming batch is an empty dummy batch -> Ignore. + # Possibly produced automatically by a PolicyServer to unblock + # an external env waiting for inputs from unresponsive/disconnected + # client(s). 
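A hedged usage sketch of the guard described in the comment above (and implemented just below). Note this only runs standalone while the buffer stays under `min_batch_size`; emitting a full batch touches the execution plan's shared metrics.

    import numpy as np
    from ray.rllib.execution.rollout_ops import ConcatBatches
    from ray.rllib.policy.sample_batch import SampleBatch

    op = ConcatBatches(min_batch_size=200)
    # Zero-sized dummy batch (e.g. sent by a PolicyServer to unblock an
    # idle client): ignored entirely, nothing is buffered.
    assert op(SampleBatch({"obs": np.zeros((0, 4))})) == []
    # Non-empty but below min_batch_size: buffered, nothing emitted yet.
    assert op(SampleBatch({"obs": np.zeros((150, 4))})) == []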
+ if size == 0: + return [] + + self.count += size + self.buffer.append(batch) if self.count >= self.min_batch_size: if self.count > self.min_batch_size * 2: diff --git a/rllib/execution/train_ops.py b/rllib/execution/train_ops.py index 6c0e089ef598a..e289d5a7f2fbb 100644 --- a/rllib/execution/train_ops.py +++ b/rllib/execution/train_ops.py @@ -1,22 +1,21 @@ import logging import numpy as np import math -import tree # pip install dm_tree from typing import List, Tuple, Any import ray -from ray.rllib.evaluation.metrics import get_learner_stats from ray.rllib.evaluation.worker_set import WorkerSet from ray.rllib.execution.common import \ AGENT_STEPS_TRAINED_COUNTER, APPLY_GRADS_TIMER, COMPUTE_GRADS_TIMER, \ - LAST_TARGET_UPDATE_TS, LEARNER_INFO, LEARN_ON_BATCH_TIMER, \ + LAST_TARGET_UPDATE_TS, LEARN_ON_BATCH_TIMER, \ LOAD_BATCH_TIMER, NUM_TARGET_UPDATES, STEPS_SAMPLED_COUNTER, \ STEPS_TRAINED_COUNTER, WORKER_UPDATE_TIMER, _check_sample_batch_type, \ _get_global_vars, _get_shared_metrics -from ray.rllib.policy.policy import LEARNER_STATS_KEY from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ MultiAgentBatch from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder, \ + LEARNER_INFO from ray.rllib.utils.sgd import do_minibatch_sgd from ray.rllib.utils.typing import PolicyID, SampleBatchType, ModelGradients @@ -62,7 +61,7 @@ def __call__(self, # train batch and loop through train batch `num_sgd_iter` times. if self.num_sgd_iter > 1 or self.sgd_minibatch_size > 0: lw = self.workers.local_worker() - info = do_minibatch_sgd( + learner_info = do_minibatch_sgd( batch, { pid: lw.get_policy(pid) for pid in self.policies @@ -70,9 +69,10 @@ def __call__(self, }, lw, self.num_sgd_iter, self.sgd_minibatch_size, []) # Single update step using train batch. else: - info = self.workers.local_worker().learn_on_batch(batch) + learner_info = \ + self.workers.local_worker().learn_on_batch(batch) - metrics.info[LEARNER_INFO] = info + metrics.info[LEARNER_INFO] = learner_info learn_timer.push_units_processed(batch.count) metrics.counters[STEPS_TRAINED_COUNTER] += batch.count if isinstance(batch, MultiAgentBatch): @@ -88,7 +88,7 @@ def __call__(self, e.set_weights.remote(weights, _get_global_vars()) # Also update global vars of the local worker. self.workers.local_worker().set_global_vars(_get_global_vars()) - return batch, info + return batch, learner_info class MultiGPUTrainOneStep: @@ -174,56 +174,43 @@ def __call__(self, # Execute minibatch SGD on loaded data. with learn_timer: - fetches = {} + # Use LearnerInfoBuilder as a unified way to build the final + # results dict from `learn_on_loaded_batch` call(s). + # This makes sure results dicts always have the same structure + # no matter the setup (multi-GPU, multi-agent, minibatch SGD, + # tf vs torch). + learner_info_builder = LearnerInfoBuilder( + num_devices=len(self.devices)) + for policy_id, samples_per_device in num_loaded_samples.items(): policy = self.local_worker.policy_map[policy_id] num_batches = max( 1, int(samples_per_device) // int(self.per_device_batch_size)) logger.debug("== sgd epochs for {} ==".format(policy_id)) - batch_fetches_all_towers = [] for _ in range(self.num_sgd_iter): permutation = np.random.permutation(num_batches) for batch_index in range(num_batches): # Learn on the pre-loaded data in the buffer. # Note: For minibatch SGD, the data is an offset into # the pre-loaded entire train batch. 
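For context, the minibatch-SGD indexing this loop performs, distilled (names follow the surrounding hunk; not runnable without a policy holding a pre-loaded buffer):

    import numpy as np

    num_batches = max(
        1, int(samples_per_device) // int(per_device_batch_size))
    for _ in range(num_sgd_iter):
        permutation = np.random.permutation(num_batches)
        for batch_index in range(num_batches):
            # Offset into the train batch pre-loaded once per device.
            offset = permutation[batch_index] * per_device_batch_size
            results = policy.learn_on_loaded_batch(offset, buffer_index=0)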
- batch_fetches = policy.learn_on_loaded_batch( + results = policy.learn_on_loaded_batch( permutation[batch_index] * self.per_device_batch_size, buffer_index=0) - # No towers: Single CPU. - if "tower_0" not in batch_fetches: - batch_fetches_all_towers.append(batch_fetches) - else: - batch_fetches_all_towers.append( - tree.map_structure_with_path( - lambda p, *s: all_tower_reduce(p, *s), - *(batch_fetches.pop( - "tower_{}".format(tower_num)) - for tower_num in range( - len(self.devices))))) - for k, v in batch_fetches.items(): - if k == LEARNER_STATS_KEY: - for k1, v1 in batch_fetches[k].items(): - batch_fetches_all_towers[-1][ - LEARNER_STATS_KEY][k1] = v1 - else: - batch_fetches_all_towers[-1][k] = v - - # Reduce mean across all minibatch SGD steps (axis=0 to keep - # all shapes as-is). - fetches[policy_id] = tree.map_structure( - lambda *s: None if s[0] is None else np.nanmean(s, axis=0), - *batch_fetches_all_towers) + learner_info_builder.add_learn_on_batch_results( + results, policy_id) + + # Tower reduce and finalize results. + learner_info = learner_info_builder.finalize() load_timer.push_units_processed(samples.count) learn_timer.push_units_processed(samples.count) metrics.counters[STEPS_TRAINED_COUNTER] += samples.count metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += samples.agent_steps() - metrics.info[LEARNER_INFO] = fetches + metrics.info[LEARNER_INFO] = learner_info if self.workers.remote_workers(): with metrics.timers[WORKER_UPDATE_TIMER]: @@ -234,24 +221,13 @@ def __call__(self, # Also update global vars of the local worker. self.workers.local_worker().set_global_vars(_get_global_vars()) - return samples, fetches + return samples, learner_info # Backward compatibility. TrainTFMultiGPU = MultiGPUTrainOneStep -def all_tower_reduce(path, *tower_data): - """Reduces stats across towers based on their stats-dict paths.""" - if len(path) == 1 and path[0] == "td_error": - return np.concatenate(tower_data, axis=0) - elif path[-1].startswith("min_"): - return np.nanmin(tower_data) - elif path[-1].startswith("max_"): - return np.nanmax(tower_data) - return np.nanmean(tower_data) - - class ComputeGradients: """Callable that computes gradients with respect to the policy loss. @@ -273,7 +249,12 @@ def __call__(self, samples: SampleBatchType) -> Tuple[ModelGradients, int]: metrics = _get_shared_metrics() with metrics.timers[COMPUTE_GRADS_TIMER]: grad, info = self.workers.local_worker().compute_gradients(samples) - metrics.info[LEARNER_INFO] = get_learner_stats(info) + # RolloutWorker.compute_gradients returns pure single agent stats + # in a non-multi agent setup. 
+ if isinstance(samples, MultiAgentBatch): + metrics.info[LEARNER_INFO] = info + else: + metrics.info[LEARNER_INFO] = {DEFAULT_POLICY_ID: info} return grad, samples.count diff --git a/rllib/models/tests/test_preprocessors.py b/rllib/models/tests/test_preprocessors.py index 015efe6edd723..2107ddec0cbd0 100644 --- a/rllib/models/tests/test_preprocessors.py +++ b/rllib/models/tests/test_preprocessors.py @@ -10,7 +10,7 @@ get_preprocessor, NoPreprocessor, TupleFlatteningPreprocessor, \ OneHotPreprocessor, AtariRamPreprocessor, GenericPixelPreprocessor from ray.rllib.utils.test_utils import check, check_compute_single_action, \ - framework_iterator + check_train_results, framework_iterator class TestPreprocessors(unittest.TestCase): @@ -50,7 +50,9 @@ def test_preprocessing_disabled(self): for _ in framework_iterator(config): trainer = ppo.PPOTrainer(config=config) for i in range(num_iterations): - print(trainer.train()) + results = trainer.train() + check_train_results(results) + print(results) check_compute_single_action(trainer) trainer.stop() diff --git a/rllib/models/tf/complex_input_net.py b/rllib/models/tf/complex_input_net.py index 2236607d3f75f..c7323c41cab96 100644 --- a/rllib/models/tf/complex_input_net.py +++ b/rllib/models/tf/complex_input_net.py @@ -38,6 +38,8 @@ def __init__(self, obs_space, action_space, num_outputs, model_config, assert isinstance(self.original_space, (Dict, Tuple)), \ "`obs_space.original_space` must be [Dict|Tuple]!" + self.processed_obs_space = self.original_space if \ + model_config.get("_disable_preprocessor_api") else obs_space super().__init__(self.original_space, action_space, num_outputs, model_config, name) @@ -124,8 +126,10 @@ def forward(self, input_dict, state, seq_lens): if SampleBatch.OBS in input_dict and "obs_flat" in input_dict: orig_obs = input_dict[SampleBatch.OBS] else: - orig_obs = restore_original_dimensions(input_dict[SampleBatch.OBS], - self.obs_space, "tf") + orig_obs = restore_original_dimensions( + input_dict[SampleBatch.OBS], + self.processed_obs_space, + tensorlib="tf") # Push image observations through our CNNs. outs = [] for i, component in enumerate(tree.flatten(orig_obs)): diff --git a/rllib/models/torch/complex_input_net.py b/rllib/models/torch/complex_input_net.py index b795e4d5485c3..ac053bab6ccf3 100644 --- a/rllib/models/torch/complex_input_net.py +++ b/rllib/models/torch/complex_input_net.py @@ -40,6 +40,9 @@ def __init__(self, obs_space, action_space, num_outputs, model_config, assert isinstance(self.original_space, (Dict, Tuple)), \ "`obs_space.original_space` must be [Dict|Tuple]!" + self.processed_obs_space = self.original_space if \ + model_config.get("_disable_preprocessor_api") else obs_space + nn.Module.__init__(self) TorchModelV2.__init__(self, self.original_space, action_space, num_outputs, model_config, name) @@ -140,8 +143,10 @@ def forward(self, input_dict, state, seq_lens): if SampleBatch.OBS in input_dict and "obs_flat" in input_dict: orig_obs = input_dict[SampleBatch.OBS] else: - orig_obs = restore_original_dimensions(input_dict[SampleBatch.OBS], - self.obs_space, "tf") + orig_obs = restore_original_dimensions( + input_dict[SampleBatch.OBS], + self.processed_obs_space, + tensorlib="torch") # Push image observations through our CNNs. 
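The `processed_obs_space` bookkeeping added to both complex-input nets (note the torch model had been passing `tensorlib="tf"` before this fix) reduces to this selection, sketched here:

    # With the preprocessor API disabled, observations keep their
    # original (Dict/Tuple) structure, so dimension restoring must use
    # the original space; otherwise the flattened obs_space is correct.
    processed_obs_space = (
        original_space
        if model_config.get("_disable_preprocessor_api") else obs_space)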
outs = [] for i, component in enumerate(tree.flatten(orig_obs)): diff --git a/rllib/models/torch/torch_modelv2.py b/rllib/models/torch/torch_modelv2.py index 5cde72c4422e6..a7cc38cef6c8a 100644 --- a/rllib/models/torch/torch_modelv2.py +++ b/rllib/models/torch/torch_modelv2.py @@ -46,6 +46,14 @@ def __init__(self, *args, **kwargs): name, framework="torch") + # Dict to store per multi-gpu tower stats into. + # In PyTorch multi-GPU, we use a single TorchPolicy and copy + # it's Model(s) n times (1 copy for each GPU). When computing the loss + # on each tower, we cannot store the stats (e.g. `entropy`) inside the + # policy object as this would lead to race conditions between the + # different towers all accessing the same property at the same time. + self.tower_stats = {} + @override(ModelV2) def variables(self, as_dict: bool = False) -> \ Union[List[TensorType], Dict[str, TensorType]]: diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py index 169dc0bad7f41..76bb4c6bb666a 100644 --- a/rllib/policy/eager_tf_policy.py +++ b/rllib/policy/eager_tf_policy.py @@ -11,13 +11,14 @@ from ray.util.debug import log_once from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.repeated_values import RepeatedValues -from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY +from ray.rllib.policy.policy import Policy from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils import add_mixins, force_list from ray.rllib.utils.annotations import override from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.spaces.space_utils import normalize_action from ray.rllib.utils.tf_ops import get_gpu_devices from ray.rllib.utils.threading import with_lock @@ -65,15 +66,17 @@ def convert_eager_inputs(func): @functools.wraps(func) def _func(*args, **kwargs): if tf.executing_eagerly(): - args = [_convert_to_tf(x) for x in args] + eager_args = [_convert_to_tf(x) for x in args] # TODO: (sven) find a way to remove key-specific hacks. - kwargs = { + eager_kwargs = { k: _convert_to_tf( v, dtype=tf.int64 if k == "timestep" else None) for k, v in kwargs.items() if k not in {"info_batch", "episodes"} } - return func(*args, **kwargs) + return func(*eager_args, **eager_kwargs) + else: + return func(*args, **kwargs) return _func @@ -182,6 +185,14 @@ def apply_gradients(self, grads): return TracedEagerPolicy +class OptimizerWrapper: + def __init__(self, tape): + self.tape = tape + + def compute_gradients(self, loss, var_list): + return list(zip(self.tape.gradient(loss, var_list), var_list)) + + def build_eager_tf_policy( name, loss_fn, @@ -323,8 +334,11 @@ def __init__(self, observation_space, action_space, config): if getattr(self, "exploration", None): optimizers = self.exploration.get_exploration_optimizer( optimizers) - # TODO: (sven) Allow tf policy to have more than 1 optimizer. - # Just like torch Policy does. + + # The list of local (tf) optimizers (one per loss term). + self._optimizers: List[LocalOptimizer] = optimizers + # Backward compatibility: A user's policy may only support a single + # loss term and optimizer (no lists). 
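A sketch of the multi-loss contract the eager policy gains when `config["_tf_policy_handles_more_than_one_loss"]` is set, simplified from the hunks below (`tape`, `variables`, and `losses` as in `_compute_gradients`):

    # One (grad, var) list per loss term ...
    grads_and_vars = [
        list(zip(tape.gradient(loss, variables), variables))
        for loss in losses
    ]
    # ... each applied with its matching optimizer.
    for optimizer, g_and_v in zip(self._optimizers, grads_and_vars):
        optimizer.apply_gradients(
            [(g, v) for g, v in g_and_v if g is not None])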
self._optimizer: LocalOptimizer = \ optimizers[0] if optimizers else None @@ -432,6 +446,7 @@ def compute_actions(self, lambda s: tf.convert_to_tensor(s), obs_batch), }, _is_training=tf.constant(False)) + self._lazy_tensor_dict(input_dict) if prev_action_batch is not None: input_dict[SampleBatch.PREV_ACTIONS] = \ tf.convert_to_tensor(prev_action_batch) @@ -465,7 +480,6 @@ def compute_actions_from_input_dict( explore, timestep) @with_lock - @convert_eager_inputs @convert_eager_outputs def _compute_action_helper(self, input_dict, state_batches, episodes, explore, timestep): @@ -481,7 +495,8 @@ def _compute_action_helper(self, input_dict, state_batches, episodes, self._is_training = False self._state_in = state_batches or [] # Calculate RNN sequence lengths. - batch_size = tree.flatten(input_dict[SampleBatch.OBS])[0].shape[0] + batch_size = int( + tree.flatten(input_dict[SampleBatch.OBS])[0].shape[0]) seq_lens = tf.ones(batch_size, dtype=tf.int32) if state_batches \ else None @@ -528,7 +543,7 @@ def _compute_action_helper(self, input_dict, state_batches, episodes, dist_inputs, self.dist_class, state_out = \ action_distribution_fn( self, self.model, - input_dict[SampleBatch.CUR_OBS], + input_dict[SampleBatch.OBS], explore=explore, timestep=timestep, is_training=False) @@ -566,7 +581,7 @@ def _compute_action_helper(self, input_dict, state_batches, episodes, extra_fetches.update(extra_action_out_fn(self)) # Update our global timestep by the batch size. - self.global_timestep += int(batch_size) + self.global_timestep += batch_size return actions, state_out, extra_fetches @@ -725,51 +740,78 @@ def export_checkpoint(self, export_dir): def _get_is_training_placeholder(self): return tf.convert_to_tensor(self._is_training) - def _apply_gradients(self, grads_and_vars): - if apply_gradients_fn: - apply_gradients_fn(self, self._optimizer, grads_and_vars) - else: - self._optimizer.apply_gradients( - [(g, v) for g, v in grads_and_vars if g is not None]) - @with_lock def _compute_gradients(self, samples): """Computes and returns grads as eager tensors.""" - with tf.GradientTape(persistent=compute_gradients_fn is not None) \ - as tape: - loss = loss_fn(self, self.model, self.dist_class, samples) - + # Gather all variables for which to calculate losses. if isinstance(self.model, tf.keras.Model): variables = self.model.trainable_variables else: variables = self.model.trainable_variables() - if compute_gradients_fn: - - class OptimizerWrapper: - def __init__(self, tape): - self.tape = tape - - def compute_gradients(self, loss, var_list): - return list( - zip(self.tape.gradient(loss, var_list), var_list)) + # Calculate the loss(es) inside a tf GradientTape. + with tf.GradientTape(persistent=compute_gradients_fn is not None) \ + as tape: + losses = loss_fn(self, self.model, self.dist_class, samples) + losses = force_list(losses) - grads_and_vars = compute_gradients_fn(self, - OptimizerWrapper(tape), - loss) + # User provided a compute_gradients_fn. + if compute_gradients_fn: + # Wrap our tape inside a wrapper, such that the resulting + # object looks like a "classic" tf.optimizer. This way, custom + # compute_gradients_fn will work on both tf static graph + # and tf-eager. + optimizer = OptimizerWrapper(tape) + # More than one loss terms/optimizers. + if self.config["_tf_policy_handles_more_than_one_loss"]: + grads_and_vars = compute_gradients_fn( + self, [optimizer] * len(losses), losses) + # Only one loss and one optimizer. 
+ else: + grads_and_vars = [ + compute_gradients_fn(self, optimizer, losses[0]) + ] + # Default: Compute gradients using the above tape. else: - grads_and_vars = list( - zip(tape.gradient(loss, variables), variables)) + grads_and_vars = [ + list(zip(tape.gradient(loss, variables), variables)) + for loss in losses + ] if log_once("grad_vars"): - for _, v in grads_and_vars: - logger.info("Optimizing variable {}".format(v.name)) + for g_and_v in grads_and_vars: + for g, v in g_and_v: + if g is not None: + logger.info(f"Optimizing variable {v.name}") + + # `grads_and_vars` is returned a list (len=num optimizers/losses) + # of lists of (grad, var) tuples. + if self.config["_tf_policy_handles_more_than_one_loss"]: + grads = [[g for g, _ in g_and_v] for g_and_v in grads_and_vars] + # `grads_and_vars` is returned as a list of (grad, var) tuples. + else: + grads_and_vars = grads_and_vars[0] + grads = [g for g, _ in grads_and_vars] - grads = [g for g, v in grads_and_vars] stats = self._stats(self, samples, grads) return grads_and_vars, stats + def _apply_gradients(self, grads_and_vars): + if apply_gradients_fn: + if self.config["_tf_policy_handles_more_than_one_loss"]: + apply_gradients_fn(self, self._optimizers, grads_and_vars) + else: + apply_gradients_fn(self, self._optimizer, grads_and_vars) + else: + if self.config["_tf_policy_handles_more_than_one_loss"]: + for i, o in enumerate(self._optimizers): + o.apply_gradients([(g, v) for g, v in grads_and_vars[i] + if g is not None]) + else: + self._optimizer.apply_gradients( + [(g, v) for g, v in grads_and_vars if g is not None]) + def _stats(self, outputs, samples, grads): fetches = {} diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index 3f75a8429c98a..6fd89f0117b97 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -14,9 +14,8 @@ from ray.rllib.utils.exploration.exploration import Exploration from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.from_config import from_config -from ray.rllib.utils.spaces.space_utils import clip_action, \ - get_base_struct_from_space, get_dummy_batch_for_space, unbatch, \ - unsquash_action +from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space, \ + get_dummy_batch_for_space, unbatch from ray.rllib.utils.typing import AgentID, ModelGradients, ModelWeights, \ TensorType, TensorStructType, TrainerConfigDict, Tuple, Union @@ -28,10 +27,6 @@ logger = logging.getLogger(__name__) -# By convention, metrics from optimizing the loss can be reported in the -# `grad_info` dict returned by learn_on_batch() / compute_grads() via this key. -LEARNER_STATS_KEY = "learner_stats" - # A policy spec used in the "config.multiagent.policies" specification dict # as values (keys are the policy IDs (str)). E.g.: # config: @@ -180,16 +175,17 @@ def compute_actions( @DeveloperAPI def compute_single_action( self, - obs: TensorStructType, + obs: Optional[TensorStructType] = None, state: Optional[List[TensorType]] = None, + *, prev_action: Optional[TensorStructType] = None, prev_reward: Optional[TensorStructType] = None, info: dict = None, + input_dict: Optional[SampleBatch] = None, episode: Optional["MultiAgentEpisode"] = None, - clip_actions: bool = None, explore: Optional[bool] = None, timestep: Optional[int] = None, - unsquash_actions: bool = None, + # Kwars placeholder for future compatibility. **kwargs) -> \ Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]: """Unbatched version of compute_actions. 
@@ -199,14 +195,13 @@ def compute_single_action( state: List of RNN state inputs, if any. prev_action: Previous action value, if any. prev_reward: Previous reward, if any. - info (dict): Info object, if any. - episode: this provides access to all - of the internal episode state, which may be useful for - model-based or multi-agent algorithms. - unsquash_actions: Should actions be unsquashed according to - the Policy's action space? - clip_actions: Should actions be clipped according to the - Policy's action space? + info: Info object, if any. + input_dict: A SampleBatch or input dict containing the + single (unbatched) Tensors to compute actions. If given, it'll + be used instead of `obs`, `state`, `prev_action|reward`, and + `info`. + episode: This provides access to all of the internal episode state, + which may be useful for model-based or multi-agent algorithms. explore: Whether to pick an exploitation or exploration action (default: None -> use self.config["explore"]). @@ -220,43 +215,37 @@ def compute_single_action( - state_outs: List of RNN state outputs, if any. - info: Dictionary of extra features, if any. """ - # If policy works in normalized space, we should unsquash the action. - # Use value of config.normalize_actions, if None. - unsquash_actions = \ - unsquash_actions if unsquash_actions is not None \ - else self.config["normalize_actions"] - clip_actions = clip_actions if clip_actions is not None else \ - self.config["clip_actions"] - - prev_action_batch = None - prev_reward_batch = None - info_batch = None + # Build the input-dict used for the call to + # `self.compute_actions_from_input_dict()`. + if input_dict is None: + input_dict = {SampleBatch.OBS: obs} + if state is not None: + for i, s in enumerate(state): + input_dict[f"state_in_{i}"] = s + if prev_action is not None: + input_dict[SampleBatch.PREV_ACTIONS] = prev_action + if prev_reward is not None: + input_dict[SampleBatch.PREV_REWARDS] = prev_reward + if info is not None: + input_dict[SampleBatch.INFOS] = info + + # Batch all data in input dict. + input_dict = tree.map_structure_with_path( + lambda p, s: (s if p == "seq_lens" else s.unsqueeze(0) if + torch and isinstance(s, torch.Tensor) else + np.expand_dims(s, 0)), + input_dict) + episodes = None - state_batch = None - if prev_action is not None: - prev_action_batch = [prev_action] - if prev_reward is not None: - prev_reward_batch = [prev_reward] - if info is not None: - info_batch = [info] if episode is not None: episodes = [episode] - if state is not None: - state_batch = [ - s.unsqueeze(0) - if torch and isinstance(s, torch.Tensor) else np.expand_dims( - s, 0) for s in state - ] - - out = self.compute_actions( - tree.map_structure(lambda s: np.array([s]), obs), - state_batch, - prev_action_batch=prev_action_batch, - prev_reward_batch=prev_reward_batch, - info_batch=info_batch, + + out = self.compute_actions_from_input_dict( + input_dict=SampleBatch(input_dict), episodes=episodes, explore=explore, - timestep=timestep) + timestep=timestep, + ) # Some policies don't return a tuple, but always just a single action. # E.g. ES and ARS. @@ -271,16 +260,6 @@ def compute_single_action( assert len(single_action) == 1 single_action = single_action[0] - # If we work in normalized action space (normalize_actions=True), - # we re-translate here into the env's action space. - if unsquash_actions: - single_action = unsquash_action(single_action, - self.action_space_struct) - # Clip, according to env's action space. 
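# ---------------------------------------------------------------------------
# Aside (illustration only, not part of the patch): what the new
# `compute_single_action()` body above effectively does - collect the
# unbatched user inputs into one dict, then prepend a batch dimension of 1
# to every leaf. Simplified to numpy only; note that dm_tree passes the
# path as a tuple, and the sample values below are made up.
import numpy as np
import tree  # pip install dm_tree

input_dict = {
    "obs": np.array([0.1, 0.2, 0.3, 0.4]),
    "prev_actions": np.array(1),
    "prev_rewards": np.array(1.0),
    "state_in_0": np.zeros(8, np.float32),
}
batched = tree.map_structure_with_path(
    lambda p, s: s if p == ("seq_lens",) else np.expand_dims(s, 0),
    input_dict)
assert batched["obs"].shape == (1, 4)
assert batched["prev_actions"].shape == (1,)
# ---------------------------------------------------------------------------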
- elif clip_actions: - single_action = clip_action(single_action, - self.action_space_struct) - # Return action, internal state(s), infos. return single_action, [s[0] for s in state_out], \ {k: v[0] for k, v in info.items()} @@ -288,7 +267,7 @@ def compute_single_action( @DeveloperAPI def compute_actions_from_input_dict( self, - input_dict: SampleBatch, + input_dict: Union[SampleBatch, Dict[str, TensorStructType]], explore: bool = None, timestep: Optional[int] = None, episodes: Optional[List["MultiAgentEpisode"]] = None, @@ -300,14 +279,19 @@ def compute_actions_from_input_dict( to construct the input_dict for the Model. Args: - input_dict (SampleBatch): A SampleBatch containing the Tensors + input_dict: A SampleBatch or input dict containing the Tensors to compute actions. `input_dict` already abides to the Policy's as well as the Model's view requirements and can thus be passed to the Model as-is. - explore (bool): Whether to pick an exploitation or exploration + explore: Whether to pick an exploitation or exploration action (default: None -> use self.config["explore"]). - timestep (Optional[int]): The current (sampling) time step. - kwargs: forward compatibility placeholder + timestep: The current (sampling) time step. + episodes: This provides access to all of the internal episodes' + state, which may be useful for model-based or multi-agent + algorithms. + + Keyword Args: + kwargs: Forward compatibility placeholder. Returns: Tuple: diff --git a/rllib/policy/policy_template.py b/rllib/policy/policy_template.py index ea231ed2abc8f..d3463df7eaf71 100644 --- a/rllib/policy/policy_template.py +++ b/rllib/policy/policy_template.py @@ -7,12 +7,13 @@ from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 -from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY +from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy import TorchPolicy from ray.rllib.utils import add_mixins, force_list, NullContextManager from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.framework import try_import_torch, try_import_jax +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.torch_ops import convert_to_non_torch_type from ray.rllib.utils.typing import ModelGradients, TensorType, \ TrainerConfigDict diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index 9192d5ba6d4d5..389278a1a4328 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -183,7 +183,7 @@ def concat_samples( >>> print(SampleBatch.concat_samples([b1, b2])) {"a": np.array([1, 2, 3]), "b": np.array([10, 11, 12])} """ - if isinstance(samples[0], MultiAgentBatch): + if any(isinstance(s, MultiAgentBatch) for s in samples): return MultiAgentBatch.concat_samples(samples) concatd_seq_lens = [] concat_samples = [] @@ -1171,7 +1171,12 @@ def concat_samples(samples: List["MultiAgentBatch"]) -> "MultiAgentBatch": policy_batches = collections.defaultdict(list) env_steps = 0 for s in samples: + # Some batches in `samples` are not MultiAgentBatch. if not isinstance(s, MultiAgentBatch): + # If empty SampleBatch: ok (just ignore). + if isinstance(s, SampleBatch) and len(s) <= 0: + continue + # Otherwise: Error. 
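# ---------------------------------------------------------------------------
# Aside (illustration only, not part of the patch): the two
# `concat_samples()` changes above in action - a mixed list is now routed
# to `MultiAgentBatch.concat_samples()`, and empty SampleBatches are
# silently skipped instead of raising. Sketch assuming an RLlib install
# that contains this patch.
import numpy as np
from ray.rllib.policy.sample_batch import (
    DEFAULT_POLICY_ID, MultiAgentBatch, SampleBatch)

b = SampleBatch({"obs": np.array([0, 1])})
ma = MultiAgentBatch({DEFAULT_POLICY_ID: b}, 2)
out = SampleBatch.concat_samples([ma, SampleBatch({}), ma])
assert out.env_steps() == 4
# ---------------------------------------------------------------------------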
raise ValueError( "`MultiAgentBatch.concat_samples()` can only concat " "MultiAgentBatch types, not {}!".format(type(s).__name__)) diff --git a/rllib/policy/tests/test_compute_log_likelihoods.py b/rllib/policy/tests/test_compute_log_likelihoods.py index 52259d6ea6e60..330ea381bddf9 100644 --- a/rllib/policy/tests/test_compute_log_likelihoods.py +++ b/rllib/policy/tests/test_compute_log_likelihoods.py @@ -57,7 +57,7 @@ def do_test_log_likelihood(run, explore=True, # Do not unsquash actions # (remain in normalized [-1.0; 1.0] space). - unsquash_actions=False, + unsquash_action=False, )) # Test all taken actions for their log-likelihoods vs expected values. diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py index bebc9fa185b26..4f4deb15c05e3 100644 --- a/rllib/policy/tf_policy.py +++ b/rllib/policy/tf_policy.py @@ -10,15 +10,16 @@ import ray import ray.experimental.tf_utils from ray.util.debug import log_once -from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY +from ray.rllib.policy.policy import Policy from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.utils import force_list -from ray.rllib.utils.annotations import override, DeveloperAPI +from ray.rllib.utils.annotations import Deprecated, DeveloperAPI, override from ray.rllib.utils.debug import summarize -from ray.rllib.utils.annotations import Deprecated +from ray.rllib.utils.deprecation import deprecation_warning from ray.rllib.utils.framework import try_import_tf, get_variable +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.schedules import PiecewiseSchedule from ray.rllib.utils.spaces.space_utils import normalize_action from ray.rllib.utils.tf_ops import get_gpu_devices @@ -423,14 +424,18 @@ def compute_actions( timestep = timestep if timestep is not None else self.global_timestep builder = TFRunBuilder(self.get_session(), "compute_actions") + + input_dict = {SampleBatch.OBS: obs_batch} + if state_batches: + for i, s in enumerate(state_batches): + input_dict[f"state_in_{i}"] = s + if prev_action_batch is not None: + input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch + if prev_reward_batch is not None: + input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch + to_fetch = self._build_compute_actions( - builder, - obs_batch=obs_batch, - state_batches=state_batches, - prev_action_batch=prev_action_batch, - prev_reward_batch=prev_reward_batch, - explore=explore, - timestep=timestep) + builder, input_dict=input_dict, explore=explore, timestep=timestep) # Execute session run to get action (and other fetches). fetched = builder.get(to_fetch) @@ -1005,6 +1010,12 @@ def _build_compute_actions(self, # TODO: (sven) This can be deprecated after trajectory view API flag is # removed and always True. 
else: + if log_once("_build_compute_actions_input_dict"): + deprecation_warning( + old="_build_compute_actions(.., obs_batch=.., ..)", + new="_build_compute_actions(.., input_dict=..)", + error=False, + ) state_batches = state_batches or [] if len(self._state_inputs) != len(state_batches): raise ValueError( diff --git a/rllib/policy/tf_policy_template.py b/rllib/policy/tf_policy_template.py index fb7e9519ec878..f2ec7dfaadcc7 100644 --- a/rllib/policy/tf_policy_template.py +++ b/rllib/policy/tf_policy_template.py @@ -6,15 +6,16 @@ from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy from ray.rllib.policy import eager_tf_policy -from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY +from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_policy import TFPolicy from ray.rllib.utils import add_mixins, force_list from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.typing import AgentID, ModelGradients, PolicyID, \ - TensorType, TrainerConfigDict +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY +from ray.rllib.utils.typing import AgentID, ModelGradients, TensorType, \ + TrainerConfigDict if TYPE_CHECKING: from ray.rllib.evaluation import MultiAgentEpisode @@ -53,7 +54,7 @@ def build_tf_policy( extra_learn_fetches_fn: Optional[Callable[[Policy], Dict[ str, TensorType]]] = None, validate_spaces: Optional[Callable[ - [PolicyID, gym.Space, gym.Space, TrainerConfigDict], None]] = None, + [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None, before_init: Optional[Callable[ [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None, before_loss_init: Optional[Callable[[ diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index f50729d005ed2..bf1c69410ff83 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -5,21 +5,23 @@ import math import numpy as np import os -import time import threading +import time +import tree # pip install dm_tree from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union, \ TYPE_CHECKING import ray from ray.rllib.models.modelv2 import ModelV2 -from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper -from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY -from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 +from ray.rllib.policy.policy import Policy from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils import force_list, NullContextManager from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.schedules import PiecewiseSchedule from ray.rllib.utils.spaces.space_utils import normalize_action from ray.rllib.utils.threading import with_lock @@ -703,6 +705,34 @@ def apply_gradients(self, gradients: ModelGradients) -> None: self._optimizers[0].step() + @DeveloperAPI + def get_tower_stats(self, stats_name: str) -> List[TensorStructType]: + """Returns list of per-tower stats, copied to this 
Policy's device. + + Args: + stats_name: The name of the stats to average over (this str + must exist as a key inside each tower's `tower_stats` dict). + + Returns: + The list of stats tensor (structs) of all towers, copied to this + Policy's device. + + Raises: + AssertionError: If the `stats_name` cannot be found in any one + of the tower's `tower_stats` dicts. + """ + data = [] + for tower in self.model_gpu_towers: + if stats_name in tower.tower_stats: + data.append( + tree.map_structure(lambda s: s.to(self.device), + tower.tower_stats[stats_name])) + assert len(data) > 0, \ + f"Stats `{stats_name}` not found in any of the towers (you have " \ + f"{len(self.model_gpu_towers)} towers in total)! Make " \ + "sure you call the loss function on at least one of the towers." + return data + @override(Policy) @DeveloperAPI def get_weights(self) -> ModelWeights: diff --git a/rllib/tests/test_exec_api.py b/rllib/tests/test_exec_api.py index b415c4faadf46..11339f08640b5 100644 --- a/rllib/tests/test_exec_api.py +++ b/rllib/tests/test_exec_api.py @@ -4,6 +4,7 @@ from ray.rllib.agents.a3c import A2CTrainer from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \ STEPS_TRAINED_COUNTER +from ray.rllib.utils.metrics.learner_info import LEARNER_INFO from ray.rllib.utils.test_utils import framework_iterator @@ -29,7 +30,7 @@ def test_exec_plan_stats(ray_start_regular): result = trainer.train() assert isinstance(result, dict) assert "info" in result - assert "learner" in result["info"] + assert LEARNER_INFO in result["info"] assert STEPS_SAMPLED_COUNTER in result["info"] assert STEPS_TRAINED_COUNTER in result["info"] assert "timers" in result diff --git a/rllib/tests/test_supported_multi_agent.py b/rllib/tests/test_supported_multi_agent.py index 0f4063bb2e886..2c114cec4d02f 100644 --- a/rllib/tests/test_supported_multi_agent.py +++ b/rllib/tests/test_supported_multi_agent.py @@ -4,7 +4,9 @@ from ray.rllib.agents.registry import get_trainer_class from ray.rllib.examples.env.multi_agent import MultiAgentCartPole, \ MultiAgentMountainCar -from ray.rllib.utils.test_utils import framework_iterator +from ray.rllib.policy.policy import PolicySpec +from ray.rllib.utils.test_utils import check_train_results, \ + framework_iterator from ray.tune import register_env @@ -13,7 +15,23 @@ def check_support_multiagent(alg, config): lambda _: MultiAgentMountainCar({"num_agents": 2})) register_env("multi_agent_cartpole", lambda _: MultiAgentCartPole({"num_agents": 2})) - config["log_level"] = "ERROR" + + # Simulate a simple multi-agent setup. 
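# ---------------------------------------------------------------------------
# Aside (illustration only, not part of the patch): the gather-and-move
# step performed by the new `TorchPolicy.get_tower_stats()` above, shown on
# two hypothetical towers that both logged a "loss" entry into their
# `tower_stats` dict.
import torch
import tree  # pip install dm_tree

class FakeTower:
    def __init__(self, loss):
        self.tower_stats = {"loss": torch.tensor(loss)}

towers = [FakeTower(0.5), FakeTower(0.7)]
device = torch.device("cpu")
data = [
    tree.map_structure(lambda s: s.to(device), t.tower_stats["loss"])
    for t in towers
]
assert abs(torch.stack(data).mean().item() - 0.6) < 1e-6
# ---------------------------------------------------------------------------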
+ policies = { + "policy_0": PolicySpec(config={"gamma": 0.99}), + "policy_1": PolicySpec(config={"gamma": 0.95}), + } + policy_ids = list(policies.keys()) + + def policy_mapping_fn(agent_id, episode, worker, **kwargs): + pol_id = policy_ids[agent_id] + return pol_id + + config["multiagent"] = { + "policies": policies, + "policy_mapping_fn": policy_mapping_fn, + } + for fw in framework_iterator(config): if fw in ["tf2", "tfe"] and \ alg in ["A3C", "APEX", "APEX_DDPG", "IMPALA"]: @@ -25,7 +43,9 @@ def check_support_multiagent(alg, config): a = get_trainer_class(alg)( config=config, env="multi_agent_cartpole") - print(a.train()) + results = a.train() + check_train_results(results) + print(results) a.stop() diff --git a/rllib/tests/test_supported_spaces.py b/rllib/tests/test_supported_spaces.py index 993558e77d223..d290d3ef87f68 100644 --- a/rllib/tests/test_supported_spaces.py +++ b/rllib/tests/test_supported_spaces.py @@ -69,6 +69,11 @@ def _do_check(alg, config, a_name, o_name): try: a = get_trainer_class(alg)(config=config, env=RandomEnv) + except ray.exceptions.RayActorError as e: + if isinstance(e.args[2], UnsupportedSpaceException): + stat = "unsupported" + else: + raise except UnsupportedSpaceException: stat = "unsupported" else: @@ -99,10 +104,11 @@ def _do_check(alg, config, a_name, o_name): _do_check(alg, config, a_name, o_name) # Do the remaining obs spaces. assert len(OBSERVATION_SPACES_TO_TEST) >= len(ACTION_SPACES_TO_TEST) + fixed_action_key = next(iter(ACTION_SPACES_TO_TEST.keys())) for i, o_name in enumerate(OBSERVATION_SPACES_TO_TEST.keys()): if i < len(ACTION_SPACES_TO_TEST): continue - _do_check(alg, config, "discrete", o_name) + _do_check(alg, config, fixed_action_key, o_name) class TestSupportedSpacesPG(unittest.TestCase): diff --git a/rllib/utils/__init__.py b/rllib/utils/__init__.py index 4f1f33083f01c..e720bfebfc468 100644 --- a/rllib/utils/__init__.py +++ b/rllib/utils/__init__.py @@ -11,7 +11,7 @@ from ray.rllib.utils.schedules import LinearSchedule, PiecewiseSchedule, \ PolynomialSchedule, ExponentialSchedule, ConstantSchedule from ray.rllib.utils.test_utils import check, check_compute_single_action, \ - framework_iterator + check_train_results, framework_iterator from ray.tune.utils import merge_dicts, deep_update @@ -77,6 +77,7 @@ def __exit__(self, *args): "add_mixins", "check", "check_compute_single_action", + "check_train_results", "deep_update", "deprecation_warning", "fc", diff --git a/rllib/utils/exploration/stochastic_sampling.py b/rllib/utils/exploration/stochastic_sampling.py index daa6089d483b4..593233625de15 100644 --- a/rllib/utils/exploration/stochastic_sampling.py +++ b/rllib/utils/exploration/stochastic_sampling.py @@ -1,7 +1,7 @@ import functools import gym import numpy as np -from typing import Union +from typing import Optional, Union from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.models.modelv2 import ModelV2 @@ -61,11 +61,12 @@ def __init__(self, dtype=np.int64) @override(Exploration) - def get_exploration_action(self, - *, - action_distribution: ActionDistribution, - timestep: Union[int, TensorType], - explore: bool = True): + def get_exploration_action( + self, + *, + action_distribution: ActionDistribution, + timestep: Optional[Union[int, TensorType]] = None, + explore: bool = True): if self.framework == "torch": return self._get_torch_exploration_action(action_distribution, timestep, explore) @@ -74,7 +75,7 @@ def get_exploration_action(self, timestep, explore) def _get_tf_exploration_action_op(self, 
action_dist, timestep, explore): - ts = timestep if timestep is not None else self.last_timestep + 1 + ts = self.last_timestep + 1 stochastic_actions = tf.cond( pred=tf.convert_to_tensor(ts < self.random_timesteps), @@ -100,10 +101,7 @@ def _get_tf_exploration_action_op(self, action_dist, timestep, explore): # Increment `last_timestep` by 1 (or set to `timestep`). if self.framework in ["tf2", "tfe"]: - if timestep is None: - self.last_timestep.assign_add(1) - else: - self.last_timestep.assign(timestep) + self.last_timestep.assign_add(1) return action, logp else: assign_op = (tf1.assign_add(self.last_timestep, 1) diff --git a/rllib/utils/metrics/__init__.py b/rllib/utils/metrics/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/rllib/utils/metrics/learner_info.py b/rllib/utils/metrics/learner_info.py new file mode 100644 index 0000000000000..ebe44a7c9fcda --- /dev/null +++ b/rllib/utils/metrics/learner_info.py @@ -0,0 +1,84 @@ +from collections import defaultdict +import numpy as np +import tree # pip install dm_tree +from typing import Dict + +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID +from ray.rllib.utils.typing import PolicyID + +# Instant metrics (keys for metrics.info). +LEARNER_INFO = "learner" +# By convention, metrics from optimizing the loss can be reported in the +# `grad_info` dict returned by learn_on_batch() / compute_grads() via this key. +LEARNER_STATS_KEY = "learner_stats" + + +class LearnerInfoBuilder: + def __init__(self, num_devices: int = 1): + self.num_devices = num_devices + self.results_all_towers = defaultdict(list) + self.is_finalized = False + + def add_learn_on_batch_results( + self, + results: Dict, + policy_id: PolicyID = DEFAULT_POLICY_ID, + ) -> None: + """Adds a policy.learn_on_(loaded)?_batch() result to this builder. + + Args: + results: The results returned by Policy.learn_on_batch or + Policy.learn_on_loaded_batch. + policy_id: The policy's ID, whose learn_on_(loaded)_batch method + returned `results`. + """ + assert not self.is_finalized, \ + "LearnerInfo already finalized! Cannot add more results." + + # No towers: Single CPU. + if "tower_0" not in results: + self.results_all_towers[policy_id].append(results) + # Multi-GPU case: + else: + self.results_all_towers[policy_id].append( + tree.map_structure_with_path( + lambda p, *s: all_tower_reduce(p, *s), + *(results.pop("tower_{}".format(tower_num)) + for tower_num in range(self.num_devices)))) + for k, v in results.items(): + if k == LEARNER_STATS_KEY: + for k1, v1 in results[k].items(): + self.results_all_towers[policy_id][-1][ + LEARNER_STATS_KEY][k1] = v1 + else: + self.results_all_towers[policy_id][-1][k] = v + + def finalize(self): + self.is_finalized = True + + info = {} + for policy_id, results_all_towers in self.results_all_towers.items(): + # Reduce mean across all minibatch SGD steps (axis=0 to keep + # all shapes as-is). + info[policy_id] = tree.map_structure( + lambda *s: None if s[0] is None else np.nanmean(s, axis=0), + *results_all_towers) + + return info + + +def all_tower_reduce(path, *tower_data): + """Reduces stats across towers based on their stats-dict paths.""" + # TD-errors: Need to stay per batch item in order to be able to update + # each item's weight in a prioritized replay buffer. + if len(path) == 1 and path[0] == "td_error": + return np.concatenate(tower_data, axis=0) + + # Min stats: Reduce min. + if path[-1].startswith("min_"): + return np.nanmin(tower_data) + # Max stats: Reduce max. 
+ elif path[-1].startswith("max_"): + return np.nanmax(tower_data) + # Everything else: Reduce mean. + return np.nanmean(tower_data) diff --git a/rllib/utils/multi_agent.py b/rllib/utils/multi_agent.py index b23726cb393db..50d5227c54e75 100644 --- a/rllib/utils/multi_agent.py +++ b/rllib/utils/multi_agent.py @@ -1,9 +1,13 @@ +from typing import Tuple + from ray.rllib.policy.policy import PolicySpec from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID -from ray.rllib.utils.typing import PartialTrainerConfigDict +from ray.rllib.utils.typing import MultiAgentPolicyConfigDict, \ + PartialTrainerConfigDict -def check_multi_agent(config: PartialTrainerConfigDict): +def check_multi_agent(config: PartialTrainerConfigDict) -> \ + Tuple[MultiAgentPolicyConfigDict, bool]: """Checks whether a (partial) config defines a multi-agent setup. Args: @@ -11,18 +15,25 @@ def check_multi_agent(config: PartialTrainerConfigDict): to check for multi-agent. Returns: - Tuple[MultiAgentPolicyConfigDict, bool]: The resulting (all - fixed) multi-agent policy dict and whether we have a - multi-agent setup or not. + The resulting (all fixed) multi-agent policy dict and whether we + have a multi-agent setup or not. """ multiagent_config = config["multiagent"] policies = multiagent_config.get("policies") + + # Nothing specified in config dict -> Assume simple single agent setup + # with DEFAULT_POLICY_ID as only policy. if not policies: policies = {DEFAULT_POLICY_ID} + # Policies given as set (of PolicyIDs) -> Setup each policy automatically + # via empty PolicySpec (will make RLlib infer obs- and action spaces + # as well as the Policy's class). if isinstance(policies, set): policies = multiagent_config["policies"] = { pid: PolicySpec() for pid in policies } + # Is this a multi-agent setup? True, iff there is more than one policy, + # or the only policy's ID is not DEFAULT_POLICY_ID. is_multiagent = len(policies) > 1 or DEFAULT_POLICY_ID not in policies return policies, is_multiagent diff --git a/rllib/utils/sgd.py b/rllib/utils/sgd.py index b163c2a36fcd4..6b4f060a95598 100644 --- a/rllib/utils/sgd.py +++ b/rllib/utils/sgd.py @@ -1,38 +1,17 @@ """Utils for minibatch SGD across multiple RLlib policies.""" -import numpy as np import logging -from collections import defaultdict +import numpy as np import random -from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch, \ MultiAgentBatch +from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder logger = logging.getLogger(__name__) -def averaged(kv, axis=None): - """Average the value lists of a dictionary. - - For non-scalar values, we simply pick the first value. - - Args: - kv (dict): dictionary with values that are lists of floats. - - Returns: - dictionary with single averaged float as values. - """ - out = {} - for k, v in kv.items(): - if v[0] is not None and not isinstance(v[0], dict): - out[k] = np.mean(v, axis=axis) - else: - out[k] = v[0] - return out - - -def standardized(array): +def standardized(array: np.ndarray): """Normalize the values in an array. Args: @@ -107,7 +86,12 @@ def do_minibatch_sgd(samples, policies, local_worker, num_sgd_iter, if isinstance(samples, SampleBatch): samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples}, samples.count) - fetches = defaultdict(dict) + # Use LearnerInfoBuilder as a unified way to build the final + # results dict from `learn_on_batch()` call(s). + # This makes sure results dicts always have the same structure + # no matter the setup (multi-GPU, multi-agent, minibatch SGD, + # tf vs torch). + learner_info_builder = LearnerInfoBuilder(num_devices=1)
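# ---------------------------------------------------------------------------
# Aside (illustration only, not part of the patch): minimal usage of the
# LearnerInfoBuilder flow that replaces the averaged()/defaultdict logic
# here - one builder per training iteration, fed with raw learn_on_batch()
# results; the loss numbers are made up.
from ray.rllib.utils.metrics.learner_info import (
    LEARNER_STATS_KEY, LearnerInfoBuilder)

builder = LearnerInfoBuilder(num_devices=1)
for policy_loss in (0.5, 0.3):
    # Shape of a single-CPU learn_on_batch() result (no "tower_0" key).
    builder.add_learn_on_batch_results(
        {LEARNER_STATS_KEY: {"policy_loss": policy_loss}})
learner_info = builder.finalize()
# -> {"default_policy": {"learner_stats": {"policy_loss": ~0.4}}}
# ---------------------------------------------------------------------------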
for policy_id in policies.keys(): if policy_id not in samples.policy_batches: continue @@ -116,23 +100,14 @@ def do_minibatch_sgd(samples, policies, local_worker, num_sgd_iter, for field in standardize_fields: batch[field] = standardized(batch[field]) - learner_stats = defaultdict(list) - model_stats = defaultdict(list) - custom_callbacks_stats = defaultdict(list) - for i in range(num_sgd_iter): for minibatch in minibatches(batch, sgd_minibatch_size): - batch_fetches = (local_worker.learn_on_batch( + results = (local_worker.learn_on_batch( MultiAgentBatch({ policy_id: minibatch }, minibatch.count)))[policy_id] - for k, v in batch_fetches.get(LEARNER_STATS_KEY, {}).items(): - learner_stats[k].append(v) - for k, v in batch_fetches.get("model", {}).items(): - model_stats[k].append(v) - for k, v in batch_fetches.get("custom_metrics", {}).items(): - custom_callbacks_stats[k].append(v) - fetches[policy_id][LEARNER_STATS_KEY] = averaged(learner_stats) - fetches[policy_id]["model"] = averaged(model_stats) - fetches[policy_id]["custom_metrics"] = averaged(custom_callbacks_stats) - return fetches + learner_info_builder.add_learn_on_batch_results( + results, policy_id) + + learner_info = learner_info_builder.finalize() + return learner_info diff --git a/rllib/utils/test_utils.py b/rllib/utils/test_utils.py index f119d3806968f..5fcb16da6471e 100644 --- a/rllib/utils/test_utils.py +++ b/rllib/utils/test_utils.py @@ -1,10 +1,12 @@ from collections import Counter import copy -import gym +from gym.spaces import Box import logging import numpy as np +import random import re import time +import tree # pip install dm_tree from typing import Any, Dict, List import yaml @@ -29,7 +31,8 @@ def framework_iterator(config=None, frameworks=("tf2", "tf", "tfe", "torch"), - session=False): + session=False, + with_eager_tracing=False): """A generator that allows for looping through n frameworks for testing. Provides the correct config entries ("framework") as well @@ -44,6 +47,8 @@ def framework_iterator(config=None, and yield that as second return value (otherwise yield (fw, None)). Also sets a seed (42) on the session to make the test deterministic. + with_eager_tracing: Include `eager_tracing=True` in the returned + configs, when framework=[tfe|tf2]. Yields: str: If enter_session is False: @@ -103,7 +108,15 @@ def framework_iterator(config=None, elif fw == "tf": assert not tf1.executing_eagerly() - yield fw if session is False else (fw, sess) + # Additionally loop through eager_tracing=True + False, if necessary. + if fw in ["tf2", "tfe"] and with_eager_tracing: + for tracing in [True, False]: + config["eager_tracing"] = tracing + yield fw if session is False else (fw, sess) + config["eager_tracing"] = False + # Yield current framework + tf-session (if necessary). + else: + yield fw if session is False else (fw, sess) # Exit any context we may have entered. if eager_ctx: @@ -260,31 +273,6 @@ def check(x, y, decimals=5, atol=None, rtol=None, false=False): "ERROR: x ({}) is the same as y ({})!".format(x, y) -def check_learning_achieved(tune_results, min_reward, evaluation=False): - """Throws an error if `min_reward` is not reached within tune_results. - - Checks the last iteration found in tune_results for its - "episode_reward_mean" value and compares it to `min_reward`.
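# ---------------------------------------------------------------------------
# Aside (illustration only, not part of the patch): a hypothetical loop over
# the new `with_eager_tracing` flag above. "tf2" is yielded twice (once with
# `eager_tracing=True`, once with False), while "torch" is yielded once.
from ray.rllib.utils.test_utils import framework_iterator

config = {"num_workers": 0}
for fw in framework_iterator(
        config, frameworks=("tf2", "torch"), with_eager_tracing=True):
    print(fw, config.get("eager_tracing"))
# ---------------------------------------------------------------------------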
- - Args: - tune_results: The tune.run returned results object. - min_reward (float): The min reward that must be reached. - - Raises: - ValueError: If `min_reward` not reached. - """ - # Get maximum reward of all trials - # (check if at least one trial achieved some learning) - avg_rewards = [(trial.last_result["episode_reward_mean"] - if not evaluation else - trial.last_result["evaluation"]["episode_reward_mean"]) - for trial in tune_results.trials] - best_avg_reward = max(avg_rewards) - if best_avg_reward < min_reward: - raise ValueError("`stop-reward` of {} not reached!".format(min_reward)) - print("ok") - - def check_compute_single_action(trainer, include_state=False, include_prev_action_reward=False): @@ -300,17 +288,120 @@ def check_compute_single_action(trainer, Raises: ValueError: If anything unexpected happens. """ + # Have to import this here to avoid circular dependency. + from ray.rllib.policy.sample_batch import SampleBatch + + # Some Trainers may not abide to the standard API. try: pol = trainer.get_policy() except AttributeError: pol = trainer.policy + # Get the policy's model. model = pol.model action_space = pol.action_space + def _test(what, method_to_test, obs_space, full_fetch, explore, timestep, + unsquash, clip): + call_kwargs = {} + if what is trainer: + call_kwargs["full_fetch"] = full_fetch + + obs = obs_space.sample() + if isinstance(obs_space, Box): + obs = np.clip(obs, -1.0, 1.0) + state_in = None + if include_state: + state_in = model.get_initial_state() + if not state_in: + state_in = [] + i = 0 + while f"state_in_{i}" in model.view_requirements: + state_in.append(model.view_requirements[f"state_in_{i}"] + .space.sample()) + i += 1 + action_in = action_space.sample() \ + if include_prev_action_reward else None + reward_in = 1.0 if include_prev_action_reward else None + + if method_to_test == "input_dict": + assert what is pol + + input_dict = {SampleBatch.OBS: obs} + if include_prev_action_reward: + input_dict[SampleBatch.PREV_ACTIONS] = action_in + input_dict[SampleBatch.PREV_REWARDS] = reward_in + if state_in: + for i, s in enumerate(state_in): + input_dict[f"state_in_{i}"] = s + input_dict_batched = SampleBatch( + tree.map_structure(lambda s: np.expand_dims(s, 0), input_dict)) + action = pol.compute_actions_from_input_dict( + input_dict=input_dict_batched, + explore=explore, + timestep=timestep, + **call_kwargs) + # Unbatch everything to be able to compare against single + # action below. + # ARS and ES return action batches as lists. + if isinstance(action[0], list): + action = (np.array(action[0]), action[1], action[2]) + action = tree.map_structure(lambda s: s[0], action) + + try: + action2 = pol.compute_single_action( + input_dict=input_dict, + explore=explore, + timestep=timestep, + **call_kwargs) + # Make sure these are the same, unless we have exploration + # switched on (or noisy layers). 
+ if not explore and not pol.config.get("noisy"): + check(action, action2) + except TypeError: + pass + else: + action = what.compute_single_action( + obs, + state_in, + prev_action=action_in, + prev_reward=reward_in, + explore=explore, + timestep=timestep, + unsquash_action=unsquash, + clip_action=clip, + **call_kwargs) + + state_out = None + if state_in or full_fetch or what is pol: + action, state_out, _ = action + if state_out: + for si, so in zip(state_in, state_out): + check(list(si.shape), so.shape) + + # Test whether unsquash/clipping works on the Trainer's + # compute_single_action method: Both flags should force the action + # to be within the space's bounds. + if method_to_test == "single" and what == trainer: + if not action_space.contains(action) and \ + (clip or unsquash or not isinstance(action_space, Box)): + raise ValueError( + f"Returned action ({action}) of trainer/policy {what} " + f"not in Env's action_space {action_space}") + # We are operating in normalized space: Expect only smaller action + # values. + if isinstance(action_space, Box) and not unsquash and \ + what.config.get("normalize_actions") and \ + np.any(np.abs(action) > 3.0): + raise ValueError( + f"Returned action ({action}) of trainer/policy {what} " + "should be in normalized space, but seems too large/small " + "for that!") + + # Loop through: Policy vs Trainer; Different API methods to calculate + # actions; unsquash option; clip option; full fetch or not. for what in [pol, trainer]: if what is trainer: - method_to_test = trainer.compute_single_action # Get the obs-space from Workers.env (not Policy) due to possible # pre-processor up front. worker_set = getattr(trainer, "workers", @@ -323,53 +414,134 @@ def check_compute_single_action(trainer, lambda p: p.observation_space) obs_space = getattr(obs_space, "original_space", obs_space) else: - method_to_test = pol.compute_single_action obs_space = pol.observation_space - for explore in [True, False]: - for full_fetch in ([False, True] if what is trainer else [False]): - call_kwargs = {} - if what is trainer: - call_kwargs["full_fetch"] = full_fetch - else: - call_kwargs["clip_actions"] = True - - obs = obs_space.sample() - if isinstance(obs_space, gym.spaces.Box): - obs = np.clip(obs, -1.0, 1.0) - state_in = None - if include_state: - state_in = model.get_initial_state() - if not state_in: - state_in = [] - i = 0 - while f"state_in_{i}" in model.view_requirements: - state_in.append(model.view_requirements[ - f"state_in_{i}"].space.sample()) - i += 1 - action_in = action_space.sample() \ - if include_prev_action_reward else None - reward_in = 1.0 if include_prev_action_reward else None - action = method_to_test( - obs, - state_in, - prev_action=action_in, - prev_reward=reward_in, - explore=explore, - **call_kwargs) + for method_to_test in ["single"] + \ + (["input_dict"] if what is pol else []): + for explore in [True, False]: + for full_fetch in ([False, True] + if what is trainer else [False]): + timestep = random.randint(0, 100000) + for unsquash in [True, False]: + for clip in ([False] if unsquash else [True, False]): + _test(what, method_to_test, obs_space, full_fetch, + explore, timestep, unsquash, clip) - state_out = None - if state_in or full_fetch or what is pol: - action, state_out, _ = action - if state_out: - for si, so in zip(state_in, state_out): - check(list(si.shape), so.shape) - if not action_space.contains(action): - raise ValueError( - "Returned action ({}) of trainer/policy {} not in " - "Env's action_space " - "({})!".format(action, 
what, action_space)) +def check_learning_achieved(tune_results, min_reward, evaluation=False): + """Throws an error if `min_reward` is not reached within tune_results. + + Checks the last iteration found in tune_results for its + "episode_reward_mean" value and compares it to `min_reward`. + + Args: + tune_results: The tune.run returned results object. + min_reward (float): The min reward that must be reached. + + Raises: + ValueError: If `min_reward` not reached. + """ + # Get maximum reward of all trials + # (check if at least one trial achieved some learning) + avg_rewards = [(trial.last_result["episode_reward_mean"] + if not evaluation else + trial.last_result["evaluation"]["episode_reward_mean"]) + for trial in tune_results.trials] + best_avg_reward = max(avg_rewards) + if best_avg_reward < min_reward: + raise ValueError("`stop-reward` of {} not reached!".format(min_reward)) + print("ok") + + +def check_train_results(train_results): + """Checks proper structure of a Trainer.train() returned dict. + + Args: + train_results: The train results dict to check. + + Raises: + AssertionError: If `train_results` doesn't have the proper structure or + data in it. + """ + # Import these here to avoid circular dependencies. + from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID + from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \ + LEARNER_STATS_KEY + from ray.rllib.utils.multi_agent import check_multi_agent + + # Assert that some keys are where we would expect them. + for key in [ + "agent_timesteps_total", + "config", + "custom_metrics", + "episode_len_mean", + "episode_reward_max", + "episode_reward_mean", + "episode_reward_min", + "episodes_total", + "hist_stats", + "info", + "iterations_since_restore", + "num_healthy_workers", + "perf", + "policy_reward_max", + "policy_reward_mean", + "policy_reward_min", + "sampler_perf", + "time_since_restore", + "time_this_iter_s", + "timesteps_since_restore", + "timesteps_total", + "timers", + "time_total_s", + "training_iteration", + ]: + assert key in train_results, \ + f"'{key}' not found in `train_results` ({train_results})!" + + _, is_multi_agent = check_multi_agent(train_results["config"]) + + # Check in particular the "info" dict. + info = train_results["info"] + assert LEARNER_INFO in info, \ + f"'learner' not in train_results['infos'] ({info})!" + assert "num_steps_trained" in info,\ + f"'num_steps_trained' not in train_results['infos'] ({info})!" + + learner_info = info[LEARNER_INFO] + + # Make sure we have a default_policy key if we are not in a + # multi-agent setup. + if not is_multi_agent: + # APEX algos sometimes have an empty learner info dict (no metrics + # collected yet). + assert len(learner_info) == 0 or DEFAULT_POLICY_ID in learner_info, \ + f"'{DEFAULT_POLICY_ID}' not found in " \ + f"train_results['infos']['learner'] ({learner_info})!" + + for pid, policy_stats in learner_info.items(): + if pid == "batch_count": + continue + # Expect td-errors to be per batch-item. + if "td_error" in policy_stats: + configured_b = train_results["config"]["train_batch_size"] + actual_b = policy_stats["td_error"].shape[0] + # R2D2 case. + if (configured_b - actual_b) / actual_b > 0.1: + assert configured_b / ( + train_results["config"]["model"]["max_seq_len"] + + train_results["config"]["burn_in"]) == actual_b + + # Make sure each policy has the LEARNER_STATS_KEY under it. 
+ assert LEARNER_STATS_KEY in policy_stats + learner_stats = policy_stats[LEARNER_STATS_KEY] + for key, value in learner_stats.items(): + # Min- and max-stats should be single values. + if key.startswith("min_") or key.startswith("max_"): + assert np.isscalar( + value), f"'{key}' value not a scalar ({value})!" + + return train_results def run_learning_tests_from_yaml( diff --git a/rllib/utils/tf_ops.py b/rllib/utils/tf_ops.py index 20b0ea3d75f98..1b577be7ef727 100644 --- a/rllib/utils/tf_ops.py +++ b/rllib/utils/tf_ops.py @@ -146,7 +146,7 @@ def zero_logps_from_actions(actions: TensorStructType) -> TensorType: # `deterministic_actions` or `stochastic_actions`). In case # actions are just [B], zeros_like works just fine here, but if # actions are [B, ...], we have to reduce logp back to just [B]. - if len(logp_.shape) > 1: + while len(logp_.shape) > 1: logp_ = logp_[:, 0] return logp_ diff --git a/rllib/utils/tf_run_builder.py b/rllib/utils/tf_run_builder.py index 82b904bd13164..28a48558f73e7 100644 --- a/rllib/utils/tf_run_builder.py +++ b/rllib/utils/tf_run_builder.py @@ -59,7 +59,10 @@ def get(self, to_fetch): _count = 0 -def run_timeline(sess, ops, debug_name, feed_dict={}, timeline_dir=None): +def run_timeline(sess, ops, debug_name, feed_dict=None, timeline_dir=None): + if feed_dict is None: + feed_dict = {} + if timeline_dir: from tensorflow.python.client import timeline diff --git a/rllib/utils/torch_ops.py b/rllib/utils/torch_ops.py index a27be53cc2695..90ccc64aad126 100644 --- a/rllib/utils/torch_ops.py +++ b/rllib/utils/torch_ops.py @@ -48,8 +48,8 @@ def atanh(x): def concat_multi_gpu_td_errors(policy): td_error = torch.cat( [ - getattr(t, "td_error", torch.tensor([0.0])).to(policy.device) - for t in policy.model_gpu_towers + t.tower_stats.get("td_error", torch.tensor([0.0])).to( + policy.device) for t in policy.model_gpu_towers ], dim=0) policy.td_error = td_error @@ -132,7 +132,7 @@ def explained_variance(y, pred): y_var = torch.var(y, dim=[0]) diff_var = torch.var(y - pred, dim=[0]) min_ = torch.tensor([-1.0]).to(pred.device) - return torch.max(min_, 1 - (diff_var / y_var)) + return torch.max(min_, 1 - (diff_var / y_var))[0] def global_norm(tensors): diff --git a/src/mock/ray/gcs/gcs_server/gcs_node_manager.h b/src/mock/ray/gcs/gcs_server/gcs_node_manager.h index 483464c1ff6eb..56be36f4c87ff 100644 --- a/src/mock/ray/gcs/gcs_server/gcs_node_manager.h +++ b/src/mock/ray/gcs/gcs_server/gcs_node_manager.h @@ -17,6 +17,7 @@ namespace gcs { class MockGcsNodeManager : public GcsNodeManager { public: + MockGcsNodeManager() : GcsNodeManager(nullptr, nullptr) {} MOCK_METHOD(void, HandleRegisterNode, (const rpc::RegisterNodeRequest &request, rpc::RegisterNodeReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); diff --git a/src/mock/ray/gcs/gcs_server/gcs_placement_group_scheduler.h b/src/mock/ray/gcs/gcs_server/gcs_placement_group_scheduler.h index f612e6d1d2841..627e3357879e7 100644 --- a/src/mock/ray/gcs/gcs_server/gcs_placement_group_scheduler.h +++ b/src/mock/ray/gcs/gcs_server/gcs_placement_group_scheduler.h @@ -1,4 +1,4 @@ -// Copyright The Ray Authors. +// Copyright 2021 The Ray Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License.
@@ -30,8 +30,8 @@ class MockGcsPlacementGroupSchedulerInterface public: MOCK_METHOD(void, ScheduleUnplacedBundles, (std::shared_ptr placement_group, - std::function)> failure_callback, - std::function)> success_callback), + PGSchedulingFailureCallback failure_callback, + PGSchedulingSuccessfulCallback success_callback), (override)); MOCK_METHOD((absl::flat_hash_map>), GetBundlesOnNode, (const NodeID &node_id), (override)); @@ -63,11 +63,12 @@ namespace gcs { class MockGcsScheduleStrategy : public GcsScheduleStrategy { public: - MOCK_METHOD(ScheduleMap, Schedule, - (std::vector> & bundles, - const std::unique_ptr &context, - GcsResourceScheduler &gcs_resource_scheduler), - (override)); + MOCK_METHOD( + ScheduleResult, Schedule, + (const std::vector> &bundles, + const std::unique_ptr &context, + GcsResourceScheduler &gcs_resource_scheduler), + (override)); }; } // namespace gcs @@ -78,11 +79,12 @@ namespace gcs { class MockGcsPackStrategy : public GcsPackStrategy { public: - MOCK_METHOD(ScheduleMap, Schedule, - (std::vector> & bundles, - const std::unique_ptr &context, - GcsResourceScheduler &gcs_resource_scheduler), - (override)); + MOCK_METHOD( + ScheduleResult, Schedule, + (const std::vector> &bundles, + const std::unique_ptr &context, + GcsResourceScheduler &gcs_resource_scheduler), + (override)); }; } // namespace gcs @@ -93,11 +95,12 @@ namespace gcs { class MockGcsSpreadStrategy : public GcsSpreadStrategy { public: - MOCK_METHOD(ScheduleMap, Schedule, - (std::vector> & bundles, - const std::unique_ptr &context, - GcsResourceScheduler &gcs_resource_scheduler), - (override)); + MOCK_METHOD( + ScheduleResult, Schedule, + (const std::vector> &bundles, + const std::unique_ptr &context, + GcsResourceScheduler &gcs_resource_scheduler), + (override)); }; } // namespace gcs @@ -108,11 +111,12 @@ namespace gcs { class MockGcsStrictPackStrategy : public GcsStrictPackStrategy { public: - MOCK_METHOD(ScheduleMap, Schedule, - (std::vector> & bundles, - const std::unique_ptr &context, - GcsResourceScheduler &gcs_resource_scheduler), - (override)); + MOCK_METHOD( + ScheduleResult, Schedule, + (const std::vector> &bundles, + const std::unique_ptr &context, + GcsResourceScheduler &gcs_resource_scheduler), + (override)); }; } // namespace gcs @@ -123,11 +127,12 @@ namespace gcs { class MockGcsStrictSpreadStrategy : public GcsStrictSpreadStrategy { public: - MOCK_METHOD(ScheduleMap, Schedule, - (std::vector> & bundles, - const std::unique_ptr &context, - GcsResourceScheduler &gcs_resource_scheduler), - (override)); + MOCK_METHOD( + ScheduleResult, Schedule, + (const std::vector> &bundles, + const std::unique_ptr &context, + GcsResourceScheduler &gcs_resource_scheduler), + (override)); }; } // namespace gcs @@ -160,8 +165,8 @@ class MockGcsPlacementGroupScheduler : public GcsPlacementGroupScheduler { public: MOCK_METHOD(void, ScheduleUnplacedBundles, (std::shared_ptr placement_group, - std::function)> failure_handler, - std::function)> success_handler), + PGSchedulingFailureCallback failure_handler, + PGSchedulingSuccessfulCallback success_handler), (override)); MOCK_METHOD(void, DestroyPlacementGroupBundleResourcesIfExists, (const PlacementGroupID &placement_group_id), (override)); diff --git a/src/mock/ray/gcs/gcs_server/gcs_resource_manager.h b/src/mock/ray/gcs/gcs_server/gcs_resource_manager.h index d981be23a5472..764bee572cabc 100644 --- a/src/mock/ray/gcs/gcs_server/gcs_resource_manager.h +++ b/src/mock/ray/gcs/gcs_server/gcs_resource_manager.h @@ -17,6 +17,7 @@ namespace gcs { class 
MockGcsResourceManager : public GcsResourceManager { public: + using GcsResourceManager::GcsResourceManager; MOCK_METHOD(void, HandleGetResources, (const rpc::GetResourcesRequest &request, rpc::GetResourcesReply *reply, rpc::SendReplyCallback send_reply_callback), diff --git a/src/mock/ray/gcs/pubsub/gcs_pub_sub.h b/src/mock/ray/gcs/pubsub/gcs_pub_sub.h new file mode 100644 index 0000000000000..21e500da0a002 --- /dev/null +++ b/src/mock/ray/gcs/pubsub/gcs_pub_sub.h @@ -0,0 +1,27 @@ +// Copyright 2021 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace ray { +namespace gcs { + +class MockGcsPubSub : public GcsPubSub { + public: + MOCK_METHOD(Status, Publish, + (const std::string &channel, const std::string &id, const std::string &data, + const StatusCallback &done), + (override)); +}; + +} // namespace gcs +} // namespace ray diff --git a/src/mock/ray/gcs/store_client/in_memory_store_client.h b/src/mock/ray/gcs/store_client/in_memory_store_client.h new file mode 100644 index 0000000000000..08af16a075a17 --- /dev/null +++ b/src/mock/ray/gcs/store_client/in_memory_store_client.h @@ -0,0 +1,66 @@ +// Copyright 2021 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +namespace ray { +namespace gcs { + +class MockInMemoryStoreClient : public InMemoryStoreClient { + public: + MOCK_METHOD(Status, AsyncPut, + (const std::string &table_name, const std::string &key, + const std::string &data, const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncPutWithIndex, + (const std::string &table_name, const std::string &key, + const std::string &index_key, const std::string &data, + const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncGet, + (const std::string &table_name, const std::string &key, + const OptionalItemCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncGetByIndex, + (const std::string &table_name, const std::string &index_key, + (const MapCallback &callback)), + (override)); + MOCK_METHOD(Status, AsyncGetAll, + (const std::string &table_name, + (const MapCallback &callback)), + (override)); + MOCK_METHOD(Status, AsyncDelete, + (const std::string &table_name, const std::string &key, + const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncDeleteWithIndex, + (const std::string &table_name, const std::string &key, + const std::string &index_key, const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncBatchDelete, + (const std::string &table_name, const std::vector &keys, + const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncBatchDeleteWithIndex, + (const std::string &table_name, const std::vector &keys, + const std::vector &index_keys, + const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncDeleteByIndex, + (const std::string &table_name, const std::string &index_key, + const StatusCallback &callback), + (override)); + MOCK_METHOD(int, GetNextJobID, (), (override)); +}; + +} // namespace gcs +} // namespace ray diff --git a/src/mock/ray/gcs/store_client/redis_store_client.h b/src/mock/ray/gcs/store_client/redis_store_client.h new file mode 100644 index 0000000000000..153a69755d3b7 --- /dev/null +++ b/src/mock/ray/gcs/store_client/redis_store_client.h @@ -0,0 +1,67 @@ +// Copyright 2021 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +namespace ray { +namespace gcs { + +class MockRedisStoreClient : public RedisStoreClient { + public: + MockRedisStoreClient() : RedisStoreClient(nullptr) {} + MOCK_METHOD(Status, AsyncPut, + (const std::string &table_name, const std::string &key, + const std::string &data, const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncPutWithIndex, + (const std::string &table_name, const std::string &key, + const std::string &index_key, const std::string &data, + const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncGet, + (const std::string &table_name, const std::string &key, + const OptionalItemCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncGetByIndex, + (const std::string &table_name, const std::string &index_key, + (const MapCallback &callback)), + (override)); + MOCK_METHOD(Status, AsyncGetAll, + (const std::string &table_name, + (const MapCallback &callback)), + (override)); + MOCK_METHOD(Status, AsyncDelete, + (const std::string &table_name, const std::string &key, + const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncDeleteWithIndex, + (const std::string &table_name, const std::string &key, + const std::string &index_key, const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncBatchDelete, + (const std::string &table_name, const std::vector &keys, + const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncBatchDeleteWithIndex, + (const std::string &table_name, const std::vector &keys, + const std::vector &index_keys, + const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncDeleteByIndex, + (const std::string &table_name, const std::string &index_key, + const StatusCallback &callback), + (override)); + MOCK_METHOD(int, GetNextJobID, (), (override)); +}; + +} // namespace gcs +} // namespace ray diff --git a/src/mock/ray/gcs/store_client/store_client.h b/src/mock/ray/gcs/store_client/store_client.h new file mode 100644 index 0000000000000..6f4e3b5382735 --- /dev/null +++ b/src/mock/ray/gcs/store_client/store_client.h @@ -0,0 +1,66 @@ +// Copyright 2021 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +namespace ray { +namespace gcs { + +class MockStoreClient : public StoreClient { + public: + MOCK_METHOD(Status, AsyncPut, + (const std::string &table_name, const std::string &key, + const std::string &data, const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncPutWithIndex, + (const std::string &table_name, const std::string &key, + const std::string &index_key, const std::string &data, + const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncGet, + (const std::string &table_name, const std::string &key, + const OptionalItemCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncGetByIndex, + (const std::string &table_name, const std::string &index_key, + (const MapCallback &callback)), + (override)); + MOCK_METHOD(Status, AsyncGetAll, + (const std::string &table_name, + (const MapCallback &callback)), + (override)); + MOCK_METHOD(Status, AsyncDelete, + (const std::string &table_name, const std::string &key, + const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncDeleteWithIndex, + (const std::string &table_name, const std::string &key, + const std::string &index_key, const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncBatchDelete, + (const std::string &table_name, const std::vector &keys, + const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncBatchDeleteWithIndex, + (const std::string &table_name, const std::vector &keys, + const std::vector &index_keys, + const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncDeleteByIndex, + (const std::string &table_name, const std::string &index_key, + const StatusCallback &callback), + (override)); + MOCK_METHOD(int, GetNextJobID, (), (override)); +}; + +} // namespace gcs +} // namespace ray diff --git a/src/mock/ray/pubsub/publisher.h b/src/mock/ray/pubsub/publisher.h new file mode 100644 index 0000000000000..7094a9afadeac --- /dev/null +++ b/src/mock/ray/pubsub/publisher.h @@ -0,0 +1,100 @@ +// Copyright 2021 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +namespace ray { +namespace pubsub { +namespace pub_internal { + +template +class MockSubscriptionIndex : public SubscriptionIndex { + public: +}; + +} // namespace pub_internal +} // namespace pubsub +} // namespace ray + +namespace ray { +namespace pubsub { +namespace pub_internal { + +class MockLongPollConnection : public LongPollConnection { + public: +}; + +} // namespace pub_internal +} // namespace pubsub +} // namespace ray + +namespace ray { +namespace pubsub { +namespace pub_internal { + +class MockSubscriber : public Subscriber { + public: +}; + +} // namespace pub_internal +} // namespace pubsub +} // namespace ray + +namespace ray { +namespace pubsub { + +class MockPublisherInterface : public PublisherInterface { + public: + MOCK_METHOD(bool, RegisterSubscription, + (const rpc::ChannelType channel_type, const SubscriberID &subscriber_id, + const std::string &key_id_binary), + (override)); + MOCK_METHOD(void, Publish, + (const rpc::ChannelType channel_type, const rpc::PubMessage &pub_message, + const std::string &key_id_binary), + (override)); + MOCK_METHOD(void, PublishFailure, + (const rpc::ChannelType channel_type, const std::string &key_id_binary), + (override)); + MOCK_METHOD(bool, UnregisterSubscription, + (const rpc::ChannelType channel_type, const SubscriberID &subscriber_id, + const std::string &key_id_binary), + (override)); +}; + +} // namespace pubsub +} // namespace ray + +namespace ray { +namespace pubsub { + +class MockPublisher : public Publisher { + public: + MOCK_METHOD(bool, RegisterSubscription, + (const rpc::ChannelType channel_type, const SubscriberID &subscriber_id, + const std::string &key_id_binary), + (override)); + MOCK_METHOD(void, Publish, + (const rpc::ChannelType channel_type, const rpc::PubMessage &pub_message, + const std::string &key_id_binary), + (override)); + MOCK_METHOD(void, PublishFailure, + (const rpc::ChannelType channel_type, const std::string &key_id_binary), + (override)); + MOCK_METHOD(bool, UnregisterSubscription, + (const rpc::ChannelType channel_type, const SubscriberID &subscriber_id, + const std::string &key_id_binary), + (override)); +}; + +} // namespace pubsub +} // namespace ray diff --git a/src/mock/ray/pubsub/subscriber.h b/src/mock/ray/pubsub/subscriber.h new file mode 100644 index 0000000000000..38dc5f32afb65 --- /dev/null +++ b/src/mock/ray/pubsub/subscriber.h @@ -0,0 +1,155 @@ +// Copyright 2021 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+namespace ray {
+namespace pubsub {
+
+template <typename KeyIdType>
+class MockSubscriptionInfo : public SubscriptionInfo<KeyIdType> {
+ public:
+};
+
+}  // namespace pubsub
+}  // namespace ray
+
+namespace ray {
+namespace pubsub {
+
+class MockSubscribeChannelInterface : public SubscribeChannelInterface {
+ public:
+  MOCK_METHOD(void, Subscribe,
+              (const rpc::Address &publisher_address, const std::string &key_id_binary,
+               SubscriptionCallback subscription_callback,
+               SubscriptionFailureCallback subscription_failure_callback),
+              (override));
+  MOCK_METHOD(bool, Unsubscribe,
+              (const rpc::Address &publisher_address, const std::string &key_id_binary),
+              (override));
+  MOCK_METHOD(void, HandlePublishedMessage,
+              (const rpc::Address &publisher_address, const rpc::PubMessage &pub_message),
+              (const, override));
+  MOCK_METHOD(void, HandlePublisherFailure, (const rpc::Address &publisher_address),
+              (override));
+  MOCK_METHOD(void, HandlePublisherFailure,
+              (const rpc::Address &publisher_address, const std::string &key_id_binary),
+              (override));
+  MOCK_METHOD(bool, SubscriptionExists, (const PublisherID &publisher_id), (override));
+  MOCK_METHOD(const rpc::ChannelType, GetChannelType, (), (const, override));
+  MOCK_METHOD(bool, CheckNoLeaks, (), (const, override));
+  MOCK_METHOD(std::string, DebugString, (), (const, override));
+};
+
+}  // namespace pubsub
+}  // namespace ray
+
+namespace ray {
+namespace pubsub {
+
+template <typename KeyIdType>
+class MockSubscriberChannel : public SubscriberChannel<KeyIdType> {
+ public:
+  MOCK_METHOD(void, Subscribe,
+              (const rpc::Address &publisher_address, const std::string &key_id,
+               SubscriptionCallback subscription_callback,
+               SubscriptionFailureCallback subscription_failure_callback),
+              (override));
+  MOCK_METHOD(bool, Unsubscribe,
+              (const rpc::Address &publisher_address, const std::string &key_id),
+              (override));
+  MOCK_METHOD(bool, CheckNoLeaks, (), (const, override));
+  MOCK_METHOD(void, HandlePublishedMessage,
+              (const rpc::Address &publisher_address, const rpc::PubMessage &pub_message),
+              (const, override));
+  MOCK_METHOD(void, HandlePublisherFailure, (const rpc::Address &publisher_address),
+              (override));
+  MOCK_METHOD(void, HandlePublisherFailure,
+              (const rpc::Address &publisher_address, const std::string &key_id_binary),
+              (override));
+  MOCK_METHOD(bool, SubscriptionExists, (const PublisherID &publisher_id), (override));
+  MOCK_METHOD(const rpc::ChannelType, GetChannelType, (), (const, override));
+  MOCK_METHOD(std::string, DebugString, (), (const, override));
+};
+
+}  // namespace pubsub
+}  // namespace ray
+
+namespace ray {
+namespace pubsub {
+
+class MockWaitForObjectEvictionChannel : public WaitForObjectEvictionChannel {
+ public:
+};
+
+}  // namespace pubsub
+}  // namespace ray
+
+namespace ray {
+namespace pubsub {
+
+class MockWaitForRefRemovedChannel : public WaitForRefRemovedChannel {
+ public:
+};
+
+}  // namespace pubsub
+}  // namespace ray
+
+namespace ray {
+namespace pubsub {
+
+class MockObjectLocationsChannel : public ObjectLocationsChannel {
+ public:
+};
+
+}  // namespace pubsub
+}  // namespace ray
+
+namespace ray {
+namespace pubsub {
+
+class MockSubscriberInterface : public SubscriberInterface {
+ public:
+  MOCK_METHOD(void, Subscribe,
+              (std::unique_ptr<rpc::SubMessage> sub_message,
+               const rpc::ChannelType channel_type, const rpc::Address &publisher_address,
+               const std::string &key_id_binary,
+               SubscriptionCallback subscription_callback,
+               SubscriptionFailureCallback subscription_failure_callback),
+              (override));
+  MOCK_METHOD(bool, Unsubscribe,
+              (const rpc::ChannelType channel_type,
+               const rpc::Address &publisher_address, const std::string &key_id_binary),
+              (override));
+  MOCK_METHOD(std::string, DebugString, (), (const, override));
+};
+
+}  // namespace pubsub
+}  // namespace ray
+
+namespace ray {
+namespace pubsub {
+
+class MockSubscriberClientInterface : public SubscriberClientInterface {
+ public:
+  MOCK_METHOD(void, PubsubLongPolling,
+              (const rpc::PubsubLongPollingRequest &request,
+               const rpc::ClientCallback<rpc::PubsubLongPollingReply> &callback),
+              (override));
+  MOCK_METHOD(void, PubsubCommandBatch,
+              (const rpc::PubsubCommandBatchRequest &request,
+               const rpc::ClientCallback<rpc::PubsubCommandBatchReply> &callback),
+              (override));
+};
+
+}  // namespace pubsub
+}  // namespace ray
diff --git a/src/mock/ray/raylet/node_manager.h b/src/mock/ray/raylet/node_manager.h
index 7edc1c9916d07..1ce3563ba450d 100644
--- a/src/mock/ray/raylet/node_manager.h
+++ b/src/mock/ray/raylet/node_manager.h
@@ -67,6 +67,11 @@ class MockNodeManager : public NodeManager {
                rpc::RequestWorkerLeaseReply *reply,
                rpc::SendReplyCallback send_reply_callback),
               (override));
+  MOCK_METHOD(void, HandleReportWorkerBacklog,
+              (const rpc::ReportWorkerBacklogRequest &request,
+               rpc::ReportWorkerBacklogReply *reply,
+               rpc::SendReplyCallback send_reply_callback),
+              (override));
   MOCK_METHOD(void, HandleReturnWorker,
               (const rpc::ReturnWorkerRequest &request, rpc::ReturnWorkerReply *reply,
                rpc::SendReplyCallback send_reply_callback),
diff --git a/src/mock/ray/raylet/scheduling/cluster_task_manager_interface.h b/src/mock/ray/raylet/scheduling/cluster_task_manager_interface.h
index 498a5088b7194..3c5d4498af18b 100644
--- a/src/mock/ray/raylet/scheduling/cluster_task_manager_interface.h
+++ b/src/mock/ray/raylet/scheduling/cluster_task_manager_interface.h
@@ -33,8 +33,6 @@ class MockClusterTaskManagerInterface : public ClusterTaskManagerInterface {
               (const, override));
   MOCK_METHOD(void, TaskFinished,
               (std::shared_ptr<WorkerInterface> worker, RayTask *task), (override));
-  MOCK_METHOD(void, ReturnWorkerResources, (std::shared_ptr<WorkerInterface> worker),
-              (override));
   MOCK_METHOD(bool, CancelTask, (const TaskID &task_id, bool runtime_env_setup_failed),
               (override));
   MOCK_METHOD(void, QueueAndScheduleTask,
diff --git a/src/mock/ray/raylet_client/raylet_client.h b/src/mock/ray/raylet_client/raylet_client.h
index cafd952e5d6e4..c2dc3dd43097c 100644
--- a/src/mock/ray/raylet_client/raylet_client.h
+++ b/src/mock/ray/raylet_client/raylet_client.h
@@ -35,6 +35,12 @@ class MockWorkerLeaseInterface : public WorkerLeaseInterface {
       const ray::rpc::ClientCallback<ray::rpc::RequestWorkerLeaseReply> &callback,
       const int64_t backlog_size),
       (override));
+  MOCK_METHOD(
+      void, RequestWorkerLease,
+      (const rpc::TaskSpec &task_spec,
+       const ray::rpc::ClientCallback<ray::rpc::RequestWorkerLeaseReply> &callback,
+       const int64_t backlog_size),
+      (override));
   MOCK_METHOD(ray::Status, ReturnWorker,
               (int worker_port, const WorkerID &worker_id, bool disconnect_worker),
               (override));
@@ -66,7 +72,7 @@ class MockResourceReserveInterface : public ResourceReserveInterface {
               (override));
   MOCK_METHOD(
       void, CancelResourceReserve,
-      (BundleSpecification & bundle_spec,
+      (const BundleSpecification &bundle_spec,
       const ray::rpc::ClientCallback<ray::rpc::CancelResourceReserveReply> &callback),
       (override));
   MOCK_METHOD(void, ReleaseUnusedBundles,
@@ -106,41 +112,27 @@ class MockResourceTrackingInterface : public ResourceTrackingInterface {
 namespace ray {
 
 class MockRayletClientInterface : public RayletClientInterface {
- public:
-  MOCK_METHOD(void, GetSystemConfig,
-              (const rpc::ClientCallback<rpc::GetSystemConfigReply> &callback),
-              (override));
-  MOCK_METHOD(void, GetGcsServerAddress,
-              (const rpc::ClientCallback<rpc::GetGcsServerAddressReply> &callback),
-              (override));
-};
-
-}  // namespace ray
-
-namespace ray {
-namespace raylet {
-
-class MockRayletConnection : public RayletConnection {
- public:
-};
-
-}  // namespace raylet
-}  // namespace ray
-
-namespace ray {
-namespace raylet {
-
-class MockRayletClient : public RayletClient {
  public:
   MOCK_METHOD(ray::Status, WaitForDirectActorCallArgs,
               (const std::vector<rpc::ObjectReference> &references, int64_t tag),
               (override));
+  MOCK_METHOD(void, ReportWorkerBacklog,
+              (const WorkerID &worker_id,
+               const std::vector<rpc::WorkerBacklogReport> &backlog_reports),
+              (override));
   MOCK_METHOD(
       void, RequestWorkerLease,
       (const ray::TaskSpecification &resource_spec,
       const ray::rpc::ClientCallback<ray::rpc::RequestWorkerLeaseReply> &callback,
       const int64_t backlog_size),
       (override));
+  MOCK_METHOD(
+      void, RequestWorkerLease,
+      (const rpc::TaskSpec &resource_spec,
+       const ray::rpc::ClientCallback<ray::rpc::RequestWorkerLeaseReply> &callback,
+       const int64_t backlog_size),
+      (override));
   MOCK_METHOD(ray::Status, ReturnWorker,
               (int worker_port, const WorkerID &worker_id, bool disconnect_worker),
               (override));
@@ -164,7 +156,7 @@ class MockRayletClient : public RayletClient {
               (override));
   MOCK_METHOD(
       void, CancelResourceReserve,
-      (BundleSpecification & bundle_spec,
+      (const BundleSpecification &bundle_spec,
       const ray::rpc::ClientCallback<ray::rpc::CancelResourceReserveReply> &callback),
       (override));
   MOCK_METHOD(void, ReleaseUnusedBundles,
@@ -191,5 +183,4 @@ class MockRayletClient : public RayletClient {
               (override));
 };
 
-}  // namespace raylet
 }  // namespace ray
diff --git a/src/mock/ray/rpc/worker/core_worker_client.h b/src/mock/ray/rpc/worker/core_worker_client.h
new file mode 100644
index 0000000000000..a4646cef99e16
--- /dev/null
+++ b/src/mock/ray/rpc/worker/core_worker_client.h
@@ -0,0 +1,123 @@
+// Copyright 2021 The Ray Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
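With the raylet-namespace mocks folded into MockRayletClientInterface, RequestWorkerLease is now overloaded on TaskSpecification vs. rpc::TaskSpec, so an EXPECT_CALL must disambiguate the overload. A sketch using gmock's An<>() matcher (this helper function is illustrative, not part of the patch):

    #include "gmock/gmock.h"

    using ::testing::_;
    using ::testing::An;

    // Pin down the TaskSpecification overload of the mocked method.
    void ExpectOneLeaseRequest(ray::MockRayletClientInterface &raylet) {
      EXPECT_CALL(raylet,
                  RequestWorkerLease(An<const ray::TaskSpecification &>(), _, _))
          .Times(1);
    }
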
+
+namespace ray {
+namespace rpc {
+
+class MockWorkerAddress : public WorkerAddress {
+ public:
+};
+
+}  // namespace rpc
+}  // namespace ray
+
+namespace ray {
+namespace rpc {
+
+class MockCoreWorkerClientInterface : public ray::pubsub::MockSubscriberClientInterface,
+                                      public CoreWorkerClientInterface {
+ public:
+  MOCK_METHOD(const rpc::Address &, Addr, (), (const, override));
+  MOCK_METHOD(void, PushActorTask,
+              (std::unique_ptr<PushTaskRequest> request, bool skip_queue,
+               const ClientCallback<PushTaskReply> &callback),
+              (override));
+  MOCK_METHOD(void, PushNormalTask,
+              (std::unique_ptr<PushTaskRequest> request,
+               const ClientCallback<PushTaskReply> &callback),
+              (override));
+  MOCK_METHOD(void, StealTasks,
+              (std::unique_ptr<StealTasksRequest> request,
+               const ClientCallback<StealTasksReply> &callback),
+              (override));
+  MOCK_METHOD(void, DirectActorCallArgWaitComplete,
+              (const DirectActorCallArgWaitCompleteRequest &request,
+               const ClientCallback<DirectActorCallArgWaitCompleteReply> &callback),
+              (override));
+  MOCK_METHOD(void, GetObjectStatus,
+              (const GetObjectStatusRequest &request,
+               const ClientCallback<GetObjectStatusReply> &callback),
+              (override));
+  MOCK_METHOD(void, WaitForActorOutOfScope,
+              (const WaitForActorOutOfScopeRequest &request,
+               const ClientCallback<WaitForActorOutOfScopeReply> &callback),
+              (override));
+  MOCK_METHOD(void, PubsubLongPolling,
+              (const PubsubLongPollingRequest &request,
+               const ClientCallback<PubsubLongPollingReply> &callback),
+              (override));
+  MOCK_METHOD(void, PubsubCommandBatch,
+              (const PubsubCommandBatchRequest &request,
+               const ClientCallback<PubsubCommandBatchReply> &callback),
+              (override));
+  MOCK_METHOD(void, UpdateObjectLocationBatch,
+              (const UpdateObjectLocationBatchRequest &request,
+               const ClientCallback<UpdateObjectLocationBatchReply> &callback),
+              (override));
+  MOCK_METHOD(void, GetObjectLocationsOwner,
+              (const GetObjectLocationsOwnerRequest &request,
+               const ClientCallback<GetObjectLocationsOwnerReply> &callback),
+              (override));
+  MOCK_METHOD(void, KillActor,
+              (const KillActorRequest &request,
+               const ClientCallback<KillActorReply> &callback),
+              (override));
+  MOCK_METHOD(void, CancelTask,
+              (const CancelTaskRequest &request,
+               const ClientCallback<CancelTaskReply> &callback),
+              (override));
+  MOCK_METHOD(void, RemoteCancelTask,
+              (const RemoteCancelTaskRequest &request,
+               const ClientCallback<RemoteCancelTaskReply> &callback),
+              (override));
+  MOCK_METHOD(void, GetCoreWorkerStats,
+              (const GetCoreWorkerStatsRequest &request,
+               const ClientCallback<GetCoreWorkerStatsReply> &callback),
+              (override));
+  MOCK_METHOD(void, LocalGC,
+              (const LocalGCRequest &request,
+               const ClientCallback<LocalGCReply> &callback),
+              (override));
+  MOCK_METHOD(void, SpillObjects,
+              (const SpillObjectsRequest &request,
+               const ClientCallback<SpillObjectsReply> &callback),
+              (override));
+  MOCK_METHOD(void, RestoreSpilledObjects,
+              (const RestoreSpilledObjectsRequest &request,
+               const ClientCallback<RestoreSpilledObjectsReply> &callback),
+              (override));
+  MOCK_METHOD(void, DeleteSpilledObjects,
+              (const DeleteSpilledObjectsRequest &request,
+               const ClientCallback<DeleteSpilledObjectsReply> &callback),
+              (override));
+  MOCK_METHOD(void, AddSpilledUrl,
+              (const AddSpilledUrlRequest &request,
+               const ClientCallback<AddSpilledUrlReply> &callback),
+              (override));
+  MOCK_METHOD(void, PlasmaObjectReady,
+              (const PlasmaObjectReadyRequest &request,
+               const ClientCallback<PlasmaObjectReadyReply> &callback),
+              (override));
+  MOCK_METHOD(void, Exit,
+              (const ExitRequest &request, const ClientCallback<ExitReply> &callback),
+              (override));
+  MOCK_METHOD(void, AssignObjectOwner,
+              (const AssignObjectOwnerRequest &request,
+               const ClientCallback<AssignObjectOwnerReply> &callback),
+              (override));
+  MOCK_METHOD(int64_t, ClientProcessedUpToSeqno, (), (override));
+};
+
+}  // namespace rpc
+}  // namespace ray
diff --git a/src/mock/ray/rpc/worker/core_worker_client_pool.h b/src/mock/ray/rpc/worker/core_worker_client_pool.h
new file mode 100644
index 0000000000000..d4e1ec607e5a2
--- /dev/null
+++ b/src/mock/ray/rpc/worker/core_worker_client_pool.h
@@ -0,0 +1,23 @@
+// Copyright 2021 The Ray Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+namespace ray {
+namespace rpc {
+
+class MockCoreWorkerClientPool : public CoreWorkerClientPool {
+ public:
+};
+
+}  // namespace rpc
+}  // namespace ray
diff --git a/src/ray/common/bundle_spec.cc b/src/ray/common/bundle_spec.cc
index 339a492360d21..c5b4a711e0275 100644
--- a/src/ray/common/bundle_spec.cc
+++ b/src/ray/common/bundle_spec.cc
@@ -74,6 +74,10 @@ PlacementGroupID BundleSpecification::PlacementGroupId() const {
   return PlacementGroupID::FromBinary(message_->bundle_id().placement_group_id());
 }
 
+NodeID BundleSpecification::NodeId() const {
+  return NodeID::FromBinary(message_->node_id());
+}
+
 int64_t BundleSpecification::Index() const {
   return message_->bundle_id().bundle_index();
 }
@@ -89,16 +93,19 @@ std::string BundleSpecification::DebugString() const {
 std::string FormatPlacementGroupResource(const std::string &original_resource_name,
                                          const PlacementGroupID &group_id,
                                          int64_t bundle_index) {
-  std::string str;
+  std::stringstream os;
   if (bundle_index >= 0) {
-    str = original_resource_name + "_group_" + std::to_string(bundle_index) + "_" +
-          group_id.Hex();
+    os << original_resource_name << kGroupKeyword << std::to_string(bundle_index) << "_"
+       << group_id.Hex();
   } else {
     RAY_CHECK(bundle_index == -1) << "Invalid index " << bundle_index;
-    str = original_resource_name + "_group_" + group_id.Hex();
+    os << original_resource_name << kGroupKeyword << group_id.Hex();
   }
-  RAY_CHECK(GetOriginalResourceName(str) == original_resource_name) << str;
-  return str;
+  std::string result = os.str();
+  RAY_DCHECK(GetOriginalResourceName(result) == original_resource_name)
+      << "Generated: " << GetOriginalResourceName(result)
+      << " Original: " << original_resource_name;
+  return result;
 }
 
 std::string FormatPlacementGroupResource(const std::string &original_resource_name,
@@ -109,12 +116,12 @@ std::string FormatPlacementGroupResource(const std::string &original_resource_na
 bool IsBundleIndex(const std::string &resource, const PlacementGroupID &group_id,
                    const int bundle_index) {
-  return resource.find("_group_" + std::to_string(bundle_index) + "_" + group_id.Hex()) !=
-         std::string::npos;
+  return resource.find(kGroupKeyword + std::to_string(bundle_index) + "_" +
+                       group_id.Hex()) != std::string::npos;
 }
 
 std::string GetOriginalResourceName(const std::string &resource) {
-  auto idx = resource.find("_group_");
+  auto idx = resource.find(kGroupKeyword);
   RAY_CHECK(idx >= 0) << "This isn't a placement group resource " << resource;
   return resource.substr(0, idx);
 }
diff --git a/src/ray/common/bundle_spec.h b/src/ray/common/bundle_spec.h
index 8437704509b58..bca5396fdc71a 100644
--- a/src/ray/common/bundle_spec.h
+++ b/src/ray/common/bundle_spec.h
@@ -32,6 +32,9 @@ typedef std::function ScheduleBundleCallback;
 /// address and the raylet's port.
 typedef std::function SpillbackBundleCallback;
+const std::string kGroupKeyword = "_group_";
+const size_t kGroupKeywordSize = kGroupKeyword.size();
+
 class BundleSpecification : public MessageWrapper<rpc::Bundle> {
  public:
   /// Construct from a protobuf message object.
@@ -54,6 +57,9 @@ class BundleSpecification : public MessageWrapper<rpc::Bundle> {
   // Return the Placement Group id which the Bundle belongs to.
   PlacementGroupID PlacementGroupId() const;
 
+  // Get the node ID that this bundle is scheduled on.
+  NodeID NodeId() const;
+
   // Return the index of the bundle.
   int64_t Index() const;
 
diff --git a/src/ray/common/client_connection.cc b/src/ray/common/client_connection.cc
index 73743820b2b9b..7eb51a953e215 100644
--- a/src/ray/common/client_connection.cc
+++ b/src/ray/common/client_connection.cc
@@ -19,7 +19,7 @@
 #include 
 #include 
 #include 
-#include 
+#include 
 #include 
 #include 
 #include 
diff --git a/src/ray/common/constants.h b/src/ray/common/constants.h
index 780c1b70d3098..19180ef356b38 100644
--- a/src/ray/common/constants.h
+++ b/src/ray/common/constants.h
@@ -51,3 +51,6 @@ constexpr int kMessagePackOffset = 9;
 /// Filename of "shim process" that sets up Python worker environment.
 /// Should be kept in sync with SETUP_WORKER_FILENAME in ray.ray_constants.
 constexpr char kSetupWorkerFilename[] = "setup_worker.py";
+
+/// The version of Ray.
+constexpr char kRayVersion[] = "2.0.0.dev0";
diff --git a/src/ray/common/id.h b/src/ray/common/id.h
index 0fc5d45599392..889128e81df11 100644
--- a/src/ray/common/id.h
+++ b/src/ray/common/id.h
@@ -492,6 +492,7 @@ std::string BaseID<T>::Hex() const {
   constexpr char hex[] = "0123456789abcdef";
   const uint8_t *id = Data();
   std::string result;
+  result.reserve(T::Size());
   for (size_t i = 0; i < T::Size(); i++) {
     unsigned int val = id[i];
     result.push_back(hex[val >> 4]);
diff --git a/src/ray/common/network_util.h b/src/ray/common/network_util.h
index 08bef7ae873af..8f268ec46b389 100644
--- a/src/ray/common/network_util.h
+++ b/src/ray/common/network_util.h
@@ -16,7 +16,7 @@
 #include 
 #include 
-#include 
+#include 
 #include 
 #include 
 #include 
diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h
index 53e0bf4d72450..0a6a61357b79f 100644
--- a/src/ray/common/ray_config_def.h
+++ b/src/ray/common/ray_config_def.h
@@ -183,8 +183,7 @@ RAY_CONFIG(int64_t, worker_register_timeout_seconds, 30)
 RAY_CONFIG(int64_t, redis_db_connect_retries, 50)
 RAY_CONFIG(int64_t, redis_db_connect_wait_milliseconds, 100)
 
-/// Timeout, in milliseconds, to wait before retrying a failed pull in the
-/// ObjectManager.
+/// The object manager's global timer interval in milliseconds.
 RAY_CONFIG(int, object_manager_timer_freq_ms, 100)
 
 /// Timeout, in milliseconds, to wait before retrying a failed pull in the
@@ -221,14 +220,8 @@ RAY_CONFIG(int32_t, maximum_profile_table_rows_count, 10 * 1000)
 /// message.
 RAY_CONFIG(uint32_t, object_store_get_max_ids_to_print_in_warning, 20)
 
-// TODO: fix win32 timeout in ci and unify these two.
-#ifdef _MSC_VER
 /// Number of threads used by rpc server in gcs server.
 RAY_CONFIG(uint32_t, gcs_server_rpc_server_thread_num, 1)
-#else
-/// Number of threads used by rpc server in gcs server.
-RAY_CONFIG(uint32_t, gcs_server_rpc_server_thread_num, 8)
-#endif
 
 /// Allow up to 5 seconds for connecting to gcs service.
 /// Note: this only takes effect when gcs service is enabled.
 RAY_CONFIG(int64_t, gcs_service_connect_retries, 50)
@@ -241,8 +234,10 @@ RAY_CONFIG(uint64_t, gcs_redis_heartbeat_interval_milliseconds, 100)
 RAY_CONFIG(uint32_t, gcs_lease_worker_retry_interval_ms, 200)
 /// Duration to wait between retries for creating actor in gcs server.
 RAY_CONFIG(uint32_t, gcs_create_actor_retry_interval_ms, 200)
-/// Duration to wait between retries for creating placement group in gcs server.
-RAY_CONFIG(uint32_t, gcs_create_placement_group_retry_interval_ms, 200)
+/// Exponential backoff parameters for the GCS server's placement group creation retries.
+RAY_CONFIG(uint32_t, gcs_create_placement_group_retry_min_interval_ms, 200)
+RAY_CONFIG(uint32_t, gcs_create_placement_group_retry_max_interval_ms, 5000)
+RAY_CONFIG(double, gcs_create_placement_group_retry_multiplier, 1.5);
 /// Maximum number of destroyed actors in GCS server memory cache.
 RAY_CONFIG(uint32_t, maximum_gcs_destroyed_actor_cached_count, 100000)
 /// Maximum number of dead nodes in GCS server memory cache.
@@ -311,12 +306,18 @@ RAY_CONFIG(int64_t, task_rpc_inlined_bytes_limit, 10 * 1024 * 1024)
 /// pipelining task submission.
 RAY_CONFIG(uint32_t, max_tasks_in_flight_per_worker, 1)
 
+/// Maximum number of pending lease requests per scheduling category.
+RAY_CONFIG(uint64_t, max_pending_lease_requests_per_scheduling_category, 10)
+
 /// Interval to restart dashboard agent after the process exit.
 RAY_CONFIG(uint32_t, agent_restart_interval_ms, 1000)
 
 /// Wait timeout for dashboard agent register.
 RAY_CONFIG(uint32_t, agent_register_timeout_ms, 30 * 1000)
 
+/// Max restart count for the dashboard agent.
+RAY_CONFIG(uint32_t, agent_max_restart_count, 5)
+
 /// If the agent manager fails to communicate with the dashboard agent, we will retry
 /// after this interval.
 RAY_CONFIG(uint32_t, agent_manager_retry_interval_ms, 1000);
@@ -325,12 +326,8 @@ RAY_CONFIG(uint32_t, agent_manager_retry_interval_ms, 1000);
 /// load reported by each raylet.
 RAY_CONFIG(int64_t, max_resource_shapes_per_load_report, 100)
 
-/// If true, the worker's queue backlog size will be propagated to the heartbeat batch
-/// data.
-RAY_CONFIG(bool, report_worker_backlog, true)
-
 /// The timeout for synchronous GCS requests in seconds.
-RAY_CONFIG(int64_t, gcs_server_request_timeout_seconds, 5)
+RAY_CONFIG(int64_t, gcs_server_request_timeout_seconds, 60)
 
 /// Whether to enable worker prestarting: https://github.com/ray-project/ray/issues/12052
 RAY_CONFIG(bool, enable_worker_prestart, true)
@@ -478,7 +475,7 @@ RAY_CONFIG(int64_t, grpc_keepalive_time_ms, 10000);
 RAY_CONFIG(int64_t, grpc_keepalive_timeout_ms, 20000);
 
 /// Whether to use log reporter in event framework
-RAY_CONFIG(bool, event_log_reporter_enabled, false)
+RAY_CONFIG(bool, event_log_reporter_enabled, true)
 
 /// Whether to use log reporter in event framework
 RAY_CONFIG(bool, actor_register_async, true)
@@ -491,3 +488,11 @@ RAY_CONFIG(bool, scheduler_avoid_gpu_nodes, true)
 
 /// Whether to skip running local GC in runtime env.
 RAY_CONFIG(bool, runtime_env_skip_local_gc, false)
+
+/// Whether or not to use TLS.
+RAY_CONFIG(int64_t, USE_TLS, 0)
+
+/// Location of TLS credentials.
+RAY_CONFIG(std::string, TLS_SERVER_CERT, "")
+RAY_CONFIG(std::string, TLS_SERVER_KEY, "")
+RAY_CONFIG(std::string, TLS_CA_CERT, "")
diff --git a/src/ray/common/ray_internal_flag_def.h b/src/ray/common/ray_internal_flag_def.h
index 20f1ef8ccc3e3..0f42d63d3f1ef 100644
--- a/src/ray/common/ray_internal_flag_def.h
+++ b/src/ray/common/ray_internal_flag_def.h
@@ -27,3 +27,6 @@ RAY_INTERNAL_FLAG(std::string, JOB_ID, "")
 
 /// Raylet process ID.
 RAY_INTERNAL_FLAG(std::string, RAYLET_PID, "")
+
+/// Override the random node ID for testing.
+RAY_INTERNAL_FLAG(std::string, OVERRIDE_NODE_ID_FOR_TESTING, "")
diff --git a/src/ray/common/runtime_env_manager.cc b/src/ray/common/runtime_env_manager.cc
index 2ec95cdecee8f..9e39488fa9149 100644
--- a/src/ray/common/runtime_env_manager.cc
+++ b/src/ray/common/runtime_env_manager.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "ray/common/runtime_env_manager.h"
+
 #include "ray/util/logging.h"
 
 namespace ray {
@@ -20,17 +21,12 @@ void RuntimeEnvManager::AddURIReference(const std::string &hex_id,
                                         const rpc::RuntimeEnv &runtime_env) {
   const auto &uris = runtime_env.uris();
   for (const auto &uri : uris) {
-    AddURIReference(hex_id, uri);
-  }
-}
-
-void RuntimeEnvManager::AddURIReference(const std::string &hex_id,
-                                        const std::string &uri) {
-  if (unused_uris_.count(uri)) {
-    unused_uris_.erase(uri);
+    if (unused_uris_.count(uri)) {
+      unused_uris_.erase(uri);
+    }
+    uri_reference_[uri]++;
+    id_to_uris_[hex_id].push_back(uri);
   }
-  uri_reference_[uri]++;
-  id_to_uris_[hex_id].push_back(uri);
 }
 
 const std::vector<std::string> &RuntimeEnvManager::GetReferences(
diff --git a/src/ray/common/runtime_env_manager.h b/src/ray/common/runtime_env_manager.h
index 510aa5fe53aa9..f9c59d74784bb 100644
--- a/src/ray/common/runtime_env_manager.h
+++ b/src/ray/common/runtime_env_manager.h
@@ -13,6 +13,7 @@
 // limitations under the License.
 #pragma once
 #include 
+
 #include "ray/common/id.h"
 #include "src/ray/protobuf/common.pb.h"
 
@@ -37,12 +38,6 @@ class RuntimeEnvManager {
   /// \param[in] runtime_env The runtime env used by the id.
   void AddURIReference(const std::string &hex_id, const rpc::RuntimeEnv &runtime_env);
 
-  /// Increase the reference of URI by URI and runtime_env.
-  ///
-  /// \param[in] hex_id The id of the runtime env. It can be an actor or job id.
-  /// \param[in] uri The URI referenced by the id.
-  void AddURIReference(const std::string &hex_id, const std::string &uri);
-
   /// Get the reference of URIs by id.
   ///
   /// \param[in] hex_id The id to look up.
diff --git a/src/ray/common/task/task.cc b/src/ray/common/task/task.cc
index 291829e36f567..4765751afa3fc 100644
--- a/src/ray/common/task/task.cc
+++ b/src/ray/common/task/task.cc
@@ -18,10 +18,9 @@
 
 namespace ray {
 
-RayTask::RayTask(const rpc::Task &message, int64_t backlog_size)
+RayTask::RayTask(const rpc::Task &message)
     : task_spec_(message.task_spec()),
-      task_execution_spec_(message.task_execution_spec()),
-      backlog_size_(backlog_size) {
+      task_execution_spec_(message.task_execution_spec()) {
   ComputeDependencies();
 }
 
@@ -50,10 +49,6 @@ void RayTask::CopyTaskExecutionSpec(const RayTask &task) {
   task_execution_spec_ = task.task_execution_spec_;
 }
 
-void RayTask::SetBacklogSize(int64_t backlog_size) { backlog_size_ = backlog_size; }
-
-int64_t RayTask::BacklogSize() const { return backlog_size_; }
-
 std::string RayTask::DebugString() const {
   std::ostringstream stream;
   stream << "task_spec={" << task_spec_.DebugString() << "}, task_execution_spec={"
diff --git a/src/ray/common/task/task.h b/src/ray/common/task/task.h
index c21ec9c94da8e..52c0e9246dab2 100644
--- a/src/ray/common/task/task.h
+++ b/src/ray/common/task/task.h
@@ -47,9 +47,7 @@ class RayTask {
   /// Construct a `RayTask` object from a protobuf message.
   ///
   /// \param message The protobuf message.
-  /// \param backlog_size The size of the task owner's backlog size for this
-  /// task's shape.
-  explicit RayTask(const rpc::Task &message, int64_t backlog_size = -1);
+  explicit RayTask(const rpc::Task &message);
 
   /// Construct a `RayTask` object from a `TaskSpecification` and a
   /// `TaskExecutionSpecification`.
@@ -103,10 +101,6 @@ class RayTask {
   /// Returns the cancellation task callback, or nullptr.
   const CancelTaskCallback &OnCancellation() const { return on_cancellation_; }
 
-  void SetBacklogSize(int64_t backlog_size);
-
-  int64_t BacklogSize() const;
-
   std::string DebugString() const;
 
  private:
@@ -133,8 +127,6 @@ class RayTask {
   /// For direct task calls, overrides the cancellation behaviour to send an
   /// RPC back to the submitting worker.
   mutable CancelTaskCallback on_cancellation_ = nullptr;
-  /// The size of the core worker's backlog when this task was submitted.
-  int64_t backlog_size_ = -1;
 };
 
 }  // namespace ray
diff --git a/src/ray/common/task/task_spec.cc b/src/ray/common/task/task_spec.cc
index 353406fd3c820..0c3d77beb5993 100644
--- a/src/ray/common/task/task_spec.cc
+++ b/src/ray/common/task/task_spec.cc
@@ -132,8 +132,10 @@ ray::FunctionDescriptor TaskSpecification::FunctionDescriptor() const {
   return ray::FunctionDescriptorBuilder::FromProto(message_->function_descriptor());
 }
 
+rpc::RuntimeEnv TaskSpecification::RuntimeEnv() const { return message_->runtime_env(); }
+
 std::string TaskSpecification::SerializedRuntimeEnv() const {
-  return message_->serialized_runtime_env();
+  return message_->runtime_env().serialized_runtime_env();
 }
 
 bool TaskSpecification::HasRuntimeEnv() const {
@@ -145,8 +147,7 @@ int TaskSpecification::GetRuntimeEnvHash() const {
   if (RayConfig::instance().worker_resource_limits_enabled()) {
     required_resource = GetRequiredResources().GetResourceMap();
   }
-  WorkerCacheKey env = {OverrideEnvironmentVariables(), SerializedRuntimeEnv(),
-                        required_resource};
+  WorkerCacheKey env = {SerializedRuntimeEnv(), required_resource};
   return env.IntHash();
 }
 
@@ -239,11 +240,6 @@ std::string TaskSpecification::GetDebuggerBreakpoint() const {
   return message_->debugger_breakpoint();
 }
 
-std::unordered_map<std::string, std::string>
-TaskSpecification::OverrideEnvironmentVariables() const {
-  return MapFromProtobuf(message_->override_environment_variables());
-}
-
 bool TaskSpecification::IsDriverTask() const {
   return message_->type() == TaskType::DRIVER_TASK;
 }
@@ -398,11 +394,9 @@ std::string TaskSpecification::CallSiteString() const {
 }
 
 WorkerCacheKey::WorkerCacheKey(
-    const std::unordered_map<std::string, std::string> override_environment_variables,
     const std::string serialized_runtime_env,
     const std::unordered_map<std::string, double> required_resources)
-    : override_environment_variables(override_environment_variables),
-      serialized_runtime_env(serialized_runtime_env),
+    : serialized_runtime_env(serialized_runtime_env),
       required_resources(std::move(required_resources)) {}
 
 bool WorkerCacheKey::operator==(const WorkerCacheKey &k) const {
@@ -411,8 +405,7 @@ bool WorkerCacheKey::operator==(const WorkerCacheKey &k) const {
 }
 
 bool WorkerCacheKey::EnvIsEmpty() const {
-  return override_environment_variables.size() == 0 &&
-         (serialized_runtime_env == "" || serialized_runtime_env == "{}") &&
+  return (serialized_runtime_env == "" || serialized_runtime_env == "{}") &&
          required_resources.empty();
 }
 
@@ -424,19 +417,6 @@ std::size_t WorkerCacheKey::Hash() const {
     // runtime envs.
     hash_ = 0;
   } else {
-    std::vector<std::pair<std::string, std::string>> env_vars(
-        override_environment_variables.begin(), override_environment_variables.end());
-    // The environment doesn't depend the order of the variables, so the hash should not
-    // either. Sort the variables so different permutations yield the same hash.
-    std::sort(env_vars.begin(), env_vars.end());
-    for (auto &pair : env_vars) {
-      // TODO(architkulkarni): boost::hash_combine isn't guaranteed to be equal during
-      // separate runs of a program, which may cause problems if these hashes are
-      // communicated between different Raylets and compared.
-      boost::hash_combine(hash_, pair.first);
-      boost::hash_combine(hash_, pair.second);
-    }
-
     boost::hash_combine(hash_, serialized_runtime_env);
 
     std::vector<std::pair<std::string, double>> resource_vars(
diff --git a/src/ray/common/task/task_spec.h b/src/ray/common/task/task_spec.h
index 8b10b163cc3cc..24dbf4afbae21 100644
--- a/src/ray/common/task/task_spec.h
+++ b/src/ray/common/task/task_spec.h
@@ -100,6 +100,8 @@ class TaskSpecification : public MessageWrapper<rpc::TaskSpec> {
 
   ray::FunctionDescriptor FunctionDescriptor() const;
 
+  [[nodiscard]] rpc::RuntimeEnv RuntimeEnv() const;
+
   std::string SerializedRuntimeEnv() const;
 
   bool HasRuntimeEnv() const;
@@ -170,8 +172,6 @@ class TaskSpecification : public MessageWrapper<rpc::TaskSpec> {
 
   std::string GetDebuggerBreakpoint() const;
 
-  std::unordered_map<std::string, std::string> OverrideEnvironmentVariables() const;
-
   bool IsDriverTask() const;
 
   Language GetLanguage() const;
@@ -254,7 +254,7 @@ class TaskSpecification : public MessageWrapper<rpc::TaskSpec> {
   /// Field storing required placement resources. Initialized in constructor.
   std::shared_ptr<ResourceSet> required_placement_resources_;
   /// Cached scheduling class of this task.
-  SchedulingClass sched_cls_id_;
+  SchedulingClass sched_cls_id_ = 0;
 
   /// Below static fields could be mutated in `ComputeResources` concurrently due to
   /// multi-threading, we need a mutex to protect it.
@@ -275,13 +275,10 @@ class WorkerCacheKey {
   /// Create a cache key with the given environment variable overrides and serialized
   /// runtime_env.
   ///
-  /// \param override_environment_variables The environment variable overrides set in this
   /// worker. \param serialized_runtime_env The JSON-serialized runtime env for this
   /// worker. \param required_resources The required resource.
-  WorkerCacheKey(
-      const std::unordered_map<std::string, std::string> override_environment_variables,
-      const std::string serialized_runtime_env,
-      const std::unordered_map<std::string, double> required_resources);
+  WorkerCacheKey(const std::string serialized_runtime_env,
+                 const std::unordered_map<std::string, double> required_resources);
 
   bool operator==(const WorkerCacheKey &k) const;
 
@@ -293,8 +290,7 @@ class WorkerCacheKey {
 
   /// Get the hash for this worker's environment.
   ///
-  /// \return The hash of the override_environment_variables and the serialized
-  /// runtime_env.
+  /// \return The hash of the serialized runtime_env.
   std::size_t Hash() const;
 
   /// Get the int-valued hash for this worker's environment, useful for portability in
@@ -304,8 +300,6 @@ class WorkerCacheKey {
   int IntHash() const;
 
  private:
-  /// The environment variable overrides for this worker.
-  const std::unordered_map<std::string, std::string> override_environment_variables;
   /// The JSON-serialized runtime env for this worker.
   const std::string serialized_runtime_env;
   /// The required resources for this worker.
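With the override_environment_variables hashing removed, WorkerCacheKey::Hash() reduces to combining the serialized runtime env with the sorted resource map. A standalone sketch of that scheme (simplified names; the member caching and the EnvIsEmpty() short-circuit are omitted):

    #include <algorithm>
    #include <string>
    #include <unordered_map>
    #include <utility>
    #include <vector>
    #include <boost/functional/hash.hpp>

    std::size_t HashWorkerEnv(
        const std::string &serialized_runtime_env,
        const std::unordered_map<std::string, double> &required_resources) {
      std::size_t hash = 0;
      boost::hash_combine(hash, serialized_runtime_env);
      // Sort so the unordered_map's iteration order cannot change the hash.
      std::vector<std::pair<std::string, double>> resource_vars(
          required_resources.begin(), required_resources.end());
      std::sort(resource_vars.begin(), resource_vars.end());
      for (const auto &pair : resource_vars) {
        boost::hash_combine(hash, pair.first);
        boost::hash_combine(hash, pair.second);
      }
      return hash;
    }
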
diff --git a/src/ray/common/task/task_util.h b/src/ray/common/task/task_util.h
index c011829c2603d..57ee5b811663e 100644
--- a/src/ray/common/task/task_util.h
+++ b/src/ray/common/task/task_util.h
@@ -106,8 +106,7 @@ class TaskSpecBuilder {
       const BundleID &bundle_id, bool placement_group_capture_child_tasks,
       const std::string &debugger_breakpoint,
       const std::string &serialized_runtime_env = "{}",
-      const std::unordered_map<std::string, std::string>
-          &override_environment_variables = {},
+      const std::vector<std::string> &runtime_env_uris = {},
       const std::string &concurrency_group_name = "") {
     message_->set_type(TaskType::NORMAL_TASK);
     message_->set_name(name);
@@ -129,11 +128,11 @@ class TaskSpecBuilder {
     message_->set_placement_group_capture_child_tasks(
         placement_group_capture_child_tasks);
     message_->set_debugger_breakpoint(debugger_breakpoint);
-    message_->set_serialized_runtime_env(serialized_runtime_env);
-    message_->set_concurrency_group_name(concurrency_group_name);
-    for (const auto &env : override_environment_variables) {
-      (*message_->mutable_override_environment_variables())[env.first] = env.second;
+    message_->mutable_runtime_env()->set_serialized_runtime_env(serialized_runtime_env);
+    for (const std::string &uri : runtime_env_uris) {
+      message_->mutable_runtime_env()->add_uris(uri);
     }
+    message_->set_concurrency_group_name(concurrency_group_name);
     return *this;
   }
 
diff --git a/src/ray/core_worker/common.h b/src/ray/core_worker/common.h
index 016d16ddc8851..dfb4fd9a39f28 100644
--- a/src/ray/core_worker/common.h
+++ b/src/ray/core_worker/common.h
@@ -60,14 +60,13 @@ struct TaskOptions {
               std::unordered_map<std::string, double> &resources,
               const std::string &concurrency_group_name = "",
               const std::string &serialized_runtime_env = "{}",
-              const std::unordered_map<std::string, std::string>
-                  &override_environment_variables = {})
+              const std::vector<std::string> &runtime_env_uris = {})
       : name(name),
        num_returns(num_returns),
        resources(resources),
        concurrency_group_name(concurrency_group_name),
        serialized_runtime_env(serialized_runtime_env),
-        override_environment_variables(override_environment_variables) {}
+        runtime_env_uris(runtime_env_uris) {}
 
   /// The name of this task.
   std::string name;
@@ -77,12 +76,10 @@ struct TaskOptions {
   std::unordered_map<std::string, double> resources;
   /// The name of the concurrency group in which this task will be executed.
   std::string concurrency_group_name;
-  // Runtime Env used by this task. Propagated to child actors and tasks.
+  // Runtime Env used by this task. Propagated to child actors and tasks.
   std::string serialized_runtime_env;
-  /// Environment variables to update for this task. Maps a variable name to its
-  /// value. Can override existing environment variables and introduce new ones.
-  /// Propagated to child actors and/or tasks.
-  const std::unordered_map<std::string, std::string> override_environment_variables;
+  // URIs contained in the runtime_env.
+  std::vector<std::string> runtime_env_uris;
 };
 
 /// Options for actor creation tasks.
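A hypothetical call site for the reworked TaskOptions, showing runtime_env_uris taking the slot previously held by override_environment_variables (the URI, the env contents, and the ray::core namespace are assumptions for illustration):

    #include <string>
    #include <unordered_map>
    #include "ray/core_worker/common.h"

    ray::core::TaskOptions MakeExampleOptions() {
      std::unordered_map<std::string, double> resources{{"CPU", 1.0}};
      return ray::core::TaskOptions(
          "my_task", /*num_returns=*/1, resources,
          /*concurrency_group_name=*/"",
          /*serialized_runtime_env=*/R"({"pip": ["requests"]})",
          /*runtime_env_uris=*/{"gcs://runtime_env_packages/pkg.zip"});
    }
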
@@ -97,8 +94,7 @@ struct ActorCreationOptions {
       BundleID placement_options = std::make_pair(PlacementGroupID::Nil(), -1),
       bool placement_group_capture_child_tasks = true,
       const std::string &serialized_runtime_env = "{}",
-      const std::unordered_map<std::string, std::string>
-          &override_environment_variables = {},
+      const std::vector<std::string> &runtime_env_uris = {},
       const std::vector<ConcurrencyGroup> &concurrency_groups = {})
       : max_restarts(max_restarts),
         max_task_retries(max_task_retries),
@@ -113,7 +109,7 @@ struct ActorCreationOptions {
         placement_options(placement_options),
         placement_group_capture_child_tasks(placement_group_capture_child_tasks),
         serialized_runtime_env(serialized_runtime_env),
-        override_environment_variables(override_environment_variables),
+        runtime_env_uris(runtime_env_uris),
         concurrency_groups(concurrency_groups.begin(), concurrency_groups.end()){};
 
   /// Maximum number of times that the actor should be restarted if it dies
@@ -155,10 +151,8 @@ struct ActorCreationOptions {
   bool placement_group_capture_child_tasks = true;
   // Runtime Env used by this actor. Propagated to child actors and tasks.
   std::string serialized_runtime_env;
-  /// Environment variables to update for this actor. Maps a variable name to its
-  /// value. Can override existing environment variables and introduce new ones.
-  /// Propagated to child actors and/or tasks.
-  const std::unordered_map<std::string, std::string> override_environment_variables;
+  // URIs contained in the runtime_env.
+  std::vector<std::string> runtime_env_uris;
   /// The actor concurrency groups to indicate how this actor performs its
   /// methods concurrently.
   const std::vector<ConcurrencyGroup> concurrency_groups;
diff --git a/src/ray/core_worker/context.cc b/src/ray/core_worker/context.cc
index 37e7797e62676..ab8f6c1884764 100644
--- a/src/ray/core_worker/context.cc
+++ b/src/ray/core_worker/context.cc
@@ -168,12 +168,7 @@ bool WorkerContext::ShouldCaptureChildTasksInPlacementGroup() const {
 }
 
 const std::string &WorkerContext::GetCurrentSerializedRuntimeEnv() const {
-  return serialized_runtime_env_;
-}
-
-const std::unordered_map<std::string, std::string>
-    &WorkerContext::GetCurrentOverrideEnvironmentVariables() const {
-  return override_environment_variables_;
+  return runtime_env_.serialized_runtime_env();
 }
 
 void WorkerContext::SetCurrentTaskId(const TaskID &task_id) {
@@ -186,10 +181,9 @@ void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) {
   if (task_spec.IsNormalTask()) {
     current_task_is_direct_call_ = true;
     // TODO(architkulkarni): Once workers are cached by runtime env, we should
     // only set runtime_env_ once and then RAY_CHECK that we
     // never see a new one.
-    serialized_runtime_env_ = task_spec.SerializedRuntimeEnv();
-    override_environment_variables_ = task_spec.OverrideEnvironmentVariables();
+    runtime_env_ = task_spec.RuntimeEnv();
   } else if (task_spec.IsActorCreationTask()) {
     RAY_CHECK(current_actor_id_.IsNil());
     current_actor_id_ = task_spec.ActorCreationId();
@@ -199,8 +193,7 @@ void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) {
     is_detached_actor_ = task_spec.IsDetachedActor();
     current_actor_placement_group_id_ = task_spec.PlacementGroupBundleId().first;
     placement_group_capture_child_tasks_ = task_spec.PlacementGroupCaptureChildTasks();
-    serialized_runtime_env_ = task_spec.SerializedRuntimeEnv();
-    override_environment_variables_ = task_spec.OverrideEnvironmentVariables();
+    runtime_env_ = task_spec.RuntimeEnv();
   } else if (task_spec.IsActorTask()) {
     RAY_CHECK(current_actor_id_ == task_spec.ActorId());
   } else {
diff --git a/src/ray/core_worker/context.h b/src/ray/core_worker/context.h
index a403ee367c973..3c5f35718235a 100644
--- a/src/ray/core_worker/context.h
+++ b/src/ray/core_worker/context.h
@@ -42,9 +42,6 @@ class WorkerContext {
 
   const std::string &GetCurrentSerializedRuntimeEnv() const;
 
-  const std::unordered_map<std::string, std::string>
-      &GetCurrentOverrideEnvironmentVariables() const;
-
   // TODO(edoakes): remove this once Python core worker uses the task interfaces.
   void SetCurrentTaskId(const TaskID &task_id);
 
@@ -98,10 +95,8 @@ class WorkerContext {
   PlacementGroupID current_actor_placement_group_id_;
   // Whether or not we should implicitly capture parent's placement group.
   bool placement_group_capture_child_tasks_;
-  // The JSON-serialized runtime env for the current actor or task.
-  std::string serialized_runtime_env_ = "{}";
-  // The environment variable overrides for the current actor or task.
-  std::unordered_map<std::string, std::string> override_environment_variables_;
+  // The runtime env for the current actor or task.
+  rpc::RuntimeEnv runtime_env_;
 
   /// The id of the (main) thread that constructed this worker context.
   boost::thread::id main_thread_id_;
diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc
index e9251cbf990ac..0ceb78c7405b8 100644
--- a/src/ray/core_worker/core_worker.cc
+++ b/src/ray/core_worker/core_worker.cc
@@ -27,34 +27,11 @@
 namespace ray {
 namespace core {
+namespace {
 // Duration between internal book-keeping heartbeats.
 const uint64_t kInternalHeartbeatMillis = 1000;
 
-void BuildCommonTaskSpec(
-    TaskSpecBuilder &builder, const JobID &job_id, const TaskID &task_id,
-    const std::string name, const TaskID &current_task_id, const uint64_t task_index,
-    const TaskID &caller_id, const rpc::Address &address, const RayFunction &function,
-    const std::vector<std::unique_ptr<TaskArg>> &args, uint64_t num_returns,
-    const std::unordered_map<std::string, double> &required_resources,
-    const std::unordered_map<std::string, double> &required_placement_resources,
-    const BundleID &bundle_id, bool placement_group_capture_child_tasks,
-    const std::string debugger_breakpoint, const std::string &serialized_runtime_env,
-    const std::unordered_map<std::string, std::string> &override_environment_variables,
-    const std::string &concurrency_group_name = "") {
-  // Build common task spec.
-  builder.SetCommonTaskSpec(
-      task_id, name, function.GetLanguage(), function.GetFunctionDescriptor(), job_id,
-      current_task_id, task_index, caller_id, address, num_returns, required_resources,
-      required_placement_resources, bundle_id, placement_group_capture_child_tasks,
-      debugger_breakpoint, serialized_runtime_env, override_environment_variables,
-      concurrency_group_name);
-  // Set task arguments.
-  for (const auto &arg : args) {
-    builder.AddArg(*arg);
-  }
-}
-
 JobID GetProcessJobID(const CoreWorkerOptions &options) {
   if (options.worker_type == WorkerType::DRIVER) {
     RAY_CHECK(!options.job_id.IsNil());
@@ -89,6 +66,16 @@ ObjectLocation CreateObjectLocation(const rpc::GetObjectLocationsOwnerReply &rep
 /// The global instance of `CoreWorkerProcess`.
 std::unique_ptr<CoreWorkerProcess> core_worker_process;
 
+/// Terminate the process without cleaning up resources.
+/// It will flush the log if logging_enabled is set to true.
+void QuickExit(bool logging_enabled) {
+  if (logging_enabled) {
+    RayLog::ShutDownRayLog();
+  }
+  _Exit(1);
+}
+}  // namespace
+
 thread_local std::weak_ptr<CoreWorker> CoreWorkerProcess::current_core_worker_;
 
 void CoreWorkerProcess::Initialize(const CoreWorkerOptions &options) {
@@ -103,10 +90,11 @@ void CoreWorkerProcess::Shutdown() {
   }
   RAY_CHECK(core_worker_process->options_.worker_type == WorkerType::DRIVER)
       << "The `Shutdown` interface is for driver only.";
-  RAY_CHECK(core_worker_process->global_worker_);
-  core_worker_process->global_worker_->Disconnect();
-  core_worker_process->global_worker_->Shutdown();
-  core_worker_process->RemoveWorker(core_worker_process->global_worker_);
+  auto global_worker = core_worker_process->GetGlobalWorker();
+  RAY_CHECK(global_worker);
+  global_worker->Disconnect();
+  global_worker->Shutdown();
+  core_worker_process->RemoveWorker(global_worker);
   core_worker_process.reset();
 }
 
@@ -147,18 +135,8 @@ CoreWorkerProcess::CoreWorkerProcess(const CoreWorkerOptions &options)
   // NOTE(kfstorm): any initialization depending on RayConfig must happen after this line.
   InitializeSystemConfig();
 
-  if (options_.num_workers == 1) {
-    // We need to create the worker instance here if:
-    // 1. This is a driver process. In this case, the driver is ready to use right after
-    // the CoreWorkerProcess::Initialize.
-    // 2. This is a Python worker process. In this case, Python will invoke some core
-    // worker APIs before `CoreWorkerProcess::RunTaskExecutionLoop` is called. So we need
-    // to create the worker instance here. One example of invocations is
-    // https://github.com/ray-project/ray/blob/45ce40e5d44801193220d2c546be8de0feeef988/python/ray/worker.py#L1281.
-    if (options_.worker_type == WorkerType::DRIVER ||
-        options_.language == Language::PYTHON) {
-      CreateWorker();
-    }
+  if (ShouldCreateGlobalWorkerOnConstruction()) {
+    CreateWorker();
   }
 
   // Assume stats module will be initialized exactly once in one process.
@@ -168,7 +146,7 @@ CoreWorkerProcess::CoreWorkerProcess(const CoreWorkerOptions &options)
   // Initialize stats in core worker global tags.
   const ray::stats::TagsType global_tags = {
       {ray::stats::ComponentKey, "core_worker"},
-      {ray::stats::VersionKey, "2.0.0.dev0"},
+      {ray::stats::VersionKey, kRayVersion},
       {ray::stats::NodeAddressKey, options_.node_ip_address}};
 
   // NOTE(lingxuan.zlx): We assume RayConfig is initialized before it's used.
@@ -256,11 +234,23 @@ void CoreWorkerProcess::InitializeSystemConfig() {
   RayConfig::instance().initialize(promise.get_future().get());
 }
 
+bool CoreWorkerProcess::ShouldCreateGlobalWorkerOnConstruction() const {
+  // We need to create the worker instance here if:
+  // 1. This is a driver process. In this case, the driver is ready to use right after
+  // the CoreWorkerProcess::Initialize.
+  // 2. This is a Python worker process. In this case, Python will invoke some core
+  // worker APIs before `CoreWorkerProcess::RunTaskExecutionLoop` is called. So we need
+  // to create the worker instance here. One example of invocations is
+  // https://github.com/ray-project/ray/blob/45ce40e5d44801193220d2c546be8de0feeef988/python/ray/worker.py#L1281.
+  return options_.num_workers == 1 && (options_.worker_type == WorkerType::DRIVER ||
+                                       options_.language == Language::PYTHON);
+}
+
 std::shared_ptr<CoreWorker> CoreWorkerProcess::TryGetWorker(const WorkerID &worker_id) {
   if (!core_worker_process) {
     return nullptr;
   }
-  absl::ReaderMutexLock workers_lock(&core_worker_process->worker_map_mutex_);
+  absl::ReaderMutexLock workers_lock(&core_worker_process->mutex_);
   auto it = core_worker_process->workers_.find(worker_id);
   if (it != core_worker_process->workers_.end()) {
     return it->second;
@@ -271,8 +261,19 @@ std::shared_ptr<CoreWorker> CoreWorkerProcess::TryGetWorker(const WorkerID &work
 CoreWorker &CoreWorkerProcess::GetCoreWorker() {
   EnsureInitialized();
   if (core_worker_process->options_.num_workers == 1) {
-    RAY_CHECK(core_worker_process->global_worker_) << "global_worker_ must not be NULL";
-    return *core_worker_process->global_worker_;
+    auto global_worker = core_worker_process->GetGlobalWorker();
+    if (core_worker_process->ShouldCreateGlobalWorkerOnConstruction() && !global_worker) {
+      // This could only happen when the worker has already been shutdown.
+      // In this case, we should exit without crashing.
+      // TODO (scv119): A better solution could be returning an error code
+      // and handling it at the language frontend.
+      RAY_LOG(ERROR) << "The global worker has already been shutdown. This happens when "
+                        "the language frontend accesses the Ray's worker after it is "
+                        "shutdown. The process will exit";
+      QuickExit(core_worker_process->options_.enable_logging);
+    }
+    RAY_CHECK(global_worker) << "global_worker_ must not be NULL";
+    return *global_worker;
   }
   auto ptr = current_core_worker_.lock();
   RAY_CHECK(ptr != nullptr)
@@ -283,7 +284,7 @@ CoreWorker &CoreWorkerProcess::GetCoreWorker() {
 void CoreWorkerProcess::SetCurrentThreadWorkerId(const WorkerID &worker_id) {
   EnsureInitialized();
   if (core_worker_process->options_.num_workers == 1) {
-    RAY_CHECK(core_worker_process->global_worker_->GetWorkerID() == worker_id);
+    RAY_CHECK(core_worker_process->GetGlobalWorker()->GetWorkerID() == worker_id);
     return;
   }
   current_core_worker_ = core_worker_process->GetWorker(worker_id);
@@ -291,23 +292,28 @@ void CoreWorkerProcess::SetCurrentThreadWorkerId(const WorkerID &worker_id) {
 
 std::shared_ptr<CoreWorker> CoreWorkerProcess::GetWorker(
     const WorkerID &worker_id) const {
-  absl::ReaderMutexLock lock(&worker_map_mutex_);
+  absl::ReaderMutexLock lock(&mutex_);
   auto it = workers_.find(worker_id);
   RAY_CHECK(it != workers_.end()) << "Worker " << worker_id << " not found.";
   return it->second;
 }
 
+std::shared_ptr<CoreWorker> CoreWorkerProcess::GetGlobalWorker() {
+  absl::ReaderMutexLock lock(&mutex_);
+  return global_worker_;
+}
+
 std::shared_ptr<CoreWorker> CoreWorkerProcess::CreateWorker() {
   auto worker = std::make_shared<CoreWorker>(
       options_,
      global_worker_id_ != WorkerID::Nil() ? global_worker_id_ : WorkerID::FromRandom());
   RAY_LOG(DEBUG) << "Worker " << worker->GetWorkerID() << " is created.";
+  absl::WriterMutexLock lock(&mutex_);
   if (options_.num_workers == 1) {
     global_worker_ = worker;
   }
   current_core_worker_ = worker;
 
-  absl::MutexLock lock(&worker_map_mutex_);
   workers_.emplace(worker->GetWorkerID(), worker);
   RAY_CHECK(workers_.size() <= static_cast<size_t>(options_.num_workers));
   return worker;
@@ -315,6 +321,7 @@ std::shared_ptr<CoreWorker> CoreWorkerProcess::CreateWorker() {
 
 void CoreWorkerProcess::RemoveWorker(std::shared_ptr<CoreWorker> worker) {
   worker->WaitForShutdown();
+  absl::WriterMutexLock lock(&mutex_);
   if (global_worker_) {
     RAY_CHECK(global_worker_ == worker);
   } else {
@@ -322,7 +329,6 @@ void CoreWorkerProcess::RemoveWorker(std::shared_ptr<CoreWorker> worker) {
   }
   current_core_worker_.reset();
   {
-    absl::MutexLock lock(&worker_map_mutex_);
     workers_.erase(worker->GetWorkerID());
     RAY_LOG(INFO) << "Removed worker " << worker->GetWorkerID();
   }
@@ -336,9 +342,10 @@ void CoreWorkerProcess::RunTaskExecutionLoop() {
   RAY_CHECK(core_worker_process->options_.worker_type == WorkerType::WORKER);
   if (core_worker_process->options_.num_workers == 1) {
     // Run the task loop in the current thread only if the number of workers is 1.
-    auto worker = core_worker_process->global_worker_
-                      ? core_worker_process->global_worker_
-                      : core_worker_process->CreateWorker();
+    auto worker = core_worker_process->GetGlobalWorker();
+    if (!worker) {
+      worker = core_worker_process->CreateWorker();
+    }
     worker->RunTaskExecutionLoop();
     core_worker_process->RemoveWorker(worker);
   } else {
@@ -370,9 +377,9 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
       periodical_runner_(io_service_),
       task_queue_length_(0),
       num_executed_tasks_(0),
-      task_execution_service_work_(task_execution_service_),
       resource_ids_(new ResourceMappingType()),
-      grpc_service_(io_service_, *this) {
+      grpc_service_(io_service_, *this),
+      task_execution_service_work_(task_execution_service_) {
   RAY_LOG(DEBUG) << "Constructing CoreWorker, worker_id: " << worker_id;
 
   // Initialize task receivers.
@@ -409,11 +416,8 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
     // Avoid using FATAL log or RAY_CHECK here because they may create a core dump file.
     RAY_LOG(ERROR) << "Failed to register worker " << worker_id << " to Raylet. "
                    << raylet_client_status;
-    if (options_.enable_logging) {
-      RayLog::ShutDownRayLog();
-    }
     // Quit the process immediately.
-    _Exit(1);
+    QuickExit(options_.enable_logging);
   }
 
   connected_ = true;
@@ -427,7 +431,8 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
   // Start RPC server after all the task receivers are properly initialized and we have
   // our assigned port from the raylet.
   core_worker_server_ = std::make_unique<rpc::GrpcServer>(
-      WorkerTypeString(options_.worker_type), assigned_port);
+      WorkerTypeString(options_.worker_type), assigned_port,
+      options_.node_ip_address == "127.0.0.1");
   core_worker_server_->RegisterService(grpc_service_);
   core_worker_server_->Run();
 
@@ -526,10 +531,6 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
                    options_.worker_type != WorkerType::RESTORE_WORKER),
       /*get_current_call_site=*/boost::bind(&CoreWorker::CurrentCallSite, this)));
   memory_store_.reset(new CoreWorkerMemoryStore(
-      [this](const RayObject &object, const ObjectID &object_id) {
-        PutObjectIntoPlasma(object, object_id);
-        return Status::OK();
-      },
       reference_counter_, local_raylet_client_, options_.check_signals,
       [this](const RayObject &obj) {
         // Run this on the event loop to avoid calling back into the language runtime
@@ -656,7 +657,8 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
           std::move(lease_policy), memory_store_, task_manager_, local_raylet_id,
           RayConfig::instance().worker_lease_timeout_milliseconds(), actor_creator_,
           RayConfig::instance().max_tasks_in_flight_per_worker(),
-          boost::asio::steady_timer(io_service_));
+          boost::asio::steady_timer(io_service_),
+          RayConfig::instance().max_pending_lease_requests_per_scheduling_category());
   auto report_locality_data_callback =
       [this](const ObjectID &object_id, const absl::flat_hash_set<NodeID> &locations,
              uint64_t object_size) {
@@ -737,6 +739,11 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
         },
         event_stats_print_interval_ms);
   }
+
+  // Set the event context for the current core worker thread.
+  RayEventContext::Instance().SetEventContext(
+      ray::rpc::Event_SourceType::Event_SourceType_CORE_WORKER,
+      {{"worker_id", worker_id.Hex()}});
 }
 
 void CoreWorker::Shutdown() {
@@ -933,17 +940,25 @@ void CoreWorker::RegisterToGcs() {
 }
 
 void CoreWorker::CheckForRayletFailure() {
+  bool should_shutdown = false;
   // When running a worker process in a container, the worker's parent process is not the
   // raylet, so we pass the raylet's PID to the worker via the RAY_RAYLET_PID environment
   // variable.
   if (auto env_pid = RayConfig::instance().RAYLET_PID(); !env_pid.empty()) {
     auto pid = static_cast<pid_t>(std::stoi(env_pid));
     if (!IsProcessAlive(pid)) {
       RAY_LOG(ERROR) << "Raylet failed. Shutting down. Raylet PID: " << pid;
-      Shutdown();
+      should_shutdown = true;
     }
   } else if (!IsParentProcessAlive()) {
     RAY_LOG(ERROR) << "Raylet failed. Shutting down.";
-    Shutdown();
+    should_shutdown = true;
+  }
+  if (should_shutdown) {
+    if (options_.worker_type == WorkerType::WORKER) {
+      task_execution_service_.post([this]() { Shutdown(); }, "CoreWorker.Shutdown");
+    } else {
+      Shutdown();
+    }
   }
 }
 
@@ -971,6 +986,12 @@ void CoreWorker::InternalHeartbeat() {
     direct_actor_submitter_->CheckTimeoutTasks();
   }
 
+  // Periodically report the latest backlog so that the
+  // local raylet has an eventually consistent view of worker backlogs,
+  // even in cases where backlog reports from direct_task_transport
+  // are lost or reordered.
+  direct_task_submitter_->ReportWorkerBacklog();
+
   // Check for unhandled exceptions to raise after a timeout on the driver.
   // Only do this for TTY, since shells like IPython sometimes save references
   // to the result and prevent normal result deletion from handling.
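The CheckForRayletFailure() change above hands Shutdown() to the task execution loop for WORKER-type processes instead of running it on the periodic-check thread. A minimal standalone analogue of that hand-off with plain Boost.Asio (Ray's instrumented_io_context adds a name argument to post, omitted here):

    #include <boost/asio.hpp>
    #include <iostream>

    int main() {
      boost::asio::io_context task_execution_service;
      auto work = boost::asio::make_work_guard(task_execution_service);
      // Stand-in for the raylet-failure check running on another thread:
      boost::asio::post(task_execution_service, [&work] {
        std::cout << "Shutdown runs on the task execution loop" << std::endl;
        work.reset();  // let run() return, standing in for Shutdown()
      });
      task_execution_service.run();
      return 0;
    }
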
@@ -992,36 +1013,6 @@ CoreWorker::GetAllReferenceCounts() const {
   return counts;
 }
 
-void CoreWorker::PutObjectIntoPlasma(const RayObject &object, const ObjectID &object_id) {
-  bool object_exists;
-  // This call will only be used by PromoteObjectToPlasma, which means that the
-  // object will always owned by us.
-  RAY_CHECK_OK(plasma_store_provider_->Put(
-      object, object_id, /* owner_address = */ rpc_address_, &object_exists));
-  if (!object_exists) {
-    // Tell the raylet to pin the object **after** it is created.
-    RAY_LOG(DEBUG) << "Pinning put object " << object_id;
-    local_raylet_client_->PinObjectIDs(
-        rpc_address_, {object_id},
-        [this, object_id](const Status &status, const rpc::PinObjectIDsReply &reply) {
-          // Only release the object once the raylet has responded to avoid the race
-          // condition that the object could be evicted before the raylet pins it.
-          if (!plasma_store_provider_->Release(object_id).ok()) {
-            RAY_LOG(ERROR) << "Failed to release ObjectID (" << object_id
-                           << "), might cause a leak in plasma.";
-          }
-        });
-  }
-  RAY_CHECK(memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), object_id));
-}
-
-void CoreWorker::PromoteObjectToPlasma(const ObjectID &object_id) {
-  auto value = memory_store_->GetOrPromoteToPlasma(object_id);
-  if (value) {
-    PutObjectIntoPlasma(*value, object_id);
-  }
-}
-
 const rpc::Address &CoreWorker::GetRpcAddress() const { return rpc_address_; }
 
 rpc::Address CoreWorker::GetOwnerAddress(const ObjectID &object_id) const {
@@ -1061,7 +1052,6 @@ void CoreWorker::GetOwnershipInfo(const ObjectID &object_id, rpc::Address *owner
                  "which task will create them. "
                  "If this was not how your object ID was generated, please file an issue "
                  "at https://github.com/ray-project/ray/issues/";
-  RAY_LOG(DEBUG) << "Promoted object to plasma " << object_id;
 
   rpc::GetObjectStatusReply object_status;
   // Optimization: if the object exists, serialize and inline its status. This also
@@ -1635,6 +1625,37 @@ std::unordered_map<std::string, double> AddPlacementGroupConstraint(
   return resources;
 }
 
+void CoreWorker::BuildCommonTaskSpec(
+    TaskSpecBuilder &builder, const JobID &job_id, const TaskID &task_id,
+    const std::string &name, const TaskID &current_task_id, uint64_t task_index,
+    const TaskID &caller_id, const rpc::Address &address, const RayFunction &function,
+    const std::vector<std::unique_ptr<TaskArg>> &args, uint64_t num_returns,
+    const std::unordered_map<std::string, double> &required_resources,
+    const std::unordered_map<std::string, double> &required_placement_resources,
+    const BundleID &bundle_id, bool placement_group_capture_child_tasks,
+    const std::string &debugger_breakpoint, const std::string &serialized_runtime_env,
+    const std::vector<std::string> &runtime_env_uris,
+    const std::string &concurrency_group_name) {
+  // Build common task spec.
+  builder.SetCommonTaskSpec(
+      task_id, name, function.GetLanguage(), function.GetFunctionDescriptor(), job_id,
+      current_task_id, task_index, caller_id, address, num_returns, required_resources,
+      required_placement_resources, bundle_id, placement_group_capture_child_tasks,
+      debugger_breakpoint,
+      // TODO(SongGuyang): Move the logic of `prepare_runtime_env` from Python to Core
+      // Worker. A common process is needed.
+      // If runtime env is not provided, use the job config. Only for Java and C++,
+      // because it has been set in Python by `prepare_runtime_env`.
+      (serialized_runtime_env.empty() || serialized_runtime_env == "{}")
+          ? job_config_->runtime_env().serialized_runtime_env()
+          : serialized_runtime_env,
+      runtime_env_uris, concurrency_group_name);
+  // Set task arguments.
+ for (const auto &arg : args) { + builder.AddArg(*arg); + } +} + std::vector CoreWorker::SubmitTask( const RayFunction &function, const std::vector> &args, const TaskOptions &task_options, int max_retries, bool retry_exceptions, @@ -1652,21 +1673,13 @@ std::vector CoreWorker::SubmitTask( auto task_name = task_options.name.empty() ? function.GetFunctionDescriptor()->DefaultTaskName() : task_options.name; - // Propagate existing environment variable overrides, but override them with any new - // ones - std::unordered_map current_override_environment_variables = - worker_context_.GetCurrentOverrideEnvironmentVariables(); - std::unordered_map override_environment_variables = - task_options.override_environment_variables; - override_environment_variables.insert(current_override_environment_variables.begin(), - current_override_environment_variables.end()); // TODO(ekl) offload task building onto a thread pool for performance - BuildCommonTaskSpec( - builder, worker_context_.GetCurrentJobID(), task_id, task_name, - worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(), rpc_address_, - function, args, task_options.num_returns, constrained_resources, required_resources, - placement_options, placement_group_capture_child_tasks, debugger_breakpoint, - task_options.serialized_runtime_env, override_environment_variables); + BuildCommonTaskSpec(builder, worker_context_.GetCurrentJobID(), task_id, task_name, + worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(), + rpc_address_, function, args, task_options.num_returns, + constrained_resources, required_resources, placement_options, + placement_group_capture_child_tasks, debugger_breakpoint, + task_options.serialized_runtime_env, task_options.runtime_env_uris); builder.SetNormalTaskSpec(max_retries, retry_exceptions); TaskSpecification task_spec = builder.Build(); RAY_LOG(DEBUG) << "Submit task " << task_spec.DebugString(); @@ -1702,12 +1715,7 @@ Status CoreWorker::CreateActor(const RayFunction &function, const JobID job_id = worker_context_.GetCurrentJobID(); // Propagate existing environment variable overrides, but override them with any new // ones - std::unordered_map current_override_environment_variables = - worker_context_.GetCurrentOverrideEnvironmentVariables(); - std::unordered_map override_environment_variables = - actor_creation_options.override_environment_variables; - override_environment_variables.insert(current_override_environment_variables.begin(), - current_override_environment_variables.end()); + std::vector return_ids; TaskSpecBuilder builder; auto new_placement_resources = AddPlacementGroupConstraint(actor_creation_options.placement_resources, @@ -1728,7 +1736,7 @@ Status CoreWorker::CreateActor(const RayFunction &function, actor_creation_options.placement_group_capture_child_tasks, "", /* debugger_breakpoint */ actor_creation_options.serialized_runtime_env, - override_environment_variables); + actor_creation_options.runtime_env_uris); auto actor_handle = std::make_unique( actor_id, GetCallerId(), rpc_address_, job_id, @@ -1905,7 +1913,6 @@ std::vector CoreWorker::SubmitActorTask( const auto task_name = task_options.name.empty() ? 
function.GetFunctionDescriptor()->DefaultTaskName() : task_options.name; - const std::unordered_map override_environment_variables = {}; BuildCommonTaskSpec(builder, actor_handle->CreationJobID(), actor_task_id, task_name, worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(), rpc_address_, function, args, num_returns, task_options.resources, @@ -1913,7 +1920,7 @@ std::vector CoreWorker::SubmitActorTask( true, /* placement_group_capture_child_tasks */ "", /* debugger_breakpoint */ "{}", /* serialized_runtime_env */ - override_environment_variables, + {}, /* runtime_env_uris */ task_options.concurrency_group_name); // NOTE: placement_group_capture_child_tasks and runtime_env will // be ignored in the actor because we should always follow the actor's option. @@ -2184,6 +2191,14 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec, task_queue_length_ -= 1; num_executed_tasks_ += 1; + // Modify the worker's per function counters. + std::string func_name = task_spec.FunctionDescriptor()->CallString(); + { + absl::MutexLock l(&task_counter_.tasks_counter_mutex_); + task_counter_.Add(TaskCounter::kPending, func_name, -1); + task_counter_.Add(TaskCounter::kRunning, func_name, 1); + } + if (!options_.is_local_mode) { worker_context_.SetCurrentTask(task_spec); SetCurrentTaskId(task_spec.TaskId()); @@ -2279,8 +2294,16 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec, resource_ids_.reset(new ResourceMappingType()); } } - RAY_LOG(INFO) << "Finished executing task " << task_spec.TaskId() - << ", status=" << status; + + // Modify the worker's per function counters. + { + absl::MutexLock l(&task_counter_.tasks_counter_mutex_); + task_counter_.Add(TaskCounter::kRunning, func_name, -1); + task_counter_.Add(TaskCounter::kFinished, func_name, 1); + } + + RAY_LOG(DEBUG) << "Finished executing task " << task_spec.TaskId() + << ", status=" << status; if (status.IsCreationTaskError()) { Exit(rpc::WorkerExitType::CREATION_TASK_ERROR, creation_task_exception_pb_bytes); } else if (status.IsIntentionalSystemExit()) { @@ -2447,8 +2470,15 @@ void CoreWorker::HandlePushTask(const rpc::PushTaskRequest &request, return; } - // Increment the task_queue_length + // Increment the task_queue_length and per function counter. task_queue_length_ += 1; + std::string func_name = + FunctionDescriptorBuilder::FromProto(request.task_spec().function_descriptor()) + ->CallString(); + { + absl::MutexLock l(&task_counter_.tasks_counter_mutex_); + task_counter_.Add(TaskCounter::kPending, func_name, 1); + } // For actor tasks, we just need to post a HandleActorTask instance to the task // execution service. @@ -2855,13 +2885,10 @@ void CoreWorker::HandleCancelTask(const rpc::CancelTaskRequest &request, << " has received a force kill request after the cancellation. Killing " "a worker..."; Disconnect(); - if (options_.enable_logging) { - RayLog::ShutDownRayLog(); - } - // NOTE(hchen): Use `_Exit()` to force-exit this process without doing cleanup. + // NOTE(hchen): Use `QuickExit()` to force-exit this process without doing cleanup. // `exit()` will destruct static objects in an incorrect order, which will lead to // core dumps. 
- _Exit(1);
+  QuickExit(options_.enable_logging);
   }
 }
@@ -2894,13 +2921,10 @@ void CoreWorker::HandleKillActor(const rpc::KillActorRequest &request,
             "please create the Java actor with some dynamic options to make it being "
             "hosted in a dedicated worker process.";
       }
-      if (options_.enable_logging) {
-        RayLog::ShutDownRayLog();
-      }
-      // NOTE(hchen): Use `_Exit()` to force-exit this process without doing cleanup.
+      // NOTE(hchen): Use `QuickExit()` to force-exit this process without doing cleanup.
       // `exit()` will destruct static objects in an incorrect order, which will lead to
       // core dumps.
-      _Exit(1);
+      QuickExit(options_.enable_logging);
     } else {
       Exit(rpc::WorkerExitType::INTENDED_EXIT);
     }
@@ -3057,15 +3081,16 @@ void CoreWorker::HandleExit(const rpc::ExitRequest &request, rpc::ExitReply *rep
   // any object pinning RPCs in flight.
   bool is_idle = !own_objects && pins_in_flight == 0;
   reply->set_success(is_idle);
-  send_reply_callback(Status::OK(),
-                      [this, is_idle]() {
-                        // If the worker is idle, we exit.
-                        if (is_idle) {
-                          Exit(rpc::WorkerExitType::IDLE_EXIT);
-                        }
-                      },
-                      // We need to kill it regardless if the RPC failed.
-                      [this]() { Exit(rpc::WorkerExitType::INTENDED_EXIT); });
+  send_reply_callback(
+      Status::OK(),
+      [this, is_idle]() {
+        // If the worker is idle, we exit.
+        if (is_idle) {
+          Exit(rpc::WorkerExitType::IDLE_EXIT);
+        }
+      },
+      // We need to kill it regardless of whether the RPC failed.
+      [this]() { Exit(rpc::WorkerExitType::INTENDED_EXIT); });
 }
 void CoreWorker::HandleAssignObjectOwner(const rpc::AssignObjectOwnerRequest &request,
@@ -3191,6 +3216,25 @@ std::shared_ptr CoreWorker::GetGcsClient() const { return gcs_cl
 bool CoreWorker::IsExiting() const { return exiting_; }
+std::unordered_map> CoreWorker::GetActorCallStats()
+    const {
+  absl::MutexLock l(&task_counter_.tasks_counter_mutex_);
+  std::unordered_map> total_counts;
+
+  for (const auto &count : task_counter_.pending_tasks_counter_map_) {
+    total_counts[count.first].resize(3, 0);
+    total_counts[count.first][0] = count.second;
+  }
+  for (const auto &count : task_counter_.running_tasks_counter_map_) {
+    total_counts[count.first][1] = count.second;
+  }
+  for (const auto &count : task_counter_.finished_tasks_counter_map_) {
+    total_counts[count.first][2] = count.second;
+  }
+
+  return total_counts;
+}
+
 Status CoreWorker::WaitForActorRegistered(const std::vector &ids) {
   std::vector actor_ids;
   for (const auto &id : ids) {
diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h
index 3ef1e2476f6d2..883a1b013ff81 100644
--- a/src/ray/core_worker/core_worker.h
+++ b/src/ray/core_worker/core_worker.h
@@ -294,23 +294,29 @@ class CoreWorkerProcess {
   void InitializeSystemConfig();
+  /// Check whether the global worker should be created on construction.
+  bool ShouldCreateGlobalWorkerOnConstruction() const;
+
   /// Get the `CoreWorker` instance by worker ID.
   ///
   /// \param[in] workerId The worker ID.
   /// \return The `CoreWorker` instance.
   std::shared_ptr GetWorker(const WorkerID &worker_id) const
-      LOCKS_EXCLUDED(worker_map_mutex_);
+      LOCKS_EXCLUDED(mutex_);
   /// Create a new `CoreWorker` instance.
   ///
   /// \return The newly created `CoreWorker` instance.
-  std::shared_ptr CreateWorker() LOCKS_EXCLUDED(worker_map_mutex_);
+  std::shared_ptr CreateWorker() LOCKS_EXCLUDED(mutex_);
   /// Remove an existing `CoreWorker` instance.
   ///
   /// \param[in] The existing `CoreWorker` instance.
   /// \return Void.
- void RemoveWorker(std::shared_ptr worker) LOCKS_EXCLUDED(worker_map_mutex_); + void RemoveWorker(std::shared_ptr worker) LOCKS_EXCLUDED(mutex_); + + /// Get the `GlobalWorker` instance, if the number of workers is 1. + std::shared_ptr GetGlobalWorker() LOCKS_EXCLUDED(mutex_); /// The various options. const CoreWorkerOptions options_; @@ -320,17 +326,16 @@ class CoreWorkerProcess { static thread_local std::weak_ptr current_core_worker_; /// The only core worker instance, if the number of workers is 1. - std::shared_ptr global_worker_; + std::shared_ptr global_worker_ GUARDED_BY(mutex_); /// The worker ID of the global worker, if the number of workers is 1. const WorkerID global_worker_id_; /// Map from worker ID to worker. - std::unordered_map> workers_ - GUARDED_BY(worker_map_mutex_); + std::unordered_map> workers_ GUARDED_BY(mutex_); - /// To protect accessing the `workers_` map. - mutable absl::Mutex worker_map_mutex_; + /// To protect access to workers_ and global_worker_ + mutable absl::Mutex mutex_; }; /// The root class that contains all the core and language-independent functionalities @@ -440,22 +445,6 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// (local, submitted_task) reference counts. For debugging purposes. std::unordered_map> GetAllReferenceCounts() const; - /// Put an object into plasma. It's a version of Put that directly put the - /// object into plasma and also pin the object. - /// - /// \param[in] The ray object. - /// \param[in] object_id The object ID to serialize. - /// appended to the serialized object ID. - void PutObjectIntoPlasma(const RayObject &object, const ObjectID &object_id); - - /// Promote an object to plasma. If the - /// object already exists locally, it will be put into the plasma store. If - /// it doesn't yet exist, it will be spilled to plasma once available. - /// - /// \param[in] object_id The object ID to serialize. - /// appended to the serialized object ID. - void PromoteObjectToPlasma(const ObjectID &object_id); - /// Get the RPC address of this worker. /// /// \param[out] The RPC address of this worker. @@ -1044,7 +1033,24 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// Return true if the core worker is in the exit process. bool IsExiting() const; + /// Retrieve the current statistics about tasks being received and executing. + /// \return an unordered_map mapping function name to list of (num_received, + /// num_executing, num_executed). It is a std map instead of absl due to its + /// interface with language bindings. + std::unordered_map> GetActorCallStats() const; + private: + void BuildCommonTaskSpec( + TaskSpecBuilder &builder, const JobID &job_id, const TaskID &task_id, + const std::string &name, const TaskID ¤t_task_id, uint64_t task_index, + const TaskID &caller_id, const rpc::Address &address, const RayFunction &function, + const std::vector> &args, uint64_t num_returns, + const std::unordered_map &required_resources, + const std::unordered_map &required_placement_resources, + const BundleID &bundle_id, bool placement_group_capture_child_tasks, + const std::string &debugger_breakpoint, const std::string &serialized_runtime_env, + const std::vector &runtime_env_uris, + const std::string &concurrency_group_name = ""); void SetCurrentTaskId(const TaskID &task_id); void SetActorId(const ActorID &actor_id); @@ -1366,12 +1372,6 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// Number of executed tasks. std::atomic num_executed_tasks_; - /// Event loop where tasks are processed. 
- instrumented_io_context task_execution_service_;
-
-  /// The asio work to keep task_execution_service_ alive.
-  boost::asio::io_service::work task_execution_service_work_;
-
   /// Profiler including a background thread that pushes profiling events to the GCS.
   std::shared_ptr profiler_;
@@ -1390,6 +1390,14 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
   // Interface that receives tasks from direct actor calls.
   std::unique_ptr direct_task_receiver_;
+  /// Event loop where tasks are processed.
+  /// task_execution_service_ should be destructed first to avoid
+  /// issues like https://github.com/ray-project/ray/issues/18857
+  instrumented_io_context task_execution_service_;
+
+  /// The asio work to keep task_execution_service_ alive.
+  boost::asio::io_service::work task_execution_service_work_;
+
   // Queue of tasks to resubmit when the specified time passes.
   std::deque> to_resubmit_ GUARDED_BY(mutex_);
@@ -1408,14 +1416,47 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
   void PlasmaCallback(SetResultCallback success,
                       std::shared_ptr ray_object, ObjectID object_id, void *py_future);
-  /// Whether we are shutting down and not running further tasks.
-  bool exiting_ = false;
+  /// Whether we are shutting down and not running further tasks.
+  /// When exiting_ is set to true, HandlePushTask becomes a no-op.
+  std::atomic exiting_ = false;
   int64_t max_direct_call_object_size_;
   friend class CoreWorkerTest;
   std::unique_ptr job_config_;
+
+  /// Simple container for per-function task counters. The counters will be
+  /// keyed by the function name in the task spec.
+  struct TaskCounter {
+    /// A task can only be in one of the following states. The pending state in
+    /// particular covers from the point of the RPC call to the beginning of execution.
+    enum TaskStatusType { kPending, kRunning, kFinished };
+
+    /// This mutex should be used by the caller to ensure consistency when transitioning
+    /// a task's state.
+ mutable absl::Mutex tasks_counter_mutex_; + absl::flat_hash_map pending_tasks_counter_map_ + GUARDED_BY(tasks_counter_mutex_); + absl::flat_hash_map running_tasks_counter_map_ + GUARDED_BY(tasks_counter_mutex_); + absl::flat_hash_map finished_tasks_counter_map_ + GUARDED_BY(tasks_counter_mutex_); + + void Add(TaskStatusType type, const std::string &func_name, int value) { + tasks_counter_mutex_.AssertHeld(); + if (type == kPending) { + pending_tasks_counter_map_[func_name] += value; + } else if (type == kRunning) { + running_tasks_counter_map_[func_name] += value; + } else if (type == kFinished) { + finished_tasks_counter_map_[func_name] += value; + } else { + RAY_CHECK(false) << "This line should not be reached."; + } + } + }; + TaskCounter task_counter_; }; } // namespace core diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc index 70a2626847574..6af083669f5ed 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc @@ -217,10 +217,9 @@ Java_io_ray_runtime_object_NativeObjectStore_nativeGetOwnerAddress(JNIEnv *env, } JNIEXPORT jbyteArray JNICALL -Java_io_ray_runtime_object_NativeObjectStore_nativePromoteAndGetOwnershipInfo( - JNIEnv *env, jclass, jbyteArray objectId) { +Java_io_ray_runtime_object_NativeObjectStore_nativeGetOwnershipInfo(JNIEnv *env, jclass, + jbyteArray objectId) { auto object_id = JavaByteArrayToId(env, objectId); - CoreWorkerProcess::GetCoreWorker().PromoteObjectToPlasma(object_id); rpc::Address address; // TODO(ekl) send serialized object status to Java land. std::string serialized_object_status; diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h index 8001bbf20df06..9358f4473c228 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h +++ b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h @@ -105,13 +105,12 @@ Java_io_ray_runtime_object_NativeObjectStore_nativeGetOwnerAddress(JNIEnv *, jcl /* * Class: io_ray_runtime_object_NativeObjectStore - * Method: nativePromoteAndGetOwnershipInfo + * Method: nativeGetOwnershipInfo * Signature: ([B)[B */ JNIEXPORT jbyteArray JNICALL -Java_io_ray_runtime_object_NativeObjectStore_nativePromoteAndGetOwnershipInfo(JNIEnv *, - jclass, - jbyteArray); +Java_io_ray_runtime_object_NativeObjectStore_nativeGetOwnershipInfo(JNIEnv *, jclass, + jbyteArray); /* * Class: io_ray_runtime_object_NativeObjectStore diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc index dd05bc76aa6e0..56a0ad473c64d 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc @@ -235,9 +235,9 @@ inline ActorCreationOptions ToActorCreationOptions(JNIEnv *env, ray_namespace, /*is_asyncio=*/false, placement_options, - true, - "{}", - {}, + /*placement_group_capture_child_tasks=*/true, + /*serialized_runtime_env=*/"{}", + /*runtime_env_uris=*/{}, concurrency_groups}; return actor_creation_options; } diff --git a/src/ray/core_worker/reference_count.h b/src/ray/core_worker/reference_count.h index 58c67a2010213..e0a9a783dd657 100644 --- a/src/ray/core_worker/reference_count.h +++ 
b/src/ray/core_worker/reference_count.h @@ -14,8 +14,6 @@ #pragma once -#include - #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" diff --git a/src/ray/core_worker/reference_count_test.cc b/src/ray/core_worker/reference_count_test.cc index 5877d7f654dfc..c265bc7af753b 100644 --- a/src/ray/core_worker/reference_count_test.cc +++ b/src/ray/core_worker/reference_count_test.cc @@ -16,9 +16,9 @@ #include +#include "absl/functional/bind_front.h" #include "gmock/gmock.h" #include "gtest/gtest.h" - #include "ray/common/asio/instrumented_io_context.h" #include "ray/common/asio/periodical_runner.h" #include "ray/common/ray_object.h" @@ -270,7 +270,7 @@ class MockWorkerClient : public MockCoreWorkerClientInterface { auto borrower_callback = [=]() { auto ref_removed_callback = - boost::bind(&ReferenceCounter::HandleRefRemoved, &rc_, _1); + absl::bind_front(&ReferenceCounter::HandleRefRemoved, &rc_); rc_.SetRefRemovedCallback(object_id, contained_in_id, owner_address, ref_removed_callback); }; @@ -656,7 +656,7 @@ TEST(MemoryStoreIntegrationTest, TestSimple) { auto subscriber = std::make_shared(); auto rc = std::shared_ptr(new ReferenceCounter( rpc::WorkerAddress(rpc::Address()), publisher.get(), subscriber.get())); - CoreWorkerMemoryStore store(nullptr, rc); + CoreWorkerMemoryStore store(rc); // Tests putting an object with no references is ignored. RAY_CHECK(store.Put(buffer, id2)); diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.cc b/src/ray/core_worker/store_provider/memory_store/memory_store.cc index 680c9c13616bc..b32b612166820 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.cc +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.cc @@ -139,13 +139,11 @@ std::shared_ptr GetRequest::Get(const ObjectID &object_id) const { } CoreWorkerMemoryStore::CoreWorkerMemoryStore( - std::function store_in_plasma, std::shared_ptr counter, std::shared_ptr raylet_client, std::function check_signals, std::function unhandled_exception_handler) - : store_in_plasma_(store_in_plasma), - ref_counter_(counter), + : ref_counter_(std::move(counter)), raylet_client_(raylet_client), check_signals_(check_signals), unhandled_exception_handler_(unhandled_exception_handler) {} @@ -186,24 +184,6 @@ std::shared_ptr CoreWorkerMemoryStore::GetIfExists(const ObjectID &ob return ptr; } -std::shared_ptr CoreWorkerMemoryStore::GetOrPromoteToPlasma( - const ObjectID &object_id) { - absl::MutexLock lock(&mu_); - auto iter = objects_.find(object_id); - if (iter != objects_.end()) { - auto obj = iter->second; - obj->SetAccessed(); - if (obj->IsInPlasmaError()) { - return nullptr; - } - return obj; - } - RAY_CHECK(store_in_plasma_ != nullptr) - << "Cannot promote object without plasma provider callback."; - promoted_to_plasma_.insert(object_id); - return nullptr; -} - bool CoreWorkerMemoryStore::Put(const RayObject &object, const ObjectID &object_id) { std::vector)>> async_callbacks; auto object_entry = std::make_shared(object.GetData(), object.GetMetadata(), @@ -212,7 +192,6 @@ bool CoreWorkerMemoryStore::Put(const RayObject &object, const ObjectID &object_ // TODO(edoakes): we should instead return a flag to the caller to put the object in // plasma. 
- bool should_put_in_plasma = false; { absl::MutexLock lock(&mu_); @@ -228,15 +207,6 @@ bool CoreWorkerMemoryStore::Put(const RayObject &object, const ObjectID &object_ object_async_get_requests_.erase(async_callback_it); } - auto promoted_it = promoted_to_plasma_.find(object_id); - if (promoted_it != promoted_to_plasma_.end()) { - RAY_CHECK(store_in_plasma_ != nullptr); - // Only need to promote to plasma if it wasn't already put into plasma - // by the task that created the object. - should_put_in_plasma = !object.IsInPlasmaError(); - promoted_to_plasma_.erase(promoted_it); - } - bool should_add_entry = true; auto object_request_iter = object_get_requests_.find(object_id); if (object_request_iter != object_get_requests_.end()) { @@ -268,14 +238,6 @@ bool CoreWorkerMemoryStore::Put(const RayObject &object, const ObjectID &object_ } } - // Must be called without holding the lock because store_in_plasma_ goes - // through the regular CoreWorker::Put() codepath, which calls into the - // in-memory store (would cause deadlock). - if (should_put_in_plasma) { - store_in_plasma_(object, object_id); - stored_in_direct_memory = false; - } - // It's important for performance to run the callbacks outside the lock. for (const auto &cb : async_callbacks) { cb(object_entry); diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.h b/src/ray/core_worker/store_provider/memory_store/memory_store.h index 542fac1ea2ea6..70bebac7f01a5 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.h +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.h @@ -44,12 +44,10 @@ class CoreWorkerMemoryStore { public: /// Create a memory store. /// - /// \param[in] store_in_plasma If not null, this is used to spill to plasma. /// \param[in] counter If not null, this enables ref counting for local objects, /// and the `remove_after_get` flag for Get() will be ignored. /// \param[in] raylet_client If not null, used to notify tasks blocked / unblocked. CoreWorkerMemoryStore( - std::function store_in_plasma = nullptr, std::shared_ptr counter = nullptr, std::shared_ptr raylet_client = nullptr, std::function check_signals = nullptr, @@ -104,14 +102,6 @@ class CoreWorkerMemoryStore { void GetAsync(const ObjectID &object_id, std::function)> callback); - /// Get a single object if available. If the object is not local yet, or if the object - /// is local but is ErrorType::OBJECT_IN_PLASMA, then nullptr will be returned, and - /// the store will ensure the object is promoted to plasma once available. - /// - /// \param[in] object_id The object id to get. - /// \return pointer to the local object, or nullptr if promoted to plasma. - std::shared_ptr GetOrPromoteToPlasma(const ObjectID &object_id); - /// Delete a list of objects from the object store. /// NOTE(swang): Objects that contain IsInPlasmaError will not be /// deleted from the in-memory store. Instead, any future Get @@ -187,9 +177,6 @@ class CoreWorkerMemoryStore { /// properly. void EraseObjectAndUpdateStats(const ObjectID &object_id) EXCLUSIVE_LOCKS_REQUIRED(mu_); - /// Optional callback for putting objects into the plasma store. - std::function store_in_plasma_; - /// If enabled, holds a reference to local worker ref counter. TODO(ekl) make this /// mandatory once Java is supported. std::shared_ptr ref_counter_ = nullptr; @@ -200,9 +187,6 @@ class CoreWorkerMemoryStore { /// Protects the data structures below. mutable absl::Mutex mu_; - /// Set of objects that should be promoted to plasma once available. 
- absl::flat_hash_set promoted_to_plasma_ GUARDED_BY(mu_); - /// Map from object ID to `RayObject`. /// NOTE: This map should be modified by EmplaceObjectAndUpdateStats and /// EraseObjectAndUpdateStats. diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index 29d95cb8fa9b8..3e0ddd631d45f 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -16,7 +16,7 @@ #include #include -#include +#include #include #include "absl/container/flat_hash_map.h" diff --git a/src/ray/core_worker/test/direct_task_transport_mock_test.cc b/src/ray/core_worker/test/direct_task_transport_mock_test.cc index 0af5c20c4eb15..8312d79a0bc43 100644 --- a/src/ray/core_worker/test/direct_task_transport_mock_test.cc +++ b/src/ray/core_worker/test/direct_task_transport_mock_test.cc @@ -28,7 +28,7 @@ using namespace ::testing; class DirectTaskTransportTest : public ::testing::Test { public: void SetUp() override { - raylet_client = std::make_shared(); + raylet_client = std::make_shared(); task_finisher = std::make_shared(); actor_creator = std::make_shared(); lease_policy = std::make_shared(); @@ -57,7 +57,7 @@ class DirectTaskTransportTest : public ::testing::Test { } std::unique_ptr task_submitter; - std::shared_ptr raylet_client; + std::shared_ptr raylet_client; std::shared_ptr task_finisher; std::shared_ptr actor_creator; std::shared_ptr lease_policy; diff --git a/src/ray/core_worker/test/direct_task_transport_test.cc b/src/ray/core_worker/test/direct_task_transport_test.cc index 473136255bc72..b631b1d372177 100644 --- a/src/ray/core_worker/test/direct_task_transport_test.cc +++ b/src/ray/core_worker/test/direct_task_transport_test.cc @@ -153,6 +153,19 @@ class MockRayletClient : public WorkerLeaseInterface { return Status::OK(); } + void ReportWorkerBacklog( + const WorkerID &worker_id, + const std::vector &backlog_reports) override { + reported_backlog_size = 0; + reported_backlogs.clear(); + for (const auto &backlog_report : backlog_reports) { + reported_backlog_size += backlog_report.backlog_size(); + const TaskSpecification resource_spec(backlog_report.resource_spec()); + const SchedulingClass scheduling_class = resource_spec.GetSchedulingClass(); + reported_backlogs[scheduling_class] = backlog_report.backlog_size(); + } + } + void RequestWorkerLease( const TaskSpecification &resource_spec, const rpc::ClientCallback &callback, @@ -161,6 +174,14 @@ class MockRayletClient : public WorkerLeaseInterface { callbacks.push_back(callback); } + void RequestWorkerLease( + const rpc::TaskSpec &task_spec, + const ray::rpc::ClientCallback &callback, + const int64_t backlog_size = -1) override { + num_workers_requested += 1; + callbacks.push_back(callback); + } + void ReleaseUnusedWorkers( const std::vector &workers_in_use, const rpc::ClientCallback &callback) override {} @@ -222,6 +243,8 @@ class MockRayletClient : public WorkerLeaseInterface { int num_workers_returned = 0; int num_workers_disconnected = 0; int num_leases_canceled = 0; + int reported_backlog_size = 0; + std::map reported_backlogs; std::list> callbacks = {}; std::list> cancel_callbacks = {}; }; @@ -246,11 +269,18 @@ class MockActorCreator : public ActorCreatorInterface { } void AsyncWaitForActorRegisterFinish(const ActorID &, - gcs::StatusCallback callback) override {} + gcs::StatusCallback callback) override { + callbacks.push_back(callback); + } - bool IsActorInRegistering(const ActorID &actor_id) const override { return false; } + [[nodiscard]] bool 
IsActorInRegistering(const ActorID &actor_id) const override { + return actor_pending; + } ~MockActorCreator() {} + + std::list callbacks; + bool actor_pending = false; }; class MockLeasePolicy : public LeasePolicyInterface { @@ -272,30 +302,6 @@ class MockLeasePolicy : public LeasePolicyInterface { int num_lease_policy_consults = 0; }; -TEST(TestMemoryStore, TestPromoteToPlasma) { - size_t num_plasma_puts = 0; - auto mem = std::make_shared( - [&](const RayObject &obj, const ObjectID &obj_id) { num_plasma_puts += 1; }); - ObjectID obj1 = ObjectID::FromRandom(); - ObjectID obj2 = ObjectID::FromRandom(); - auto data = GenerateRandomObject(); - ASSERT_TRUE(mem->Put(*data, obj1)); - - // Test getting an already existing object. - ASSERT_TRUE(mem->GetOrPromoteToPlasma(obj1) != nullptr); - ASSERT_TRUE(num_plasma_puts == 0); - - // Testing getting an object that doesn't exist yet causes promotion. - ASSERT_TRUE(mem->GetOrPromoteToPlasma(obj2) == nullptr); - ASSERT_TRUE(num_plasma_puts == 0); - ASSERT_FALSE(mem->Put(*data, obj2)); - ASSERT_TRUE(num_plasma_puts == 1); - - // The next time you get it, it's already there so no need to promote. - ASSERT_TRUE(mem->GetOrPromoteToPlasma(obj2) != nullptr); - ASSERT_TRUE(num_plasma_puts == 1); -} - TEST(LocalDependencyResolverTest, TestNoDependencies) { auto store = std::make_shared(); auto task_finisher = std::make_shared(); @@ -308,6 +314,77 @@ TEST(LocalDependencyResolverTest, TestNoDependencies) { ASSERT_EQ(task_finisher->num_inlined_dependencies, 0); } +TEST(LocalDependencyResolverTest, TestActorAndObjectDependencies1) { + // Actor dependency resolved first. + auto store = std::make_shared(); + auto task_finisher = std::make_shared(); + MockActorCreator actor_creator; + LocalDependencyResolver resolver(*store, *task_finisher, actor_creator); + TaskSpecification task; + ObjectID obj = ObjectID::FromRandom(); + task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj.Binary()); + + ActorID actor_id = ActorID::Of(JobID::FromInt(0), TaskID::Nil(), 0); + ObjectID actor_handle_id = ObjectID::ForActorHandle(actor_id); + task.GetMutableMessage().add_args()->add_nested_inlined_refs()->set_object_id( + actor_handle_id.Binary()); + + int num_resolved = 0; + actor_creator.actor_pending = true; + resolver.ResolveDependencies(task, [&](const Status &) { num_resolved++; }); + ASSERT_EQ(num_resolved, 0); + ASSERT_EQ(resolver.NumPendingTasks(), 1); + + for (const auto &cb : actor_creator.callbacks) { + cb(Status()); + } + ASSERT_EQ(num_resolved, 0); + + std::string meta = std::to_string(static_cast(rpc::ErrorType::OBJECT_IN_PLASMA)); + auto metadata = const_cast(reinterpret_cast(meta.data())); + auto meta_buffer = std::make_shared(metadata, meta.size()); + auto data = RayObject(nullptr, meta_buffer, std::vector()); + ASSERT_TRUE(store->Put(data, obj)); + ASSERT_EQ(num_resolved, 1); + + ASSERT_EQ(resolver.NumPendingTasks(), 0); +} + +TEST(LocalDependencyResolverTest, TestActorAndObjectDependencies2) { + // Object dependency resolved first. 
+ auto store = std::make_shared(); + auto task_finisher = std::make_shared(); + MockActorCreator actor_creator; + LocalDependencyResolver resolver(*store, *task_finisher, actor_creator); + TaskSpecification task; + ObjectID obj = ObjectID::FromRandom(); + task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj.Binary()); + + ActorID actor_id = ActorID::Of(JobID::FromInt(0), TaskID::Nil(), 0); + ObjectID actor_handle_id = ObjectID::ForActorHandle(actor_id); + task.GetMutableMessage().add_args()->add_nested_inlined_refs()->set_object_id( + actor_handle_id.Binary()); + + int num_resolved = 0; + actor_creator.actor_pending = true; + resolver.ResolveDependencies(task, [&](const Status &) { num_resolved++; }); + ASSERT_EQ(num_resolved, 0); + ASSERT_EQ(resolver.NumPendingTasks(), 1); + + std::string meta = std::to_string(static_cast(rpc::ErrorType::OBJECT_IN_PLASMA)); + auto metadata = const_cast(reinterpret_cast(meta.data())); + auto meta_buffer = std::make_shared(metadata, meta.size()); + auto data = RayObject(nullptr, meta_buffer, std::vector()); + ASSERT_EQ(num_resolved, 0); + ASSERT_TRUE(store->Put(data, obj)); + + for (const auto &cb : actor_creator.callbacks) { + cb(Status()); + } + ASSERT_EQ(num_resolved, 1); + ASSERT_EQ(resolver.NumPendingTasks(), 0); +} + TEST(LocalDependencyResolverTest, TestHandlePlasmaPromotion) { auto store = std::make_shared(); auto task_finisher = std::make_shared(); @@ -563,9 +640,78 @@ TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) { auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, - lease_policy, store, task_finisher, - NodeID::Nil(), kLongTimeout, actor_creator); + CoreWorkerDirectTaskSubmitter submitter( + address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, + NodeID::Nil(), kLongTimeout, actor_creator, 1, absl::nullopt, 2); + + TaskSpecification task1 = BuildEmptyTaskSpec(); + TaskSpecification task2 = BuildEmptyTaskSpec(); + TaskSpecification task3 = BuildEmptyTaskSpec(); + + ASSERT_TRUE(submitter.SubmitTask(task1).ok()); + ASSERT_TRUE(submitter.SubmitTask(task2).ok()); + ASSERT_TRUE(submitter.SubmitTask(task3).ok()); + ASSERT_EQ(lease_policy->num_lease_policy_consults, 2); + ASSERT_EQ(raylet_client->num_workers_requested, 2); + ASSERT_EQ(raylet_client->num_workers_returned, 0); + ASSERT_EQ(raylet_client->reported_backlog_size, 0); + ASSERT_EQ(worker_client->callbacks.size(), 0); + + // Trigger the periodic backlog report + submitter.ReportWorkerBacklog(); + ASSERT_EQ(raylet_client->reported_backlog_size, 1); + + // Task 1 is pushed; worker 3 is requested. + ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1000, NodeID::Nil())); + ASSERT_EQ(worker_client->callbacks.size(), 1); + ASSERT_EQ(lease_policy->num_lease_policy_consults, 3); + ASSERT_EQ(raylet_client->num_workers_requested, 3); + ASSERT_EQ(raylet_client->reported_backlog_size, 0); + + // Task 2 is pushed; no more workers requested. + ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, NodeID::Nil())); + ASSERT_EQ(worker_client->callbacks.size(), 2); + ASSERT_EQ(lease_policy->num_lease_policy_consults, 3); + ASSERT_EQ(raylet_client->num_workers_requested, 3); + ASSERT_EQ(raylet_client->reported_backlog_size, 0); + + // Task 3 is pushed; no more workers requested. 
+ ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1002, NodeID::Nil())); + ASSERT_EQ(worker_client->callbacks.size(), 3); + ASSERT_EQ(lease_policy->num_lease_policy_consults, 3); + ASSERT_EQ(raylet_client->num_workers_requested, 3); + ASSERT_EQ(raylet_client->reported_backlog_size, 0); + + // All workers returned. + while (!worker_client->callbacks.empty()) { + ASSERT_TRUE(worker_client->ReplyPushTask()); + } + ASSERT_EQ(raylet_client->num_workers_returned, 3); + ASSERT_EQ(raylet_client->num_workers_disconnected, 0); + ASSERT_EQ(task_finisher->num_tasks_complete, 3); + ASSERT_EQ(task_finisher->num_tasks_failed, 0); + ASSERT_EQ(raylet_client->num_leases_canceled, 0); + ASSERT_EQ(raylet_client->reported_backlog_size, 0); + ASSERT_FALSE(raylet_client->ReplyCancelWorkerLease()); + + // Check that there are no entries left in the scheduling_key_entries_ hashmap. These + // would otherwise cause a memory leak. + ASSERT_TRUE(submitter.CheckNoSchedulingKeyEntriesPublic()); +} + +TEST(DirectTaskTransportTest, TestSubmitMultipleTasks) { + rpc::Address address; + auto raylet_client = std::make_shared(); + auto worker_client = std::make_shared(); + auto store = std::make_shared(); + auto client_pool = std::make_shared( + [&](const rpc::Address &addr) { return worker_client; }); + auto task_finisher = std::make_shared(); + auto actor_creator = std::make_shared(); + auto lease_policy = std::make_shared(); + CoreWorkerDirectTaskSubmitter submitter( + address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, + NodeID::Nil(), kLongTimeout, actor_creator, 1, absl::nullopt, 1); TaskSpecification task1 = BuildEmptyTaskSpec(); TaskSpecification task2 = BuildEmptyTaskSpec(); @@ -576,18 +722,21 @@ TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) { ASSERT_TRUE(submitter.SubmitTask(task3).ok()); ASSERT_EQ(lease_policy->num_lease_policy_consults, 1); ASSERT_EQ(raylet_client->num_workers_requested, 1); + ASSERT_EQ(raylet_client->reported_backlog_size, 0); // Task 1 is pushed; worker 2 is requested. ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1000, NodeID::Nil())); ASSERT_EQ(worker_client->callbacks.size(), 1); ASSERT_EQ(lease_policy->num_lease_policy_consults, 2); ASSERT_EQ(raylet_client->num_workers_requested, 2); + ASSERT_EQ(raylet_client->reported_backlog_size, 1); // Task 2 is pushed; worker 3 is requested. ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, NodeID::Nil())); ASSERT_EQ(worker_client->callbacks.size(), 2); ASSERT_EQ(lease_policy->num_lease_policy_consults, 3); ASSERT_EQ(raylet_client->num_workers_requested, 3); + ASSERT_EQ(raylet_client->reported_backlog_size, 0); // Task 3 is pushed; no more workers requested. ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1002, NodeID::Nil())); @@ -604,6 +753,7 @@ TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) { ASSERT_EQ(task_finisher->num_tasks_complete, 3); ASSERT_EQ(task_finisher->num_tasks_failed, 0); ASSERT_EQ(raylet_client->num_leases_canceled, 0); + ASSERT_EQ(raylet_client->reported_backlog_size, 0); ASSERT_FALSE(raylet_client->ReplyCancelWorkerLease()); // Check that there are no entries left in the scheduling_key_entries_ hashmap. 
These @@ -621,9 +771,9 @@ TEST(DirectTaskTransportTest, TestReuseWorkerLease) { auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, - lease_policy, store, task_finisher, - NodeID::Nil(), kLongTimeout, actor_creator); + CoreWorkerDirectTaskSubmitter submitter( + address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, + NodeID::Nil(), kLongTimeout, actor_creator, 1, absl::nullopt, 1); TaskSpecification task1 = BuildEmptyTaskSpec(); TaskSpecification task2 = BuildEmptyTaskSpec(); @@ -684,9 +834,9 @@ TEST(DirectTaskTransportTest, TestRetryLeaseCancellation) { auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, - lease_policy, store, task_finisher, - NodeID::Nil(), kLongTimeout, actor_creator); + CoreWorkerDirectTaskSubmitter submitter( + address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, + NodeID::Nil(), kLongTimeout, actor_creator, 1, absl::nullopt, 1); TaskSpecification task1 = BuildEmptyTaskSpec(); TaskSpecification task2 = BuildEmptyTaskSpec(); TaskSpecification task3 = BuildEmptyTaskSpec(); @@ -797,9 +947,9 @@ TEST(DirectTaskTransportTest, TestWorkerNotReusedOnError) { auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, - lease_policy, store, task_finisher, - NodeID::Nil(), kLongTimeout, actor_creator); + CoreWorkerDirectTaskSubmitter submitter( + address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, + NodeID::Nil(), kLongTimeout, actor_creator, 1, absl::nullopt, 1); TaskSpecification task1 = BuildEmptyTaskSpec(); TaskSpecification task2 = BuildEmptyTaskSpec(); @@ -1013,9 +1163,9 @@ void TestSchedulingKey(const std::shared_ptr store, auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, - lease_policy, store, task_finisher, - NodeID::Nil(), kLongTimeout, actor_creator); + CoreWorkerDirectTaskSubmitter submitter( + address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, + NodeID::Nil(), kLongTimeout, actor_creator, 1, absl::nullopt, 1); ASSERT_TRUE(submitter.SubmitTask(same1).ok()); ASSERT_TRUE(submitter.SubmitTask(same2).ok()); @@ -1130,6 +1280,65 @@ TEST(DirectTaskTransportTest, TestSchedulingKeys) { TestSchedulingKey(store, same_deps_1, same_deps_2, different_deps); } +TEST(DirectTaskTransportTest, TestBacklogReport) { + rpc::Address address; + auto raylet_client = std::make_shared(); + auto worker_client = std::make_shared(); + auto store = std::make_shared(); + auto client_pool = std::make_shared( + [&](const rpc::Address &addr) { return worker_client; }); + auto task_finisher = std::make_shared(); + auto actor_creator = std::make_shared(); + auto lease_policy = std::make_shared(); + CoreWorkerDirectTaskSubmitter submitter( + address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, + NodeID::Nil(), kLongTimeout, actor_creator, 1, absl::nullopt, 1); + + TaskSpecification task1 = BuildEmptyTaskSpec(); + + std::unordered_map 
resources1({{"a", 1.0}}); + std::unordered_map resources2({{"b", 2.0}}); + FunctionDescriptor descriptor1 = + FunctionDescriptorBuilder::BuildPython("a", "", "", ""); + FunctionDescriptor descriptor2 = + FunctionDescriptorBuilder::BuildPython("b", "", "", ""); + ObjectID plasma1 = ObjectID::FromRandom(); + ObjectID plasma2 = ObjectID::FromRandom(); + // Force plasma objects to be promoted. + std::string meta = std::to_string(static_cast(rpc::ErrorType::OBJECT_IN_PLASMA)); + auto metadata = const_cast(reinterpret_cast(meta.data())); + auto meta_buffer = std::make_shared(metadata, meta.size()); + auto plasma_data = RayObject(nullptr, meta_buffer, std::vector()); + ASSERT_TRUE(store->Put(plasma_data, plasma1)); + ASSERT_TRUE(store->Put(plasma_data, plasma2)); + + // Same SchedulingClass, different SchedulingKey + TaskSpecification task2 = BuildTaskSpec(resources1, descriptor1); + task2.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id( + plasma1.Binary()); + TaskSpecification task3 = BuildTaskSpec(resources1, descriptor1); + task3.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id( + plasma2.Binary()); + TestSchedulingKey(store, task2, task2, task3); + + TaskSpecification task4 = BuildTaskSpec(resources2, descriptor2); + + ASSERT_TRUE(submitter.SubmitTask(task1).ok()); + // One is requested and one is in the backlog for each SchedulingKey + ASSERT_TRUE(submitter.SubmitTask(task2).ok()); + ASSERT_TRUE(submitter.SubmitTask(task2).ok()); + ASSERT_TRUE(submitter.SubmitTask(task3).ok()); + ASSERT_TRUE(submitter.SubmitTask(task3).ok()); + ASSERT_TRUE(submitter.SubmitTask(task4).ok()); + ASSERT_TRUE(submitter.SubmitTask(task4).ok()); + + submitter.ReportWorkerBacklog(); + ASSERT_EQ(raylet_client->reported_backlogs.size(), 3); + ASSERT_EQ(raylet_client->reported_backlogs[task1.GetSchedulingClass()], 0); + ASSERT_EQ(raylet_client->reported_backlogs[task2.GetSchedulingClass()], 2); + ASSERT_EQ(raylet_client->reported_backlogs[task4.GetSchedulingClass()], 1); +} + TEST(DirectTaskTransportTest, TestWorkerLeaseTimeout) { rpc::Address address; auto raylet_client = std::make_shared(); @@ -1140,10 +1349,10 @@ TEST(DirectTaskTransportTest, TestWorkerLeaseTimeout) { auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, - lease_policy, store, task_finisher, - NodeID::Nil(), - /*lease_timeout_ms=*/5, actor_creator); + CoreWorkerDirectTaskSubmitter submitter( + address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, + NodeID::Nil(), + /*lease_timeout_ms=*/5, actor_creator, 1, absl::nullopt, 1); TaskSpecification task1 = BuildEmptyTaskSpec(); TaskSpecification task2 = BuildEmptyTaskSpec(); TaskSpecification task3 = BuildEmptyTaskSpec(); @@ -1322,7 +1531,8 @@ TEST(DirectTaskTransportTest, TestPipeliningConcurrentWorkerLeases) { uint32_t max_tasks_in_flight_per_worker = 10; CoreWorkerDirectTaskSubmitter submitter( address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, - NodeID::Nil(), kLongTimeout, actor_creator, max_tasks_in_flight_per_worker); + NodeID::Nil(), kLongTimeout, actor_creator, max_tasks_in_flight_per_worker, + absl::nullopt, 1); // Prepare 20 tasks and save them in a vector. 
std::vector tasks; @@ -1396,7 +1606,8 @@ TEST(DirectTaskTransportTest, TestPipeliningReuseWorkerLease) { uint32_t max_tasks_in_flight_per_worker = 10; CoreWorkerDirectTaskSubmitter submitter( address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, - NodeID::Nil(), kLongTimeout, actor_creator, max_tasks_in_flight_per_worker); + NodeID::Nil(), kLongTimeout, actor_creator, max_tasks_in_flight_per_worker, + absl::nullopt, 2); // prepare 30 tasks and save them in a vector std::vector tasks; @@ -1405,16 +1616,16 @@ TEST(DirectTaskTransportTest, TestPipeliningReuseWorkerLease) { } ASSERT_EQ(tasks.size(), 30); - // Submit the 30 tasks and check that one worker is requested + // Submit the 30 tasks and check that two workers are requested for (auto task : tasks) { ASSERT_TRUE(submitter.SubmitTask(task).ok()); } - ASSERT_EQ(raylet_client->num_workers_requested, 1); + ASSERT_EQ(raylet_client->num_workers_requested, 2); // Task 1-10 are pushed, and a new worker is requested. ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1000, NodeID::Nil())); ASSERT_EQ(worker_client->callbacks.size(), 10); - ASSERT_EQ(raylet_client->num_workers_requested, 2); + ASSERT_EQ(raylet_client->num_workers_requested, 3); // The lease is not cancelled, as there is more work to do ASSERT_EQ(raylet_client->num_leases_canceled, 0); @@ -1441,7 +1652,7 @@ TEST(DirectTaskTransportTest, TestPipeliningReuseWorkerLease) { ASSERT_EQ(raylet_client->num_workers_returned, 1); ASSERT_EQ(worker_client->callbacks.size(), 0); ASSERT_EQ(task_finisher->num_tasks_complete, 30); - ASSERT_EQ(raylet_client->num_leases_canceled, 1); + ASSERT_EQ(raylet_client->num_leases_canceled, 2); ASSERT_TRUE(raylet_client->ReplyCancelWorkerLease()); // The second lease request is returned immediately. @@ -1451,8 +1662,19 @@ TEST(DirectTaskTransportTest, TestPipeliningReuseWorkerLease) { ASSERT_EQ(raylet_client->num_workers_disconnected, 0); ASSERT_EQ(task_finisher->num_tasks_complete, 30); ASSERT_EQ(task_finisher->num_tasks_failed, 0); - ASSERT_EQ(raylet_client->num_leases_canceled, 1); - ASSERT_FALSE(raylet_client->ReplyCancelWorkerLease()); + ASSERT_EQ(raylet_client->num_workers_requested, 3); + ASSERT_EQ(raylet_client->num_leases_canceled, 3); + ASSERT_TRUE(raylet_client->ReplyCancelWorkerLease()); + + // The third lease request is returned immediately. + ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, NodeID::Nil())); + ASSERT_EQ(worker_client->callbacks.size(), 0); + ASSERT_EQ(raylet_client->num_workers_returned, 3); + ASSERT_EQ(raylet_client->num_workers_disconnected, 0); + ASSERT_EQ(task_finisher->num_tasks_complete, 30); + ASSERT_EQ(task_finisher->num_tasks_failed, 0); + ASSERT_EQ(raylet_client->num_leases_canceled, 3); + ASSERT_TRUE(raylet_client->ReplyCancelWorkerLease()); // Check that there are no entries left in the scheduling_key_entries_ hashmap. These // would otherwise cause a memory leak. 
@@ -1476,7 +1698,8 @@ TEST(DirectTaskTransportTest, TestPipeliningNumberOfWorkersRequested) { uint32_t max_tasks_in_flight_per_worker = 10; CoreWorkerDirectTaskSubmitter submitter( address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, - NodeID::Nil(), kLongTimeout, actor_creator, max_tasks_in_flight_per_worker); + NodeID::Nil(), kLongTimeout, actor_creator, max_tasks_in_flight_per_worker, + absl::nullopt, 1); // prepare 30 tasks and save them in a vector std::vector tasks; @@ -1661,7 +1884,8 @@ TEST(DirectTaskTransportTest, TestStealingTasks) { uint32_t max_tasks_in_flight_per_worker = 10; CoreWorkerDirectTaskSubmitter submitter( address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, - NodeID::Nil(), kLongTimeout, actor_creator, max_tasks_in_flight_per_worker); + NodeID::Nil(), kLongTimeout, actor_creator, max_tasks_in_flight_per_worker, + absl::nullopt, 1); // prepare 20 tasks and save them in a vector std::vector tasks; @@ -1841,7 +2065,8 @@ TEST(DirectTaskTransportTest, TestNoStealingByExpiredWorker) { uint32_t max_tasks_in_flight_per_worker = 10; CoreWorkerDirectTaskSubmitter submitter( address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, - NodeID::Nil(), 1000, actor_creator, max_tasks_in_flight_per_worker); + NodeID::Nil(), 1000, actor_creator, max_tasks_in_flight_per_worker, absl::nullopt, + 1); // prepare 30 tasks and save them in a vector std::vector tasks; @@ -1979,23 +2204,24 @@ TEST(DirectTaskTransportTest, TestNoWorkerRequestedIfStealingUnavailable) { uint32_t max_tasks_in_flight_per_worker = 10; CoreWorkerDirectTaskSubmitter submitter( address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, - NodeID::Nil(), kLongTimeout, actor_creator, max_tasks_in_flight_per_worker); + NodeID::Nil(), kLongTimeout, actor_creator, max_tasks_in_flight_per_worker, + absl::nullopt, 2); - // prepare 2 tasks and save them in a vector + // prepare 10 tasks and save them in a vector std::vector tasks; for (int i = 0; i < 10; i++) { tasks.push_back(BuildEmptyTaskSpec()); } ASSERT_EQ(tasks.size(), 10); - // submit both tasks + // submit all tasks for (int i = 1; i <= 10; i++) { auto task = tasks.front(); ASSERT_TRUE(submitter.SubmitTask(task).ok()); tasks.erase(tasks.begin()); } ASSERT_EQ(tasks.size(), 0); - ASSERT_EQ(raylet_client->num_workers_requested, 1); + ASSERT_EQ(raylet_client->num_workers_requested, 2); ASSERT_EQ(task_finisher->num_tasks_complete, 0); ASSERT_EQ(task_finisher->num_tasks_failed, 0); ASSERT_EQ(raylet_client->num_leases_canceled, 0); @@ -2006,7 +2232,7 @@ TEST(DirectTaskTransportTest, TestNoWorkerRequestedIfStealingUnavailable) { std::string worker1_id = "worker1_ID_abcdefghijklmnopq"; ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, NodeID::Nil(), false, worker1_id)); - ASSERT_EQ(raylet_client->num_workers_requested, 2); + ASSERT_EQ(raylet_client->num_workers_requested, 3); ASSERT_EQ(raylet_client->num_workers_disconnected, 0); ASSERT_EQ(raylet_client->num_workers_returned, 0); ASSERT_EQ(task_finisher->num_tasks_complete, 0); @@ -2020,7 +2246,7 @@ TEST(DirectTaskTransportTest, TestNoWorkerRequestedIfStealingUnavailable) { ASSERT_TRUE(worker_client->ReplyPushTask()); } - ASSERT_EQ(raylet_client->num_workers_requested, 2); + ASSERT_EQ(raylet_client->num_workers_requested, 3); ASSERT_EQ(raylet_client->num_workers_disconnected, 0); ASSERT_EQ(raylet_client->num_workers_returned, 0); ASSERT_EQ(task_finisher->num_tasks_complete, 9); @@ -2036,23 +2262,23 @@ 
TEST(DirectTaskTransportTest, TestNoWorkerRequestedIfStealingUnavailable) { worker2_id)); // Check that no more workers are requested now that there are no more stealable tasks. - ASSERT_EQ(raylet_client->num_workers_requested, 2); + ASSERT_EQ(raylet_client->num_workers_requested, 3); ASSERT_EQ(raylet_client->num_workers_disconnected, 0); ASSERT_EQ(raylet_client->num_workers_returned, 1); ASSERT_EQ(task_finisher->num_tasks_complete, 9); ASSERT_EQ(task_finisher->num_tasks_failed, 0); - ASSERT_EQ(raylet_client->num_leases_canceled, 0); + ASSERT_EQ(raylet_client->num_leases_canceled, 1); ASSERT_EQ(worker_client->callbacks.size(), 1); ASSERT_EQ(worker_client->steal_callbacks.size(), 0); // Last task runs and first worker is returned ASSERT_TRUE(worker_client->ReplyPushTask()); - ASSERT_EQ(raylet_client->num_workers_requested, 2); + ASSERT_EQ(raylet_client->num_workers_requested, 3); ASSERT_EQ(raylet_client->num_workers_returned, 2); ASSERT_EQ(raylet_client->num_workers_disconnected, 0); ASSERT_EQ(task_finisher->num_tasks_complete, 10); ASSERT_EQ(task_finisher->num_tasks_failed, 0); - ASSERT_EQ(raylet_client->num_leases_canceled, 0); + ASSERT_EQ(raylet_client->num_leases_canceled, 2); ASSERT_EQ(worker_client->callbacks.size(), 0); ASSERT_EQ(worker_client->steal_callbacks.size(), 0); } diff --git a/src/ray/core_worker/test/memory_store_test.cc b/src/ray/core_worker/test/memory_store_test.cc index 84a7c8f7996ac..feee9973db850 100644 --- a/src/ray/core_worker/test/memory_store_test.cc +++ b/src/ray/core_worker/test/memory_store_test.cc @@ -12,10 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "absl/synchronization/mutex.h" - #include "ray/core_worker/store_provider/memory_store/memory_store.h" +#include "absl/synchronization/mutex.h" #include "gtest/gtest.h" #include "ray/common/test_util.h" @@ -29,8 +28,7 @@ TEST(TestMemoryStore, TestReportUnhandledErrors) { std::shared_ptr provider = std::make_shared( - nullptr, nullptr, nullptr, nullptr, - [&](const RayObject &obj) { unhandled_count++; }); + nullptr, nullptr, nullptr, [&](const RayObject &obj) { unhandled_count++; }); RayObject obj1(rpc::ErrorType::TASK_EXECUTION_EXCEPTION); RayObject obj2(rpc::ErrorType::TASK_EXECUTION_EXCEPTION); auto id1 = ObjectID::FromRandom(); @@ -52,7 +50,7 @@ TEST(TestMemoryStore, TestReportUnhandledErrors) { RAY_CHECK(provider->Put(obj1, id1)); RAY_CHECK(provider->Put(obj1, id2)); RAY_UNUSED(provider->Get({id1}, 1, 100, context, false, &results)); - provider->GetOrPromoteToPlasma(id2); + RAY_UNUSED(provider->Get({id2}, 1, 100, context, false, &results)); provider->Delete({id1, id2}); ASSERT_EQ(unhandled_count, 0); @@ -68,8 +66,7 @@ TEST(TestMemoryStore, TestReportUnhandledErrors) { TEST(TestMemoryStore, TestMemoryStoreStats) { /// Simple validation for test memory store stats. std::shared_ptr provider = - std::make_shared(nullptr, nullptr, nullptr, nullptr, - nullptr); + std::make_shared(nullptr, nullptr, nullptr, nullptr); // Iterate through the memory store and compare the values that are obtained by // GetMemoryStoreStatisticalData. 
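The TestBacklogReport case above exercises the key property of backlog reporting: backlogs are tracked per scheduling key but reported per scheduling class, so distinct keys that share a class are summed before the report goes to the raylet. Below is a standalone sketch of that aggregation with simplified stand-in types (SchedulingClass as an int and a flat list of entries; these names are illustrative, while the real code walks scheduling_key_entries_ and attaches a resource spec to each report).

#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <vector>

// Simplified stand-ins: Ray keys backlogs by SchedulingKey but reports them
// keyed by SchedulingClass, so distinct keys with the same class are summed.
using SchedulingClass = int;
struct SchedulingKeyEntry {
  SchedulingClass scheduling_class;
  int64_t backlog_size;
};

std::unordered_map<SchedulingClass, int64_t> AggregateBacklogs(
    const std::vector<SchedulingKeyEntry> &entries) {
  std::unordered_map<SchedulingClass, int64_t> backlogs;
  for (const auto &entry : entries) {
    // operator[] default-initializes to 0, so the first key of a class
    // starts the running sum.
    backlogs[entry.scheduling_class] += entry.backlog_size;
  }
  return backlogs;
}

int main() {
  // Two scheduling keys sharing class 1 (e.g. same resources, different
  // plasma dependencies) plus one key of class 2, as in TestBacklogReport.
  std::vector<SchedulingKeyEntry> entries = {{1, 2}, {1, 0}, {2, 1}};
  for (const auto &[cls, size] : AggregateBacklogs(entries)) {
    std::cout << "class " << cls << " backlog " << size << std::endl;
  }
  return 0;
}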
diff --git a/src/ray/core_worker/transport/dependency_resolver.cc b/src/ray/core_worker/transport/dependency_resolver.cc index 3948c3732f1c4..da52aff657627 100644 --- a/src/ray/core_worker/transport/dependency_resolver.cc +++ b/src/ray/core_worker/transport/dependency_resolver.cc @@ -139,11 +139,13 @@ void LocalDependencyResolver::ResolveDependencies( for (const auto &actor_id : state->actor_dependencies) { actor_creator_.AsyncWaitForActorRegisterFinish( - actor_id, [state, on_complete](Status status) { + actor_id, [this, state, on_complete](const Status &status) { if (!status.ok()) { state->status = status; } - if (--state->actor_dependencies_remaining == 0) { + if (--state->actor_dependencies_remaining == 0 && + state->obj_dependencies_remaining == 0) { + num_pending_--; on_complete(state->status); } }); diff --git a/src/ray/core_worker/transport/direct_task_transport.cc b/src/ray/core_worker/transport/direct_task_transport.cc index 508f205337566..912abca3b7aed 100644 --- a/src/ray/core_worker/transport/direct_task_transport.cc +++ b/src/ray/core_worker/transport/direct_task_transport.cc @@ -374,15 +374,14 @@ void CoreWorkerDirectTaskSubmitter::CancelWorkerLeaseIfNeeded( RAY_LOG(DEBUG) << "Task queue is empty, and there are no stealable tasks; canceling lease request"; - auto &pending_lease_request = scheduling_key_entry.pending_lease_request; - if (pending_lease_request.first) { + for (auto &pending_lease_request : scheduling_key_entry.pending_lease_requests) { // There is an in-flight lease request. Cancel it. - auto &lease_client = pending_lease_request.first; - auto &lease_id = pending_lease_request.second; - RAY_LOG(DEBUG) << "Canceling lease request " << lease_id; + auto lease_client = GetOrConnectLeaseClient(&pending_lease_request.second); + auto &task_id = pending_lease_request.first; + RAY_LOG(DEBUG) << "Canceling lease request " << task_id; lease_client->CancelWorkerLease( - lease_id, [this, scheduling_key](const Status &status, - const rpc::CancelWorkerLeaseReply &reply) { + task_id, [this, scheduling_key](const Status &status, + const rpc::CancelWorkerLeaseReply &reply) { absl::MutexLock lock(&mu_); if (status.ok() && !reply.success()) { // The cancellation request can fail if the raylet does not have @@ -423,15 +422,58 @@ CoreWorkerDirectTaskSubmitter::GetOrConnectLeaseClient( return lease_client; } +void CoreWorkerDirectTaskSubmitter::ReportWorkerBacklog() { + absl::MutexLock lock(&mu_); + ReportWorkerBacklogInternal(); +} + +void CoreWorkerDirectTaskSubmitter::ReportWorkerBacklogInternal() { + absl::flat_hash_map> backlogs; + for (auto &scheduling_key_and_entry : scheduling_key_entries_) { + const SchedulingClass scheduling_class = std::get<0>(scheduling_key_and_entry.first); + if (backlogs.find(scheduling_class) == backlogs.end()) { + backlogs[scheduling_class].first = scheduling_key_and_entry.second.resource_spec; + backlogs[scheduling_class].second = 0; + } + // We report backlog size per scheduling class not per scheduling key + // so we need to aggregate backlog sizes of different scheduling keys + // with the same scheduling class + backlogs[scheduling_class].second += scheduling_key_and_entry.second.BacklogSize(); + scheduling_key_and_entry.second.last_reported_backlog_size = + scheduling_key_and_entry.second.BacklogSize(); + } + + std::vector backlog_reports; + for (const auto &backlog : backlogs) { + rpc::WorkerBacklogReport backlog_report; + backlog_report.mutable_resource_spec()->CopyFrom(backlog.second.first.GetMessage()); + 
backlog_report.set_backlog_size(backlog.second.second); + backlog_reports.emplace_back(backlog_report); + } + local_lease_client_->ReportWorkerBacklog(WorkerID::FromBinary(rpc_address_.worker_id()), + backlog_reports); +} + +void CoreWorkerDirectTaskSubmitter::ReportWorkerBacklogIfNeeded( + const SchedulingKey &scheduling_key) { + const auto &scheduling_key_entry = scheduling_key_entries_[scheduling_key]; + + if (scheduling_key_entry.last_reported_backlog_size != + scheduling_key_entry.BacklogSize()) { + ReportWorkerBacklogInternal(); + } +} + void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded( const SchedulingKey &scheduling_key, const rpc::Address *raylet_address) { auto &scheduling_key_entry = scheduling_key_entries_[scheduling_key]; - auto &pending_lease_request = scheduling_key_entry.pending_lease_request; - if (pending_lease_request.first) { - // There's already an outstanding lease request for this type of task. + if (scheduling_key_entry.pending_lease_requests.size() == + max_pending_lease_requests_per_scheduling_category_) { return; } + RAY_CHECK(scheduling_key_entry.pending_lease_requests.size() < + max_pending_lease_requests_per_scheduling_category_); // Check whether we really need a new worker or whether we have // enough room in an existing worker's pipeline to send the new tasks. If the pipelines @@ -444,7 +486,7 @@ void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded( return; } - auto &task_queue = scheduling_key_entry.task_queue; + const auto &task_queue = scheduling_key_entry.task_queue; // Check if the task queue is empty. If that is the case, it only makes sense to // consider requesting a new worker if work stealing is enabled, and there is at least a // worker with stealable tasks. If work stealing is not enabled, or there is no tasks @@ -461,15 +503,18 @@ void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded( } return; } + } else if (scheduling_key_entry.task_queue.size() <= + scheduling_key_entry.pending_lease_requests.size()) { + // All tasks have corresponding pending leases, no need to request more + return; } + num_leases_requested_++; // Create a TaskSpecification with an overwritten TaskID to make sure we don't reuse the // same TaskID to request a worker - num_leases_requested_++; auto resource_spec_msg = scheduling_key_entry.resource_spec.GetMutableMessage(); resource_spec_msg.set_task_id(TaskID::ForFakeTask().Binary()); - TaskSpecification resource_spec = TaskSpecification(resource_spec_msg); - + const TaskSpecification resource_spec = TaskSpecification(resource_spec_msg); rpc::Address best_node_address; if (raylet_address == nullptr) { // If no raylet address is given, find the best worker for our next lease request. @@ -478,22 +523,17 @@ void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded( } auto lease_client = GetOrConnectLeaseClient(raylet_address); - TaskID task_id = resource_spec.TaskId(); - // Subtract 1 so we don't double count the task we are requesting for. 
- int64_t queue_size = task_queue.size() - 1;
+ const TaskID task_id = resource_spec.TaskId();
 lease_client->RequestWorkerLease(
 resource_spec,
- [this, scheduling_key](const Status &status,
- const rpc::RequestWorkerLeaseReply &reply) {
+ [this, scheduling_key, task_id, raylet_address = *raylet_address](
+ const Status &status, const rpc::RequestWorkerLeaseReply &reply) {
 absl::MutexLock lock(&mu_);
 auto &scheduling_key_entry = scheduling_key_entries_[scheduling_key];
- auto &pending_lease_request = scheduling_key_entry.pending_lease_request;
- RAY_CHECK(pending_lease_request.first);
- auto lease_client = std::move(pending_lease_request.first);
- const auto task_id = pending_lease_request.second;
- pending_lease_request = std::make_pair(nullptr, TaskID::Nil());
+ auto lease_client = GetOrConnectLeaseClient(&raylet_address);
+ scheduling_key_entry.pending_lease_requests.erase(task_id);

 if (status.ok()) {
 if (reply.runtime_env_setup_failed()) {
@@ -551,8 +591,9 @@ void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded(
 RAY_LOG(FATAL) << status.ToString();
 }
 },
- queue_size);
- pending_lease_request = std::make_pair(lease_client, task_id);
+ task_queue.size());
+ scheduling_key_entry.pending_lease_requests.emplace(task_id, *raylet_address);
+ ReportWorkerBacklogIfNeeded(scheduling_key);
 }

 void CoreWorkerDirectTaskSubmitter::PushNormalTask(
diff --git a/src/ray/core_worker/transport/direct_task_transport.h b/src/ray/core_worker/transport/direct_task_transport.h
index 25f7a18912795..7731d95ad6626 100644
--- a/src/ray/core_worker/transport/direct_task_transport.h
+++ b/src/ray/core_worker/transport/direct_task_transport.h
@@ -66,7 +66,9 @@ class CoreWorkerDirectTaskSubmitter {
 int64_t lease_timeout_ms,
 std::shared_ptr actor_creator,
 uint32_t max_tasks_in_flight_per_worker =
 ::RayConfig::instance().max_tasks_in_flight_per_worker(),
- absl::optional cancel_timer = absl::nullopt)
+ absl::optional cancel_timer = absl::nullopt,
+ uint64_t max_pending_lease_requests_per_scheduling_category =
+ ::RayConfig::instance().max_pending_lease_requests_per_scheduling_category())
 : rpc_address_(rpc_address),
 local_lease_client_(lease_client),
 lease_client_factory_(lease_client_factory),
@@ -78,6 +80,8 @@ class CoreWorkerDirectTaskSubmitter {
 actor_creator_(actor_creator),
 client_cache_(core_worker_client_pool),
 max_tasks_in_flight_per_worker_(max_tasks_in_flight_per_worker),
+ max_pending_lease_requests_per_scheduling_category_(
+ max_pending_lease_requests_per_scheduling_category),
 cancel_retry_timer_(std::move(cancel_timer)) {}

 /// Schedule a task for direct submission to a worker.
@@ -107,6 +111,11 @@ class CoreWorkerDirectTaskSubmitter {
 return num_leases_requested_;
 }

+ /// Report worker backlog information to the local raylet.
+ /// Since each worker only reports to its local raylet,
+ /// we avoid double counting backlogs in the autoscaler.
+ void ReportWorkerBacklog();
+
 private:
 /// Schedule more work onto an idle worker or return it back to the raylet if
 /// no more tasks are queued for submission. If an error was encountered
@@ -127,6 +136,14 @@ class CoreWorkerDirectTaskSubmitter {
 std::shared_ptr GetOrConnectLeaseClient(
 const rpc::Address *raylet_address) EXCLUSIVE_LOCKS_REQUIRED(mu_);

+ /// Report worker backlog information to the local raylet.
+ void ReportWorkerBacklogInternal() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ /// Report the backlog if the backlog size has changed for this scheduling key
+ /// since the last report.
+ void ReportWorkerBacklogIfNeeded(const SchedulingKey &scheduling_key)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
 /// Request a new worker from the raylet if no such requests are currently in
 /// flight and there are tasks queued. If a raylet address is provided, then
 /// the worker should be requested from the raylet at that address. Else, the
@@ -237,6 +254,9 @@ class CoreWorkerDirectTaskSubmitter {
 // worker using a single lease.
 const uint32_t max_tasks_in_flight_per_worker_;

+ // Max number of pending lease requests per SchedulingKey.
+ const uint64_t max_pending_lease_requests_per_scheduling_category_;
+
 /// A LeaseEntry struct is used to condense the metadata about a single executor:
 /// (1) The lease client through which the worker should be returned
 /// (2) The expiration time of a worker's lease.
@@ -296,8 +316,7 @@ class CoreWorkerDirectTaskSubmitter {
 struct SchedulingKeyEntry {
 // Keep track of pending worker lease requests to the raylet.
- std::pair, TaskID> pending_lease_request =
- std::make_pair(nullptr, TaskID::Nil());
+ absl::flat_hash_map pending_lease_requests;
 TaskSpecification resource_spec = TaskSpecification();
 // Tasks that are queued for execution. We keep an individual queue per
 // scheduling class to ensure fairness.
@@ -308,11 +327,12 @@ class CoreWorkerDirectTaskSubmitter {
 absl::flat_hash_set();
 // Keep track of how many tasks with this SchedulingKey are in flight, in total
 uint32_t total_tasks_in_flight = 0;
+ int64_t last_reported_backlog_size = 0;
 // Check whether it's safe to delete this SchedulingKeyEntry from the
 // scheduling_key_entries_ hashmap.
 inline bool CanDelete() const {
- if (!pending_lease_request.first && task_queue.empty() &&
+ if (pending_lease_requests.empty() && task_queue.empty() &&
 active_workers.size() == 0 && total_tasks_in_flight == 0) {
 return true;
 }
@@ -339,6 +359,18 @@ class CoreWorkerDirectTaskSubmitter {
 // If any worker has more than one task in flight, then that task can be stolen.
 return total_tasks_in_flight > active_workers.size();
 }
+
+ // Get the current backlog size for this scheduling key.
+ [[nodiscard]] inline int64_t BacklogSize() const {
+ if (task_queue.size() < pending_lease_requests.size()) {
+ // During work stealing we may have more pending lease requests than the
+ // number of queued tasks.
+ return 0;
+ }
+
+ // Subtract tasks with pending lease requests so we don't double count them.
+ return task_queue.size() - pending_lease_requests.size(); + } }; // For each Scheduling Key, scheduling_key_entries_ contains a SchedulingKeyEntry struct diff --git a/src/ray/gcs/asio.h b/src/ray/gcs/asio.h index fdcbbbf3cc3ef..d37083986ae1e 100644 --- a/src/ray/gcs/asio.h +++ b/src/ray/gcs/asio.h @@ -38,7 +38,7 @@ #include #include -#include +#include #include #include diff --git a/src/ray/gcs/gcs_client/service_based_accessor.cc b/src/ray/gcs/gcs_client/service_based_accessor.cc index e3bdcd96d79ab..6e54cb6b4f047 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.cc +++ b/src/ray/gcs/gcs_client/service_based_accessor.cc @@ -731,7 +731,7 @@ Status ServiceBasedNodeResourceInfoAccessor::AsyncUpdateResources( }); }; - sequencer_.Post(node_id, operation); + sequencer_.Post(node_id, std::move(operation)); return Status::OK(); } diff --git a/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc b/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc index 223ee7ca71b52..a4950fabb0f14 100644 --- a/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc +++ b/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc @@ -36,6 +36,7 @@ class GlobalStateAccessorTest : public ::testing::Test { config.grpc_server_name = "MockedGcsServer"; config.grpc_server_thread_num = 1; config.redis_address = "127.0.0.1"; + config.node_ip_address = "127.0.0.1"; config.enable_sharding_conn = false; config.redis_port = TEST_REDIS_SERVER_PORTS.front(); diff --git a/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc b/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc index 0e51ca7b84cce..0adf74b5c4e8b 100644 --- a/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc +++ b/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc @@ -47,6 +47,7 @@ class ServiceBasedGcsClientTest : public ::testing::Test { config_.grpc_server_name = "MockedGcsServer"; config_.grpc_server_thread_num = 1; config_.redis_address = "127.0.0.1"; + config_.node_ip_address = "127.0.0.1"; config_.enable_sharding_conn = false; config_.redis_port = TEST_REDIS_SERVER_PORTS.front(); // Tests legacy code paths. The poller and broadcaster have their own dedicated unit diff --git a/src/ray/gcs/gcs_server/gcs_actor_distribution.cc b/src/ray/gcs/gcs_server/gcs_actor_distribution.cc index 6eb523cdf730b..bec0fb7b89f7c 100644 --- a/src/ray/gcs/gcs_server/gcs_actor_distribution.cc +++ b/src/ray/gcs/gcs_server/gcs_actor_distribution.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "ray/gcs/gcs_server/gcs_actor_distribution.h" + #include "ray/util/event.h" namespace ray { @@ -49,6 +50,9 @@ GcsBasedActorScheduler::GcsBasedActorScheduler( gcs_resource_scheduler_(std::move(gcs_resource_scheduler)) {} NodeID GcsBasedActorScheduler::SelectNode(std::shared_ptr actor) { + if (actor->GetActorWorkerAssignment()) { + ResetActorWorkerAssignment(actor.get()); + } // TODO(Chong-Li): Java actors may not need a sole assignment (worker process). 
 bool need_sole_actor_worker_assignment = true;
 if (auto selected_actor_worker_assignment = SelectOrAllocateActorWorkerAssignment(
@@ -221,5 +225,31 @@ void GcsBasedActorScheduler::HandleWorkerLeaseRejectedReply(
 Reschedule(actor);
 }

+void GcsBasedActorScheduler::AddResourcesChangedListener(std::function listener) {
+ RAY_CHECK(listener != nullptr);
+ resource_changed_listeners_.emplace_back(std::move(listener));
+}
+
+void GcsBasedActorScheduler::NotifyClusterResourcesChanged() {
+ for (auto &listener : resource_changed_listeners_) {
+ listener();
+ }
+}
+
+void GcsBasedActorScheduler::ResetActorWorkerAssignment(GcsActor *actor) {
+ if (gcs_resource_manager_->ReleaseResources(
+ actor->GetActorWorkerAssignment()->GetNodeID(),
+ actor->GetActorWorkerAssignment()->GetResources())) {
+ NotifyClusterResourcesChanged();
+ }
+ actor->SetActorWorkerAssignment(nullptr);
+}
+
+void GcsBasedActorScheduler::OnActorDestruction(std::shared_ptr actor) {
+ if (actor && actor->GetActorWorkerAssignment()) {
+ ResetActorWorkerAssignment(actor.get());
+ }
+}
+
 } // namespace gcs
 } // namespace ray
\ No newline at end of file
diff --git a/src/ray/gcs/gcs_server/gcs_actor_distribution.h b/src/ray/gcs/gcs_server/gcs_actor_distribution.h
index b8e2b6b2bd6d4..55f0f492e9a74 100644
--- a/src/ray/gcs/gcs_server/gcs_actor_distribution.h
+++ b/src/ray/gcs/gcs_server/gcs_actor_distribution.h
@@ -93,6 +93,14 @@ class GcsBasedActorScheduler : public GcsActorScheduler {

 virtual ~GcsBasedActorScheduler() = default;

+ /// Handle the destruction of an actor.
+ ///
+ /// \param actor The actor to be destroyed.
+ void OnActorDestruction(std::shared_ptr actor) override;
+
+ /// Add resources changed event handler.
+ void AddResourcesChangedListener(std::function listener);
+
 protected:
 /// Select a node for the actor based on cluster resources.
 ///
@@ -143,8 +151,17 @@ class GcsBasedActorScheduler : public GcsActorScheduler {
 void HandleWorkerLeaseRejectedReply(std::shared_ptr actor,
 const rpc::RequestWorkerLeaseReply &reply);

+ /// Reset the actor's current assignment, while releasing acquired resources.
+ void ResetActorWorkerAssignment(GcsActor *actor);
+
+ /// Notify that the cluster resources have changed.
+ void NotifyClusterResourcesChanged();
+
 std::shared_ptr gcs_resource_manager_;

+ /// The resource changed listeners.
+ std::vector> resource_changed_listeners_; + /// Gcs resource scheduler std::shared_ptr gcs_resource_scheduler_; }; diff --git a/src/ray/gcs/gcs_server/gcs_actor_manager.cc b/src/ray/gcs/gcs_server/gcs_actor_manager.cc index c9f8c62375a6f..48a469f1cb377 100644 --- a/src/ray/gcs/gcs_server/gcs_actor_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_actor_manager.cc @@ -107,6 +107,7 @@ void GcsActor::SetActorWorkerAssignment( ///////////////////////////////////////////////////////////////////////////////////////// GcsActorManager::GcsActorManager( + boost::asio::io_context &io_context, std::shared_ptr scheduler, std::shared_ptr gcs_table_storage, std::shared_ptr gcs_pub_sub, RuntimeEnvManager &runtime_env_manager, @@ -115,7 +116,8 @@ GcsActorManager::GcsActorManager( std::function, boost::posix_time::milliseconds)> run_delayed, const rpc::ClientFactoryFn &worker_client_factory) - : gcs_actor_scheduler_(std::move(scheduler)), + : io_context_(io_context), + gcs_actor_scheduler_(std::move(scheduler)), gcs_table_storage_(std::move(gcs_table_storage)), gcs_pub_sub_(std::move(gcs_pub_sub)), worker_client_factory_(worker_client_factory), @@ -126,6 +128,17 @@ GcsActorManager::GcsActorManager( actor_gc_delay_(RayConfig::instance().gcs_actor_table_min_duration_ms()) { RAY_CHECK(worker_client_factory_); RAY_CHECK(destroy_owned_placement_group_if_needed_); + if (RayConfig::instance().gcs_actor_scheduling_enabled()) { + auto gcs_actor_scheduler = + std::dynamic_pointer_cast(gcs_actor_scheduler_); + gcs_actor_scheduler->AddResourcesChangedListener([this] { + bool posted = GetSchedulePendingActorsPosted(); + if (!posted) { + SetSchedulePendingActorsPosted(true); + io_context_.post([this] { SchedulePendingActors(); }); + } + }); + } } void GcsActorManager::HandleRegisterActor(const rpc::RegisterActorRequest &request, @@ -187,13 +200,13 @@ void GcsActorManager::HandleGetActorInfo(const rpc::GetActorInfoRequest &request const auto ®istered_actor_iter = registered_actors_.find(actor_id); if (registered_actor_iter != registered_actors_.end()) { - reply->mutable_actor_table_data()->CopyFrom( - registered_actor_iter->second->GetActorTableData()); + reply->unsafe_arena_set_allocated_actor_table_data( + registered_actor_iter->second->GetMutableActorTableData()); } else { const auto &destroyed_actor_iter = destroyed_actors_.find(actor_id); if (destroyed_actor_iter != destroyed_actors_.end()) { - reply->mutable_actor_table_data()->CopyFrom( - destroyed_actor_iter->second->GetActorTableData()); + reply->unsafe_arena_set_allocated_actor_table_data( + destroyed_actor_iter->second->GetMutableActorTableData()); } } @@ -210,10 +223,12 @@ void GcsActorManager::HandleGetAllActorInfo(const rpc::GetAllActorInfoRequest &r ++counts_[CountType::GET_ALL_ACTOR_INFO_REQUEST]; if (request.show_dead_jobs() == false) { for (const auto &iter : registered_actors_) { - reply->add_actor_table_data()->CopyFrom(iter.second->GetActorTableData()); + reply->mutable_actor_table_data()->UnsafeArenaAddAllocated( + const_cast(iter.second->GetMutableActorTableData())); } for (const auto &iter : destroyed_actors_) { - reply->add_actor_table_data()->CopyFrom(iter.second->GetActorTableData()); + reply->mutable_actor_table_data()->UnsafeArenaAddAllocated( + const_cast(iter.second->GetMutableActorTableData())); } RAY_LOG(DEBUG) << "Finished getting all actor info."; GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); @@ -227,7 +242,9 @@ void GcsActorManager::HandleGetAllActorInfo(const rpc::GetAllActorInfoRequest &r [reply, send_reply_callback]( 
 const std::unordered_map &result) {
 for (const auto &pair : result) {
- reply->add_actor_table_data()->CopyFrom(pair.second);
+ // TODO(yic): Fix const_cast.
+ reply->mutable_actor_table_data()->UnsafeArenaAddAllocated(
+ const_cast(&pair.second));
 }
 GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK());
 RAY_LOG(DEBUG) << "Finished getting all actor info.";
@@ -258,7 +275,8 @@ void GcsActorManager::HandleGetNamedActorInfo(
 RAY_LOG(WARNING) << stream.str();
 status = Status::NotFound(stream.str());
 } else {
- reply->mutable_actor_table_data()->CopyFrom(iter->second->GetActorTableData());
+ reply->unsafe_arena_set_allocated_actor_table_data(
+ iter->second->GetMutableActorTableData());
 RAY_LOG(DEBUG) << "Finished getting actor info, job id = " << actor_id.JobId()
 << ", actor id = " << actor_id;
 }
@@ -275,10 +293,9 @@ void GcsActorManager::HandleListNamedActors(const rpc::ListNamedActorsRequest &r
 std::vector> actors =
 ListNamedActors(request.all_namespaces(), ray_namespace);
 for (const auto &actor : actors) {
- rpc::NamedActorInfo named_actor_info;
- named_actor_info.set_ray_namespace(actor.first);
- named_actor_info.set_name(actor.second);
- reply->add_named_actors_list()->CopyFrom(named_actor_info);
+ auto named_actor_info = reply->add_named_actors_list();
+ named_actor_info->set_ray_namespace(actor.first);
+ named_actor_info->set_name(actor.second);
 }
 GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK());
 ++counts_[CountType::LIST_NAMED_ACTORS_REQUEST];
@@ -381,13 +398,9 @@ Status GcsActorManager::RegisterActor(const ray::rpc::RegisterActorRequest &requ
 // owner to determine when the actor should be removed.
 PollOwnerForActorOutOfScope(actor);
 } else {
- // If it's a detached actor, we need to register the runtime env it used to GC
- auto job_id = JobID::FromBinary(request.task_spec().job_id());
- const auto &uris = runtime_env_manager_.GetReferences(job_id.Hex());
- auto actor_id_hex = actor->GetActorID().Hex();
- for (const auto &uri : uris) {
- runtime_env_manager_.AddURIReference(actor_id_hex, uri);
- }
+ // If it's a detached actor, we need to register the runtime env it uses for
+ // GC purposes.
+ runtime_env_manager_.AddURIReference(actor->GetActorID().Hex(),
+ request.task_spec().runtime_env());
 }

 // The backend storage is supposed to be reliable, so the status must be ok.
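The handler hunks above swap deep copies (CopyFrom) for borrowed pointers (unsafe_arena_set_allocated_actor_table_data, UnsafeArenaAddAllocated), so the reply references ActorTableData owned by the manager instead of duplicating it; that is only sound because the actor maps outlive the serialized reply. A compilable sketch of the same ownership contract using plain pointers follows; the types here are stand-ins, not the protobuf API.

    #include <iostream>
    #include <string>
    #include <vector>

    // Stand-in for a message owned by a long-lived registry (e.g. the GCS).
    struct ActorTableData {
      std::string name;
    };

    // Stand-in for an RPC reply. The "unsafe" path borrows a pointer instead
    // of copying; the caller must guarantee the pointee outlives the reply.
    struct GetAllActorInfoReply {
      std::vector<const ActorTableData *> borrowed;  // zero-copy path
      std::vector<ActorTableData> copied;            // CopyFrom path

      void UnsafeAddAllocated(const ActorTableData *d) { borrowed.push_back(d); }
      void AddCopy(const ActorTableData &d) { copied.push_back(d); }
    };

    int main() {
      // The registry outlives every reply, like registered_actors_ above.
      std::vector<ActorTableData> registry = {{"actor-a"}, {"actor-b"}};

      GetAllActorInfoReply reply;
      for (const auto &d : registry) {
        reply.UnsafeAddAllocated(&d);  // no per-actor copy on the hot path
      }
      // The reply is serialized and sent before the registry can change,
      // which is the lifetime argument the comments above rely on.
      for (const auto *d : reply.borrowed) std::cout << d->name << "\n";
    }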
@@ -575,6 +588,11 @@ void GcsActorManager::DestroyActor(const ActorID &actor_id) { RAY_LOG(INFO) << "Tried to destroy actor that does not exist " << actor_id; return; } + + if (RayConfig::instance().gcs_actor_scheduling_enabled()) { + gcs_actor_scheduler_->OnActorDestruction(it->second); + } + const auto &task_id = it->second->GetCreationTaskSpecification().TaskId(); it->second->GetMutableActorTableData()->mutable_task_spec()->Clear(); it->second->GetMutableActorTableData()->set_timestamp(current_sys_time_ms()); @@ -957,6 +975,7 @@ void GcsActorManager::OnActorCreationSuccess(const std::shared_ptr &ac } void GcsActorManager::SchedulePendingActors() { + schedule_pending_actors_posted_ = false; if (pending_actors_.empty()) { return; } @@ -968,6 +987,14 @@ void GcsActorManager::SchedulePendingActors() { } } +bool GcsActorManager::GetSchedulePendingActorsPosted() const { + return schedule_pending_actors_posted_; +} + +void GcsActorManager::SetSchedulePendingActorsPosted(bool posted) { + schedule_pending_actors_posted_ = posted; +} + void GcsActorManager::Initialize(const GcsInitData &gcs_init_data) { const auto &jobs = gcs_init_data.Jobs(); std::unordered_map> node_to_workers; diff --git a/src/ray/gcs/gcs_server/gcs_actor_manager.h b/src/ray/gcs/gcs_server/gcs_actor_manager.h index 569c9b2b19172..9050eb4dfc9fe 100644 --- a/src/ray/gcs/gcs_server/gcs_actor_manager.h +++ b/src/ray/gcs/gcs_server/gcs_actor_manager.h @@ -21,6 +21,7 @@ #include "ray/common/runtime_env_manager.h" #include "ray/common/task/task_execution_spec.h" #include "ray/common/task/task_spec.h" +#include "ray/gcs/gcs_server/gcs_actor_distribution.h" #include "ray/gcs/gcs_server/gcs_actor_scheduler.h" #include "ray/gcs/gcs_server/gcs_init_data.h" #include "ray/gcs/gcs_server/gcs_table_storage.h" @@ -86,7 +87,8 @@ class GcsActor { break; } - actor_table_data_.set_serialized_runtime_env(task_spec.serialized_runtime_env()); + actor_table_data_.set_serialized_runtime_env( + task_spec.runtime_env().serialized_runtime_env()); } /// Get the node id on which this actor is created. @@ -193,6 +195,7 @@ class GcsActorManager : public rpc::ActorInfoHandler { /// \param gcs_table_storage Used to flush actor data to storage. /// \param gcs_pub_sub Used to publish gcs message. GcsActorManager( + boost::asio::io_context &io_context, std::shared_ptr scheduler, std::shared_ptr gcs_table_storage, std::shared_ptr gcs_pub_sub, RuntimeEnvManager &runtime_env_manager, @@ -341,6 +344,10 @@ class GcsActorManager : public rpc::ActorInfoHandler { std::string DebugString() const; + bool GetSchedulePendingActorsPosted() const; + + void SetSchedulePendingActorsPosted(bool posted); + private: /// A data structure representing an actor's owner. struct Owner { @@ -485,6 +492,7 @@ class GcsActorManager : public rpc::ActorInfoHandler { /// according to its owner, or the owner dies. absl::flat_hash_map> owners_; + boost::asio::io_context &io_context_; /// The scheduler to schedule all registered actors. std::shared_ptr gcs_actor_scheduler_; /// Used to update actor information upon creation, deletion, etc. @@ -508,6 +516,9 @@ class GcsActorManager : public rpc::ActorInfoHandler { run_delayed_; const boost::posix_time::milliseconds actor_gc_delay_; + /// Indicate whether a call of SchedulePendingActors has been posted. + bool schedule_pending_actors_posted_; + // Debug info. 
 enum CountType {
 REGISTER_ACTOR_REQUEST = 0,
diff --git a/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc b/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc
index 81d476a80854b..cc4a426cea653 100644
--- a/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc
+++ b/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc
@@ -38,7 +38,6 @@ GcsActorScheduler::GcsActorScheduler(
 gcs_pub_sub_(std::move(gcs_pub_sub)),
 schedule_failure_handler_(std::move(schedule_failure_handler)),
 schedule_success_handler_(std::move(schedule_success_handler)),
- report_worker_backlog_(RayConfig::instance().report_worker_backlog()),
 raylet_client_pool_(raylet_client_pool),
 core_worker_clients_(client_factory) {
 RAY_CHECK(schedule_failure_handler_ != nullptr && schedule_success_handler_ != nullptr);
@@ -230,14 +229,13 @@ void GcsActorScheduler::LeaseWorkerFromNode(std::shared_ptr actor,
 auto lease_client = GetOrConnectLeaseClient(remote_address);
 // Actor leases should be sent to the raylet immediately, so we should never build up a
 // backlog in GCS.
- int backlog_size = report_worker_backlog_ ? 0 : -1;
 lease_client->RequestWorkerLease(
- actor->GetCreationTaskSpecification(),
+ actor->GetActorTableData().task_spec(),
 [this, actor, node](const Status &status,
 const rpc::RequestWorkerLeaseReply &reply) {
 HandleWorkerLeaseReply(actor, node, status, reply);
 },
- backlog_size);
+ 0);
 }

 void GcsActorScheduler::RetryLeasingWorkerFromNode(
diff --git a/src/ray/gcs/gcs_server/gcs_actor_scheduler.h b/src/ray/gcs/gcs_server/gcs_actor_scheduler.h
index 34d7d3ea3a186..55bd6b6bd73f6 100644
--- a/src/ray/gcs/gcs_server/gcs_actor_scheduler.h
+++ b/src/ray/gcs/gcs_server/gcs_actor_scheduler.h
@@ -75,6 +75,11 @@ class GcsActorSchedulerInterface {
 virtual void ReleaseUnusedWorkers(
 const std::unordered_map> &node_to_workers) = 0;

+ /// Handle the destruction of an actor.
+ ///
+ /// \param actor The actor to be destroyed.
+ virtual void OnActorDestruction(std::shared_ptr actor) = 0;
+
 virtual ~GcsActorSchedulerInterface() {}
 };

@@ -146,6 +151,11 @@ class GcsActorScheduler : public GcsActorSchedulerInterface {
 void ReleaseUnusedWorkers(
 const std::unordered_map> &node_to_workers) override;

+ /// Handle the destruction of an actor.
+ ///
+ /// \param actor The actor to be destroyed.
+ void OnActorDestruction(std::shared_ptr actor) override {}
+
 protected:
 /// The GcsLeasedWorker is kind of abstraction of remote leased worker inside raylet. It
 /// contains the address of remote leased worker as well as the leased resources and the
@@ -302,8 +312,6 @@ class GcsActorScheduler : public GcsActorSchedulerInterface {
 /// The handler to handle the successful scheduling.
 std::function,
 const rpc::PushTaskReply &reply)>
 schedule_success_handler_;
- /// Whether or not to report the backlog of actors waiting to be scheduled.
- bool report_worker_backlog_;
 /// The nodes which are releasing unused workers.
 absl::flat_hash_set nodes_of_releasing_unused_workers_;
 /// The cached raylet clients used to communicate with raylet.
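OnActorDestruction and AddResourcesChangedListener form a small observer pipeline: destroying an actor releases its worker assignment's resources, the release notifies listeners, and the listener posts a single SchedulePendingActors pass to the GCS event loop (the schedule_pending_actors_posted_ flag introduced earlier collapses bursts of notifications into one pass). A self-contained sketch of the pattern, with EventLoop as a toy stand-in for boost::asio::io_context:

    #include <functional>
    #include <iostream>
    #include <queue>
    #include <vector>

    // Toy event loop: posted callbacks run later, in FIFO order.
    struct EventLoop {
      std::queue<std::function<void()>> tasks;
      void post(std::function<void()> f) { tasks.push(std::move(f)); }
      void run() {
        while (!tasks.empty()) {
          auto f = std::move(tasks.front());
          tasks.pop();
          f();
        }
      }
    };

    struct Scheduler {
      std::vector<std::function<void()>> listeners;
      void AddResourcesChangedListener(std::function<void()> l) {
        listeners.push_back(std::move(l));
      }
      void NotifyClusterResourcesChanged() {
        for (auto &l : listeners) l();
      }
    };

    int main() {
      EventLoop io;
      Scheduler scheduler;
      bool schedule_posted = false;  // collapses bursts into one pass

      scheduler.AddResourcesChangedListener([&] {
        if (!schedule_posted) {
          schedule_posted = true;
          io.post([&] {
            schedule_posted = false;
            std::cout << "SchedulePendingActors()\n";  // one pass per burst
          });
        }
      });

      // Two rapid resource changes produce a single scheduling pass.
      scheduler.NotifyClusterResourcesChanged();
      scheduler.NotifyClusterResourcesChanged();
      io.run();
    }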
diff --git a/src/ray/gcs/gcs_server/gcs_node_manager.cc b/src/ray/gcs/gcs_server/gcs_node_manager.cc
index 91782a712db5b..c84e19372a6b4 100644
--- a/src/ray/gcs/gcs_server/gcs_node_manager.cc
+++ b/src/ray/gcs/gcs_server/gcs_node_manager.cc
@@ -84,11 +84,15 @@ void GcsNodeManager::HandleUnregisterNode(const rpc::UnregisterNodeRequest &requ
 void GcsNodeManager::HandleGetAllNodeInfo(const rpc::GetAllNodeInfoRequest &request,
 rpc::GetAllNodeInfoReply *reply,
 rpc::SendReplyCallback send_reply_callback) {
+ // The unsafe arena allocation is safe here because entry.second's lifetime is
+ // longer than the reply's: the reply is sent when send_reply_callback is
+ // called and is not used after that, while entry remains valid.
 for (const auto &entry : alive_nodes_) {
- reply->add_node_info_list()->CopyFrom(*entry.second);
+ reply->mutable_node_info_list()->UnsafeArenaAddAllocated(entry.second.get());
 }
 for (const auto &entry : dead_nodes_) {
- reply->add_node_info_list()->CopyFrom(*entry.second);
+ reply->mutable_node_info_list()->UnsafeArenaAddAllocated(entry.second.get());
 }
 GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK());
 ++counts_[CountType::GET_ALL_NODE_INFO_REQUEST];
diff --git a/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc b/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc
index f41f9d45bd6e7..7879a9fd71bce 100644
--- a/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc
+++ b/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc
@@ -45,22 +45,26 @@ std::string GcsPlacementGroup::GetRayNamespace() const {
 return placement_group_table_data_.ray_namespace();
 }

-std::vector> GcsPlacementGroup::GetBundles() const {
- const auto &bundles = placement_group_table_data_.bundles();
- std::vector> ret_bundles;
- for (const auto &bundle : bundles) {
- ret_bundles.push_back(std::make_shared(bundle));
+std::vector> &GcsPlacementGroup::GetBundles()
+ const {
+ // Fill the cache if it hasn't been filled yet.
+ if (cached_bundle_specs_.empty()) {
+ const auto &bundles = placement_group_table_data_.bundles();
+ for (const auto &bundle : bundles) {
+ cached_bundle_specs_.push_back(std::make_shared(bundle));
+ }
 }
- return ret_bundles;
+ return cached_bundle_specs_;
 }

-std::vector> GcsPlacementGroup::GetUnplacedBundles()
- const {
- const auto &bundles = placement_group_table_data_.bundles();
- std::vector> unplaced_bundles;
- for (const auto &bundle : bundles) {
- if (NodeID::FromBinary(bundle.node_id()).IsNil()) {
- unplaced_bundles.push_back(std::make_shared(bundle));
+std::vector>
+GcsPlacementGroup::GetUnplacedBundles() const {
+ const auto &bundle_specs = GetBundles();
+
+ std::vector> unplaced_bundles;
+ for (const auto &bundle : bundle_specs) {
+ if (bundle->NodeId().IsNil()) {
+ unplaced_bundles.push_back(bundle);
 }
 }
 return unplaced_bundles;
@@ -83,6 +87,8 @@ std::string GcsPlacementGroup::DebugString() const {
 }

 rpc::Bundle *GcsPlacementGroup::GetMutableBundle(int bundle_index) {
+ // Invalidate the cache.
+ cached_bundle_specs_.clear(); return placement_group_table_data_.mutable_bundles(bundle_index); } @@ -176,7 +182,7 @@ void GcsPlacementGroupManager::RegisterPlacementGroup( .emplace_back(std::move(callback)); registered_placement_groups_.emplace(placement_group->GetPlacementGroupID(), placement_group); - pending_placement_groups_.emplace_back(placement_group); + AddToPendingQueue(placement_group); RAY_CHECK_OK(gcs_table_storage_->PlacementGroupTable().Put( placement_group_id, placement_group->GetPlacementGroupTableData(), @@ -221,7 +227,8 @@ PlacementGroupID GcsPlacementGroupManager::GetPlacementGroupIDByName( } void GcsPlacementGroupManager::OnPlacementGroupCreationFailed( - std::shared_ptr placement_group, bool is_feasible) { + std::shared_ptr placement_group, ExponentialBackOff backoff, + bool is_feasible) { RAY_LOG(DEBUG) << "Failed to create placement group " << placement_group->GetName() << ", id: " << placement_group->GetPlacementGroupID() << ", try again."; @@ -229,7 +236,6 @@ void GcsPlacementGroupManager::OnPlacementGroupCreationFailed( // We will attempt to schedule this placement_group once an eligible node is // registered. infeasible_placement_groups_.emplace_back(std::move(placement_group)); - MarkSchedulingDone(); } else { auto state = placement_group->GetState(); RAY_CHECK(state == rpc::PlacementGroupTableData::RESCHEDULING || @@ -241,14 +247,13 @@ void GcsPlacementGroupManager::OnPlacementGroupCreationFailed( // NOTE: If a node is dead, the placement group scheduler should try to recover the // group by rescheduling the bundles of the dead node. This should have higher // priority than trying to place other placement groups. - pending_placement_groups_.emplace_front(std::move(placement_group)); + AddToPendingQueue(std::move(placement_group), /* rank */ 0); } else { - pending_placement_groups_.emplace_back(std::move(placement_group)); + AddToPendingQueue(std::move(placement_group), std::nullopt, backoff); } - - MarkSchedulingDone(); - RetryCreatingPlacementGroup(); } + io_context_.post([this] { SchedulePendingPlacementGroups(); }); + MarkSchedulingDone(); } void GcsPlacementGroupManager::OnPlacementGroupCreationSuccess( @@ -256,16 +261,11 @@ void GcsPlacementGroupManager::OnPlacementGroupCreationSuccess( RAY_LOG(INFO) << "Successfully created placement group " << placement_group->GetName() << ", id: " << placement_group->GetPlacementGroupID(); placement_group->UpdateState(rpc::PlacementGroupTableData::CREATED); - // Mark the scheduling done firstly. - MarkSchedulingDone(); auto placement_group_id = placement_group->GetPlacementGroupID(); RAY_CHECK_OK(gcs_table_storage_->PlacementGroupTable().Put( placement_group_id, placement_group->GetPlacementGroupTableData(), [this, placement_group_id](Status status) { RAY_CHECK_OK(status); - - SchedulePendingPlacementGroups(); - // Invoke all callbacks for all `WaitPlacementGroupUntilReady` requests of this // placement group and remove all of them from // placement_group_to_create_callbacks_. 
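With the changes above, a failed placement group is re-queued through AddToPendingQueue with an ExponentialBackOff rather than by arming a retry timer: the pending queue (declared later in this patch) is ordered by the earliest time each group may be scheduled, so a retry is just a re-insert with a key pushed into the future. A minimal sketch of that retry scheme, with std::multimap standing in for absl::btree_multimap and a simplified backoff:

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <map>
    #include <string>
    #include <utility>

    // Simplified backoff: each Next() grows the delay up to a cap.
    class ExponentialBackOff {
     public:
      ExponentialBackOff(int64_t base_ns, double mult, int64_t max_ns)
          : current_(base_ns), mult_(mult), max_(max_ns) {}
      int64_t Next() {
        int64_t v = current_;
        current_ = std::min<int64_t>(static_cast<int64_t>(current_ * mult_), max_);
        return v;
      }
     private:
      int64_t current_;
      double mult_;
      int64_t max_;
    };

    int main() {
      int64_t now_ns = 0;  // stand-in for absl::GetCurrentTimeNanos()
      // Key: earliest schedule time; value: (backoff, placement group name).
      std::multimap<int64_t, std::pair<ExponentialBackOff, std::string>> pending;

      // A new group is schedulable immediately (rank == now).
      pending.emplace(now_ns, std::make_pair(
          ExponentialBackOff(1000000, 2.0, 1000000000), std::string("pg-1")));

      // Scheduling pass: only entries whose rank has arrived are eligible.
      auto it = pending.begin();
      if (it->first <= now_ns) {
        auto backoff = it->second.first;
        auto pg = it->second.second;
        pending.erase(it);
        // Creation failed: re-insert with a rank pushed into the future, so
        // no separate retry job has to be posted to the io context.
        pending.emplace(now_ns + backoff.Next(), std::make_pair(backoff, pg));
      }
      std::cout << "next retry at ns=" << pending.begin()->first << "\n";
    }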
@@ -278,6 +278,8 @@ void GcsPlacementGroupManager::OnPlacementGroupCreationSuccess(
 placement_group_to_create_callbacks_.erase(pg_to_create_iter);
 }
 }));
+ io_context_.post([this] { SchedulePendingPlacementGroups(); });
+ MarkSchedulingDone();
 }

 void GcsPlacementGroupManager::SchedulePendingPlacementGroups() {
@@ -294,16 +296,28 @@ void GcsPlacementGroupManager::SchedulePendingPlacementGroups() {

 bool is_new_placement_group_scheduled = false;
 while (!pending_placement_groups_.empty() && !is_new_placement_group_scheduled) {
- const auto placement_group = pending_placement_groups_.front();
- pending_placement_groups_.pop_front();
+ auto iter = pending_placement_groups_.begin();
+ if (iter->first > absl::GetCurrentTimeNanos()) {
+ // The rank is the time at which the group becomes schedulable, and the map
+ // is ordered, so every other entry is due after this one. If the first
+ // entry is not due yet, we can stop here; the next Tick() will retry.
+ break;
+ }
+ auto backoff = iter->second.first;
+ auto placement_group = std::move(iter->second.second);
+ pending_placement_groups_.erase(iter);
+
 const auto &placement_group_id = placement_group->GetPlacementGroupID();
 // Do not reschedule if the placement group has removed already.
 if (registered_placement_groups_.contains(placement_group_id)) {
 MarkSchedulingStarted(placement_group_id);
 gcs_placement_group_scheduler_->ScheduleUnplacedBundles(
 placement_group,
- [this](std::shared_ptr placement_group, bool is_insfeasble) {
- OnPlacementGroupCreationFailed(std::move(placement_group), is_insfeasble);
+ [this, backoff](std::shared_ptr placement_group,
+ bool is_infeasible) {
+ OnPlacementGroupCreationFailed(std::move(placement_group), backoff,
+ is_infeasible);
 },
 [this](std::shared_ptr placement_group) {
 OnPlacementGroupCreationSuccess(std::move(placement_group));
@@ -312,6 +326,7 @@
 }
 // If the placement group is not registered == removed.
 }
+ ++counts_[CountType::SCHEDULING_PENDING_PLACEMENT_GROUP];
 }

 void GcsPlacementGroupManager::HandleCreatePlacementGroup(
@@ -393,18 +408,10 @@ void GcsPlacementGroupManager::RemovePlacementGroup(
 }

 // Remove a placement group from a pending list if exists.
- auto pending_it = std::find_if(
- pending_placement_groups_.begin(), pending_placement_groups_.end(),
- [placement_group_id](const std::shared_ptr &placement_group) {
- return placement_group->GetPlacementGroupID() == placement_group_id;
- });
- if (pending_it != pending_placement_groups_.end()) {
- // The placement group was pending scheduling, remove it from the queue.
- pending_placement_groups_.erase(pending_it);
- }
+ RemoveFromPendingQueue(placement_group_id);

 // Remove a placement group from infeasible queue if exists.
- pending_it = std::find_if( + auto pending_it = std::find_if( infeasible_placement_groups_.begin(), infeasible_placement_groups_.end(), [placement_group_id](const std::shared_ptr &placement_group) { return placement_group->GetPlacementGroupID() == placement_group_id; @@ -573,9 +580,36 @@ void GcsPlacementGroupManager::WaitPlacementGroup( } } -void GcsPlacementGroupManager::RetryCreatingPlacementGroup() { - execute_after(io_context_, [this] { SchedulePendingPlacementGroups(); }, - RayConfig::instance().gcs_create_placement_group_retry_interval_ms()); +void GcsPlacementGroupManager::AddToPendingQueue( + std::shared_ptr pg, std::optional rank, + std::optional exp_backer) { + if (!rank) { + rank = absl::GetCurrentTimeNanos(); + } + + if (!exp_backer) { + exp_backer = ExponentialBackOff( + 1000000 * + RayConfig::instance().gcs_create_placement_group_retry_min_interval_ms(), + RayConfig::instance().gcs_create_placement_group_retry_multiplier(), + 1000000 * + RayConfig::instance().gcs_create_placement_group_retry_max_interval_ms()); + } else { + *rank += static_cast(exp_backer->Next()); + } + auto val = std::make_pair(*exp_backer, std::move(pg)); + pending_placement_groups_.emplace(*rank, std::move(val)); +} + +void GcsPlacementGroupManager::RemoveFromPendingQueue(const PlacementGroupID &pg_id) { + auto it = std::find_if(pending_placement_groups_.begin(), + pending_placement_groups_.end(), [&pg_id](const auto &val) { + return val.second.second->GetPlacementGroupID() == pg_id; + }); + // The placement group was pending scheduling, remove it from the queue. + if (it != pending_placement_groups_.end()) { + pending_placement_groups_.erase(it); + } } void GcsPlacementGroupManager::OnNodeDead(const NodeID &node_id) { @@ -593,7 +627,7 @@ void GcsPlacementGroupManager::OnNodeDead(const NodeID &node_id) { // creating until a node with the resources is added. we will solve it in next pr. if (iter->second->GetState() != rpc::PlacementGroupTableData::RESCHEDULING) { iter->second->UpdateState(rpc::PlacementGroupTableData::RESCHEDULING); - pending_placement_groups_.emplace_front(iter->second); + AddToPendingQueue(iter->second, 0); } } } @@ -609,9 +643,9 @@ void GcsPlacementGroupManager::OnNodeAdd(const NodeID &node_id) { // Move all the infeasible placement groups to the pending queue so that we can // reschedule them. if (infeasible_placement_groups_.size() > 0) { - auto end_it = pending_placement_groups_.end(); - pending_placement_groups_.insert(end_it, infeasible_placement_groups_.cbegin(), - infeasible_placement_groups_.cend()); + for (auto &pg : infeasible_placement_groups_) { + AddToPendingQueue(std::move(pg)); + } infeasible_placement_groups_.clear(); } SchedulePendingPlacementGroups(); @@ -667,14 +701,16 @@ void GcsPlacementGroupManager::Tick() { // Note that we don't currently have a known race condition that requires this, but we // added as a safety check. 
https://github.com/ray-project/ray/pull/18419 SchedulePendingPlacementGroups(); - execute_after(io_context_, [this] { Tick(); }, 1000 /* milliseconds */); + execute_after( + io_context_, [this] { Tick(); }, 1000 /* milliseconds */); } void GcsPlacementGroupManager::UpdatePlacementGroupLoad() { std::shared_ptr placement_group_load = std::make_shared(); int total_cnt = 0; - for (const auto &pending_pg_spec : pending_placement_groups_) { + for (const auto &elem : pending_placement_groups_) { + const auto pending_pg_spec = elem.second.second; auto placement_group_data = placement_group_load->add_placement_group_data(); auto placement_group_table_data = pending_pg_spec->GetPlacementGroupTableData(); placement_group_data->Swap(&placement_group_table_data); @@ -710,7 +746,7 @@ void GcsPlacementGroupManager::Initialize(const GcsInitData &gcs_init_data) { if (item.second.state() == rpc::PlacementGroupTableData::PENDING || item.second.state() == rpc::PlacementGroupTableData::RESCHEDULING) { - pending_placement_groups_.emplace_back(std::move(placement_group)); + AddToPendingQueue(std::move(placement_group)); } if (item.second.state() == rpc::PlacementGroupTableData::CREATED || @@ -749,6 +785,8 @@ std::string GcsPlacementGroupManager::DebugString() const { << counts_[CountType::WAIT_PLACEMENT_GROUP_UNTIL_READY_REQUEST] << ", GetNamedPlacementGroup request count: " << counts_[CountType::GET_NAMED_PLACEMENT_GROUP_REQUEST] + << ", Scheduling pending placement group count: " + << counts_[CountType::SCHEDULING_PENDING_PLACEMENT_GROUP] << ", Registered placement groups count: " << registered_placement_groups_.size() << ", Named placement group count: " << num_pgs << ", Pending placement groups count: " << pending_placement_groups_.size() diff --git a/src/ray/gcs/gcs_server/gcs_placement_group_manager.h b/src/ray/gcs/gcs_server/gcs_placement_group_manager.h index bc3407fd8ac02..93bc68d306e43 100644 --- a/src/ray/gcs/gcs_server/gcs_placement_group_manager.h +++ b/src/ray/gcs/gcs_server/gcs_placement_group_manager.h @@ -13,8 +13,12 @@ // limitations under the License. #pragma once +#include + +#include #include +#include "absl/container/btree_map.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "ray/common/asio/instrumented_io_context.h" @@ -89,10 +93,10 @@ class GcsPlacementGroup { std::string GetRayNamespace() const; /// Get the bundles of this placement_group (including unplaced). - std::vector> GetBundles() const; + std::vector> &GetBundles() const; /// Get the unplaced bundles of this placement group. - std::vector> GetUnplacedBundles() const; + std::vector> GetUnplacedBundles() const; /// Get the Strategy rpc::PlacementStrategy GetStrategy() const; @@ -121,9 +125,14 @@ class GcsPlacementGroup { bool IsDetached() const; private: + FRIEND_TEST(GcsPlacementGroupManagerTest, TestPlacementGroupBundleCache); /// The placement_group meta data which contains the task specification as well as the /// state of the gcs placement_group and so on (see gcs.proto). rpc::PlacementGroupTableData placement_group_table_data_; + /// Creating bundle specification requires heavy computation because it needs to compute + /// formatted strings for all resources (heavy string operations). To optimize the CPU + /// usage, we cache bundle specs. 
+ mutable std::vector> cached_bundle_specs_;
 };

 /// GcsPlacementGroupManager is responsible for managing the lifecycle of all placement
@@ -209,7 +218,7 @@ class GcsPlacementGroupManager : public rpc::PlacementGroupInfoHandler {
 /// \param placement_group The placement_group whose creation task is infeasible.
 /// \param is_feasible whether the scheduler can be retry or not currently.
 void OnPlacementGroupCreationFailed(std::shared_ptr placement_group,
- bool is_feasible = true);
+ ExponentialBackOff backoff, bool is_feasible);

 /// Handle placement_group creation task success. This should be called when the
 /// placement_group creation task has been scheduled successfully.
@@ -277,6 +286,19 @@ class GcsPlacementGroupManager : public rpc::PlacementGroupInfoHandler {
 std::string DebugString() const;

 private:
+ /// Push a placement group to the pending queue.
+ ///
+ /// \param pg The placement group we are adding.
+ /// \param rank The rank for this placement group. Semantically, it's the time
+ /// at which this placement group should be scheduled. By default it's the
+ /// current time.
+ /// \param exp_backer The exponential backoff. A default one will be given if
+ /// it's not set. This is used to generate the deferred schedule time for
+ /// this placement group.
+ void AddToPendingQueue(std::shared_ptr pg,
+ std::optional rank = std::nullopt,
+ std::optional exp_backer = std::nullopt);
+ void RemoveFromPendingQueue(const PlacementGroupID &pg_id);
+
 /// Try to create placement group after a short time.
 void RetryCreatingPlacementGroup();

@@ -322,12 +344,17 @@ class GcsPlacementGroupManager : public rpc::PlacementGroupInfoHandler {
 absl::flat_hash_map>
 registered_placement_groups_;

- /// The pending placement_groups which will not be scheduled until there's a resource
- /// change.
- /// NOTE: When we remove placement group, we need to look for
- /// `pending_placement_groups_` and delete the specific placement group, so we can't use
- /// `std::priority_queue`.
- std::deque> pending_placement_groups_;
+ /// The pending placement_groups which will not be scheduled until there's a
+ /// resource change. The pending queue is represented as an ordered map, where
+ /// the key is the time at which to schedule the pg and the value is a pair
+ /// containing the actual placement group and an exponential backoff.
+ /// When an error happens, we retry later simply by re-inserting the element
+ /// into the queue with a bigger key. With this, we don't need to post a retry
+ /// job to the io context. And when scheduling pending placement groups, we
+ /// always start with the one with the smallest key.
+ absl::btree_multimap>>
+ pending_placement_groups_;

 /// The infeasible placement_groups that can't be scheduled currently.
std::deque> infeasible_placement_groups_; @@ -363,9 +390,14 @@ class GcsPlacementGroupManager : public rpc::PlacementGroupInfoHandler { GET_ALL_PLACEMENT_GROUP_REQUEST = 3, WAIT_PLACEMENT_GROUP_UNTIL_READY_REQUEST = 4, GET_NAMED_PLACEMENT_GROUP_REQUEST = 5, - CountType_MAX = 6, + SCHEDULING_PENDING_PLACEMENT_GROUP = 6, + CountType_MAX = 7, }; uint64_t counts_[CountType::CountType_MAX] = {0}; + + FRIEND_TEST(GcsPlacementGroupManagerMockTest, PendingQueuePriorityReschedule); + FRIEND_TEST(GcsPlacementGroupManagerMockTest, PendingQueuePriorityFailed); + FRIEND_TEST(GcsPlacementGroupManagerMockTest, PendingQueuePriorityOrder); }; } // namespace gcs diff --git a/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc b/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc index c2ca3c3c8cd40..7c9391315a945 100644 --- a/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc +++ b/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc @@ -39,7 +39,7 @@ GcsPlacementGroupScheduler::GcsPlacementGroupScheduler( } std::vector GcsScheduleStrategy::GetRequiredResourcesFromBundles( - const std::vector> &bundles) { + const std::vector> &bundles) { std::vector required_resources; for (const auto &bundle : bundles) { required_resources.push_back(bundle->GetRequiredResources()); @@ -48,7 +48,7 @@ std::vector GcsScheduleStrategy::GetRequiredResourcesFromBundles( } ScheduleResult GcsScheduleStrategy::GenerateScheduleResult( - const std::vector> &bundles, + const std::vector> &bundles, const std::vector &selected_nodes, const SchedulingResultStatus &status) { ScheduleMap schedule_map; if (status == SUCCESS && !selected_nodes.empty()) { @@ -62,7 +62,7 @@ ScheduleResult GcsScheduleStrategy::GenerateScheduleResult( } ScheduleResult GcsStrictPackStrategy::Schedule( - const std::vector> &bundles, + const std::vector> &bundles, const std::unique_ptr &context, GcsResourceScheduler &gcs_resource_scheduler) { const auto &required_resources = GetRequiredResourcesFromBundles(bundles); @@ -73,7 +73,7 @@ ScheduleResult GcsStrictPackStrategy::Schedule( } ScheduleResult GcsPackStrategy::Schedule( - const std::vector> &bundles, + const std::vector> &bundles, const std::unique_ptr &context, GcsResourceScheduler &gcs_resource_scheduler) { // The current algorithm is to select a node and deploy as many bundles as possible. @@ -87,7 +87,7 @@ ScheduleResult GcsPackStrategy::Schedule( } ScheduleResult GcsSpreadStrategy::Schedule( - const std::vector> &bundles, + const std::vector> &bundles, const std::unique_ptr &context, GcsResourceScheduler &gcs_resource_scheduler) { const auto &required_resources = GetRequiredResourcesFromBundles(bundles); @@ -98,7 +98,7 @@ ScheduleResult GcsSpreadStrategy::Schedule( } ScheduleResult GcsStrictSpreadStrategy::Schedule( - const std::vector> &bundles, + const std::vector> &bundles, const std::unique_ptr &context, GcsResourceScheduler &gcs_resource_scheduler) { // TODO(ffbin): A bundle may require special resources, such as GPU. 
We need to @@ -211,7 +211,7 @@ void GcsPlacementGroupScheduler::MarkScheduleCancelled( } void GcsPlacementGroupScheduler::PrepareResources( - const std::shared_ptr &bundle, + const std::shared_ptr &bundle, const absl::optional> &node, const StatusCallback &callback) { if (!node.has_value()) { @@ -240,7 +240,7 @@ void GcsPlacementGroupScheduler::PrepareResources( } void GcsPlacementGroupScheduler::CommitResources( - const std::shared_ptr &bundle, + const std::shared_ptr &bundle, const absl::optional> &node, const StatusCallback callback) { RAY_CHECK(node.has_value()); @@ -265,7 +265,7 @@ void GcsPlacementGroupScheduler::CommitResources( } void GcsPlacementGroupScheduler::CancelResourceReserve( - const std::shared_ptr &bundle_spec, + const std::shared_ptr &bundle_spec, const absl::optional> &node) { if (!node.has_value()) { RAY_LOG(INFO) << "Node for a placement group id " << bundle_spec->PlacementGroupId() @@ -660,7 +660,7 @@ void BundleLocationIndex::AddNodes( LeaseStatusTracker::LeaseStatusTracker( std::shared_ptr placement_group, - const std::vector> &unplaced_bundles, + const std::vector> &unplaced_bundles, const ScheduleMap &schedule_map) : placement_group_(placement_group), bundles_to_schedule_(unplaced_bundles) { preparing_bundle_locations_ = std::make_shared(); @@ -675,13 +675,13 @@ LeaseStatusTracker::LeaseStatusTracker( } bool LeaseStatusTracker::MarkPreparePhaseStarted( - const NodeID &node_id, std::shared_ptr bundle) { + const NodeID &node_id, const std::shared_ptr &bundle) { const auto &bundle_id = bundle->BundleId(); return node_to_bundles_when_preparing_[node_id].emplace(bundle_id).second; } void LeaseStatusTracker::MarkPrepareRequestReturned( - const NodeID &node_id, const std::shared_ptr bundle, + const NodeID &node_id, const std::shared_ptr &bundle, const Status &status) { RAY_CHECK(prepare_request_returned_count_ <= bundles_to_schedule_.size()); auto leasing_bundles = node_to_bundles_when_preparing_.find(node_id); @@ -715,7 +715,7 @@ bool LeaseStatusTracker::AllPrepareRequestsSuccessful() const { } void LeaseStatusTracker::MarkCommitRequestReturned( - const NodeID &node_id, const std::shared_ptr bundle, + const NodeID &node_id, const std::shared_ptr &bundle, const Status &status) { commit_request_returned_count_ += 1; // If the request succeeds, record it. @@ -762,7 +762,7 @@ const std::shared_ptr &LeaseStatusTracker::GetBundleLocations() return bundle_locations_; } -const std::vector> +const std::vector> &LeaseStatusTracker::GetBundlesToSchedule() const { return bundles_to_schedule_; } diff --git a/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.h b/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.h index bdfee4276dec5..4e921ab13e248 100644 --- a/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.h +++ b/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.h @@ -49,9 +49,8 @@ struct pair_hash { }; using ScheduleMap = std::unordered_map; using ScheduleResult = std::pair; -using BundleLocations = - absl::flat_hash_map>, - pair_hash>; +using BundleLocations = absl::flat_hash_map< + BundleID, std::pair>, pair_hash>; class GcsPlacementGroupSchedulerInterface { public: @@ -112,7 +111,7 @@ class GcsScheduleStrategy { public: virtual ~GcsScheduleStrategy() {} virtual ScheduleResult Schedule( - const std::vector> &bundles, + const std::vector> &bundles, const std::unique_ptr &context, GcsResourceScheduler &gcs_resource_scheduler) = 0; @@ -122,7 +121,7 @@ class GcsScheduleStrategy { /// \param bundles Bundles to be scheduled. /// \return Required resources. 
std::vector GetRequiredResourcesFromBundles( - const std::vector> &bundles); + const std::vector> &bundles); /// Generate `ScheduleResult` from bundles and nodes . /// @@ -131,7 +130,7 @@ class GcsScheduleStrategy { /// \param status Status of the scheduling result. /// \return The scheduling result from the required resource. ScheduleResult GenerateScheduleResult( - const std::vector> &bundles, + const std::vector> &bundles, const std::vector &selected_nodes, const SchedulingResultStatus &status); }; @@ -141,7 +140,7 @@ class GcsScheduleStrategy { class GcsPackStrategy : public GcsScheduleStrategy { public: ScheduleResult Schedule( - const std::vector> &bundles, + const std::vector> &bundles, const std::unique_ptr &context, GcsResourceScheduler &gcs_resource_scheduler) override; }; @@ -150,7 +149,7 @@ class GcsPackStrategy : public GcsScheduleStrategy { class GcsSpreadStrategy : public GcsScheduleStrategy { public: ScheduleResult Schedule( - const std::vector> &bundles, + const std::vector> &bundles, const std::unique_ptr &context, GcsResourceScheduler &gcs_resource_scheduler) override; }; @@ -160,7 +159,7 @@ class GcsSpreadStrategy : public GcsScheduleStrategy { class GcsStrictPackStrategy : public GcsScheduleStrategy { public: ScheduleResult Schedule( - const std::vector> &bundles, + const std::vector> &bundles, const std::unique_ptr &context, GcsResourceScheduler &gcs_resource_scheduler) override; }; @@ -171,7 +170,7 @@ class GcsStrictPackStrategy : public GcsScheduleStrategy { class GcsStrictSpreadStrategy : public GcsScheduleStrategy { public: ScheduleResult Schedule( - const std::vector> &bundles, + const std::vector> &bundles, const std::unique_ptr &context, GcsResourceScheduler &gcs_resource_scheduler) override; }; @@ -192,7 +191,7 @@ class LeaseStatusTracker { public: LeaseStatusTracker( std::shared_ptr placement_group, - const std::vector> &unplaced_bundles, + const std::vector> &unplaced_bundles, const ScheduleMap &schedule_map); ~LeaseStatusTracker() = default; @@ -202,7 +201,7 @@ class LeaseStatusTracker { /// \param bundle Bundle specification the node is supposed to prepare. /// \return False if the prepare phase was already started. True otherwise. bool MarkPreparePhaseStarted(const NodeID &node_id, - std::shared_ptr bundle); + const std::shared_ptr &bundle); /// Indicate the tracker that all prepare requests are returned. /// @@ -210,9 +209,9 @@ class LeaseStatusTracker { /// \param bundle Bundle specification the node was supposed to schedule. /// \param status Status of the prepare response. /// \param void - void MarkPrepareRequestReturned(const NodeID &node_id, - std::shared_ptr bundle, - const Status &status); + void MarkPrepareRequestReturned( + const NodeID &node_id, const std::shared_ptr &bundle, + const Status &status); /// Used to know if all prepare requests are returned. /// @@ -230,7 +229,7 @@ class LeaseStatusTracker { /// \param bundle Bundle specification the node was supposed to schedule. /// \param status Status of the returned commit request. void MarkCommitRequestReturned(const NodeID &node_id, - const std::shared_ptr bundle, + const std::shared_ptr &bundle, const Status &status); /// Used to know if all commit requests are returend. @@ -251,7 +250,8 @@ class LeaseStatusTracker { /// Return bundles that should be scheduled. /// /// \return List of bundle specification that are supposed to be scheduled. 
- const std::vector> &GetBundlesToSchedule() const; + [[nodiscard]] const std::vector> + &GetBundlesToSchedule() const; /// This method returns bundle locations that succeed to prepare resources. /// @@ -324,7 +324,7 @@ class LeaseStatusTracker { node_to_bundles_when_preparing_; /// Bundles to schedule. - std::vector> bundles_to_schedule_; + std::vector> bundles_to_schedule_; /// Location of bundles. std::shared_ptr bundle_locations_; @@ -460,7 +460,7 @@ class GcsPlacementGroupScheduler : public GcsPlacementGroupSchedulerInterface { /// \param node A node to prepare resources for a given bundle. /// \param callback void PrepareResources( - const std::shared_ptr &bundle, + const std::shared_ptr &bundle, const absl::optional> &node, const StatusCallback &callback); @@ -470,7 +470,7 @@ class GcsPlacementGroupScheduler : public GcsPlacementGroupSchedulerInterface { /// \param bundle A bundle to schedule on a node. /// \param node A node to commit resources for a given bundle. /// \param callback - void CommitResources(const std::shared_ptr &bundle, + void CommitResources(const std::shared_ptr &bundle, const absl::optional> &node, const StatusCallback callback); @@ -481,7 +481,7 @@ class GcsPlacementGroupScheduler : public GcsPlacementGroupSchedulerInterface { /// \param bundle A description of the bundle to return. /// \param node The node that the worker will be returned for. void CancelResourceReserve( - const std::shared_ptr &bundle_spec, + const std::shared_ptr &bundle_spec, const absl::optional> &node); /// Get an existing lease client or connect a new one or connect a new one. diff --git a/src/ray/gcs/gcs_server/gcs_resource_manager.cc b/src/ray/gcs/gcs_server/gcs_resource_manager.cc index 983edfe7df9c3..eec2d5d4dfd22 100644 --- a/src/ray/gcs/gcs_server/gcs_resource_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_resource_manager.cc @@ -13,6 +13,7 @@ // limitations under the License. 
#include "ray/gcs/gcs_server/gcs_resource_manager.h" + #include "ray/common/ray_config.h" #include "ray/stats/stats.h" @@ -233,10 +234,8 @@ void GcsResourceManager::HandleGetAllResourceUsage( aggregate_demand.set_num_infeasible_requests_queued( aggregate_demand.num_infeasible_requests_queued() + demand.num_infeasible_requests_queued()); - if (RayConfig::instance().report_worker_backlog()) { - aggregate_demand.set_backlog_size(aggregate_demand.backlog_size() + - demand.backlog_size()); - } + aggregate_demand.set_backlog_size(aggregate_demand.backlog_size() + + demand.backlog_size()); } batch->add_batch()->CopyFrom(usage.second); diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc index 6a4d60c685e9b..84821c6af1d3a 100644 --- a/src/ray/gcs/gcs_server/gcs_server.cc +++ b/src/ray/gcs/gcs_server/gcs_server.cc @@ -35,7 +35,7 @@ GcsServer::GcsServer(const ray::gcs::GcsServerConfig &config, : config_(config), main_service_(main_service), rpc_server_(config.grpc_server_name, config.grpc_server_port, - config.grpc_server_thread_num, + config.node_ip_address == "127.0.0.1", config.grpc_server_thread_num, /*keepalive_time_ms=*/RayConfig::instance().grpc_keepalive_time_ms()), client_call_manager_(main_service), raylet_client_pool_( @@ -267,7 +267,8 @@ void GcsServer::InitGcsActorManager(const GcsInitData &gcs_init_data) { client_factory); } gcs_actor_manager_ = std::make_shared( - std::move(scheduler), gcs_table_storage_, gcs_pub_sub_, *runtime_env_manager_, + main_service_, std::move(scheduler), gcs_table_storage_, gcs_pub_sub_, + *runtime_env_manager_, [this](const ActorID &actor_id) { gcs_placement_group_manager_->CleanPlacementGroupIfNeededWhenActorDead(actor_id); }, @@ -478,7 +479,7 @@ void GcsServer::InstallEventListeners() { gcs_placement_group_manager_->CleanPlacementGroupIfNeededWhenJobDead(*job_id); }); - // Install scheduling policy event listeners. + // Install scheduling event listeners. if (RayConfig::instance().gcs_actor_scheduling_enabled()) { gcs_resource_manager_->AddResourcesChangedListener([this] { main_service_.post([this] { @@ -513,9 +514,10 @@ void GcsServer::PrintDebugInfo() { // TODO(ffbin): We will get the session_dir in the next PR, and write the log to // gcs_debug_state.txt. 
   RAY_LOG(INFO) << stream.str();
-  execute_after(main_service_, [this] { PrintDebugInfo(); },
-                (RayConfig::instance().gcs_dump_debug_log_interval_minutes() *
-                 60000) /* milliseconds */);
+  execute_after(
+      main_service_, [this] { PrintDebugInfo(); },
+      (RayConfig::instance().gcs_dump_debug_log_interval_minutes() *
+       60000) /* milliseconds */);
 }

 void GcsServer::PrintAsioStats() {
@@ -524,8 +526,9 @@ void GcsServer::PrintAsioStats() {
       RayConfig::instance().event_stats_print_interval_ms();
   if (event_stats_print_interval_ms != -1 && RayConfig::instance().event_stats()) {
     RAY_LOG(INFO) << "Event stats:\n\n" << main_service_.StatsString() << "\n\n";
-    execute_after(main_service_, [this] { PrintAsioStats(); },
-                  event_stats_print_interval_ms /* milliseconds */);
+    execute_after(
+        main_service_, [this] { PrintAsioStats(); },
+        event_stats_print_interval_ms /* milliseconds */);
   }
 }

diff --git a/src/ray/gcs/gcs_server/gcs_server.h b/src/ray/gcs/gcs_server/gcs_server.h
index cadb70a3f3541..507ab2820cab7 100644
--- a/src/ray/gcs/gcs_server/gcs_server.h
+++ b/src/ray/gcs/gcs_server/gcs_server.h
@@ -16,7 +16,6 @@

 #include "ray/common/asio/instrumented_io_context.h"
 #include "ray/common/runtime_env_manager.h"
-#include "ray/gcs/gcs_server/gcs_actor_distribution.h"
 #include "ray/gcs/gcs_server/gcs_heartbeat_manager.h"
 #include "ray/gcs/gcs_server/gcs_init_data.h"
 #include "ray/gcs/gcs_server/gcs_kv_manager.h"
diff --git a/src/ray/gcs/gcs_server/gcs_server_main.cc b/src/ray/gcs/gcs_server/gcs_server_main.cc
index 79a64514409b6..abb39469b427e 100644
--- a/src/ray/gcs/gcs_server/gcs_server_main.cc
+++ b/src/ray/gcs/gcs_server/gcs_server_main.cc
@@ -80,13 +80,12 @@ int main(int argc, char *argv[]) {
         storage->InternalConfigTable().Put(ray::UniqueID::Nil(), config, on_done));
     boost::asio::io_service::work work(service);
     service.run();
-  })
-      .detach();
+  }).detach();
   promise->get_future().get();

   const ray::stats::TagsType global_tags = {
       {ray::stats::ComponentKey, "gcs_server"},
-      {ray::stats::VersionKey, "2.0.0.dev0"},
+      {ray::stats::VersionKey, kRayVersion},
       {ray::stats::NodeAddressKey, node_ip_address}};
   ray::stats::Init(global_tags, metrics_agent_port);

diff --git a/src/ray/gcs/gcs_server/gcs_table_storage.h b/src/ray/gcs/gcs_server/gcs_table_storage.h
index 84a70a347ebf7..ed48cf71abdf2 100644
--- a/src/ray/gcs/gcs_server/gcs_table_storage.h
+++ b/src/ray/gcs/gcs_server/gcs_table_storage.h
@@ -14,6 +14,7 @@

 #pragma once

+#include <memory>
 #include <utility>

 #include "ray/common/asio/instrumented_io_context.h"
@@ -51,8 +52,8 @@ using rpc::WorkerTableData;
 template <typename Key, typename Data>
 class GcsTable {
  public:
-  explicit GcsTable(std::shared_ptr<StoreClient> &store_client)
-      : store_client_(store_client) {}
+  explicit GcsTable(std::shared_ptr<StoreClient> store_client)
+      : store_client_(std::move(store_client)) {}

   virtual ~GcsTable() = default;

@@ -106,8 +107,8 @@ class GcsTable {
 template <typename Key, typename Data>
 class GcsTableWithJobId : public GcsTable<Key, Data> {
  public:
-  explicit GcsTableWithJobId(std::shared_ptr<StoreClient> &store_client)
-      : GcsTable<Key, Data>(store_client) {}
+  explicit GcsTableWithJobId(std::shared_ptr<StoreClient> store_client)
+      : GcsTable<Key, Data>(std::move(store_client)) {}

   /// Write data to the table asynchronously.
   ///
@@ -152,16 +153,16 @@ class GcsTableWithJobId : public GcsTable<Key, Data> {
 class GcsJobTable : public GcsTable<JobID, JobTableData> {
  public:
-  explicit GcsJobTable(std::shared_ptr<StoreClient> &store_client)
-      : GcsTable(store_client) {
+  explicit GcsJobTable(std::shared_ptr<StoreClient> store_client)
+      : GcsTable(std::move(store_client)) {
     table_name_ = TablePrefix_Name(TablePrefix::JOB);
   }
 };

 class GcsActorTable : public GcsTableWithJobId<ActorID, ActorTableData> {
  public:
-  explicit GcsActorTable(std::shared_ptr<StoreClient> &store_client)
-      : GcsTableWithJobId(store_client) {
+  explicit GcsActorTable(std::shared_ptr<StoreClient> store_client)
+      : GcsTableWithJobId(std::move(store_client)) {
     table_name_ = TablePrefix_Name(TablePrefix::ACTOR);
   }

@@ -172,16 +173,16 @@ class GcsActorTable : public GcsTableWithJobId<ActorID, ActorTableData> {
 class GcsPlacementGroupTable : public GcsTable<PlacementGroupID, PlacementGroupTableData> {
  public:
-  explicit GcsPlacementGroupTable(std::shared_ptr<StoreClient> &store_client)
-      : GcsTable(store_client) {
+  explicit GcsPlacementGroupTable(std::shared_ptr<StoreClient> store_client)
+      : GcsTable(std::move(store_client)) {
     table_name_ = TablePrefix_Name(TablePrefix::PLACEMENT_GROUP);
   }
 };

 class GcsTaskTable : public GcsTableWithJobId<TaskID, TaskTableData> {
  public:
-  explicit GcsTaskTable(std::shared_ptr<StoreClient> &store_client)
-      : GcsTableWithJobId(store_client) {
+  explicit GcsTaskTable(std::shared_ptr<StoreClient> store_client)
+      : GcsTableWithJobId(std::move(store_client)) {
     table_name_ = TablePrefix_Name(TablePrefix::TASK);
   }

@@ -191,8 +192,8 @@ class GcsTaskTable : public GcsTableWithJobId<TaskID, TaskTableData> {
 class GcsTaskLeaseTable : public GcsTableWithJobId<TaskID, TaskLeaseData> {
  public:
-  explicit GcsTaskLeaseTable(std::shared_ptr<StoreClient> &store_client)
-      : GcsTableWithJobId(store_client) {
+  explicit GcsTaskLeaseTable(std::shared_ptr<StoreClient> store_client)
+      : GcsTableWithJobId(std::move(store_client)) {
     table_name_ = TablePrefix_Name(TablePrefix::TASK_LEASE);
   }

@@ -203,8 +204,8 @@ class GcsTaskLeaseTable : public GcsTableWithJobId<TaskID, TaskLeaseData> {
 class GcsTaskReconstructionTable
     : public GcsTableWithJobId<TaskID, TaskReconstructionData> {
  public:
-  explicit GcsTaskReconstructionTable(std::shared_ptr<StoreClient> &store_client)
-      : GcsTableWithJobId(store_client) {
+  explicit GcsTaskReconstructionTable(std::shared_ptr<StoreClient> store_client)
+      : GcsTableWithJobId(std::move(store_client)) {
     table_name_ = TablePrefix_Name(TablePrefix::TASK_RECONSTRUCTION);
   }

@@ -214,8 +215,8 @@ class GcsTaskReconstructionTable
 class GcsObjectTable : public GcsTableWithJobId<ObjectID, ObjectLocationInfo> {
  public:
-  explicit GcsObjectTable(std::shared_ptr<StoreClient> &store_client)
-      : GcsTableWithJobId(store_client) {
+  explicit GcsObjectTable(std::shared_ptr<StoreClient> store_client)
+      : GcsTableWithJobId(std::move(store_client)) {
     table_name_ = TablePrefix_Name(TablePrefix::OBJECT);
   }

@@ -225,56 +226,56 @@ class GcsObjectTable : public GcsTableWithJobId<ObjectID, ObjectLocationInfo> {
 class GcsNodeTable : public GcsTable<NodeID, GcsNodeInfo> {
  public:
-  explicit GcsNodeTable(std::shared_ptr<StoreClient> &store_client)
-      : GcsTable(store_client) {
+  explicit GcsNodeTable(std::shared_ptr<StoreClient> store_client)
+      : GcsTable(std::move(store_client)) {
     table_name_ = TablePrefix_Name(TablePrefix::NODE);
   }
 };

 class GcsNodeResourceTable : public GcsTable<NodeID, ResourceMap> {
  public:
-  explicit GcsNodeResourceTable(std::shared_ptr<StoreClient> &store_client)
-      : GcsTable(store_client) {
+  explicit GcsNodeResourceTable(std::shared_ptr<StoreClient> store_client)
+      : GcsTable(std::move(store_client)) {
     table_name_ = TablePrefix_Name(TablePrefix::NODE_RESOURCE);
   }
 };

 class GcsPlacementGroupScheduleTable : public GcsTable<PlacementGroupID, ScheduleData> {
  public:
-  explicit GcsPlacementGroupScheduleTable(std::shared_ptr<StoreClient> &store_client)
-      : GcsTable(store_client) {
+  explicit GcsPlacementGroupScheduleTable(std::shared_ptr<StoreClient> store_client)
+      : GcsTable(std::move(store_client)) {
     table_name_ =
         TablePrefix_Name(TablePrefix::PLACEMENT_GROUP_SCHEDULE);
   }
 };

 class GcsResourceUsageBatchTable : public GcsTable<NodeID, ResourceUsageBatchData> {
  public:
-  explicit GcsResourceUsageBatchTable(std::shared_ptr<StoreClient> &store_client)
-      : GcsTable(store_client) {
+  explicit GcsResourceUsageBatchTable(std::shared_ptr<StoreClient> store_client)
+      : GcsTable(std::move(store_client)) {
     table_name_ = TablePrefix_Name(TablePrefix::RESOURCE_USAGE_BATCH);
   }
 };

 class GcsProfileTable : public GcsTable<UniqueID, ProfileTableData> {
  public:
-  explicit GcsProfileTable(std::shared_ptr<StoreClient> &store_client)
-      : GcsTable(store_client) {
+  explicit GcsProfileTable(std::shared_ptr<StoreClient> store_client)
+      : GcsTable(std::move(store_client)) {
     table_name_ = TablePrefix_Name(TablePrefix::PROFILE);
   }
 };

 class GcsWorkerTable : public GcsTable<WorkerID, WorkerTableData> {
  public:
-  explicit GcsWorkerTable(std::shared_ptr<StoreClient> &store_client)
-      : GcsTable(store_client) {
+  explicit GcsWorkerTable(std::shared_ptr<StoreClient> store_client)
+      : GcsTable(std::move(store_client)) {
     table_name_ = TablePrefix_Name(TablePrefix::WORKERS);
   }
 };

 class GcsInternalConfigTable : public GcsTable<UniqueID, StoredConfig> {
  public:
-  explicit GcsInternalConfigTable(std::shared_ptr<StoreClient> &store_client)
-      : GcsTable(store_client) {
+  explicit GcsInternalConfigTable(std::shared_ptr<StoreClient> store_client)
+      : GcsTable(std::move(store_client)) {
     table_name_ = TablePrefix_Name(TablePrefix::INTERNAL_CONFIG);
   }
 };
@@ -285,6 +286,27 @@
 /// derive from this class and override class member variables.
 class GcsTableStorage {
  public:
+  explicit GcsTableStorage(std::shared_ptr<StoreClient> store_client)
+      : store_client_(std::move(store_client)) {
+    job_table_ = std::make_unique<GcsJobTable>(store_client_);
+    actor_table_ = std::make_unique<GcsActorTable>(store_client_);
+    placement_group_table_ = std::make_unique<GcsPlacementGroupTable>(store_client_);
+    task_table_ = std::make_unique<GcsTaskTable>(store_client_);
+    task_lease_table_ = std::make_unique<GcsTaskLeaseTable>(store_client_);
+    task_reconstruction_table_ =
+        std::make_unique<GcsTaskReconstructionTable>(store_client_);
+    object_table_ = std::make_unique<GcsObjectTable>(store_client_);
+    node_table_ = std::make_unique<GcsNodeTable>(store_client_);
+    node_resource_table_ = std::make_unique<GcsNodeResourceTable>(store_client_);
+    placement_group_schedule_table_ =
+        std::make_unique<GcsPlacementGroupScheduleTable>(store_client_);
+    resource_usage_batch_table_ =
+        std::make_unique<GcsResourceUsageBatchTable>(store_client_);
+    profile_table_ = std::make_unique<GcsProfileTable>(store_client_);
+    worker_table_ = std::make_unique<GcsWorkerTable>(store_client_);
+    system_config_table_ = std::make_unique<GcsInternalConfigTable>(store_client_);
+  }
+
   GcsJobTable &JobTable() {
     RAY_CHECK(job_table_ != nullptr);
     return *job_table_;
@@ -383,26 +407,8 @@ class GcsTableStorage {
 /// that uses redis as storage.
 class RedisGcsTableStorage : public GcsTableStorage {
  public:
-  explicit RedisGcsTableStorage(std::shared_ptr<RedisClient> redis_client) {
-    store_client_ = std::make_shared<RedisStoreClient>(redis_client);
-    job_table_.reset(new GcsJobTable(store_client_));
-    actor_table_.reset(new GcsActorTable(store_client_));
-    placement_group_table_.reset(new GcsPlacementGroupTable(store_client_));
-    task_table_.reset(new GcsTaskTable(store_client_));
-    task_lease_table_.reset(new GcsTaskLeaseTable(store_client_));
-    task_reconstruction_table_.reset(new GcsTaskReconstructionTable(store_client_));
-    object_table_.reset(new GcsObjectTable(store_client_));
-    node_table_.reset(new GcsNodeTable(store_client_));
-    node_resource_table_.reset(new GcsNodeResourceTable(store_client_));
-    placement_group_schedule_table_.reset(
-        new GcsPlacementGroupScheduleTable(store_client_));
-    placement_group_schedule_table_.reset(
-        new GcsPlacementGroupScheduleTable(store_client_));
-    resource_usage_batch_table_.reset(new GcsResourceUsageBatchTable(store_client_));
-    profile_table_.reset(new GcsProfileTable(store_client_));
-    worker_table_.reset(new GcsWorkerTable(store_client_));
-    system_config_table_.reset(new GcsInternalConfigTable(store_client_));
-  }
+  explicit RedisGcsTableStorage(std::shared_ptr<RedisClient> redis_client)
+      : GcsTableStorage(std::make_shared<RedisStoreClient>(std::move(redis_client))) {}
 };

 /// \class InMemoryGcsTableStorage
@@ -410,24 +416,8 @@ class RedisGcsTableStorage : public GcsTableStorage {
 /// that uses memory as storage.
 class InMemoryGcsTableStorage : public GcsTableStorage {
  public:
-  explicit InMemoryGcsTableStorage(instrumented_io_context &main_io_service) {
-    store_client_ = std::make_shared<InMemoryStoreClient>(main_io_service);
-    job_table_.reset(new GcsJobTable(store_client_));
-    actor_table_.reset(new GcsActorTable(store_client_));
-    placement_group_table_.reset(new GcsPlacementGroupTable(store_client_));
-    task_table_.reset(new GcsTaskTable(store_client_));
-    task_lease_table_.reset(new GcsTaskLeaseTable(store_client_));
-    task_reconstruction_table_.reset(new GcsTaskReconstructionTable(store_client_));
-    object_table_.reset(new GcsObjectTable(store_client_));
-    node_table_.reset(new GcsNodeTable(store_client_));
-    node_resource_table_.reset(new GcsNodeResourceTable(store_client_));
-    placement_group_schedule_table_.reset(
-        new GcsPlacementGroupScheduleTable(store_client_));
-    resource_usage_batch_table_.reset(new GcsResourceUsageBatchTable(store_client_));
-    profile_table_.reset(new GcsProfileTable(store_client_));
-    worker_table_.reset(new GcsWorkerTable(store_client_));
-    system_config_table_.reset(new GcsInternalConfigTable(store_client_));
-  }
+  explicit InMemoryGcsTableStorage(instrumented_io_context &main_io_service)
+      : GcsTableStorage(std::make_shared<InMemoryStoreClient>(main_io_service)) {}
 };

 }  // namespace gcs
diff --git a/src/ray/gcs/gcs_server/test/gcs_actor_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_actor_manager_test.cc
index b921fd2acd2a0..f43d40dd392ac 100644
--- a/src/ray/gcs/gcs_server/test/gcs_actor_manager_test.cc
+++ b/src/ray/gcs/gcs_server/test/gcs_actor_manager_test.cc
@@ -33,6 +33,7 @@ class MockActorScheduler : public gcs::GcsActorSchedulerInterface {
   void Reschedule(std::shared_ptr<gcs::GcsActor> actor) {}
   void ReleaseUnusedWorkers(
       const std::unordered_map<NodeID, std::vector<WorkerID>> &node_to_workers) {}
+  void OnActorDestruction(std::shared_ptr<gcs::GcsActor> actor) {}

   MOCK_METHOD1(CancelOnNode, std::vector<ActorID>(const NodeID &node_id));
   MOCK_METHOD2(CancelOnWorker, ActorID(const NodeID &node_id, const WorkerID &worker_id));
@@ -105,8 +106,8 @@ class GcsActorManagerTest : public ::testing::Test {
     store_client_ = std::make_shared<gcs::InMemoryStoreClient>(io_service_);
     gcs_table_storage_ = std::make_shared<gcs::InMemoryGcsTableStorage>(io_service_);
     gcs_actor_manager_.reset(new gcs::GcsActorManager(
-        mock_actor_scheduler_, gcs_table_storage_, gcs_pub_sub_, *runtime_env_mgr_,
-        [](const ActorID &actor_id) {},
+        io_service_, mock_actor_scheduler_, gcs_table_storage_, gcs_pub_sub_,
+        *runtime_env_mgr_, [](const ActorID &actor_id) {},
         [this](const JobID &job_id) { return job_namespace_table_[job_id]; },
         [this](std::function<void()> fn, boost::posix_time::milliseconds delay) {
           if (skip_delay_) {
@@ -953,6 +954,7 @@ TEST_F(GcsActorManagerTest, TestRayNamespace) {
 }

 TEST_F(GcsActorManagerTest, TestActorTableDataDelayedGC) {
+  google::protobuf::Arena arena;
   skip_delay_ = false;
   auto job_id_1 = JobID::FromInt(1);
   auto request1 = Mocker::GenRegisterActorRequest(job_id_1, /*max_restarts=*/0,
@@ -971,7 +973,8 @@ TEST_F(GcsActorManagerTest, TestActorTableDataDelayedGC) {

   {
     rpc::GetAllActorInfoRequest request;
-    rpc::GetAllActorInfoReply reply;
+    auto &reply =
+        *google::protobuf::Arena::CreateMessage<rpc::GetAllActorInfoReply>(&arena);
     bool called = false;
     auto callback = [&called](Status status, std::function<void()> success,
                               std::function<void()> failure) { called = true; };
@@ -981,7 +984,8 @@ TEST_F(GcsActorManagerTest, TestActorTableDataDelayedGC) {
   }
   {
     rpc::GetAllActorInfoRequest request;
-    rpc::GetAllActorInfoReply reply;
+    auto &reply =
+        *google::protobuf::Arena::CreateMessage<rpc::GetAllActorInfoReply>(&arena);
     request.set_show_dead_jobs(true);
     std::promise<void> promise;
     auto callback = [&promise](Status status, std::function<void()> success,
@@ -994,7 +998,8 @@ TEST_F(GcsActorManagerTest, TestActorTableDataDelayedGC) {
   delayed_to_run_();
   {
     rpc::GetAllActorInfoRequest request;
-    rpc::GetAllActorInfoReply reply;
+    auto &reply =
+        *google::protobuf::Arena::CreateMessage<rpc::GetAllActorInfoReply>(&arena);
     request.set_show_dead_jobs(true);
     std::promise<void> promise;
     auto callback = [&promise](Status status, std::function<void()> success,
diff --git a/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_mock_test.cc b/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_mock_test.cc
new file mode 100644
index 0000000000000..0829caf3e0d91
--- /dev/null
+++ b/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_mock_test.cc
@@ -0,0 +1,139 @@
+// Copyright 2017 The Ray Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
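The arena pattern this test adopts is worth a standalone look. A minimal sketch, assuming the generated header for rpc::GetAllActorInfoReply is available and arenas are enabled for it (the cc_enable_arenas option is added to the .proto files later in this series):

  #include <google/protobuf/arena.h>
  #include "src/ray/protobuf/gcs_service.pb.h"  // generated header; path assumed

  void ArenaReplySketch() {
    google::protobuf::Arena arena;
    // The reply lives on the arena rather than the stack; every message
    // allocated on the arena is released in one shot when `arena` goes out
    // of scope, which is cheaper than destructing a large reply field by field.
    auto &reply =
        *google::protobuf::Arena::CreateMessage<ray::rpc::GetAllActorInfoReply>(&arena);
    // Hand &reply and the callbacks to the handler, as the test above does.
    (void)reply;
  }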
+
+// clang-format off
+#include "gtest/gtest.h"
+#include "gmock/gmock.h"
+#include "ray/gcs/gcs_server/gcs_actor_manager.h"
+#include "ray/gcs/gcs_server/gcs_actor_scheduler.h"
+#include "mock/ray/gcs/store_client/store_client.h"
+#include "mock/ray/gcs/gcs_server/gcs_node_manager.h"
+#include "mock/ray/raylet_client/raylet_client.h"
+#include "mock/ray/pubsub/subscriber.h"
+#include "mock/ray/gcs/pubsub/gcs_pub_sub.h"
+#include "mock/ray/rpc/worker/core_worker_client.h"
+// clang-format on
+
+using namespace ::testing;
+
+namespace ray {
+namespace gcs {
+struct MockCallback {
+  MOCK_METHOD(void, Call, ((std::shared_ptr<GcsActor>)));
+  void operator()(std::shared_ptr<GcsActor> a) { return Call(a); }
+};
+
+class GcsActorSchedulerTest : public Test {
+ public:
+  void SetUp() override {
+    store_client = std::make_shared<MockStoreClient>();
+    actor_table = std::make_unique<GcsActorTable>(store_client);
+    gcs_node_manager = std::make_unique<MockGcsNodeManager>();
+    pub_sub = std::make_shared<MockGcsPubSub>();
+    raylet_client = std::make_shared<MockRayletClientInterface>();
+    core_worker_client = std::make_shared<rpc::MockCoreWorkerClientInterface>();
+    client_pool = std::make_shared<rpc::NodeManagerClientPool>(
+        [this](const rpc::Address &) { return raylet_client; });
+    actor_scheduler = std::make_unique<RayletBasedActorScheduler>(
+        io_context, *actor_table, *gcs_node_manager, pub_sub,
+        [this](auto a) { schedule_failure_handler(a); },
+        [this](auto a, const rpc::PushTaskReply) { schedule_success_handler(a); },
+        client_pool, [this](const rpc::Address &) { return core_worker_client; });
+    auto node_info = std::make_shared<rpc::GcsNodeInfo>();
+    node_info->set_state(rpc::GcsNodeInfo::ALIVE);
+    node_id = NodeID::FromRandom();
+    node_info->set_node_id(node_id.Binary());
+    worker_id = WorkerID::FromRandom();
+    gcs_node_manager->AddNode(node_info);
+  }
+  std::shared_ptr<MockRayletClientInterface> raylet_client;
+  instrumented_io_context io_context;
+  std::shared_ptr<MockStoreClient> store_client;
+  std::unique_ptr<GcsActorTable> actor_table;
+  std::unique_ptr<GcsActorScheduler> actor_scheduler;
+  std::unique_ptr<MockGcsNodeManager> gcs_node_manager;
+  std::shared_ptr<MockGcsPubSub> pub_sub;
+  std::shared_ptr<rpc::MockCoreWorkerClientInterface> core_worker_client;
+  std::shared_ptr<rpc::NodeManagerClientPool> client_pool;
+  MockCallback schedule_failure_handler;
+  MockCallback schedule_success_handler;
+  NodeID node_id;
+  WorkerID worker_id;
+};
+
+TEST_F(GcsActorSchedulerTest, KillWorkerLeak1) {
+  // Ensure the worker does not leak in the following case:
+  //   1. GCS starts to lease a worker.
+  //   2. GCS cancels the actor.
+  //   3. The GCS lease request is answered with a grant.
+  // We'd like to test that the worker is eventually released.
+  // The worker is released by killing the actor.
+  auto actor_id = ActorID::FromHex("f4ce02420592ca68c1738a0d01000000");
+  rpc::ActorTableData actor_data;
+  actor_data.set_state(rpc::ActorTableData::PENDING_CREATION);
+  actor_data.set_actor_id(actor_id.Binary());
+  auto actor = std::make_shared<GcsActor>(actor_data);
+  std::function<void(const Status &, const rpc::RequestWorkerLeaseReply &)> cb;
+  EXPECT_CALL(*raylet_client,
+              RequestWorkerLease(Matcher<const rpc::TaskSpec &>(), _, _))
+      .WillOnce(testing::SaveArg<1>(&cb));
+  // Ensure the actor is killed.
+  EXPECT_CALL(*core_worker_client, KillActor(_, _));
+  actor_scheduler->Schedule(actor);
+  actor->GetMutableActorTableData()->set_state(rpc::ActorTableData::DEAD);
+  actor_scheduler->CancelOnNode(node_id);
+  ray::rpc::RequestWorkerLeaseReply reply;
+  reply.mutable_worker_address()->set_raylet_id(node_id.Binary());
+  reply.mutable_worker_address()->set_worker_id(worker_id.Binary());
+  cb(Status::OK(), reply);
+}
+
+TEST_F(GcsActorSchedulerTest, KillWorkerLeak2) {
+  // Ensure the worker does not leak in the following case:
+  //   1. The actor is pending creation.
+  //   2. GCS pushes the creation task to run on the worker.
+  //   3. The task is canceled.
+  //   4. The task creation reply is received.
+  // We'd like to test that the worker is eventually released.
+  // The worker is released by killing the actor.
+  auto actor_id = ActorID::FromHex("f4ce02420592ca68c1738a0d01000000");
+  rpc::ActorTableData actor_data;
+  actor_data.set_state(rpc::ActorTableData::PENDING_CREATION);
+  actor_data.set_actor_id(actor_id.Binary());
+  auto actor = std::make_shared<GcsActor>(actor_data);
+  rpc::ClientCallback<rpc::RequestWorkerLeaseReply> request_worker_lease_cb;
+  // Ensure the actor is killed.
+  EXPECT_CALL(*core_worker_client, KillActor(_, _));
+  EXPECT_CALL(*raylet_client,
+              RequestWorkerLease(Matcher<const rpc::TaskSpec &>(), _, _))
+      .WillOnce(testing::SaveArg<1>(&request_worker_lease_cb));
+
+  std::function<void(Status)> async_put_with_index_cb;
+  // Leasing succeeds.
+  EXPECT_CALL(*store_client, AsyncPutWithIndex(_, _, _, _, _))
+      .WillOnce(DoAll(SaveArg<4>(&async_put_with_index_cb), Return(Status::OK())));
+  actor_scheduler->Schedule(actor);
+  rpc::RequestWorkerLeaseReply reply;
+  reply.mutable_worker_address()->set_raylet_id(node_id.Binary());
+  reply.mutable_worker_address()->set_worker_id(worker_id.Binary());
+  request_worker_lease_cb(Status::OK(), reply);
+
+  rpc::ClientCallback<rpc::PushTaskReply> push_normal_task_cb;
+  // The worker starts to run the task.
+  EXPECT_CALL(*core_worker_client, PushNormalTask(_, _))
+      .WillOnce(testing::SaveArg<1>(&push_normal_task_cb));
+  async_put_with_index_cb(Status::OK());
+  actor->GetMutableActorTableData()->set_state(rpc::ActorTableData::DEAD);
+  actor_scheduler->CancelOnWorker(node_id, worker_id);
+  push_normal_task_cb(Status::OK(), rpc::PushTaskReply());
+}
+}  // namespace gcs
+}  // namespace ray
diff --git a/src/ray/gcs/gcs_server/test/gcs_based_actor_scheduler_test.cc b/src/ray/gcs/gcs_server/test/gcs_based_actor_scheduler_test.cc
index ada5f0094872b..48793907f117f 100644
--- a/src/ray/gcs/gcs_server/test/gcs_based_actor_scheduler_test.cc
+++ b/src/ray/gcs/gcs_server/test/gcs_based_actor_scheduler_test.cc
@@ -147,13 +147,15 @@ TEST_F(GcsBasedActorSchedulerTest, TestNotEnoughClusterResources) {
   ASSERT_TRUE(actor->GetNodeID().IsNil());
 }

-TEST_F(GcsBasedActorSchedulerTest, TestScheduleOneActor) {
+TEST_F(GcsBasedActorSchedulerTest, TestScheduleAndDestroyOneActor) {
   // Add a node with 64 memory units and 8 CPU.
   std::unordered_map<std::string, double> node_resources = {
       {kMemory_ResourceLabel, 64}, {kCPU_ResourceLabel, 8}};
   auto node = AddNewNode(node_resources);
   auto node_id = NodeID::FromBinary(node->node_id());
   ASSERT_EQ(1, gcs_node_manager_->GetAllAliveNodes().size());
+  auto cluster_resources_before_scheduling = gcs_resource_manager_->GetClusterResources();
+  ASSERT_TRUE(cluster_resources_before_scheduling.contains(node_id));

   // Schedule an actor (requiring 32 memory units and 4 CPU).
   std::unordered_map<std::string, double> required_placement_resources = {
@@ -182,6 +184,20 @@
   ASSERT_EQ(actor, success_actors_.front());
   ASSERT_EQ(actor->GetNodeID(), node_id);
   ASSERT_EQ(actor->GetWorkerID(), worker_id);
+
+  auto cluster_resources_after_scheduling = gcs_resource_manager_->GetClusterResources();
+  ASSERT_TRUE(cluster_resources_after_scheduling.contains(node_id));
+  ASSERT_FALSE(
+      cluster_resources_before_scheduling[node_id].GetAvailableResources().IsEqual(
+          cluster_resources_after_scheduling[node_id].GetAvailableResources()));
+
+  // When destroying an actor, its acquired resources have to be returned.
+  gcs_actor_scheduler_->OnActorDestruction(actor);
+  auto cluster_resources_after_destruction = gcs_resource_manager_->GetClusterResources();
+  ASSERT_TRUE(cluster_resources_after_destruction.contains(node_id));
+  ASSERT_TRUE(
+      cluster_resources_before_scheduling[node_id].GetAvailableResources().IsEqual(
+          cluster_resources_after_destruction[node_id].GetAvailableResources()));
 }

 TEST_F(GcsBasedActorSchedulerTest, TestBalancedSchedule) {
diff --git a/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_mock_test.cc b/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_mock_test.cc
new file mode 100644
index 0000000000000..e017fb793bafe
--- /dev/null
+++ b/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_mock_test.cc
@@ -0,0 +1,174 @@
+// Copyright 2017 The Ray Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// clang-format off
+#include "gtest/gtest.h"
+#include "gmock/gmock.h"
+#include "ray/gcs/gcs_server/gcs_placement_group_manager.h"
+#include "mock/ray/gcs/gcs_server/gcs_placement_group_manager.h"
+#include "mock/ray/gcs/gcs_server/gcs_placement_group_scheduler.h"
+#include "mock/ray/gcs/gcs_server/gcs_resource_manager.h"
+#include "mock/ray/gcs/store_client/store_client.h"
+#include "ray/gcs/test/gcs_test_util.h"
+// clang-format on
+
+using namespace ::testing;
+using namespace ray;
+using namespace ray::gcs;
+namespace ray {
+namespace gcs {
+
+class GcsPlacementGroupManagerMockTest : public Test {
+ public:
+  void SetUp() override {
+    store_client_ = std::make_shared<MockStoreClient>();
+    gcs_table_storage_ = std::make_shared<GcsTableStorage>(store_client_);
+    gcs_placement_group_scheduler_ =
+        std::make_shared<MockGcsPlacementGroupScheduler>();
+    resource_manager_ =
+        std::make_shared<GcsResourceManager>(io_context_, nullptr, nullptr, true);
+
+    gcs_placement_group_manager_ = std::make_unique<GcsPlacementGroupManager>(
+        io_context_, gcs_placement_group_scheduler_, gcs_table_storage_,
+        *resource_manager_, [](auto &) { return ""; });
+  }
+
+  std::unique_ptr<GcsPlacementGroupManager> gcs_placement_group_manager_;
+  std::shared_ptr<MockGcsPlacementGroupScheduler> gcs_placement_group_scheduler_;
+  std::shared_ptr<GcsTableStorage> gcs_table_storage_;
+  std::shared_ptr<MockStoreClient> store_client_;
+  std::shared_ptr<GcsResourceManager> resource_manager_;
+  instrumented_io_context io_context_;
+};
+
+TEST_F(GcsPlacementGroupManagerMockTest, PendingQueuePriorityReschedule) {
+  // Tests that the pending queue priority works:
+  // when a placement group comes back for rescheduling, it should be given
+  // the highest priority.
+  auto req =
+      Mocker::GenCreatePlacementGroupRequest("", rpc::PlacementStrategy::SPREAD, 1);
+  auto pg = std::make_shared<GcsPlacementGroup>(req, "");
+  auto cb = [](Status s) {};
+  PGSchedulingFailureCallback failure_callback;
+  PGSchedulingSuccessfulCallback success_callback;
+  StatusCallback put_cb;
+  EXPECT_CALL(*store_client_, AsyncPut(_, _, _, _))
+      .WillOnce(DoAll(SaveArg<3>(&put_cb), Return(Status::OK())));
+  EXPECT_CALL(*gcs_placement_group_scheduler_, ScheduleUnplacedBundles(_, _, _))
+      .WillOnce(DoAll(SaveArg<1>(&failure_callback), SaveArg<2>(&success_callback)));
+  auto now = absl::GetCurrentTimeNanos();
+  gcs_placement_group_manager_->RegisterPlacementGroup(pg, cb);
+  auto &pending_queue =
+      gcs_placement_group_manager_->pending_placement_groups_;
+  ASSERT_EQ(1, pending_queue.size());
+  ASSERT_LE(now, pending_queue.begin()->first);
+  ASSERT_GE(absl::GetCurrentTimeNanos(), pending_queue.begin()->first);
+  put_cb(Status::OK());
+  pg->UpdateState(rpc::PlacementGroupTableData::RESCHEDULING);
+  failure_callback(pg, true);
+  ASSERT_EQ(1, pending_queue.size());
+  ASSERT_GE(0, pending_queue.begin()->first);
+}
+
+TEST_F(GcsPlacementGroupManagerMockTest, PendingQueuePriorityFailed) {
+  // Tests that the pending queue priority works:
+  // when scheduling returns with a failure, exponential backoff should apply.
+  auto req =
+      Mocker::GenCreatePlacementGroupRequest("", rpc::PlacementStrategy::SPREAD, 1);
+  auto pg = std::make_shared<GcsPlacementGroup>(req, "");
+  auto cb = [](Status s) {};
+  PGSchedulingFailureCallback failure_callback;
+  PGSchedulingSuccessfulCallback success_callback;
+  StatusCallback put_cb;
+  EXPECT_CALL(*store_client_, AsyncPut(_, _, _, _))
+      .WillOnce(DoAll(SaveArg<3>(&put_cb), Return(Status::OK())));
+  EXPECT_CALL(*gcs_placement_group_scheduler_, ScheduleUnplacedBundles(_, _, _))
+      .Times(2)
+      .WillRepeatedly(
+          DoAll(SaveArg<1>(&failure_callback), SaveArg<2>(&success_callback)));
+  auto now = absl::GetCurrentTimeNanos();
+  gcs_placement_group_manager_->RegisterPlacementGroup(pg, cb);
+  auto &pending_queue = gcs_placement_group_manager_->pending_placement_groups_;
+  ASSERT_EQ(1, pending_queue.size());
+  ASSERT_LE(now, pending_queue.begin()->first);
+  ASSERT_GE(absl::GetCurrentTimeNanos(), pending_queue.begin()->first);
+  put_cb(Status::OK());
+  pg->UpdateState(rpc::PlacementGroupTableData::PENDING);
+  now = absl::GetCurrentTimeNanos();
+  failure_callback(pg, true);
+  auto exp_backer = ExponentialBackOff(
+      1000000 * RayConfig::instance().gcs_create_placement_group_retry_min_interval_ms(),
+      RayConfig::instance().gcs_create_placement_group_retry_multiplier(),
+      1000000 * RayConfig::instance().gcs_create_placement_group_retry_max_interval_ms());
+  auto next = exp_backer.Next();
+  ASSERT_DOUBLE_EQ(
+      next,
+      1000000 * RayConfig::instance().gcs_create_placement_group_retry_min_interval_ms());
+  ASSERT_EQ(1, pending_queue.size());
+  auto rank = pending_queue.begin()->first;
+  ASSERT_LE(now + next, rank);
+  // ScheduleUnplacedBundles is not called here.
+  gcs_placement_group_manager_->SchedulePendingPlacementGroups();
+  ASSERT_EQ(1, pending_queue.size());
+  ASSERT_EQ(rank, pending_queue.begin()->first);
+
+  absl::SleepFor(absl::Milliseconds(1) +
+                 absl::Nanoseconds(rank - absl::GetCurrentTimeNanos()));
+  gcs_placement_group_manager_->SchedulePendingPlacementGroups();
+  ASSERT_EQ(0, pending_queue.size());
+  pg->UpdateState(rpc::PlacementGroupTableData::PENDING);
+  now = absl::GetCurrentTimeNanos();
+  failure_callback(pg, true);
+  next = RayConfig::instance().gcs_create_placement_group_retry_multiplier() * next;
+  ASSERT_EQ(1, pending_queue.size());
+  ASSERT_LE(now + next, pending_queue.begin()->first);
+}
+
+TEST_F(GcsPlacementGroupManagerMockTest, PendingQueuePriorityOrder) {
+  // Tests that the pending queue priority ordering works:
+  // add two placement groups, fail one, and make sure it is scheduled later.
+  auto req1 =
+      Mocker::GenCreatePlacementGroupRequest("", rpc::PlacementStrategy::SPREAD, 1);
+  auto pg1 = std::make_shared<GcsPlacementGroup>(req1, "");
+  auto req2 =
+      Mocker::GenCreatePlacementGroupRequest("", rpc::PlacementStrategy::SPREAD, 1);
+  auto pg2 = std::make_shared<GcsPlacementGroup>(req2, "");
+  auto cb = [](Status s) {};
+  PGSchedulingFailureCallback failure_callback;
+  PGSchedulingSuccessfulCallback success_callback;
+  StatusCallback put_cb;
+  EXPECT_CALL(*store_client_, AsyncPut(_, _, _, _))
+      .Times(2)
+      .WillRepeatedly(DoAll(SaveArg<3>(&put_cb), Return(Status::OK())));
+  EXPECT_CALL(*gcs_placement_group_scheduler_, ScheduleUnplacedBundles(_, _, _))
+      .Times(2)
+      .WillRepeatedly(
+          DoAll(SaveArg<1>(&failure_callback), SaveArg<2>(&success_callback)));
+  gcs_placement_group_manager_->RegisterPlacementGroup(pg1, cb);
+  gcs_placement_group_manager_->RegisterPlacementGroup(pg2, cb);
+  auto &pending_queue = gcs_placement_group_manager_->pending_placement_groups_;
+  ASSERT_EQ(2, pending_queue.size());
+  put_cb(Status::OK());
+  ASSERT_EQ(1, pending_queue.size());
+  // PG1 is scheduled first, so PG2 is in the pending queue.
+  ASSERT_EQ(pg2, pending_queue.begin()->second.second);
+  failure_callback(pg1, true);
+  ASSERT_EQ(2, pending_queue.size());
+  gcs_placement_group_manager_->SchedulePendingPlacementGroups();
+  // PG2 is scheduled next, so PG1 is in the pending queue.
+  ASSERT_EQ(1, pending_queue.size());
+  ASSERT_EQ(pg1, pending_queue.begin()->second.second);
+}
+
+}  // namespace gcs
+}  // namespace ray
diff --git a/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_test.cc
index 7c941aa27f815..8eeed97f7eca6 100644
--- a/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_test.cc
+++ b/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_test.cc
@@ -22,6 +22,7 @@
 #include "ray/gcs/test/gcs_test_util.h"

 namespace ray {
+namespace gcs {

 using ::testing::_;
 using StatusCallback = std::function<void(Status status)>;
@@ -135,6 +136,8 @@ class GcsPlacementGroupManagerTest : public ::testing::Test {
     EXPECT_TRUE(WaitForCondition(condition, 10 * 1000));
   }

+  ExponentialBackOff GetExpBackOff() { return ExponentialBackOff(0, 1); }
+
   std::shared_ptr<MockPlacementGroupScheduler> mock_placement_group_scheduler_;
   std::unique_ptr<gcs::GcsPlacementGroupManager> gcs_placement_group_manager_;
   std::unordered_map<JobID, std::string> job_namespace_table_;
@@ -148,6 +151,26 @@ class GcsPlacementGroupManagerTest : public ::testing::Test {
   std::shared_ptr<gcs::RedisClient> redis_client_;
 };

+TEST_F(GcsPlacementGroupManagerTest, TestPlacementGroupBundleCache) {
+  auto request = Mocker::GenCreatePlacementGroupRequest();
+  std::atomic<int> registered_placement_group_count(0);
+  RegisterPlacementGroup(request,
+                         [&registered_placement_group_count](const Status &status) {
+                           ++registered_placement_group_count;
+                         });
+  ASSERT_EQ(registered_placement_group_count, 1);
+  WaitForExpectedPgCount(1);
+  auto placement_group = mock_placement_group_scheduler_->placement_groups_.back();
+  ASSERT_TRUE(placement_group->cached_bundle_specs_.empty());
+  // Fill the cache and verify it.
+  const auto &bundle_specs = placement_group->GetBundles();
+  ASSERT_EQ(placement_group->cached_bundle_specs_, bundle_specs);
+  ASSERT_FALSE(placement_group->cached_bundle_specs_.empty());
+  // Invalidate the cache and verify it.
+  RAY_UNUSED(placement_group->GetMutableBundle(0));
+  ASSERT_TRUE(placement_group->cached_bundle_specs_.empty());
+}
+
 TEST_F(GcsPlacementGroupManagerTest, TestBasic) {
   auto request = Mocker::GenCreatePlacementGroupRequest();
   std::atomic<int> registered_placement_group_count(0);
@@ -176,7 +199,8 @@ TEST_F(GcsPlacementGroupManagerTest, TestSchedulingFailed) {
   auto placement_group = mock_placement_group_scheduler_->placement_groups_.back();
   mock_placement_group_scheduler_->placement_groups_.clear();

-  gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group);
+  gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group,
+                                                               GetExpBackOff(), true);
   gcs_placement_group_manager_->SchedulePendingPlacementGroups();
   ASSERT_EQ(mock_placement_group_scheduler_->placement_groups_.size(), 1);
   mock_placement_group_scheduler_->placement_groups_.clear();
@@ -240,7 +264,8 @@ TEST_F(GcsPlacementGroupManagerTest, TestRescheduleWhenNodeAdd) {
   mock_placement_group_scheduler_->placement_groups_.pop_back();

   // If the creation of the placement group fails, it will be rescheduled after a short time.
-  gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group);
+  gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group,
+                                                               GetExpBackOff(), true);
   WaitForExpectedPgCount(1);
 }

@@ -255,7 +280,8 @@ TEST_F(GcsPlacementGroupManagerTest, TestRemovingPendingPlacementGroup) {
   auto placement_group = mock_placement_group_scheduler_->placement_groups_.back();
   mock_placement_group_scheduler_->placement_groups_.clear();

-  gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group);
+  gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group,
+                                                               GetExpBackOff(), true);
   ASSERT_EQ(placement_group->GetState(), rpc::PlacementGroupTableData::PENDING);

   const auto &placement_group_id = placement_group->GetPlacementGroupID();
   gcs_placement_group_manager_->RemovePlacementGroup(placement_group_id,
@@ -291,7 +317,8 @@ TEST_F(GcsPlacementGroupManagerTest, TestRemovingLeasingPlacementGroup) {
   gcs_placement_group_manager_->RemovePlacementGroup(placement_group_id,
                                                      [](const Status &status) {});
   ASSERT_EQ(placement_group->GetState(), rpc::PlacementGroupTableData::REMOVED);
-  gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group);
+  gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group,
+                                                               GetExpBackOff(), true);

   // Make sure it is not rescheduled
   gcs_placement_group_manager_->SchedulePendingPlacementGroups();
@@ -354,7 +381,6 @@ TEST_F(GcsPlacementGroupManagerTest, TestRescheduleWhenNodeDead) {
   placement_group->GetMutableBundle(0)->set_node_id(NodeID::FromRandom().Binary());
   placement_group->GetMutableBundle(1)->set_node_id(NodeID::FromRandom().Binary());
   mock_placement_group_scheduler_->placement_groups_.pop_back();
-
   // If a node dies, we will set the bundles above it to be unplaced and reschedule the
   // placement group. The placement group state is set to `RESCHEDULING` and will be
   // scheduled first.
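The GetExpBackOff() helper and the new OnPlacementGroupCreationFailed signature thread an ExponentialBackOff through retries. A minimal stand-in consistent with how the tests drive it (the constructor arguments and Next() semantics are inferred from the assertions above, not copied from ray/util):

  #include <algorithm>
  #include <cstdint>

  // Minimal sketch: Next() returns the current delay, then grows it
  // geometrically, capped at max_value. The real ray::ExponentialBackOff
  // may differ in detail.
  class ExponentialBackOff {
   public:
    ExponentialBackOff(uint64_t base, double multiplier,
                       uint64_t max_value = UINT64_MAX)
        : curr_(base), multiplier_(multiplier), max_(max_value) {}

    uint64_t Next() {
      uint64_t v = curr_;
      curr_ = std::min(static_cast<uint64_t>(curr_ * multiplier_), max_);
      return v;
    }

   private:
    uint64_t curr_;
    double multiplier_;
    uint64_t max_;
  };

This matches the test: the first Next() yields the minimum retry interval, and each failure multiplies the next delay by the configured multiplier.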
@@ -373,14 +399,15 @@ TEST_F(GcsPlacementGroupManagerTest, TestRescheduleWhenNodeDead) {
                 placement_group->GetPlacementGroupID());
   const auto &bundles =
       mock_placement_group_scheduler_->placement_groups_[0]->GetBundles();
-  EXPECT_TRUE(NodeID::FromBinary(bundles[0]->GetMutableMessage().node_id()).IsNil());
-  EXPECT_FALSE(NodeID::FromBinary(bundles[1]->GetMutableMessage().node_id()).IsNil());
+  EXPECT_TRUE(NodeID::FromBinary(bundles[0]->GetMessage().node_id()).IsNil());
+  EXPECT_FALSE(NodeID::FromBinary(bundles[1]->GetMessage().node_id()).IsNil());

   // If `RESCHEDULING` placement group fails to create, we will schedule it again first.
   placement_group = mock_placement_group_scheduler_->placement_groups_.back();
   mock_placement_group_scheduler_->placement_groups_.pop_back();
   ASSERT_EQ(mock_placement_group_scheduler_->placement_groups_.size(), 0);
-  gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group);
+  gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group,
+                                                               GetExpBackOff(), true);
   WaitForExpectedPgCount(1);
   ASSERT_EQ(mock_placement_group_scheduler_->placement_groups_[0]->GetPlacementGroupID(),
             placement_group->GetPlacementGroupID());
@@ -526,7 +553,8 @@ TEST_F(GcsPlacementGroupManagerTest, TestSchedulingCanceledWhenPgIsInfeasible) {
   mock_placement_group_scheduler_->placement_groups_.clear();

   // Mark it non-retryable.
-  gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group, false);
+  gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group,
+                                                               GetExpBackOff(), false);

   // Schedule twice to make sure it will not be scheduled afterward.
   gcs_placement_group_manager_->SchedulePendingPlacementGroups();
@@ -607,6 +635,7 @@ TEST_F(GcsPlacementGroupManagerTest, TestRayNamespace) {
   }
 }

+}  // namespace gcs
 }  // namespace ray

 int main(int argc, char **argv) {
diff --git a/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc b/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc
index cbe1ba78495f4..5d265ac1bbb59 100644
--- a/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc
+++ b/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc
@@ -33,6 +33,7 @@ class GcsServerTest : public ::testing::Test {
     config.grpc_server_name = "MockedGcsServer";
     config.grpc_server_thread_num = 1;
     config.redis_address = "127.0.0.1";
+    config.node_ip_address = "127.0.0.1";
     config.enable_sharding_conn = false;
     config.redis_port = TEST_REDIS_SERVER_PORTS.front();
     gcs_server_.reset(new gcs::GcsServer(config, io_service_));
diff --git a/src/ray/gcs/gcs_server/test/gcs_server_test_util.h b/src/ray/gcs/gcs_server/test/gcs_server_test_util.h
index 249ac5a9fdd6a..11f0783bb8465 100644
--- a/src/ray/gcs/gcs_server/test/gcs_server_test_util.h
+++ b/src/ray/gcs/gcs_server/test/gcs_server_test_util.h
@@ -70,6 +70,10 @@ struct GcsServerMocker {
       return Status::OK();
     }

+    void ReportWorkerBacklog(
+        const WorkerID &worker_id,
+        const std::vector<rpc::WorkerBacklogReport> &backlog_reports) override {}
+
     /// WorkerLeaseInterface
     void RequestWorkerLease(
         const ray::TaskSpecification &resource_spec,
@@ -79,6 +83,14 @@
       callbacks.push_back(callback);
     }

+    void RequestWorkerLease(
+        const rpc::TaskSpec &spec,
+        const rpc::ClientCallback<rpc::RequestWorkerLeaseReply> &callback,
+        const int64_t backlog_size = -1) override {
+      num_workers_requested += 1;
+      callbacks.push_back(callback);
+    }
+
     /// WorkerLeaseInterface
     void ReleaseUnusedWorkers(
         const std::vector<WorkerID> &workers_in_use,
@@ -180,7 +192,7 @@ struct GcsServerMocker {

     /// ResourceReserveInterface
     void CancelResourceReserve(
-        BundleSpecification &bundle_spec,
+        const BundleSpecification &bundle_spec,
         const ray::rpc::ClientCallback<ray::rpc::CancelResourceReserveReply> &callback)
         override {
       num_return_requested += 1;
diff --git a/src/ray/gcs/pubsub/gcs_pub_sub.h b/src/ray/gcs/pubsub/gcs_pub_sub.h
index b871a02b13ddd..70828a3679691 100644
--- a/src/ray/gcs/pubsub/gcs_pub_sub.h
+++ b/src/ray/gcs/pubsub/gcs_pub_sub.h
@@ -96,6 +96,9 @@ class GcsPubSub {

   std::string DebugString() const;

+ protected:
+  GcsPubSub() : GcsPubSub(nullptr) {}
+
  private:
   /// Represents a caller's command to subscribe or unsubscribe to a given
   /// channel.
diff --git a/src/ray/gcs/redis_context.h b/src/ray/gcs/redis_context.h
index c7244aac80549..443c42f9dee69 100644
--- a/src/ray/gcs/redis_context.h
+++ b/src/ray/gcs/redis_context.h
@@ -15,7 +15,7 @@
 #pragma once

 #include
-#include
+#include
 #include
 #include
 #include
diff --git a/src/ray/object_manager/object_buffer_pool.cc b/src/ray/object_manager/object_buffer_pool.cc
index e6e214b3062f2..b25439cd7203c 100644
--- a/src/ray/object_manager/object_buffer_pool.cc
+++ b/src/ray/object_manager/object_buffer_pool.cc
@@ -14,6 +14,7 @@

 #include "ray/object_manager/object_buffer_pool.h"

+#include "absl/time/time.h"
 #include "ray/common/status.h"
 #include "ray/util/logging.h"

@@ -21,26 +22,49 @@ namespace ray {

 ObjectBufferPool::ObjectBufferPool(const std::string &store_socket_name,
                                    uint64_t chunk_size)
-    : default_chunk_size_(chunk_size) {
-  store_socket_name_ = store_socket_name;
+    : store_socket_name_(store_socket_name), default_chunk_size_(chunk_size) {
   RAY_CHECK_OK(store_client_.Connect(store_socket_name_.c_str(), "", 0, 300));
 }

 ObjectBufferPool::~ObjectBufferPool() {
-  // Abort everything in progress.
-  auto create_buf_state_copy = create_buffer_state_;
-  for (const auto &pair : create_buf_state_copy) {
-    AbortCreate(pair.first);
+  absl::MutexLock lock(&pool_mutex_);
+  auto inflight_ops = create_buffer_ops_;
+  pool_mutex_.Unlock();
+
+  for (const auto &[id, cond_var] : inflight_ops) {
+    cond_var->SignalAll();
+  }
+  auto no_inflight = [this]() {
+    pool_mutex_.AssertReaderHeld();
+    return create_buffer_ops_.empty();
+  };
+  // Assume no request would arrive; acquire pool_mutex_ once there is no inflight
+  // operation. Otherwise print an error.
+  if (!pool_mutex_.LockWhenWithTimeout(absl::Condition(&no_inflight), absl::Seconds(5))) {
+    RAY_LOG(ERROR)
+        << create_buffer_ops_.size() << " remaining inflight create buffer operations "
+        << "during ObjectBufferPool destruction. Either abort these operations before "
+        << "destroying ObjectBufferPool, or refactor ObjectBufferPool to make it "
+           "unnecessary to wait for the operations' completion.";
   }
+
+  // Abort unfinished buffers in progress.
+  for (auto it = create_buffer_state_.begin(); it != create_buffer_state_.end();) {
+    RAY_CHECK_OK(store_client_.Release(it->first));
+    RAY_CHECK_OK(store_client_.Abort(it->first));
+    create_buffer_state_.erase(it++);
+  }
+
+  RAY_CHECK(create_buffer_state_.empty());
   RAY_CHECK_OK(store_client_.Disconnect());
 }

-uint64_t ObjectBufferPool::GetNumChunks(uint64_t data_size) {
+uint64_t ObjectBufferPool::GetNumChunks(uint64_t data_size) const {
   return (data_size + default_chunk_size_ - 1) / default_chunk_size_;
 }

-uint64_t ObjectBufferPool::GetBufferLength(uint64_t chunk_index, uint64_t data_size) {
+uint64_t ObjectBufferPool::GetBufferLength(uint64_t chunk_index,
+                                           uint64_t data_size) const {
   return (chunk_index + 1) * default_chunk_size_ > data_size
              ? data_size % default_chunk_size_
              : default_chunk_size_;
 }

@@ -49,7 +73,7 @@ uint64_t ObjectBufferPool::GetBufferLength(uint64_t chunk_index, uint64_t data_s

 std::pair<std::shared_ptr<MemoryObjectReader>, ray::Status>
 ObjectBufferPool::CreateObjectReader(const ObjectID &object_id,
                                      rpc::Address owner_address) {
-  std::lock_guard<std::mutex> lock(pool_mutex_);
+  absl::MutexLock lock(&pool_mutex_);

   std::vector<ObjectID> object_ids{object_id};
   std::vector<plasma::ObjectBuffer> object_buffers(1);
@@ -76,53 +100,21 @@ ray::Status ObjectBufferPool::CreateChunk(const ObjectID &object_id,
                                           const rpc::Address &owner_address,
                                           uint64_t data_size, uint64_t metadata_size,
                                           uint64_t chunk_index) {
-  std::unique_lock<std::mutex> lock(pool_mutex_);
-  if (create_buffer_state_.count(object_id) == 0) {
-    int64_t object_size = data_size - metadata_size;
-    // Try to create shared buffer.
-    std::shared_ptr<Buffer> data;
-
-    // Release the buffer pool lock during the blocking create call.
-    lock.unlock();
-    Status s = store_client_.CreateAndSpillIfNeeded(
-        object_id, owner_address, object_size, NULL, metadata_size, &data,
-        plasma::flatbuf::ObjectSource::ReceivedFromRemoteRaylet);
-    lock.lock();
-
-    // Another thread may have succeeded in creating the chunk while the lock
-    // was released. In that case skip the remainder of the creation block.
-    if (create_buffer_state_.count(object_id) == 0) {
-      std::vector buffer;
-      if (!s.ok()) {
-        // Create failed. The object may already exist locally. If something else went
-        // wrong, another chunk will succeed in creating the buffer, and this
-        // chunk will eventually make it here via pull requests.
-        return ray::Status::IOError(s.message());
-      }
-      // Read object into store.
-      uint8_t *mutable_data = data->Data();
-      uint64_t num_chunks = GetNumChunks(data_size);
-      create_buffer_state_.emplace(
-          std::piecewise_construct, std::forward_as_tuple(object_id),
-          std::forward_as_tuple(BuildChunks(object_id, mutable_data, data_size, data)));
-      RAY_LOG(DEBUG) << "Created object " << object_id
-                     << " in plasma store, number of chunks: " << num_chunks
-                     << ", chunk index: " << chunk_index;
-      RAY_CHECK(create_buffer_state_[object_id].chunk_info.size() == num_chunks);
-    }
-  }
-  if (create_buffer_state_[object_id].chunk_state[chunk_index] !=
-      CreateChunkState::AVAILABLE) {
+  absl::MutexLock lock(&pool_mutex_);
+  RAY_RETURN_NOT_OK(EnsureBufferExists(object_id, owner_address, data_size,
+                                       metadata_size, chunk_index));
+  auto &state = create_buffer_state_.at(object_id);
+  if (state.chunk_state[chunk_index] != CreateChunkState::AVAILABLE) {
     // There can be only one reference to this chunk at any given time.
     return ray::Status::IOError("Chunk already received by a different thread.");
   }
-  create_buffer_state_[object_id].chunk_state[chunk_index] = CreateChunkState::REFERENCED;
+  state.chunk_state[chunk_index] = CreateChunkState::REFERENCED;
   return ray::Status::OK();
 }

 void ObjectBufferPool::WriteChunk(const ObjectID &object_id, const uint64_t chunk_index,
                                   const std::string &data) {
-  std::lock_guard<std::mutex> lock(pool_mutex_);
+  absl::MutexLock lock(&pool_mutex_);
   auto it = create_buffer_state_.find(object_id);
   if (it == create_buffer_state_.end() ||
       it->second.chunk_state.at(chunk_index) != CreateChunkState::REFERENCED) {
@@ -148,7 +140,7 @@ void ObjectBufferPool::WriteChunk(const ObjectID &object_id, const uint64_t chun
 }

 void ObjectBufferPool::AbortCreate(const ObjectID &object_id) {
-  std::lock_guard<std::mutex> lock(pool_mutex_);
+  absl::MutexLock lock(&pool_mutex_);
   auto it = create_buffer_state_.find(object_id);
   if (it != create_buffer_state_.end()) {
     RAY_LOG(INFO) << "Not enough memory to create requested object " << object_id
@@ -179,13 +171,84 @@ std::vector<ObjectBufferPool::ChunkInfo> ObjectBufferPool::BuildChunks(
   return chunks;
 }

+ray::Status ObjectBufferPool::EnsureBufferExists(const ObjectID &object_id,
+                                                 const rpc::Address &owner_address,
+                                                 uint64_t data_size,
+                                                 uint64_t metadata_size,
+                                                 uint64_t chunk_index) {
+  while (true) {
+    // Buffer for object_id already exists.
+    if (create_buffer_state_.contains(object_id)) {
+      return ray::Status::OK();
+    }
+
+    auto it = create_buffer_ops_.find(object_id);
+    if (it == create_buffer_ops_.end()) {
+      // No inflight create buffer operation, proceed to start one.
+      break;
+    }
+
+    auto cond_var = it->second;
+    // Release pool_mutex_ while waiting, until the current inflight create buffer
+    // operation finishes.
+    cond_var->Wait(&pool_mutex_);
+  }
+
+  // Indicate that there is an inflight create buffer operation, by inserting into
+  // create_buffer_ops_.
+  RAY_CHECK(
+      create_buffer_ops_.insert({object_id, std::make_shared<absl::CondVar>()}).second);
+  const int64_t object_size =
+      static_cast<int64_t>(data_size) - static_cast<int64_t>(metadata_size);
+  std::shared_ptr<Buffer> data;
+
+  // Release pool_mutex_ during the blocking create call.
+  pool_mutex_.Unlock();
+  Status s = store_client_.CreateAndSpillIfNeeded(
+      object_id, owner_address, static_cast<int64_t>(object_size), nullptr,
+      static_cast<int64_t>(metadata_size), &data,
+      plasma::flatbuf::ObjectSource::ReceivedFromRemoteRaylet);
+  pool_mutex_.Lock();
+
+  // No other thread could have created the buffer.
+  RAY_CHECK(!create_buffer_state_.contains(object_id));
+
+  // Remove object_id from create_buffer_ops_ to indicate to the waiting ops that the
+  // inflight operation has finished. Wake up waiters so they can either start another
+  // create buffer op, or proceed after the buffer has been created.
+  {
+    auto it = create_buffer_ops_.find(object_id);
+    it->second->SignalAll();
+    create_buffer_ops_.erase(it);
+  }
+
+  if (!s.ok()) {
+    // Create failed. Buffer creation will be tried by another chunk,
+    // and this chunk will eventually make it here via retried pull requests.
+    return ray::Status::IOError(s.message());
+  }
+
+  // Read object into store.
+  uint8_t *mutable_data = data->Data();
+  uint64_t num_chunks = GetNumChunks(data_size);
+  create_buffer_state_.emplace(
+      std::piecewise_construct, std::forward_as_tuple(object_id),
+      std::forward_as_tuple(BuildChunks(object_id, mutable_data, data_size, data)));
+  RAY_CHECK(create_buffer_state_[object_id].chunk_info.size() == num_chunks);
+  RAY_LOG(DEBUG) << "Created object " << object_id
+                 << " in plasma store, number of chunks: " << num_chunks
+                 << ", chunk index: " << chunk_index;
+
+  return ray::Status::OK();
+}
+
 void ObjectBufferPool::FreeObjects(const std::vector<ObjectID> &object_ids) {
-  std::lock_guard<std::mutex> lock(pool_mutex_);
+  absl::MutexLock lock(&pool_mutex_);
   RAY_CHECK_OK(store_client_.Delete(object_ids));
 }

 std::string ObjectBufferPool::DebugString() const {
-  std::lock_guard<std::mutex> lock(pool_mutex_);
+  absl::MutexLock lock(&pool_mutex_);
   std::stringstream result;
   result << "BufferPool:";
   result << "\n- create buffer state map size: " << create_buffer_state_.size();
diff --git a/src/ray/object_manager/object_buffer_pool.h b/src/ray/object_manager/object_buffer_pool.h
index 05c51e5e00117..b2722a3eceecc 100644
--- a/src/ray/object_manager/object_buffer_pool.h
+++ b/src/ray/object_manager/object_buffer_pool.h
@@ -16,12 +16,14 @@

 #include
 #include
-#include
+#include
 #include
 #include
-#include
 #include

+#include "absl/base/thread_annotations.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/synchronization/mutex.h"
 #include "ray/common/id.h"
 #include "ray/common/status.h"
 #include "ray/object_manager/memory_object_reader.h"
@@ -68,14 +70,14 @@ class ObjectBufferPool {
   ///
   /// \param data_size The size of the object + metadata.
   /// \return The number of chunks into which the object will be split.
-  uint64_t GetNumChunks(uint64_t data_size);
+  uint64_t GetNumChunks(uint64_t data_size) const;

   /// Computes the buffer length of a chunk of an object.
   ///
   /// \param chunk_index The chunk index for which to obtain the buffer length.
   /// \param data_size The size of the object + metadata.
   /// \return The buffer length of the chunk at chunk_index.
-  uint64_t GetBufferLength(uint64_t chunk_index, uint64_t data_size);
+  uint64_t GetBufferLength(uint64_t chunk_index, uint64_t data_size) const;

   /// Returns an object reader for read.
   ///
@@ -85,7 +87,7 @@ class ObjectBufferPool {
   /// this method. An IOError status is returned if the Get call on the plasma store
   /// fails, and the MemoryObjectReader will be empty.
   std::pair<std::shared_ptr<MemoryObjectReader>, ray::Status> CreateObjectReader(
-      const ObjectID &object_id, rpc::Address owner_address);
+      const ObjectID &object_id, rpc::Address owner_address) LOCKS_EXCLUDED(pool_mutex_);

   /// Returns a chunk of an empty object at the given chunk_index. The object chunk
   /// serves as the buffer that is to be written to by a connection receiving an
@@ -106,7 +108,7 @@ class ObjectBufferPool {
   /// (with no intermediate AbortCreateChunk).
   ray::Status CreateChunk(const ObjectID &object_id, const rpc::Address &owner_address,
                           uint64_t data_size, uint64_t metadata_size,
-                          uint64_t chunk_index);
+                          uint64_t chunk_index) LOCKS_EXCLUDED(pool_mutex_);

   /// Write to a Chunk of an object. If all chunks of an object are written,
   /// it seals the object.
@@ -119,34 +121,44 @@ class ObjectBufferPool {
   /// \param chunk_index The index of the chunk.
   /// \param data The data to write into the chunk.
   void WriteChunk(const ObjectID &object_id, uint64_t chunk_index,
-                  const std::string &data);
+                  const std::string &data) LOCKS_EXCLUDED(pool_mutex_);

   /// Free a list of objects from object store.
   ///
   /// \param object_ids The list of ObjectIDs to be deleted.
   /// \return Void.
-  void FreeObjects(const std::vector<ObjectID> &object_ids);
+  void FreeObjects(const std::vector<ObjectID> &object_ids) LOCKS_EXCLUDED(pool_mutex_);

   /// Abort the create operation associated with an object. This destroys the buffer
   /// state, including create operations in progress for all chunks of the object.
-  void AbortCreate(const ObjectID &object_id);
+  void AbortCreate(const ObjectID &object_id) LOCKS_EXCLUDED(pool_mutex_);

   /// Returns debug string for class.
   ///
   /// \return string.
-  std::string DebugString() const;
+  std::string DebugString() const LOCKS_EXCLUDED(pool_mutex_);

  private:
   /// Splits an object into ceil(data_size/chunk_size) chunks, which will
   /// either be read or written to in parallel.
   std::vector<ChunkInfo> BuildChunks(const ObjectID &object_id, uint8_t *data,
                                      uint64_t data_size,
-                                     std::shared_ptr<Buffer> buffer_ref);
+                                     std::shared_ptr<Buffer> buffer_ref)
+      EXCLUSIVE_LOCKS_REQUIRED(pool_mutex_);
+
+  /// Ensures buffer for the object exists, and creates the buffer if needed.
+  /// Returns OK if buffer exists.
+  /// Must hold pool_mutex_ when calling this function. pool_mutex_ can be released
+  /// during the call.
+  ray::Status EnsureBufferExists(const ObjectID &object_id,
+                                 const rpc::Address &owner_address, uint64_t data_size,
+                                 uint64_t metadata_size, uint64_t chunk_index)
+      EXCLUSIVE_LOCKS_REQUIRED(pool_mutex_);

   /// The state of a chunk associated with a create operation.
   enum class CreateChunkState : unsigned int { AVAILABLE = 0, REFERENCED, SEALED };

-  /// Holds the state of a create buffer.
+  /// Holds the state of creating chunks. Members are protected by pool_mutex_.
   struct CreateBufferState {
     CreateBufferState() {}
     CreateBufferState(std::vector<ChunkInfo> chunk_info)
@@ -166,18 +178,29 @@ class ObjectBufferPool {
   /// Returned when GetChunk or CreateChunk fails.
   const ChunkInfo errored_chunk_ = {0, nullptr, 0, nullptr};

-  /// Mutex on public methods for thread-safe operations on
-  /// get_buffer_state_, create_buffer_state_, and store_client_.
-  mutable std::mutex pool_mutex_;
+  /// Socket name of plasma store.
+  const std::string store_socket_name_;
+
   /// Determines the maximum chunk size to be transferred by a single thread.
   const uint64_t default_chunk_size_;
+
+  /// Mutex to protect create_buffer_ops_, create_buffer_state_ and following invariants:
+  /// - create_buffer_ops_ contains an object_id iff there is an inflight operation to
+  ///   create the buffer for the object.
+  /// - An object_id cannot appear in both create_buffer_ops_ and create_buffer_state_.
+  mutable absl::Mutex pool_mutex_;
+  /// Makes sure each object has at most one inflight create buffer operation.
+  /// Other operations can wait on the absl::CondVar for the operation
+  /// to complete. If successful, the corresponding entry in create_buffer_state_
+  /// will be created.
+  absl::flat_hash_map<ObjectID, std::shared_ptr<absl::CondVar>> create_buffer_ops_
+      GUARDED_BY(pool_mutex_);
   /// The state of a buffer that's currently being used.
-  std::unordered_map<ObjectID, CreateBufferState> create_buffer_state_;
+  absl::flat_hash_map<ObjectID, CreateBufferState> create_buffer_state_
+      GUARDED_BY(pool_mutex_);

   /// Plasma client pool.
   plasma::PlasmaClient store_client_;

-  /// Socket name of plasma store.
- std::string store_socket_name_; }; } // namespace ray diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc index 8e4dd703b91fb..3ee951d75553d 100644 --- a/src/ray/object_manager/object_manager.cc +++ b/src/ray/object_manager/object_manager.cc @@ -88,6 +88,7 @@ ObjectManager::ObjectManager( buffer_pool_(config_.store_socket_name, config_.object_chunk_size), rpc_work_(rpc_service_), object_manager_server_("ObjectManager", config_.object_manager_port, + config_.object_manager_address == "127.0.0.1", config_.rpc_service_threads_number), object_manager_service_(rpc_service_, *this), client_call_manager_(main_service, config_.rpc_service_threads_number), @@ -441,17 +442,18 @@ void ObjectManager::PushObjectInternal(const ObjectID &object_id, const NodeID & [=]() { // Post to the multithreaded RPC event loop so that data is copied // off of the main thread. - SendObjectChunk(push_id, object_id, node_id, chunk_id, rpc_client, - [=](const Status &status) { - // Post back to the main event loop because the - // PushManager is thread-safe. - main_service_->post( - [this, node_id, object_id]() { - push_manager_->OnChunkComplete(node_id, object_id); - }, - "ObjectManager.Push"); - }, - std::move(chunk_reader)); + SendObjectChunk( + push_id, object_id, node_id, chunk_id, rpc_client, + [=](const Status &status) { + // Post back to the main event loop because the + // PushManager is thread-safe. + main_service_->post( + [this, node_id, object_id]() { + push_manager_->OnChunkComplete(node_id, object_id); + }, + "ObjectManager.Push"); + }, + chunk_reader); }, "ObjectManager.Push"); }); diff --git a/src/ray/object_manager/object_manager.h b/src/ray/object_manager/object_manager.h index 3aaa847f03381..c0519a38306bd 100644 --- a/src/ray/object_manager/object_manager.h +++ b/src/ray/object_manager/object_manager.h @@ -17,7 +17,7 @@ #include #include #include -#include +#include #include #include #include @@ -49,6 +49,8 @@ namespace ray { struct ObjectManagerConfig { + /// The IP address this object manager is running on. + std::string object_manager_address; /// The port that the object manager should use to listen for connections /// from other object managers. If this is 0, the object manager will choose /// its own port. @@ -56,7 +58,7 @@ struct ObjectManagerConfig { /// The object manager's global timer frequency. unsigned int timer_freq_ms; /// The time in milliseconds to wait before retrying a pull - /// that fails due to node id lookup. + /// that failed. unsigned int pull_timeout_ms; /// Object chunk size, in bytes uint64_t object_chunk_size; diff --git a/src/ray/object_manager/plasma/store.cc b/src/ray/object_manager/plasma/store.cc index ff9e98ddb765c..0b8b24dbac56d 100644 --- a/src/ray/object_manager/plasma/store.cc +++ b/src/ray/object_manager/plasma/store.cc @@ -32,7 +32,7 @@ #include #include -#include +#include #include #include #include @@ -53,6 +53,7 @@ #include "ray/object_manager/plasma/protocol.h" #include "ray/util/util.h" +namespace ph = boost::placeholders; namespace fb = plasma::flatbuf; namespace plasma { @@ -297,7 +298,9 @@ void PlasmaStore::ConnectClient(const boost::system::error_code &error) { if (!error) { // Accept a new local client and dispatch it to the node manager. auto new_connection = Client::Create( - boost::bind(&PlasmaStore::ProcessMessage, this, _1, _2, _3), std::move(socket_)); + // NOLINTNEXTLINE : handler must be of boost::AcceptHandler type. 
+ boost::bind(&PlasmaStore::ProcessMessage, this, ph::_1, ph::_2, ph::_3), + std::move(socket_)); } // We're ready to accept another client. DoAccept(); diff --git a/src/ray/object_manager/pull_manager.h b/src/ray/object_manager/pull_manager.h index 9b7f20c14a478..6c5108f111abe 100644 --- a/src/ray/object_manager/pull_manager.h +++ b/src/ray/object_manager/pull_manager.h @@ -16,7 +16,7 @@ #include #include -#include +#include #include #include "absl/container/flat_hash_map.h" diff --git a/src/ray/protobuf/agent_manager.proto b/src/ray/protobuf/agent_manager.proto index f573f53766525..cbbd127004536 100644 --- a/src/ray/protobuf/agent_manager.proto +++ b/src/ray/protobuf/agent_manager.proto @@ -13,6 +13,7 @@ // limitations under the License. syntax = "proto3"; +option cc_enable_arenas = true; package ray.rpc; diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index dd9cf403c305c..1d3dd8124484d 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -13,6 +13,7 @@ // limitations under the License. syntax = "proto3"; +option cc_enable_arenas = true; package ray.rpc; @@ -155,10 +156,12 @@ message RayException { /// The runtime environment describes all the runtime packages needed to /// run some task or actor. message RuntimeEnv { - /// The raw json passed from user - string raw_json = 1; - /// Uris used in this runtime env + /// The serialized runtime env passed from the user. + string serialized_runtime_env = 1; + /// URIs used in this runtime env. These will be used for reference counting. repeated string uris = 2; + /// Indicates whether to install runtime env eagerly before the workers are leased. + bool runtime_env_eager_install = 3; } /// The task specification encapsulates all immutable information about the @@ -209,21 +212,19 @@ message TaskSpec { int64 placement_group_bundle_index = 19; // Whether or not this task should capture parent's placement group automatically. bool placement_group_capture_child_tasks = 20; - // Environment variables to override for this task - map override_environment_variables = 21; // Whether or not to skip the execution of this task. When it's true, // the receiver will not execute the task. This field is used by async actors // to guarantee task submission order after restart. - bool skip_execution = 22; + bool skip_execution = 21; // Breakpoint if this task should drop into the debugger when it starts executing // and "" if the task should not drop into the debugger. - bytes debugger_breakpoint = 23; - // Serialized JSON string of the parsed runtime environment dict for this task. - string serialized_runtime_env = 24; + bytes debugger_breakpoint = 22; + // Runtime environment for this task. + RuntimeEnv runtime_env = 23; // The concurrency group name in which this task will be performed. - string concurrency_group_name = 25; + string concurrency_group_name = 24; // Whether application-level errors (exceptions) should be retried. - bool retry_exceptions = 26; + bool retry_exceptions = 25; } message Bundle { diff --git a/src/ray/protobuf/core_worker.proto b/src/ray/protobuf/core_worker.proto index 9af0a87231326..81a8fbb5fd3d2 100644 --- a/src/ray/protobuf/core_worker.proto +++ b/src/ray/protobuf/core_worker.proto @@ -13,6 +13,7 @@ // limitations under the License. 
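The `option cc_enable_arenas = true;` lines added throughout these protos allow the generated C++ message classes to be arena-allocated, trading per-message heap churn for bulk frees on hot RPC paths. A minimal sketch of what that enables (the `MyMessage` type and its field are hypothetical stand-ins for any arena-enabled generated message):

#include <google/protobuf/arena.h>
// #include "my_message.pb.h"  // hypothetical generated header

void BuildManyMessages() {
  google::protobuf::Arena arena;
  for (int i = 0; i < 1000; i++) {
    // Allocated on the arena, not the heap; no per-message delete.
    MyMessage *msg = google::protobuf::Arena::CreateMessage<MyMessage>(&arena);
    msg->set_id(i);  // hypothetical field
  }
}  // All 1000 messages are freed in one shot when `arena` is destroyed.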
syntax = "proto3"; +option cc_enable_arenas = true; package ray.rpc; diff --git a/src/ray/protobuf/event.proto b/src/ray/protobuf/event.proto index 2edc202776f6b..5ec8ee9402492 100644 --- a/src/ray/protobuf/event.proto +++ b/src/ray/protobuf/event.proto @@ -1,4 +1,5 @@ syntax = "proto3"; +option cc_enable_arenas = true; package ray.rpc; diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto index ec1f3e7380d53..5f35c1a21e4d5 100644 --- a/src/ray/protobuf/gcs.proto +++ b/src/ray/protobuf/gcs.proto @@ -13,6 +13,7 @@ // limitations under the License. syntax = "proto3"; +option cc_enable_arenas = true; package ray.rpc; @@ -149,19 +150,17 @@ message ActorTableData { RayException creation_task_exception = 18; // The actor's namespace. Named `ray_namespace` to avoid confusions when invoked in c++. string ray_namespace = 19; - // Runtime required to run this actor - // It'll only be set if it's a detached actor and the original job has this field - RuntimeEnv runtime_env = 20; // The unix ms timestamp the actor was started at. - uint64 start_time = 21; + uint64 start_time = 20; // The unix ms timestamp the actor was ended at. - uint64 end_time = 22; + uint64 end_time = 21; + // Serialized runtime_env used to report in the dashboard snapshot. We need to populate + // it here instead of grabbing it from the task spec because the task spec is cleared + // for deleted actors: https://github.com/ray-project/ray/pull/11149. + string serialized_runtime_env = 22; // The actor's class name. This is necessary because the task spec's lifetime // is shorter than the ActorTableData. string class_name = 23; - // The actor's serialized runtime environment. This is necessary because the - // task spec's lifetime is shorter than the ActorTableData. - string serialized_runtime_env = 24; } message ErrorTableData { @@ -278,24 +277,20 @@ message TaskLeaseData { } message JobConfig { - // Environment variables to be set on worker processes. - map worker_env = 1; // The number of java workers per worker process. - uint32 num_java_workers_per_process = 2; + uint32 num_java_workers_per_process = 1; // The jvm options for java workers of the job. - repeated string jvm_options = 3; + repeated string jvm_options = 2; // A list of directories or files (jar files or dynamic libraries) that specify the // search path for user code. This will be used as `CLASSPATH` in Java, and `PYTHONPATH` // in Python. In C++, libraries under these paths will be loaded by 'dlopen'. - repeated string code_search_path = 4; + repeated string code_search_path = 3; // Runtime environment to run the code - RuntimeEnv runtime_env = 5; + RuntimeEnv runtime_env = 4; // The job's namespace. Named `ray_namespace` to avoid confusions when invoked in c++. - string ray_namespace = 6; - // Serialized JSON string of the parsed runtime environment dict for this job. - string serialized_runtime_env = 7; + string ray_namespace = 5; // An opaque kv store for job related metadata. - map metadata = 8; + map metadata = 6; } message JobTableData { diff --git a/src/ray/protobuf/gcs_service.proto b/src/ray/protobuf/gcs_service.proto index 308083f201208..65e9bbad13bc3 100644 --- a/src/ray/protobuf/gcs_service.proto +++ b/src/ray/protobuf/gcs_service.proto @@ -13,7 +13,7 @@ // limitations under the License. 
syntax = "proto3"; - +option cc_enable_arenas = true; package ray.rpc; import "src/ray/protobuf/common.proto"; diff --git a/src/ray/protobuf/job_agent.proto b/src/ray/protobuf/job_agent.proto index 07355a0a8f7c0..e187de67ae0f5 100644 --- a/src/ray/protobuf/job_agent.proto +++ b/src/ray/protobuf/job_agent.proto @@ -15,6 +15,7 @@ syntax = "proto3"; package ray.rpc; +option cc_enable_arenas = true; import "src/ray/protobuf/agent_manager.proto"; diff --git a/src/ray/protobuf/node_manager.proto b/src/ray/protobuf/node_manager.proto index 0c56bb7832b3a..0331369528753 100644 --- a/src/ray/protobuf/node_manager.proto +++ b/src/ray/protobuf/node_manager.proto @@ -13,12 +13,31 @@ // limitations under the License. syntax = "proto3"; +option cc_enable_arenas = true; package ray.rpc; import "src/ray/protobuf/common.proto"; import "src/ray/protobuf/gcs.proto"; +message WorkerBacklogReport { + // TaskSpec indicating the scheduling class. + // Cannot send scheduling class directly + // since it's local to each process. + TaskSpec resource_spec = 1; + // Size of the backlog for the above scheduling class. + int64 backlog_size = 2; +} + +message ReportWorkerBacklogRequest { + // Unique id of the worker that's reporting the backlog + bytes worker_id = 1; + // Backlog report per scheduling class + repeated WorkerBacklogReport backlog_reports = 2; +} + +message ReportWorkerBacklogReply {} + // Request a worker from the raylet with the specified resources. message RequestWorkerLeaseRequest { // TaskSpec containing the requested resources. @@ -254,6 +273,8 @@ service NodeManagerService { returns (RequestResourceReportReply); // Request a worker from the raylet. rpc RequestWorkerLease(RequestWorkerLeaseRequest) returns (RequestWorkerLeaseReply); + // Report task backlog information from a worker to the raylet + rpc ReportWorkerBacklog(ReportWorkerBacklogRequest) returns (ReportWorkerBacklogReply); // Release a worker back to its raylet. rpc ReturnWorker(ReturnWorkerRequest) returns (ReturnWorkerReply); // This method is only used by GCS, and the purpose is to release leased workers diff --git a/src/ray/protobuf/object_manager.proto b/src/ray/protobuf/object_manager.proto index 8bd6986f6b5b1..c212b18b266d1 100644 --- a/src/ray/protobuf/object_manager.proto +++ b/src/ray/protobuf/object_manager.proto @@ -13,6 +13,7 @@ // limitations under the License. syntax = "proto3"; +option cc_enable_arenas = true; package ray.rpc; diff --git a/src/ray/protobuf/pubsub.proto b/src/ray/protobuf/pubsub.proto index fc046afcf69c2..8181f886ffb3c 100644 --- a/src/ray/protobuf/pubsub.proto +++ b/src/ray/protobuf/pubsub.proto @@ -13,6 +13,7 @@ // limitations under the License. syntax = "proto3"; +option cc_enable_arenas = true; package ray.rpc; diff --git a/src/ray/protobuf/ray_client.proto b/src/ray/protobuf/ray_client.proto index e207263e515a7..5dab0499d7d56 100644 --- a/src/ray/protobuf/ray_client.proto +++ b/src/ray/protobuf/ray_client.proto @@ -13,6 +13,7 @@ // limitations under the License. syntax = "proto3"; +option cc_enable_arenas = true; package ray.rpc; @@ -61,8 +62,6 @@ message ClientTask { // A name parameter, if the payload can be called in more than one way // (like a method on a payload object). string name = 2; - // A namespace parameter. - string namespace = 9; // A reference to the payload. bytes payload_id = 3; // Positional parameters to pass to this call. @@ -76,6 +75,8 @@ message ClientTask { TaskOptions options = 7; // Options passed to create the default remote task excution environment. 
TaskOptions baseline_options = 8; + // A namespace parameter. + string namespace = 9; } message ClientTaskTicket { diff --git a/src/ray/protobuf/reporter.proto b/src/ray/protobuf/reporter.proto index 225c520481cc5..00849c0683960 100644 --- a/src/ray/protobuf/reporter.proto +++ b/src/ray/protobuf/reporter.proto @@ -13,6 +13,7 @@ // limitations under the License. syntax = "proto3"; +option cc_enable_arenas = true; package ray.rpc; diff --git a/src/ray/protobuf/runtime_env_agent.proto b/src/ray/protobuf/runtime_env_agent.proto index a7903f8939c91..f36adf38cdb2a 100644 --- a/src/ray/protobuf/runtime_env_agent.proto +++ b/src/ray/protobuf/runtime_env_agent.proto @@ -13,6 +13,7 @@ // limitations under the License. syntax = "proto3"; +option cc_enable_arenas = true; package ray.rpc; @@ -21,6 +22,10 @@ import "src/ray/protobuf/agent_manager.proto"; message CreateRuntimeEnvRequest { string serialized_runtime_env = 1; bytes job_id = 2; + // Serialized allocated resource instances. Key is resource type, value is allocated + // instances. For example,{"CPU":20000,"memory":40000,"GPU":[10000, 10000]} means 2 cpu + // cores, 2 Gi memory, GPU 0 and GPU 1. + string serialized_allocated_resource_instances = 3; } message CreateRuntimeEnvReply { diff --git a/src/ray/protobuf/serialization.proto b/src/ray/protobuf/serialization.proto index e5fed8e4a3876..84da8dff1531c 100644 --- a/src/ray/protobuf/serialization.proto +++ b/src/ray/protobuf/serialization.proto @@ -13,6 +13,7 @@ // limitations under the License. syntax = "proto3"; +option cc_enable_arenas = true; package ray.serialization; diff --git a/src/ray/protobuf/serve.proto b/src/ray/protobuf/serve.proto index 24e755a0b883a..2636dcf685544 100644 --- a/src/ray/protobuf/serve.proto +++ b/src/ray/protobuf/serve.proto @@ -13,6 +13,7 @@ // limitations under the License. syntax = "proto3"; +option cc_enable_arenas = true; package ray.serve; @@ -31,19 +32,17 @@ message AutoscalingConfig { uint32 max_replicas = 2; // Target number of in flight requests per replicas. This is the primary configuration // knob for replica autoscaler. Lower the number, the more rapidly will the replicas - // being scaled up. Must be a non-negative inter. + // being scaled up. Must be a non-negative integer. uint32 target_num_ongoing_requests_per_replica = 3; // The frequency of how long does each replica sending metrics to autoscaler. double metrics_interval_s = 4; - // The interval (in seconds) of autoscaler evaluating metrics and performing scaling - // decision. - double loop_period_s = 5; + // The window (in seconds) for autoscaler to calculate rolling average of metrics on. - double look_back_period_s = 6; + double look_back_period_s = 5; // The multiplicative "gain" factor to limit scaling decisions. - double smoothing_factor = 7; + double smoothing_factor = 6; } // Configuration options for a backend, to be set by the user. @@ -62,11 +61,11 @@ message BackendConfig { // Duration that backend workers will wait until there is no more work to be done before // shutting down. Defaults to 2s. - double experimental_graceful_shutdown_wait_loop_s = 4; + double graceful_shutdown_wait_loop_s = 4; // Controller waits for this duration to forcefully kill the replica for shutdown. // Defaults to 20s. - double experimental_graceful_shutdown_timeout_s = 5; + double graceful_shutdown_timeout_s = 5; // Is the construction of backend is cross language? 
bool is_cross_language = 6; @@ -95,3 +94,35 @@ message RequestMetadata { message RequestWrapper { bytes body = 1; } + +message UpdatedObject { + bytes object_snapshot = 1; + int32 snapshot_id = 2; +} + +message LongPollRequest { + map keys_to_snapshot_ids = 1; +} + +message LongPollResult { + map updated_objects = 1; +} + +message EndpointInfo { + string endpoint_name = 1; + string route = 2; + map config = 3; +} + +message EndpointSet { + map endpoints = 1; +} + +message ActorSet { + repeated string names = 1; +} + +message BackendVersion { + string code_version = 1; + bytes user_config = 2; +} diff --git a/src/ray/ray_version_script.lds b/src/ray/ray_version_script.lds index 6d53de5ed92d1..b18b99d675dfa 100644 --- a/src/ray/ray_version_script.lds +++ b/src/ray/ray_version_script.lds @@ -39,7 +39,6 @@ VERSION_1.0 { *ray*streaming*; *aligned_free*; *aligned_malloc*; - *absl*; *grpc*; local: *; }; diff --git a/src/ray/raylet/agent_manager.cc b/src/ray/raylet/agent_manager.cc index 55fe1392f6686..ec15c27a85be4 100644 --- a/src/ray/raylet/agent_manager.cc +++ b/src/ray/raylet/agent_manager.cc @@ -36,6 +36,8 @@ void AgentManager::HandleRegisterAgent(const rpc::RegisterAgentRequest &request, RAY_LOG(INFO) << "HandleRegisterAgent, ip: " << agent_ip_address_ << ", port: " << agent_port_ << ", pid: " << agent_pid_; reply->set_status(rpc::AGENT_RPC_STATUS_OK); + // Reset the restart count after registration is done. + agent_restart_count_ = 0; send_reply_callback(ray::Status::OK(), nullptr, nullptr); } @@ -65,14 +67,16 @@ void AgentManager::StartAgent() { ProcessEnvironment env; env.insert({"RAY_NODE_ID", options_.node_id.Hex()}); env.insert({"RAY_RAYLET_PID", std::to_string(getpid())}); + // Report the restart count to the agent so that we can decide whether or not + // to report the error message to drivers. + env.insert({"RESTART_COUNT", std::to_string(agent_restart_count_)}); + env.insert({"MAX_RESTART_COUNT", + std::to_string(RayConfig::instance().agent_max_restart_count())}); Process child(argv.data(), nullptr, ec, false, env); if (!child.IsValid() || ec) { // The worker failed to start. This is a fatal error. RAY_LOG(FATAL) << "Failed to start agent with return value " << ec << ": " << ec.message(); - RAY_UNUSED(delay_executor_([this] { StartAgent(); }, - RayConfig::instance().agent_restart_interval_ms())); - return; } std::thread monitor_thread([this, child]() mutable { @@ -101,22 +105,39 @@ void AgentManager::StartAgent() { .WithField("pid", agent_pid_) << "Agent process with pid " << child.GetId() << " exit, return value " << exit_code; - RAY_UNUSED(delay_executor_([this] { StartAgent(); }, - RayConfig::instance().agent_restart_interval_ms())); + if (agent_restart_count_ < RayConfig::instance().agent_max_restart_count()) { + RAY_UNUSED(delay_executor_( + [this] { + agent_restart_count_++; + StartAgent(); + }, + // Retry with exponential backoff. + RayConfig::instance().agent_restart_interval_ms() * + std::pow(2, (agent_restart_count_ + 1)))); + } else { + RAY_LOG(INFO) << "Agent has failed " + << RayConfig::instance().agent_max_restart_count() + << " times in a row without registering the agent. It is highly " "likely that there is a bug in the dashboard agent. 
Please check out " + "the dashboard_agent.log file."; + } }); monitor_thread.detach(); } -void AgentManager::CreateRuntimeEnv(const JobID &job_id, - const std::string &serialized_runtime_env, - CreateRuntimeEnvCallback callback) { +void AgentManager::CreateRuntimeEnv( + const JobID &job_id, const std::string &serialized_runtime_env, + const std::string &serialized_allocated_resource_instances, + CreateRuntimeEnvCallback callback) { if (runtime_env_agent_client_ == nullptr) { RAY_LOG(INFO) << "Runtime env agent is not registered yet. Will retry CreateRuntimeEnv later: " << serialized_runtime_env; delay_executor_( - [this, job_id, serialized_runtime_env, callback] { - CreateRuntimeEnv(job_id, serialized_runtime_env, callback); + [this, job_id, serialized_runtime_env, serialized_allocated_resource_instances, + callback] { + CreateRuntimeEnv(job_id, serialized_runtime_env, + serialized_allocated_resource_instances, callback); }, RayConfig::instance().agent_manager_retry_interval_ms()); return; @@ -124,9 +145,12 @@ void AgentManager::CreateRuntimeEnv(const JobID &job_id, rpc::CreateRuntimeEnvRequest request; request.set_job_id(job_id.Hex()); request.set_serialized_runtime_env(serialized_runtime_env); + request.set_serialized_allocated_resource_instances( + serialized_allocated_resource_instances); runtime_env_agent_client_->CreateRuntimeEnv( - request, [this, job_id, serialized_runtime_env, callback]( - Status status, const rpc::CreateRuntimeEnvReply &reply) { + request, + [this, job_id, serialized_runtime_env, serialized_allocated_resource_instances, + callback](const Status &status, const rpc::CreateRuntimeEnvReply &reply) { if (status.ok()) { if (reply.status() == rpc::AGENT_RPC_STATUS_OK) { callback(true, reply.serialized_runtime_env_context()); @@ -142,8 +166,10 @@ void AgentManager::CreateRuntimeEnv(const JobID &job_id, << ", status = " << status << ", maybe there are some network problems, will retry it later."; delay_executor_( - [this, job_id, serialized_runtime_env, callback] { - CreateRuntimeEnv(job_id, serialized_runtime_env, callback); + [this, job_id, serialized_runtime_env, + serialized_allocated_resource_instances, callback] { + CreateRuntimeEnv(job_id, serialized_runtime_env, + serialized_allocated_resource_instances, callback); }, RayConfig::instance().agent_manager_retry_interval_ms()); } diff --git a/src/ray/raylet/agent_manager.h b/src/ray/raylet/agent_manager.h index bb12df0f64da4..ba81454b84536 100644 --- a/src/ray/raylet/agent_manager.h +++ b/src/ray/raylet/agent_manager.h @@ -64,9 +64,10 @@ class AgentManager : public rpc::AgentManagerServiceHandler { /// Request agent to create a runtime env. /// \param[in] runtime_env The runtime env. - virtual void CreateRuntimeEnv(const JobID &job_id, - const std::string &serialized_runtime_env, - CreateRuntimeEnvCallback callback); + virtual void CreateRuntimeEnv( + const JobID &job_id, const std::string &serialized_runtime_env, + const std::string &serialized_allocated_resource_instances, + CreateRuntimeEnvCallback callback); /// Request agent to delete a list of URIs. /// \param[in] URIs The list of URIs to delete. @@ -80,6 +81,8 @@ class AgentManager : public rpc::AgentManagerServiceHandler { Options options_; pid_t agent_pid_ = 0; int agent_port_ = 0; + /// The number of times the agent is restarted. 
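For a concrete sense of the retry cadence introduced above: with the delay computed as `agent_restart_interval_ms * 2^(agent_restart_count_ + 1)`, the wait doubles on each attempt. A small sketch of the arithmetic, assuming a hypothetical base interval of 1000 ms:

#include <cmath>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t base_ms = 1000;  // hypothetical agent_restart_interval_ms
  for (int restart_count = 0; restart_count < 4; ++restart_count) {
    // Mirrors the delay expression used in StartAgent() above.
    const auto delay_ms =
        static_cast<int64_t>(base_ms * std::pow(2, restart_count + 1));
    std::cout << "restart " << restart_count << ": wait " << delay_ms << " ms\n";
  }
  // Prints 2000, 4000, 8000, and 16000 ms for restarts 0 through 3.
  return 0;
}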
+ std::atomic agent_restart_count_ = 0; std::string agent_ip_address_; DelayExecutorFn delay_executor_; RuntimeEnvAgentClientFactoryFn runtime_env_agent_client_factory_; diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index aa096b3f1e86b..93655b7501d1e 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -212,6 +212,7 @@ int main(int argc, char *argv[]) { // Configuration for the object manager. ray::ObjectManagerConfig object_manager_config; + object_manager_config.object_manager_address = node_ip_address; object_manager_config.object_manager_port = object_manager_port; object_manager_config.store_socket_name = store_socket_name; @@ -244,7 +245,7 @@ int main(int argc, char *argv[]) { // Initialize stats. const ray::stats::TagsType global_tags = { {ray::stats::ComponentKey, "raylet"}, - {ray::stats::VersionKey, "2.0.0.dev0"}, + {ray::stats::VersionKey, kRayVersion}, {ray::stats::NodeAddressKey, node_ip_address}}; ray::stats::Init(global_tags, metrics_agent_port); diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 4260542319060..eb0e1f7cadc37 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -252,7 +252,8 @@ NodeManager::NodeManager(instrumented_io_context &io_service, const NodeID &self temp_dir_(config.temp_dir), initial_config_(config), dependency_manager_(object_manager_), - node_manager_server_("NodeManager", config.node_manager_port), + node_manager_server_("NodeManager", config.node_manager_port, + config.node_manager_address == "127.0.0.1"), node_manager_service_(io_service, *this), agent_manager_service_handler_( new DefaultAgentManagerServiceHandler(agent_manager_)), @@ -372,7 +373,8 @@ NodeManager::NodeManager(instrumented_io_context &io_service, const NodeID &self }, /*runtime_env_agent_factory=*/ [this](const std::string &ip_address, int port) { - RAY_CHECK(!ip_address.empty() && port != 0); + RAY_CHECK(!ip_address.empty() && port != 0) + << "ip_address: " << ip_address << " port: " << port; return std::shared_ptr( new rpc::RuntimeEnvAgentClient(ip_address, port, client_call_manager_)); }); @@ -525,7 +527,7 @@ void NodeManager::DestroyWorker(std::shared_ptr worker, } void NodeManager::HandleJobStarted(const JobID &job_id, const JobTableData &job_data) { - RAY_LOG(DEBUG) << "HandleJobStarted " << job_id; + RAY_LOG(DEBUG) << "HandleJobStarted for job " << job_id; worker_pool_.HandleJobStarted(job_id, job_data.config()); // NOTE: Technically `HandleJobStarted` isn't idempotent because we'll // increment the ref count multiple times. This is fine because @@ -1255,6 +1257,8 @@ void NodeManager::DisconnectClient( // Return the resources that were being used by this worker. cluster_task_manager_->ReleaseWorkerResources(worker); + cluster_task_manager_->ClearWorkerBacklog(worker->WorkerId()); + // Since some resources may have been released, we can try to dispatch more tasks. 
cluster_task_manager_->ScheduleAndDispatchTasks(); } else if (is_driver) { @@ -1500,19 +1504,28 @@ void NodeManager::HandleRequestResourceReport( send_reply_callback(Status::OK(), nullptr, nullptr); } +void NodeManager::HandleReportWorkerBacklog( + const rpc::ReportWorkerBacklogRequest &request, rpc::ReportWorkerBacklogReply *reply, + rpc::SendReplyCallback send_reply_callback) { + const WorkerID worker_id = WorkerID::FromBinary(request.worker_id()); + cluster_task_manager_->ClearWorkerBacklog(worker_id); + std::unordered_set seen; + for (const auto &backlog_report : request.backlog_reports()) { + const TaskSpecification resource_spec(backlog_report.resource_spec()); + const SchedulingClass scheduling_class = resource_spec.GetSchedulingClass(); + RAY_CHECK(seen.find(scheduling_class) == seen.end()); + seen.insert(scheduling_class); + cluster_task_manager_->SetWorkerBacklog(scheduling_class, worker_id, + backlog_report.backlog_size()); + } + send_reply_callback(Status::OK(), nullptr, nullptr); +} + void NodeManager::HandleRequestWorkerLease(const rpc::RequestWorkerLeaseRequest &request, rpc::RequestWorkerLeaseReply *reply, rpc::SendReplyCallback send_reply_callback) { rpc::Task task_message; task_message.mutable_task_spec()->CopyFrom(request.resource_spec()); - auto backlog_size = -1; - if (RayConfig::instance().report_worker_backlog()) { - // We add 1 to the backlog size because we need a worker to fulfill the - // current request, as well as workers to serve the requests in the - // backlog. - backlog_size = request.backlog_size() + 1; - } - RayTask task(task_message, backlog_size); + RayTask task(task_message); bool is_actor_creation_task = task.GetTaskSpecification().IsActorCreationTask(); ActorID actor_id = ActorID::Nil(); metrics_num_task_scheduled_ += 1; @@ -1662,7 +1675,7 @@ void NodeManager::HandleReturnWorker(const rpc::ReturnWorkerRequest &request, if (worker->IsBlocked()) { HandleDirectCallTaskUnblocked(worker); } - cluster_task_manager_->ReturnWorkerResources(worker); + cluster_task_manager_->ReleaseWorkerResources(worker); HandleWorkerAvailable(worker); } } else { @@ -1868,7 +1881,8 @@ void NodeManager::FinishAssignedActorCreationTask(WorkerInterface &worker, auto job_id = task.GetTaskSpecification().JobId(); auto job_config = worker_pool_.GetJobConfig(job_id); RAY_CHECK(job_config); - runtime_env_manager_.AddURIReference(actor_id.Hex(), job_config->runtime_env()); + runtime_env_manager_.AddURIReference(actor_id.Hex(), + task.GetTaskSpecification().RuntimeEnv()); } } diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index a699635c439f7..e8fb4e3050254 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -48,7 +48,6 @@ namespace ray { namespace raylet { -using rpc::ActorTableData; using rpc::ErrorType; using rpc::GcsNodeInfo; using rpc::HeartbeatTableData; @@ -273,13 +272,6 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// returned to idle. bool FinishAssignedTask(const std::shared_ptr &worker_ptr); - /// Helper function to produce actor table data for a newly created actor. - /// - /// \param task_spec RayTask specification of the actor creation task that created the - /// actor. - /// \param worker The port that the actor is listening on. - std::shared_ptr CreateActorTableDataFromCreationTask( - const TaskSpecification &task_spec, int port, const WorkerID &worker_id); /// Handle a worker finishing an assigned actor creation task. /// \param worker The worker that finished the task. /// \param task The actor task or actor creation task. 
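To illustrate the contract the handler above enforces (a worker reports its whole backlog at once, with at most one entry per scheduling class), a sender might assemble the request as below. This helper is a hypothetical sketch, not code from this patch:

// Hypothetical sketch: summarize a worker's queued lease requests into a
// single ReportWorkerBacklogRequest. `backlogs` pairs one representative
// TaskSpec per scheduling class with the number of queued requests.
rpc::ReportWorkerBacklogRequest MakeBacklogReport(
    const WorkerID &worker_id,
    const std::vector<std::pair<rpc::TaskSpec, int64_t>> &backlogs) {
  rpc::ReportWorkerBacklogRequest request;
  request.set_worker_id(worker_id.Binary());
  for (const auto &[resource_spec, backlog_size] : backlogs) {
    auto *report = request.add_backlog_reports();
    report->mutable_resource_spec()->CopyFrom(resource_spec);
    report->set_backlog_size(backlog_size);
  }
  return request;
}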
@@ -495,6 +487,11 @@ class NodeManager : public rpc::NodeManagerServiceHandler { rpc::RequestWorkerLeaseReply *reply, rpc::SendReplyCallback send_reply_callback) override; + /// Handle a `ReportWorkerBacklog` request. + void HandleReportWorkerBacklog(const rpc::ReportWorkerBacklogRequest &request, + rpc::ReportWorkerBacklogReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + /// Handle a `ReturnWorker` request. void HandleReturnWorker(const rpc::ReturnWorkerRequest &request, rpc::ReturnWorkerReply *reply, diff --git a/src/ray/raylet/placement_group_resource_manager.cc b/src/ray/raylet/placement_group_resource_manager.cc index 8639689edb949..d9ccfd1ac0574 100644 --- a/src/ray/raylet/placement_group_resource_manager.cc +++ b/src/ray/raylet/placement_group_resource_manager.cc @@ -152,6 +152,9 @@ void NewPlacementGroupResourceManager::ReturnBundle( // will be resource leak. cluster_resource_scheduler_->DeleteLocalResource(resource.first); deleted.push_back(resource.first); + } else { + RAY_LOG(DEBUG) << "Available bundle resource:[" << resource.first + << "] is not empty. Resources are not deleted from the local node."; } } pg_bundles_.erase(it); diff --git a/src/ray/raylet/raylet.cc b/src/ray/raylet/raylet.cc index b8040b6f8acdc..c2f431b20027d 100644 --- a/src/ray/raylet/raylet.cc +++ b/src/ray/raylet/raylet.cc @@ -15,7 +15,7 @@ #include "ray/raylet/raylet.h" #include -#include +#include #include #include @@ -61,7 +61,10 @@ Raylet::Raylet(instrumented_io_context &main_service, const std::string &socket_ const ObjectManagerConfig &object_manager_config, std::shared_ptr gcs_client, int metrics_export_port) : main_service_(main_service), - self_node_id_(NodeID::FromRandom()), + self_node_id_( + !RayConfig::instance().OVERRIDE_NODE_ID_FOR_TESTING().empty() + ? NodeID::FromHex(RayConfig::instance().OVERRIDE_NODE_ID_FOR_TESTING()) + : NodeID::FromRandom()), gcs_client_(gcs_client), node_manager_(main_service, self_node_id_, node_manager_config, object_manager_config, gcs_client_), diff --git a/src/ray/raylet/scheduling/cluster_resource_data.cc b/src/ray/raylet/scheduling/cluster_resource_data.cc index ea4ae6621f6b5..f19287d0915f5 100644 --- a/src/ray/raylet/scheduling/cluster_resource_data.cc +++ b/src/ray/raylet/scheduling/cluster_resource_data.cc @@ -13,6 +13,7 @@ // limitations under the License. 
#include "ray/raylet/scheduling/cluster_resource_data.h" + #include "ray/common/bundle_spec.h" #include "ray/common/task/scheduling_resources.h" @@ -536,7 +537,7 @@ bool TaskResourceInstances::IsEmpty() const { return true; } -std::string TaskResourceInstances::DebugString() const { +std::string TaskResourceInstances::DebugString(const StringIdMap &string_id_map) const { std::stringstream buffer; buffer << std::endl << " Allocation: {"; for (size_t i = 0; i < this->predefined_resources.size(); i++) { @@ -547,7 +548,7 @@ std::string TaskResourceInstances::DebugString() const { buffer << " ["; for (auto it = this->custom_resources.begin(); it != this->custom_resources.end(); ++it) { - buffer << it->first << ":" << VectorToString(it->second) << ", "; + buffer << string_id_map.Get(it->first) << ":" << VectorToString(it->second) << ", "; } buffer << "]" << std::endl; diff --git a/src/ray/raylet/scheduling/cluster_resource_data.h b/src/ray/raylet/scheduling/cluster_resource_data.h index 0398726f39d42..783ab12da9eee 100644 --- a/src/ray/raylet/scheduling/cluster_resource_data.h +++ b/src/ray/raylet/scheduling/cluster_resource_data.h @@ -138,7 +138,7 @@ class TaskResourceInstances { /// Check whether there are no resource instances. bool IsEmpty() const; /// Returns human-readable string for these resources. - std::string DebugString() const; + [[nodiscard]] std::string DebugString(const StringIdMap &string_id_map) const; }; /// Total and available capacities of each resource of a node. @@ -189,7 +189,7 @@ class NodeResourceInstances { /// Returns if this equals another node resources. bool operator==(const NodeResourceInstances &other); /// Returns human-readable string for these resources. - std::string DebugString(StringIdMap string_to_int_map) const; + [[nodiscard]] std::string DebugString(StringIdMap string_to_int_map) const; }; struct Node { diff --git a/src/ray/raylet/scheduling/cluster_resource_scheduler.cc b/src/ray/raylet/scheduling/cluster_resource_scheduler.cc index 6fcff8a501c55..1174f138395e0 100644 --- a/src/ray/raylet/scheduling/cluster_resource_scheduler.cc +++ b/src/ray/raylet/scheduling/cluster_resource_scheduler.cc @@ -456,8 +456,7 @@ void ClusterResourceScheduler::AddLocalResourceInstances( for (size_t i = 0; i < instances.size(); i++) { node_instances->available[i] += instances[i]; - node_instances->total[i] = - std::max(node_instances->total[i], node_instances->available[i]); + node_instances->total[i] += instances[i]; } UpdateLocalAvailableResourcesFromResourceInstances(); } diff --git a/src/ray/raylet/scheduling/cluster_task_manager.cc b/src/ray/raylet/scheduling/cluster_task_manager.cc index 1b90e93fb1bf4..14a999d3be4e4 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager.cc +++ b/src/ray/raylet/scheduling/cluster_task_manager.cc @@ -48,7 +48,6 @@ ClusterTaskManager::ClusterTaskManager( announce_infeasible_task_(announce_infeasible_task), max_resource_shapes_per_load_report_( RayConfig::instance().max_resource_shapes_per_load_report()), - report_worker_backlog_(RayConfig::instance().report_worker_backlog()), worker_pool_(worker_pool), leased_workers_(leased_workers), get_task_arguments_(get_task_arguments), @@ -426,7 +425,6 @@ void ClusterTaskManager::QueueAndScheduleTask( } else { tasks_to_schedule_[scheduling_class].push_back(work); } - AddToBacklogTracker(task); ScheduleAndDispatchTasks(); } @@ -563,12 +561,6 @@ void ClusterTaskManager::ReleaseTaskArgs(const TaskID &task_id) { } } -void ClusterTaskManager::ReturnWorkerResources(std::shared_ptr worker) { - // 
TODO(Shanly): This method will be removed and can be replaced by - // `ReleaseWorkerResources` directly once we remove the legacy scheduler. - ReleaseWorkerResources(worker); -} - void ReplyCancelled(std::shared_ptr &work, bool runtime_env_setup_failed) { auto reply = work->reply; auto callback = work->callback; @@ -587,7 +579,6 @@ bool ClusterTaskManager::CancelTask(const TaskID &task_id, for (auto work_it = work_queue.begin(); work_it != work_queue.end(); work_it++) { const auto &task = (*work_it)->task; if (task.GetTaskSpecification().TaskId() == task_id) { - RemoveFromBacklogTracker(task); RAY_LOG(DEBUG) << "Canceling task " << task_id << " from schedule queue."; ReplyCancelled(*work_it, runtime_env_setup_failed); work_queue.erase(work_it); @@ -604,7 +595,6 @@ bool ClusterTaskManager::CancelTask(const TaskID &task_id, for (auto work_it = work_queue.begin(); work_it != work_queue.end(); work_it++) { const auto &task = (*work_it)->task; if (task.GetTaskSpecification().TaskId() == task_id) { - RemoveFromBacklogTracker(task); RAY_LOG(DEBUG) << "Canceling task " << task_id << " from dispatch queue."; ReplyCancelled(*work_it, runtime_env_setup_failed); if ((*work_it)->status == WorkStatus::WAITING_FOR_WORKER) { @@ -634,7 +624,6 @@ bool ClusterTaskManager::CancelTask(const TaskID &task_id, for (auto work_it = work_queue.begin(); work_it != work_queue.end(); work_it++) { const auto &task = (*work_it)->task; if (task.GetTaskSpecification().TaskId() == task_id) { - RemoveFromBacklogTracker(task); RAY_LOG(DEBUG) << "Canceling task " << task_id << " from infeasible queue."; ReplyCancelled(*work_it, runtime_env_setup_failed); work_queue.erase(work_it); @@ -649,7 +638,6 @@ bool ClusterTaskManager::CancelTask(const TaskID &task_id, auto iter = waiting_tasks_index_.find(task_id); if (iter != waiting_tasks_index_.end()) { const auto &task = (*iter->second)->task; - RemoveFromBacklogTracker(task); ReplyCancelled(*iter->second, runtime_env_setup_failed); if (!task.GetTaskSpecification().GetDependencies().empty()) { task_dependency_manager_.RemoveTaskDependencies( @@ -716,36 +704,35 @@ void ClusterTaskManager::FillResourceUsage( TaskSpecification::GetSchedulingClass(one_cpu_resource_set)); { num_reported++; - int count = 0; + int ready_count = 0; auto it = tasks_to_schedule_.find(one_cpu_scheduling_cls); if (it != tasks_to_schedule_.end()) { - count += it->second.size(); + ready_count += it->second.size(); } it = tasks_to_dispatch_.find(one_cpu_scheduling_cls); if (it != tasks_to_dispatch_.end()) { - count += it->second.size(); + ready_count += it->second.size(); } - - if (count > 0) { + int infeasible_count = 0; + it = infeasible_tasks_.find(one_cpu_scheduling_cls); + if (it != infeasible_tasks_.end()) { + infeasible_count += it->second.size(); + } + const int total_count = ready_count + infeasible_count; + if (total_count > 0) { auto by_shape_entry = resource_load_by_shape->Add(); - for (const auto &resource : one_cpu_resource_set.GetResourceMap()) { + for (const auto &[label, quantity] : one_cpu_resource_set.GetResourceMap()) { // Add to `resource_loads`. - const auto &label = resource.first; - const auto &quantity = resource.second; - (*resource_loads)[label] += quantity * count; + (*resource_loads)[label] += quantity * total_count; // Add to `resource_load_by_shape`. 
(*by_shape_entry->mutable_shape())[label] = quantity; } - int num_ready = by_shape_entry->num_ready_requests_queued(); - by_shape_entry->set_num_ready_requests_queued(num_ready + count); - - auto backlog_it = backlog_tracker_.find(one_cpu_scheduling_cls); - if (backlog_it != backlog_tracker_.end()) { - by_shape_entry->set_backlog_size(backlog_it->second); - } + by_shape_entry->set_num_ready_requests_queued(ready_count); + by_shape_entry->set_num_infeasible_requests_queued(infeasible_count); + by_shape_entry->set_backlog_size(TotalBacklogSize(one_cpu_scheduling_cls)); } } @@ -783,10 +770,7 @@ void ClusterTaskManager::FillResourceUsage( // ClusterResourceScheduler::GetBestSchedulableNode for more details. int num_ready = by_shape_entry->num_ready_requests_queued(); by_shape_entry->set_num_ready_requests_queued(num_ready + count); - auto backlog_it = backlog_tracker_.find(scheduling_class); - if (backlog_it != backlog_tracker_.end()) { - by_shape_entry->set_backlog_size(backlog_it->second); - } + by_shape_entry->set_backlog_size(TotalBacklogSize(scheduling_class)); } for (const auto &pair : tasks_to_dispatch_) { @@ -819,10 +803,7 @@ void ClusterTaskManager::FillResourceUsage( } int num_ready = by_shape_entry->num_ready_requests_queued(); by_shape_entry->set_num_ready_requests_queued(num_ready + count); - auto backlog_it = backlog_tracker_.find(scheduling_class); - if (backlog_it != backlog_tracker_.end()) { - by_shape_entry->set_backlog_size(backlog_it->second); - } + by_shape_entry->set_backlog_size(TotalBacklogSize(scheduling_class)); } for (const auto &pair : infeasible_tasks_) { @@ -858,10 +839,7 @@ void ClusterTaskManager::FillResourceUsage( // ClusterResourceScheduler::GetBestSchedulableNode for more details. int num_infeasible = by_shape_entry->num_infeasible_requests_queued(); by_shape_entry->set_num_infeasible_requests_queued(num_infeasible + count); - auto backlog_it = backlog_tracker_.find(scheduling_class); - if (backlog_it != backlog_tracker_.end()) { - by_shape_entry->set_backlog_size(backlog_it->second); - } + by_shape_entry->set_backlog_size(TotalBacklogSize(scheduling_class)); } if (RayConfig::instance().enable_light_weight_resource_report()) { @@ -1015,7 +993,6 @@ void ClusterTaskManager::Dispatch( RAY_CHECK(leased_workers.find(worker->WorkerId()) == leased_workers.end()); leased_workers[worker->WorkerId()] = worker; - RemoveFromBacklogTracker(task); // Update our internal view of the cluster state. 
std::shared_ptr allocated_resources; @@ -1071,7 +1048,6 @@ void ClusterTaskManager::Spillback(const NodeID &spillback_to, metric_tasks_spilled_++; const auto &task = work->task; const auto &task_spec = task.GetTaskSpecification(); - RemoveFromBacklogTracker(task); RAY_LOG(DEBUG) << "Spilling task " << task_spec.TaskId() << " to node " << spillback_to; if (!cluster_resource_scheduler_->AllocateRemoteTaskResources( @@ -1098,23 +1074,44 @@ void ClusterTaskManager::Spillback(const NodeID &spillback_to, send_reply_callback(); } -void ClusterTaskManager::AddToBacklogTracker(const RayTask &task) { - if (report_worker_backlog_) { - auto cls = task.GetTaskSpecification().GetSchedulingClass(); - backlog_tracker_[cls] += task.BacklogSize(); +void ClusterTaskManager::ClearWorkerBacklog(const WorkerID &worker_id) { + for (auto it = backlog_tracker_.begin(); it != backlog_tracker_.end();) { + it->second.erase(worker_id); + if (it->second.empty()) { + it = backlog_tracker_.erase(it); + } else { + ++it; + } } } -void ClusterTaskManager::RemoveFromBacklogTracker(const RayTask &task) { - if (report_worker_backlog_) { - SchedulingClass cls = task.GetTaskSpecification().GetSchedulingClass(); - backlog_tracker_[cls] -= task.BacklogSize(); - if (backlog_tracker_[cls] == 0) { - backlog_tracker_.erase(backlog_tracker_.find(cls)); +void ClusterTaskManager::SetWorkerBacklog(SchedulingClass scheduling_class, + const WorkerID &worker_id, + int64_t backlog_size) { + if (backlog_size == 0) { + backlog_tracker_[scheduling_class].erase(worker_id); + if (backlog_tracker_[scheduling_class].empty()) { + backlog_tracker_.erase(scheduling_class); } + } else { + backlog_tracker_[scheduling_class][worker_id] = backlog_size; } } +int64_t ClusterTaskManager::TotalBacklogSize(SchedulingClass scheduling_class) { + auto backlog_it = backlog_tracker_.find(scheduling_class); + if (backlog_it == backlog_tracker_.end()) { + return 0; + } + + int64_t sum = 0; + for (const auto &worker_id_and_backlog_size : backlog_it->second) { + sum += worker_id_and_backlog_size.second; + } + + return sum; +} + void ClusterTaskManager::ReleaseWorkerResources(std::shared_ptr worker) { RAY_CHECK(worker != nullptr); auto allocated_instances = worker->GetAllocatedInstances(); @@ -1196,8 +1193,6 @@ void ClusterTaskManager::ScheduleAndDispatchTasks() { } void ClusterTaskManager::SpillWaitingTasks() { - RAY_LOG(DEBUG) << "Attempting to spill back from waiting task queue, num waiting: " - << waiting_task_queue_.size(); // Try to spill waiting tasks to a remote node, prioritizing those at the end // of the queue. Waiting tasks are spilled if there are enough remote // resources AND (we have no resources available locally OR their diff --git a/src/ray/raylet/scheduling/cluster_task_manager.h b/src/ray/raylet/scheduling/cluster_task_manager.h index 57ee9aab80678..4259afc8d04c5 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager.h +++ b/src/ray/raylet/scheduling/cluster_task_manager.h @@ -102,6 +102,11 @@ class ClusterTaskManager : public ClusterTaskManagerInterface { get_task_arguments, size_t max_pinned_task_arguments_bytes); + void SetWorkerBacklog(SchedulingClass scheduling_class, const WorkerID &worker_id, + int64_t backlog_size) override; + + void ClearWorkerBacklog(const WorkerID &worker_id) override; + /// (Step 1) Queue tasks and schedule. /// Queue task and schedule. This hanppens when processing the worker lease request. /// @@ -125,13 +130,6 @@ class ClusterTaskManager : public ClusterTaskManagerInterface { /// \param task: Output parameter. 
void TaskFinished(std::shared_ptr worker, RayTask *task) override; - /// Return worker resources. - /// This method will be removed and can be replaced by `ReleaseWorkerResources` directly - /// once we remove the legacy scheduler. - /// - /// \param worker: The worker which was running the task. - void ReturnWorkerResources(std::shared_ptr worker) override; - /// Attempt to cancel an already queued task. /// /// \param task_id: The id of the task to remove. @@ -261,7 +259,6 @@ class ClusterTaskManager : public ClusterTaskManagerInterface { std::function announce_infeasible_task_; const int max_resource_shapes_per_load_report_; - const bool report_worker_backlog_; /// TODO(swang): Add index from TaskID -> Work to avoid having to iterate /// through queues to cancel tasks, etc. @@ -307,8 +304,9 @@ class ClusterTaskManager : public ClusterTaskManagerInterface { std::unordered_map>> infeasible_tasks_; - /// Track the cumulative backlog of all workers requesting a lease to this raylet. - std::unordered_map backlog_tracker_; + /// Track the backlog of all workers belonging to this raylet. + std::unordered_map> + backlog_tracker_; /// TODO(Shanly): Remove `worker_pool_` and `leased_workers_` and make them as /// parameters of methods if necessary once we remove the legacy scheduler. @@ -360,8 +358,8 @@ class ClusterTaskManager : public ClusterTaskManagerInterface { void Spillback(const NodeID &spillback_to, const std::shared_ptr &work); - void AddToBacklogTracker(const RayTask &task); - void RemoveFromBacklogTracker(const RayTask &task); + /// Sum up the backlog size across all workers for a given scheduling class. + int64_t TotalBacklogSize(SchedulingClass scheduling_class); // Helper function to pin a task's args immediately before dispatch. This // returns false if there are missing args (due to eviction) or if there is diff --git a/src/ray/raylet/scheduling/cluster_task_manager_interface.h b/src/ray/raylet/scheduling/cluster_task_manager_interface.h index 457daa4d7b320..71864f38df846 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager_interface.h +++ b/src/ray/raylet/scheduling/cluster_task_manager_interface.h @@ -78,13 +78,6 @@ class ClusterTaskManagerInterface { /// \param task: Output parameter. virtual void TaskFinished(std::shared_ptr worker, RayTask *task) = 0; - /// Return worker resources. - /// This method will be removed and can be replaced by `ReleaseWorkerResources` directly - /// once we remove the legacy scheduler - /// - /// \param worker: The worker which was running the task. - virtual void ReturnWorkerResources(std::shared_ptr worker) = 0; - /// Attempt to cancel an already queued task. /// /// \param task_id: The id of the task to remove. @@ -96,6 +89,20 @@ class ClusterTaskManagerInterface { virtual bool CancelTask(const TaskID &task_id, bool runtime_env_setup_failed = false) = 0; + /// Set the worker backlog size for a particular scheduling class. + /// + /// \param scheduling_class: The scheduling class this backlog is for. + /// \param worker_id: The ID of the worker that owns the backlog information. + /// \param backlog_size: The size of the backlog. + virtual void SetWorkerBacklog(SchedulingClass scheduling_class, + const WorkerID &worker_id, int64_t backlog_size) = 0; + + /// Remove all backlog information about the given worker. + /// + /// \param worker_id: The ID of the worker owning the backlog information + /// that we want to remove. + virtual void ClearWorkerBacklog(const WorkerID &worker_id) = 0; + /// Queue task and schedule. 
This hanppens when processing the worker lease request. /// /// \param task: The incoming task to be queued and scheduled. diff --git a/src/ray/raylet/scheduling/cluster_task_manager_test.cc b/src/ray/raylet/scheduling/cluster_task_manager_test.cc index 78fe7320c8631..f19386ad4bab5 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager_test.cc +++ b/src/ray/raylet/scheduling/cluster_task_manager_test.cc @@ -48,8 +48,7 @@ class MockWorkerPool : public WorkerPoolInterface { void PopWorker(const TaskSpecification &task_spec, const PopWorkerCallback &callback, const std::string &allocated_instances_serialized_json) { num_pops++; - const WorkerCacheKey env = { - task_spec.OverrideEnvironmentVariables(), task_spec.SerializedRuntimeEnv(), {}}; + const WorkerCacheKey env = {task_spec.SerializedRuntimeEnv(), {}}; const int runtime_env_hash = env.IntHash(); callbacks[runtime_env_hash].push_back(callback); } @@ -101,10 +100,11 @@ class MockWorkerPool : public WorkerPoolInterface { int num_pops; }; -std::shared_ptr CreateSingleNodeScheduler( - const std::string &id, double num_gpus = 0.0) { +std::shared_ptr CreateSingleNodeScheduler(const std::string &id, + double num_cpus, + double num_gpus) { std::unordered_map local_node_resources; - local_node_resources[ray::kCPU_ResourceLabel] = 8; + local_node_resources[ray::kCPU_ResourceLabel] = num_cpus; local_node_resources[ray::kGPU_ResourceLabel] = num_gpus; local_node_resources[ray::kMemory_ResourceLabel] = 128; @@ -116,16 +116,18 @@ std::shared_ptr CreateSingleNodeScheduler( RayTask CreateTask(const std::unordered_map &required_resources, int num_args = 0, std::vector args = {}, - std::string serialized_runtime_env = "{}") { + const std::string &serialized_runtime_env = "{}", + const std::vector &runtime_env_uris = {}) { TaskSpecBuilder spec_builder; TaskID id = RandomTaskId(); JobID job_id = RandomJobId(); rpc::Address address; - spec_builder.SetCommonTaskSpec( - id, "dummy_task", Language::PYTHON, - FunctionDescriptorBuilder::BuildPython("", "", "", ""), job_id, TaskID::Nil(), 0, - TaskID::Nil(), address, 0, required_resources, {}, - std::make_pair(PlacementGroupID::Nil(), -1), true, "", serialized_runtime_env); + spec_builder.SetCommonTaskSpec(id, "dummy_task", Language::PYTHON, + FunctionDescriptorBuilder::BuildPython("", "", "", ""), + job_id, TaskID::Nil(), 0, TaskID::Nil(), address, 0, + required_resources, {}, + std::make_pair(PlacementGroupID::Nil(), -1), true, "", + serialized_runtime_env, runtime_env_uris); if (!args.empty()) { for (auto &arg : args) { @@ -177,39 +179,41 @@ class MockTaskDependencyManager : public TaskDependencyManagerInterface { class ClusterTaskManagerTest : public ::testing::Test { public: - ClusterTaskManagerTest(double num_gpus_at_head = 0.0) + ClusterTaskManagerTest(double num_cpus_at_head = 8.0, double num_gpus_at_head = 0.0) : id_(NodeID::FromRandom()), - scheduler_(CreateSingleNodeScheduler(id_.Binary(), num_gpus_at_head)), + scheduler_( + CreateSingleNodeScheduler(id_.Binary(), num_cpus_at_head, num_gpus_at_head)), is_owner_alive_(true), node_info_calls_(0), announce_infeasible_task_calls_(0), dependency_manager_(missing_objects_), - task_manager_(id_, scheduler_, dependency_manager_, - /* is_owner_alive= */ - [this](const WorkerID &worker_id, const NodeID &node_id) { - return is_owner_alive_; - }, - /* get_node_info= */ - [this](const NodeID &node_id) { - node_info_calls_++; - return node_info_[node_id]; - }, - /* announce_infeasible_task= */ - [this](const RayTask &task) { announce_infeasible_task_calls_++; }, 
- pool_, leased_workers_, - /* get_task_arguments= */ - [this](const std::vector &object_ids, - std::vector> *results) { - for (auto &obj_id : object_ids) { - if (missing_objects_.count(obj_id) == 0) { - results->emplace_back(MakeDummyArg()); - } else { - results->emplace_back(nullptr); - } - } - return true; - }, - /*max_pinned_task_arguments_bytes=*/1000) {} + task_manager_( + id_, scheduler_, dependency_manager_, + /* is_owner_alive= */ + [this](const WorkerID &worker_id, const NodeID &node_id) { + return is_owner_alive_; + }, + /* get_node_info= */ + [this](const NodeID &node_id) { + node_info_calls_++; + return node_info_[node_id]; + }, + /* announce_infeasible_task= */ + [this](const RayTask &task) { announce_infeasible_task_calls_++; }, pool_, + leased_workers_, + /* get_task_arguments= */ + [this](const std::vector &object_ids, + std::vector> *results) { + for (auto &obj_id : object_ids) { + if (missing_objects_.count(obj_id) == 0) { + results->emplace_back(MakeDummyArg()); + } else { + results->emplace_back(nullptr); + } + } + return true; + }, + /*max_pinned_task_arguments_bytes=*/1000) {} RayObject *MakeDummyArg() { std::vector data; @@ -287,7 +291,15 @@ class ClusterTaskManagerTest : public ::testing::Test { // Same as ClusterTaskManagerTest, but the head node starts with 4.0 num gpus. class ClusterTaskManagerTestWithGPUsAtHead : public ClusterTaskManagerTest { public: - ClusterTaskManagerTestWithGPUsAtHead() : ClusterTaskManagerTest(4.0) {} + ClusterTaskManagerTestWithGPUsAtHead() + : ClusterTaskManagerTest(/*num_cpus_at_head=*/8.0, /*num_gpus_at_head=*/4.0) {} +}; + +// Same as ClusterTaskManagerTest, but the head node starts with 0.0 num cpus. +class ClusterTaskManagerTestWithoutCPUsAtHead : public ClusterTaskManagerTest { + public: + ClusterTaskManagerTestWithoutCPUsAtHead() + : ClusterTaskManagerTest(/*num_cpus_at_head=*/0.0) {} }; TEST_F(ClusterTaskManagerTest, BasicTest) { @@ -367,8 +379,7 @@ TEST_F(ClusterTaskManagerTest, DispatchQueueNonBlockingTest) { pool_.TriggerCallbacks(); // Push a worker that can only run task A. - const WorkerCacheKey env_A = { - /*override_environment_variables=*/{}, serialized_runtime_env_A, {}}; + const WorkerCacheKey env_A = {serialized_runtime_env_A, {}}; const int runtime_env_hash_A = env_A.IntHash(); std::shared_ptr worker_A = std::make_shared(WorkerID::FromRandom(), 1234, runtime_env_hash_A); @@ -860,7 +871,7 @@ TEST_F(ClusterTaskManagerTest, HeartbeatTest) { TEST_F(ClusterTaskManagerTest, BacklogReportTest) { /* Test basic scheduler functionality: - 1. Queue and attempt to schedule/dispatch atest with no workers available + 1. Queue and attempt to schedule/dispatch a test with no workers available 2. A worker becomes available, dispatch again. */ rpc::RequestWorkerLeaseReply reply; @@ -873,18 +884,21 @@ TEST_F(ClusterTaskManagerTest, BacklogReportTest) { std::vector to_cancel; - // Don't add these fist 2 tasks to `to_cancel`. + const WorkerID worker_id_submitting_first_task = WorkerID::FromRandom(); + // Don't add the first task to `to_cancel`. 
for (int i = 0; i < 1; i++) { RayTask task = CreateTask({{ray::kCPU_ResourceLabel, 8}}); - task.SetBacklogSize(10 - i); task_manager_.QueueAndScheduleTask(task, &reply, callback); + task_manager_.SetWorkerBacklog(task.GetTaskSpecification().GetSchedulingClass(), + worker_id_submitting_first_task, 10 - i); pool_.TriggerCallbacks(); } for (int i = 1; i < 10; i++) { RayTask task = CreateTask({{ray::kCPU_ResourceLabel, 8}}); - task.SetBacklogSize(10 - i); task_manager_.QueueAndScheduleTask(task, &reply, callback); + task_manager_.SetWorkerBacklog(task.GetTaskSpecification().GetSchedulingClass(), + WorkerID::FromRandom(), 10 - i); pool_.TriggerCallbacks(); to_cancel.push_back(task.GetTaskSpecification().TaskId()); } @@ -910,6 +924,7 @@ TEST_F(ClusterTaskManagerTest, BacklogReportTest) { std::make_shared(WorkerID::FromRandom(), 1234); pool_.PushWorker(worker); task_manager_.ScheduleAndDispatchTasks(); + task_manager_.ClearWorkerBacklog(worker_id_submitting_first_task); pool_.TriggerCallbacks(); { @@ -1525,6 +1540,50 @@ TEST_F(ClusterTaskManagerTest, PopWorkerExactlyOnce) { AssertNoLeaks(); } +// Regression test for https://github.com/ray-project/ray/issues/16935: +// When a task requires 1 CPU and is infeasible because the head node has 0 CPUs, +// make sure the task's resource demand is reported. +TEST_F(ClusterTaskManagerTestWithoutCPUsAtHead, OneCpuInfeasibleTask) { + rpc::RequestWorkerLeaseReply reply; + bool callback_occurred = false; + bool *callback_occurred_ptr = &callback_occurred; + auto callback = [callback_occurred_ptr](const Status &, const std::function &, + const std::function &) { + *callback_occurred_ptr = true; + }; + + constexpr int num_cases = 5; + // Create 5 tasks with different CPU requests. + const std::array cpu_request = {1, 2, 1, 3, 1}; + // Each distinct CPU request size corresponds to a type of resource demand. + const std::array demand_types = {1, 2, 2, 3, 3}; + // Number of infeasible 1 CPU requests. + const std::array num_infeasible_1cpu = {1, 1, 2, 2, 3}; + + for (int i = 0; i < num_cases; ++i) { + RayTask task = CreateTask({{ray::kCPU_ResourceLabel, cpu_request[i]}}); + task_manager_.QueueAndScheduleTask(task, &reply, callback); + pool_.TriggerCallbacks(); + + // The task cannot run because there is only 1 node (head) with 0 CPU. + ASSERT_FALSE(callback_occurred); + ASSERT_EQ(leased_workers_.size(), 0); + ASSERT_EQ(pool_.workers.size(), 0); + ASSERT_EQ(node_info_calls_, 0); + + rpc::ResourcesData data; + task_manager_.FillResourceUsage(data); + const auto &resource_load_by_shape = data.resource_load_by_shape(); + ASSERT_EQ(resource_load_by_shape.resource_demands().size(), demand_types[i]); + + // The 1 CPU demand is currently always the first entry. + const auto &demand = resource_load_by_shape.resource_demands()[0]; + EXPECT_EQ(demand.num_infeasible_requests_queued(), num_infeasible_1cpu[i]); + ASSERT_EQ(demand.shape().size(), 1); + ASSERT_EQ(demand.shape().at("CPU"), 1); + } +} + int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); diff --git a/src/ray/raylet/scheduling/fixed_point.cc b/src/ray/raylet/scheduling/fixed_point.cc deleted file mode 100644 index ec0b3ed9af16d..0000000000000 --- a/src/ray/raylet/scheduling/fixed_point.cc +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright 2020-2021 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "ray/raylet/scheduling/fixed_point.h" - -#include - -FixedPoint::FixedPoint(double d) { i_ = (uint64_t)(d * RESOURCE_UNIT_SCALING); } - -FixedPoint::FixedPoint(int i) { i_ = (i * RESOURCE_UNIT_SCALING); } - -FixedPoint::FixedPoint(uint32_t i) { i_ = (i * RESOURCE_UNIT_SCALING); } - -FixedPoint::FixedPoint(int64_t i) : FixedPoint((double)i) {} - -FixedPoint::FixedPoint(uint64_t i) : FixedPoint((double)i) {} - -FixedPoint FixedPoint::operator+(FixedPoint const &ru) const { - FixedPoint res; - res.i_ = i_ + ru.i_; - return res; -} - -FixedPoint FixedPoint::operator+=(FixedPoint const &ru) { - i_ += ru.i_; - return *this; -} - -FixedPoint FixedPoint::operator-(FixedPoint const &ru) const { - FixedPoint res; - res.i_ = i_ - ru.i_; - return res; -} - -FixedPoint FixedPoint::operator-=(FixedPoint const &ru) { - i_ -= ru.i_; - return *this; -} - -FixedPoint FixedPoint::operator-() const { - FixedPoint res; - res.i_ = -i_; - return res; -} - -FixedPoint FixedPoint::operator+(double const d) const { - FixedPoint res; - res.i_ = i_ + (int64_t)(d * RESOURCE_UNIT_SCALING); - return res; -} - -FixedPoint FixedPoint::operator-(double const d) const { - FixedPoint res; - res.i_ = i_ - (int64_t)(d * RESOURCE_UNIT_SCALING); - return res; -} - -FixedPoint FixedPoint::operator=(double const d) { - i_ = (int64_t)(d * RESOURCE_UNIT_SCALING); - return *this; -} - -FixedPoint FixedPoint::operator+=(double const d) { - i_ += (int64_t)(d * RESOURCE_UNIT_SCALING); - return *this; -} - -FixedPoint FixedPoint::operator+=(int64_t const ru) { - *this += (double)ru; - return *this; -} - -bool FixedPoint::operator<(FixedPoint const &ru1) const { return (i_ < ru1.i_); }; -bool FixedPoint::operator>(FixedPoint const &ru1) const { return (i_ > ru1.i_); }; -bool FixedPoint::operator<=(FixedPoint const &ru1) const { return (i_ <= ru1.i_); }; -bool FixedPoint::operator>=(FixedPoint const &ru1) const { return (i_ >= ru1.i_); }; -bool FixedPoint::operator==(FixedPoint const &ru1) const { return (i_ == ru1.i_); }; -bool FixedPoint::operator!=(FixedPoint const &ru1) const { return (i_ != ru1.i_); }; - -std::ostream &operator<<(std::ostream &out, FixedPoint const &ru1) { - out << ru1.i_; - return out; -} - -double FixedPoint::Double() const { return round(i_) / RESOURCE_UNIT_SCALING; }; diff --git a/src/ray/raylet/scheduling/fixed_point.h b/src/ray/raylet/scheduling/fixed_point.h index f133397ec6251..a18ffd1873218 100644 --- a/src/ray/raylet/scheduling/fixed_point.h +++ b/src/ray/raylet/scheduling/fixed_point.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include @@ -25,41 +26,85 @@ class FixedPoint { int64_t i_ = 0; public: - FixedPoint() = default; - FixedPoint(double d); - FixedPoint(int i); - FixedPoint(uint32_t i); - FixedPoint(int64_t i); - FixedPoint(uint64_t i); - - FixedPoint operator+(FixedPoint const &ru) const; - - FixedPoint operator+=(FixedPoint const &ru); - - FixedPoint operator-(FixedPoint const &ru) const; - - FixedPoint operator-=(FixedPoint const &ru); - - FixedPoint operator-() const; - - FixedPoint operator+(double const d) const; - - FixedPoint operator-(double const d) 
const;
-
-  FixedPoint operator=(double const d);
-
-  FixedPoint operator+=(double const d);
-
-  FixedPoint operator+=(int64_t const ru);
-
-  bool operator<(FixedPoint const &ru1) const;
-  bool operator>(FixedPoint const &ru1) const;
-  bool operator<=(FixedPoint const &ru1) const;
-  bool operator>=(FixedPoint const &ru1) const;
-  bool operator==(FixedPoint const &ru1) const;
-  bool operator!=(FixedPoint const &ru1) const;
-
-  double Double() const;
+  FixedPoint() : FixedPoint(0.0) {}
+  FixedPoint(double d) { i_ = static_cast<int64_t>(d * RESOURCE_UNIT_SCALING); }  // NOLINT
+
+  FixedPoint(int i) { i_ = (i * RESOURCE_UNIT_SCALING); }  // NOLINT
+
+  FixedPoint(uint32_t i) { i_ = (i * RESOURCE_UNIT_SCALING); }  // NOLINT
+
+  FixedPoint(int64_t i) : FixedPoint((double)i) {}  // NOLINT
+
+  FixedPoint(uint64_t i) : FixedPoint((double)i) {}  // NOLINT
+
+  FixedPoint operator+(FixedPoint const &ru) const {
+    FixedPoint res;
+    res.i_ = i_ + ru.i_;
+    return res;
+  }
+
+  FixedPoint &operator+=(FixedPoint const &ru) {
+    i_ += ru.i_;
+    return *this;
+  }
+
+  FixedPoint operator-(FixedPoint const &ru) const {
+    FixedPoint res;
+    res.i_ = i_ - ru.i_;
+    return res;
+  }
+
+  FixedPoint &operator-=(FixedPoint const &ru) {
+    i_ -= ru.i_;
+    return *this;
+  }
+
+  FixedPoint operator-() const {
+    FixedPoint res;
+    res.i_ = -i_;
+    return res;
+  }
+
+  FixedPoint operator+(double const d) const {
+    FixedPoint res;
+    res.i_ = i_ + static_cast<int64_t>(d * RESOURCE_UNIT_SCALING);
+    return res;
+  }
+
+  FixedPoint operator-(double const d) const {
+    FixedPoint res;
+    res.i_ = i_ - static_cast<int64_t>(d * RESOURCE_UNIT_SCALING);
+    return res;
+  }
+
+  FixedPoint operator=(double const d) {
+    i_ = static_cast<int64_t>(d * RESOURCE_UNIT_SCALING);
+    return *this;
+  }
+
+  FixedPoint operator+=(double const d) {
+    i_ += static_cast<int64_t>(d * RESOURCE_UNIT_SCALING);
+    return *this;
+  }
+
+  FixedPoint operator+=(int64_t const ru) {
+    *this += static_cast<double>(ru);
+    return *this;
+  }
+
+  bool operator<(FixedPoint const &ru1) const { return (i_ < ru1.i_); };
+  bool operator>(FixedPoint const &ru1) const { return (i_ > ru1.i_); };
+  bool operator<=(FixedPoint const &ru1) const { return (i_ <= ru1.i_); };
+  bool operator>=(FixedPoint const &ru1) const { return (i_ >= ru1.i_); };
+  bool operator==(FixedPoint const &ru1) const { return (i_ == ru1.i_); };
+  bool operator!=(FixedPoint const &ru1) const { return (i_ != ru1.i_); };
+
+  [[nodiscard]] double Double() const { return round(i_) / RESOURCE_UNIT_SCALING; };

   friend std::ostream &operator<<(std::ostream &out, FixedPoint const &ru1);
 };
+
+inline std::ostream &operator<<(std::ostream &out, FixedPoint const &ru1) {
+  out << ru1.i_;
+  return out;
+}
diff --git a/src/ray/raylet/scheduling/scheduling_policy.cc b/src/ray/raylet/scheduling/scheduling_policy.cc
index 4bf28bdb75a21..40c1ca39605d8 100644
--- a/src/ray/raylet/scheduling/scheduling_policy.cc
+++ b/src/ray/raylet/scheduling/scheduling_policy.cc
@@ -57,7 +57,7 @@ int64_t HybridPolicyWithFilter(const ResourceRequest &resource_request,
     if (node_filter == NodeFilter::kGPU) {
       return has_gpu;
     }
-    RAY_CHECK(node_filter == NodeFilter::kCPUOnly);
+    RAY_CHECK(node_filter == NodeFilter::kNonGpu);
     return !has_gpu;
   };

@@ -149,16 +149,18 @@ int64_t HybridPolicy(const ResourceRequest &resource_request, const int64_t loca
                                   spread_threshold, force_spillback, require_available);
   }

-  // Try schedule on CPU-only nodes.
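For readers skimming this hunk: the header now both inlines the FixedPoint operations and fixes operator-(double), which previously added the scaled value instead of subtracting it. A minimal sketch of the same scaled-integer idea, outside the diff (the 10000 scale is an assumption for illustration, not necessarily Ray's RESOURCE_UNIT_SCALING):

#include <cstdint>
#include <iostream>

constexpr int64_t kScale = 10000;  // assumed scale for this sketch

struct Fixed {
  int64_t i = 0;  // stores value * kScale
  static Fixed FromDouble(double d) { return {static_cast<int64_t>(d * kScale)}; }
  Fixed operator+(Fixed o) const { return {i + o.i}; }
  // Subtraction must subtract the scaled operand; using `+` here was the bug
  // fixed in the hunk above.
  Fixed operator-(double d) const { return {i - static_cast<int64_t>(d * kScale)}; }
  double ToDouble() const { return static_cast<double>(i) / kScale; }
};

int main() {
  Fixed cpu = Fixed::FromDouble(2.5);
  cpu = cpu - 0.5;                      // release half a CPU
  std::cout << cpu.ToDouble() << "\n";  // prints 2 exactly; no floating-point drift
}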
-  const auto node_id =
-      HybridPolicyWithFilter(resource_request, local_node_id, nodes, spread_threshold,
-                             force_spillback, require_available, NodeFilter::kCPUOnly);
-  if (node_id != -1) {
-    return node_id;
+  // Try to schedule on non-GPU nodes.
+  auto best_node_id = HybridPolicyWithFilter(
+      resource_request, local_node_id, nodes, spread_threshold, force_spillback,
+      /*require_available*/ true, NodeFilter::kNonGpu);
+  if (best_node_id != -1) {
+    return best_node_id;
   }
-  // Could not schedule on CPU-only nodes, schedule on GPU nodes as a last resort.
+
+  // If no available node was found among the non-GPU nodes, fall back to the original
+  // scheduling over all nodes.
   return HybridPolicyWithFilter(resource_request, local_node_id, nodes, spread_threshold,
-                                force_spillback, require_available, NodeFilter::kGPU);
+                                force_spillback, require_available);
 }

 }  // namespace raylet_scheduling_policy
diff --git a/src/ray/raylet/scheduling/scheduling_policy.h b/src/ray/raylet/scheduling/scheduling_policy.h
index b6f382ff1d078..b137491576690 100644
--- a/src/ray/raylet/scheduling/scheduling_policy.h
+++ b/src/ray/raylet/scheduling/scheduling_policy.h
@@ -62,8 +62,15 @@ int64_t HybridPolicy(
     bool force_spillback, bool require_available,
     bool scheduler_avoid_gpu_nodes = RayConfig::instance().scheduler_avoid_gpu_nodes());

-//
-enum class NodeFilter { kAny, kGPU, kCPUOnly };
+enum class NodeFilter {
+  /// Default scheduling.
+  kAny,
+  /// Schedule only on GPU nodes.
+  kGPU,
+  /// Schedule only on nodes without a GPU. Since GPUs are scarcer resources, they
+  /// need special handling.
+  kNonGpu
+};

 /// \param resource_request: The resource request we're attempting to schedule.
 /// \param local_node_id: The id of the local node, which is needed for traversal order.
@@ -72,7 +79,7 @@ enum class NodeFilter { kAny, kGPU, kCPUOnly };
 ///                          truncated to 0.
 /// \param node_filter: defines the subset of nodes we are allowed to schedule on.
 /// can be one of kAny (can schedule on all nodes), kGPU (can only schedule on GPU
-/// nodes), kCPUOnly (can only schedule on non-GPU nodes.
+/// nodes), kNonGpu (can only schedule on non-GPU nodes).
 ///
 /// \return -1 if the task is infeasible, otherwise the node id (key in `nodes`) to
 /// schedule on.
diff --git a/src/ray/raylet/scheduling/scheduling_policy_test.cc b/src/ray/raylet/scheduling/scheduling_policy_test.cc
index fb51d7f4c8711..6a834db1966e9 100644
--- a/src/ray/raylet/scheduling/scheduling_policy_test.cc
+++ b/src/ray/raylet/scheduling/scheduling_policy_test.cc
@@ -338,6 +338,42 @@ TEST_F(SchedulingPolicyTest, ForceSpillbackOnlyFeasibleLocallyTest) {
   ASSERT_EQ(to_schedule, -1);
 }

+TEST_F(SchedulingPolicyTest, NonGpuNodePreferredSchedulingTest) {
+  // Prefer to schedule on CPU-only nodes first.
+  // GPU nodes should be used only as a last resort.
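The two-pass flow above is easier to see in isolation: first restrict the candidate set to non-GPU nodes and require availability, then fall back to the unfiltered set. A simplified model of that idea with toy types (not Ray's ResourceRequest or HybridPolicyWithFilter signatures):

#include <cstdint>
#include <functional>
#include <map>

struct Node {
  double cpus_free = 0;
  double gpus_total = 0;
};

// Toy stand-in for HybridPolicyWithFilter: pick the first node that matches
// the filter and has enough free CPUs; return -1 when there is none.
int64_t PickNode(const std::map<int64_t, Node> &nodes, double cpus_needed,
                 const std::function<bool(const Node &)> &filter) {
  for (const auto &[id, node] : nodes) {
    if (filter(node) && node.cpus_free >= cpus_needed) return id;
  }
  return -1;
}

int64_t HybridPick(const std::map<int64_t, Node> &nodes, double cpus_needed) {
  // Pass 1: avoid GPU nodes so the scarce GPUs stay free for GPU tasks.
  int64_t id = PickNode(nodes, cpus_needed,
                        [](const Node &n) { return n.gpus_total == 0; });
  if (id != -1) return id;
  // Pass 2: no feasible non-GPU node; fall back to every node, GPUs included.
  return PickNode(nodes, cpus_needed, [](const Node &) { return true; });
}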
+  StringIdMap map;
+  int64_t local_node = 0;
+  int64_t remote_node_1 = 1;
+  int64_t remote_node_2 = 2;
+
+  // local {CPU:2, GPU:1}, remote_node_1 {CPU:2}, remote_node_2 {CPU:3}
+  absl::flat_hash_map<int64_t, Node> nodes;
+  nodes.emplace(local_node, CreateNodeResources(2, 2, 0, 0, 1, 1));
+  nodes.emplace(remote_node_1, CreateNodeResources(2, 2, 0, 0, 0, 0));
+  nodes.emplace(remote_node_2, CreateNodeResources(3, 3, 0, 0, 0, 0));
+
+  ResourceRequest req = ResourceMapToResourceRequest(map, {{"CPU", 1}}, false);
+  int to_schedule = raylet_scheduling_policy::HybridPolicy(
+      req, local_node, nodes, 0.51, false, true, /*gpu_avoid_scheduling*/ true);
+  ASSERT_EQ(to_schedule, remote_node_1);
+
+  req = ResourceMapToResourceRequest(map, {{"CPU", 3}}, false);
+  to_schedule = raylet_scheduling_policy::HybridPolicy(
+      req, local_node, nodes, 0.51, false, true, /*gpu_avoid_scheduling*/ true);
+  ASSERT_EQ(to_schedule, remote_node_2);
+
+  req = ResourceMapToResourceRequest(map, {{"CPU", 1}, {"GPU", 1}}, false);
+  to_schedule = raylet_scheduling_policy::HybridPolicy(
+      req, local_node, nodes, 0.51, false, true, /*gpu_avoid_scheduling*/ true);
+  ASSERT_EQ(to_schedule, local_node);
+
+  req = ResourceMapToResourceRequest(map, {{"CPU", 2}}, false);
+  to_schedule = raylet_scheduling_policy::HybridPolicy(
+      req, local_node, nodes, 0.51, false, true, /*gpu_avoid_scheduling*/ true);
+  ASSERT_EQ(to_schedule, remote_node_1);
+}
+
 int main(int argc, char **argv) {
   ::testing::InitGoogleTest(&argc, argv);
   return RUN_ALL_TESTS();
diff --git a/src/ray/raylet/worker.cc b/src/ray/raylet/worker.cc
index 08331a75f176d..fd2b7b723f755 100644
--- a/src/ray/raylet/worker.cc
+++ b/src/ray/raylet/worker.cc
@@ -14,7 +14,7 @@

 #include "ray/raylet/worker.h"

-#include
+#include

 #include "ray/raylet/format/node_manager_generated.h"
 #include "ray/raylet/raylet.h"
diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc
index f0268021280f8..959cc551f0dbc 100644
--- a/src/ray/raylet/worker_pool.cc
+++ b/src/ray/raylet/worker_pool.cc
@@ -176,7 +176,6 @@ Process WorkerPool::StartWorkerProcess(
     const Language &language, const rpc::WorkerType worker_type, const JobID &job_id,
     PopWorkerStatus *status, const std::vector<std::string> &dynamic_options,
     const int runtime_env_hash, const std::string &serialized_runtime_env,
-    std::unordered_map<std::string, std::string> override_environment_variables,
     const std::string &serialized_runtime_env_context,
     const std::string &allocated_instances_serialized_json) {
   rpc::JobConfig *job_config = nullptr;
@@ -313,39 +312,41 @@ Process WorkerPool::StartWorkerProcess(
     // need to add a new CLI parameter for both Python and Java workers.
     env.emplace(kEnvVarKeyJobId, job_id.Hex());
   }
-  if (job_config) {
-    env.insert(job_config->worker_env().begin(), job_config->worker_env().end());
-  }
-
-  for (const auto &pair : override_environment_variables) {
-    env[pair.first] = pair.second;
-  }
-
-  if (language == Language::PYTHON) {
+  if (language == Language::PYTHON || language == Language::JAVA) {
     if (serialized_runtime_env != "{}" && serialized_runtime_env != "") {
       worker_command_args.push_back("--serialized-runtime-env=" + serialized_runtime_env);
       // Allocated_resource_json is only used in "shim process".
       worker_command_args.push_back("--allocated-instances-serialized-json=" +
                                     allocated_instances_serialized_json);
+
+      worker_command_args.push_back("--language=" + Language_Name(language));
+
+      worker_command_args.push_back("--runtime-env-hash=" +
+                                    std::to_string(runtime_env_hash));
+
+      if (serialized_runtime_env_context != "{}" &&
+          !serialized_runtime_env_context.empty()) {
+        worker_command_args.push_back("--serialized-runtime-env-context=" +
+                                      serialized_runtime_env_context);
+      }
     } else {
       // The "shim process" setup worker is not needed, so do not run it.
       // Check that the arg really is the path to the setup worker before erasing it, to
       // prevent breaking tests that mock out the worker command args.
       if (worker_command_args.size() >= 2 &&
           worker_command_args[1].find(kSetupWorkerFilename) != std::string::npos) {
-        worker_command_args.erase(worker_command_args.begin() + 1,
-                                  worker_command_args.begin() + 2);
+        if (language == Language::PYTHON) {
+          worker_command_args.erase(worker_command_args.begin() + 1,
+                                    worker_command_args.begin() + 2);
+        } else {
+          // Erase the python executable as well for other languages.
+          worker_command_args.erase(worker_command_args.begin(),
+                                    worker_command_args.begin() + 2);
+        }
       }
     }

-    worker_command_args.push_back("--runtime-env-hash=" +
-                                  std::to_string(runtime_env_hash));
-
-    if (serialized_runtime_env_context != "{}" && serialized_runtime_env_context != "") {
-      worker_command_args.push_back("--serialized-runtime-env-context=" +
-                                    serialized_runtime_env_context);
-    }
-
     if (ray_debugger_external) {
       worker_command_args.push_back("--ray-debugger-external");
     }
@@ -483,6 +484,24 @@ void WorkerPool::MarkPortAsFree(int port) {

 void WorkerPool::HandleJobStarted(const JobID &job_id, const rpc::JobConfig &job_config) {
   all_jobs_[job_id] = job_config;
+  if (job_config.has_runtime_env() &&
+      job_config.runtime_env().runtime_env_eager_install()) {
+    auto const &runtime_env = job_config.runtime_env().serialized_runtime_env();
+    RAY_LOG(INFO) << "[Eagerly] Starting to install runtime environment for job "
+                  << job_id << ". The runtime environment was " << runtime_env << ".";
+    CreateRuntimeEnv(
+        runtime_env, job_id,
+        [job_id](bool successful, const std::string &serialized_runtime_env_context) {
+          if (successful) {
+            RAY_LOG(INFO) << "[Eagerly] Created runtime env successfully for job "
+                          << job_id << ". The result context was "
+                          << serialized_runtime_env_context << ".";
+          } else {
+            RAY_LOG(ERROR) << "[Eagerly] Couldn't create a runtime environment for job "
+                           << job_id << ".";
+          }
+        });
+  }
 }

 void WorkerPool::HandleJobFinished(const JobID &job_id) {
@@ -749,7 +768,7 @@ void WorkerPool::PushWorker(const std::shared_ptr<WorkerInterface> &worker) {
   // The worker is used for the actor creation task with dynamic options.
   if (!used) {
     // Put it into idle dedicated worker pool.
-    // TODO(guyang.sgy): This worker will not be used forever. We should kill it.
+    // TODO(SongGuyang): This worker will not be used forever. We should kill it.
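The eager-install hook above is fire-and-forget: env creation starts as soon as the job is registered, and the callback only logs, because no worker is blocked on the result yet. The shape of that pattern in a self-contained sketch (AgentStub and the names below are hypothetical, not Ray's AgentManager API):

#include <functional>
#include <iostream>
#include <string>

using CreateEnvCallback = std::function<void(bool, const std::string &)>;

// Hypothetical async agent handle; a real one would issue an RPC to the agent.
struct AgentStub {
  void CreateRuntimeEnv(const std::string &env, const CreateEnvCallback &cb) {
    cb(true, "{\"resolved\":\"" + env + "\"}");  // pretend creation succeeded
  }
};

void OnJobStarted(AgentStub &agent, const std::string &serialized_env) {
  if (serialized_env.empty() || serialized_env == "{}") return;  // nothing to install
  agent.CreateRuntimeEnv(serialized_env, [](bool ok, const std::string &ctx) {
    // Eager install: nothing consumes ctx yet, so just record the outcome.
    std::cout << (ok ? "env ready: " + ctx : std::string("env creation failed"))
              << "\n";
  });
}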
     state.idle_dedicated_workers[task_id] = worker;
   }
   return;
@@ -921,7 +940,8 @@ void WorkerPool::TryKillingIdleWorkers() {
 void WorkerPool::PopWorker(const TaskSpecification &task_spec,
                            const PopWorkerCallback &callback,
                            const std::string &allocated_instances_serialized_json) {
-  RAY_LOG(DEBUG) << "Pop worker for task " << task_spec.TaskId();
+  RAY_LOG(DEBUG) << "Pop worker for task " << task_spec.TaskId() << " task name "
+                 << task_spec.FunctionDescriptor()->ToString();
   auto &state = GetStateForLanguage(task_spec.GetLanguage());

   std::shared_ptr<WorkerInterface> worker = nullptr;
@@ -936,8 +956,7 @@ void WorkerPool::PopWorker(const TaskSpecification &task_spec,
         Process proc = StartWorkerProcess(
             task_spec.GetLanguage(), rpc::WorkerType::WORKER, task_spec.JobId(), &status,
             dynamic_options, task_spec.GetRuntimeEnvHash(), serialized_runtime_env,
-            task_spec.OverrideEnvironmentVariables(), serialized_runtime_env_context,
-            allocated_instances_serialized_json);
+            serialized_runtime_env_context, allocated_instances_serialized_json);
         if (status == PopWorkerStatus::OK) {
           RAY_CHECK(proc.IsValid());
           WarnAboutSize();
@@ -948,7 +967,7 @@ void WorkerPool::PopWorker(const TaskSpecification &task_spec,
           state.starting_workers_to_tasks[proc] = std::move(task_info);
         }
       } else {
-        // TODO(guyang.sgy): Wait until a worker is pushed or a worker can be started If
+        // TODO(SongGuyang): Wait until a worker is pushed or a worker can be started, if
         // startup concurrency maxed out or job not started.
         PopWorkerCallbackAsync(callback, nullptr, status);
       }
@@ -976,24 +995,24 @@ void WorkerPool::PopWorker(const TaskSpecification &task_spec,
       dynamic_options = task_spec.DynamicWorkerOptions();
     }

-    // create runtime env.
     if (task_spec.HasRuntimeEnv()) {
-      agent_manager_->CreateRuntimeEnv(
-          task_spec.JobId(), task_spec.SerializedRuntimeEnv(),
-          [start_worker_process_fn, callback, &state, task_spec, dynamic_options,
-           allocated_instances_serialized_json](
-              bool success, const std::string &serialized_runtime_env_context) {
-            if (success) {
+      // create runtime env.
+      CreateRuntimeEnv(
+          task_spec.SerializedRuntimeEnv(), task_spec.JobId(),
+          [start_worker_process_fn, callback, &state, task_spec, dynamic_options](
+              bool successful, const std::string &serialized_runtime_env_context) {
+            if (successful) {
               start_worker_process_fn(task_spec, state, dynamic_options, true,
                                       task_spec.SerializedRuntimeEnv(),
                                       serialized_runtime_env_context, callback);
             } else {
-              RAY_LOG(WARNING) << "Couldn't create a runtime environment for task "
-                               << task_spec.TaskId() << ". The runtime environment was "
-                               << task_spec.SerializedRuntimeEnv() << ".";
               callback(nullptr, PopWorkerStatus::RuntimeEnvCreationFailed);
+              RAY_LOG(WARNING)
+                  << "Failed to create the runtime env for task " << task_spec.TaskId()
+                  << ", so the dedicated worker could not be created.";
             }
-          });
+          },
+          allocated_instances_serialized_json);
     } else {
       start_worker_process_fn(task_spec, state, dynamic_options, true, "", "", callback);
@@ -1036,8 +1055,8 @@ void WorkerPool::PopWorker(const TaskSpecification &task_spec,
     // Start a new worker process.
     if (task_spec.HasRuntimeEnv()) {
       // create runtime env.
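One detail worth calling out in the PopWorker changes above: a runtime env failure now reaches the caller as a status through the callback (PopWorkerStatus::RuntimeEnvCreationFailed) rather than being only logged, so the lease request can fail fast. A stripped-down model of that status-carrying callback (toy types, not the full WorkerPool interface):

#include <functional>
#include <memory>

struct Worker {};
enum class PopWorkerStatus { OK, RuntimeEnvCreationFailed };
using PopWorkerCallback =
    std::function<void(std::shared_ptr<Worker>, PopWorkerStatus)>;

// Either hand back a worker with OK, or a nullptr plus a machine-readable
// reason the caller can branch on (e.g. reject the lease instead of retrying).
void PopWorker(bool env_creation_failed, const PopWorkerCallback &callback) {
  if (env_creation_failed) {
    callback(nullptr, PopWorkerStatus::RuntimeEnvCreationFailed);
    return;
  }
  callback(std::make_shared<Worker>(), PopWorkerStatus::OK);
}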
-      agent_manager_->CreateRuntimeEnv(
-          task_spec.JobId(), task_spec.SerializedRuntimeEnv(),
+      CreateRuntimeEnv(
+          task_spec.SerializedRuntimeEnv(), task_spec.JobId(),
           [start_worker_process_fn, callback, &state, task_spec](
               bool successful, const std::string &serialized_runtime_env_context) {
             if (successful) {
@@ -1045,12 +1064,13 @@ void WorkerPool::PopWorker(const TaskSpecification &task_spec,
                                       task_spec.SerializedRuntimeEnv(),
                                       serialized_runtime_env_context, callback);
             } else {
-              RAY_LOG(WARNING) << "Couldn't create a runtime environment for task "
-                               << task_spec.TaskId() << ". The runtime environment was "
-                               << task_spec.SerializedRuntimeEnv() << ".";
               callback(nullptr, PopWorkerStatus::RuntimeEnvCreationFailed);
+              RAY_LOG(WARNING)
+                  << "Failed to create the runtime env for task " << task_spec.TaskId()
+                  << ", so the worker could not be created.";
             }
-          });
+          },
+          allocated_instances_serialized_json);
     } else {
       start_worker_process_fn(task_spec, state, {}, false, "", "", callback);
     }
@@ -1067,7 +1087,7 @@ void WorkerPool::PrestartWorkers(const TaskSpecification &task_spec, int64_t bac
                                  int64_t num_available_cpus) {
   // Code path of task that needs a dedicated worker.
   if ((task_spec.IsActorCreationTask() && !task_spec.DynamicWorkerOptions().empty()) ||
-      task_spec.OverrideEnvironmentVariables().size() > 0 || task_spec.HasRuntimeEnv()) {
+      task_spec.HasRuntimeEnv()) {
     return;  // Not handled.
     // TODO(architkulkarni): We'd eventually like to prestart workers with the same
     // runtime env to improve initial startup performance.
@@ -1324,6 +1344,26 @@ WorkerPool::IOWorkerState &WorkerPool::GetIOWorkerStateFromWorkerType(
   UNREACHABLE;
 }

+void WorkerPool::CreateRuntimeEnv(
+    const std::string &serialized_runtime_env, const JobID &job_id,
+    const std::function<void(bool, const std::string &)> &callback,
+    const std::string &serialized_allocated_resource_instances) {
+  // create runtime env.
+  agent_manager_->CreateRuntimeEnv(
+      job_id, serialized_runtime_env, serialized_allocated_resource_instances,
+      [job_id, serialized_runtime_env, callback](
+          bool successful, const std::string &serialized_runtime_env_context) {
+        if (successful) {
+          callback(true, serialized_runtime_env_context);
+        } else {
+          RAY_LOG(WARNING) << "Couldn't create a runtime environment for job " << job_id
+                           << ". The runtime environment was " << serialized_runtime_env
+                           << ".";
+          callback(false, "");
+        }
+      });
+}
+
 }  // namespace raylet
 }  // namespace ray
diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h
index 7991600cfd6c6..92c19329c17dc 100644
--- a/src/ray/raylet/worker_pool.h
+++ b/src/ray/raylet/worker_pool.h
@@ -397,7 +397,6 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface {
       PopWorkerStatus *status /*output*/,
       const std::vector<std::string> &dynamic_options = {},
       const int runtime_env_hash = 0, const std::string &serialized_runtime_env = "{}",
-      std::unordered_map<std::string, std::string> override_environment_variables = {},
       const std::string &serialized_runtime_env_context = "{}",
       const std::string &allocated_instances_serialized_json = "{}");

@@ -589,6 +588,12 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface {
                            const PopWorkerStatus &status, bool *found /* output */,
                            bool *worker_used /* output */, TaskID *task_id /* output */);

+  /// Create a runtime env asynchronously via the runtime env agent.
+  void CreateRuntimeEnv(
+      const std::string &serialized_runtime_env, const JobID &job_id,
+      const std::function<void(bool, const std::string &)> &callback,
+      const std::string &serialized_allocated_resource_instances = "{}");
+
   /// For Process class for managing subprocesses (e.g. reaping zombies).
   instrumented_io_context *io_service_;
   /// Node ID of the current node.
diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc
index 9a28520700a8e..37fb903b4a7ab 100644
--- a/src/ray/raylet/worker_pool_test.cc
+++ b/src/ray/raylet/worker_pool_test.cc
@@ -103,9 +103,10 @@ class WorkerPoolMock : public WorkerPool {
                  const WorkerCommandMap &worker_commands,
                  absl::flat_hash_map<WorkerID, std::shared_ptr<MockWorkerClient>>
                      &mock_worker_rpc_clients)
-      : WorkerPool(io_service, NodeID::FromRandom(), "", POOL_SIZE_SOFT_LIMIT, 0,
-                   MAXIMUM_STARTUP_CONCURRENCY, 0, 0, {}, nullptr, worker_commands,
-                   []() {}, 0, [this]() { return current_time_ms_; }),
+      : WorkerPool(
+            io_service, NodeID::FromRandom(), "", POOL_SIZE_SOFT_LIMIT, 0,
+            MAXIMUM_STARTUP_CONCURRENCY, 0, 0, {}, nullptr, worker_commands, []() {}, 0,
+            [this]() { return current_time_ms_; }),
         last_worker_process_(),
         instrumented_io_service_(io_service),
         error_message_type_(1),
@@ -257,7 +258,7 @@ class WorkerPoolMock : public WorkerPool {
         is_java = true;
       }
     }
-    // TODO(guyang.sgy): support C++ language workers.
+    // TODO(SongGuyang): support C++ language workers.
     int num_workers = is_java ? NUM_WORKERS_PER_PROCESS_JAVA : 1;
     for (int i = 0; i < num_workers; i++) {
       auto worker =
@@ -458,7 +459,7 @@ static inline TaskSpecification ExampleTaskSpec(
   } else {
     message.set_type(TaskType::NORMAL_TASK);
   }
-  message.set_serialized_runtime_env(serialized_runtime_env);
+  message.mutable_runtime_env()->set_serialized_runtime_env(serialized_runtime_env);
   return TaskSpecification(std::move(message));
 }

@@ -1257,8 +1258,7 @@ TEST_F(WorkerPoolTest, CacheWorkersByRuntimeEnvHash) {
                       ActorID::Nil(), Language::PYTHON, JOB_ID, ActorID::Nil(),
                       /*dynamic_options=*/{}, TaskID::ForFakeTask(), "mock_runtime_env_2");

-  const WorkerCacheKey env1 = {
-      /*override_environment_variables=*/{}, "mock_runtime_env_1", {}};
+  const WorkerCacheKey env1 = {"mock_runtime_env_1", {}};
   const int runtime_env_hash_1 = env1.IntHash();

   // Push worker with runtime env 1.
diff --git a/src/ray/raylet_client/raylet_client.cc b/src/ray/raylet_client/raylet_client.cc
index 41e9611491c7d..3da524b8611a4 100644
--- a/src/ray/raylet_client/raylet_client.cc
+++ b/src/ray/raylet_client/raylet_client.cc
@@ -296,13 +296,20 @@ Status raylet::RayletClient::FreeObjects(const std::vector<ObjectID> &object_ids

 void raylet::RayletClient::RequestWorkerLease(
-    const TaskSpecification &resource_spec,
+    const rpc::TaskSpec &task_spec,
     const rpc::ClientCallback<rpc::RequestWorkerLeaseReply> &callback,
     const int64_t backlog_size) {
-  rpc::RequestWorkerLeaseRequest request;
-  request.mutable_resource_spec()->CopyFrom(resource_spec.GetMessage());
-  request.set_backlog_size(backlog_size);
-  grpc_client_->RequestWorkerLease(request, callback);
+  google::protobuf::Arena arena;
+  auto request =
+      google::protobuf::Arena::CreateMessage<rpc::RequestWorkerLeaseRequest>(&arena);
+  // The unsafe arena call below is actually safe because task_spec outlives request:
+  // the request is sent before the end of this call, and it is not used afterwards.
+  request->unsafe_arena_set_allocated_resource_spec(
+      const_cast<rpc::TaskSpec *>(&task_spec));
+  request->set_backlog_size(backlog_size);
+  grpc_client_->RequestWorkerLease(*request, callback);
 }

 /// Spill objects to external storage.
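The arena change above exists to skip a deep CopyFrom of the task spec: the request temporarily borrows the caller's rpc::TaskSpec for the duration of the call. The pattern, sketched with a hypothetical generated message MyRequest that has a message field named spec (Arena::CreateMessage and the unsafe_arena_set_allocated_* accessors are real protobuf API; the message type and header are not):

#include <google/protobuf/arena.h>
#include "my_request.pb.h"  // hypothetical: message MyRequest { TaskSpec spec = 1; }

void Send(const TaskSpec &spec /* must outlive this call */) {
  google::protobuf::Arena arena;
  auto *request = google::protobuf::Arena::CreateMessage<MyRequest>(&arena);
  // Borrow `spec` instead of copying it. This is safe only because the RPC is
  // serialized before this function returns and the request is never reused.
  request->unsafe_arena_set_allocated_spec(const_cast<TaskSpec *>(&spec));
  // ... issue the RPC with *request here ...
  // Arena-owned messages do not delete submessages installed this way, so the
  // borrowed pointer is not freed when `arena` goes out of scope.
}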
@@ -314,6 +321,20 @@ void raylet::RayletClient::RequestObjectSpillage(
   grpc_client_->RequestObjectSpillage(request, callback);
 }

+void raylet::RayletClient::ReportWorkerBacklog(
+    const WorkerID &worker_id,
+    const std::vector<rpc::WorkerBacklogReport> &backlog_reports) {
+  rpc::ReportWorkerBacklogRequest request;
+  request.set_worker_id(worker_id.Binary());
+  request.mutable_backlog_reports()->Add(backlog_reports.begin(), backlog_reports.end());
+  grpc_client_->ReportWorkerBacklog(
+      request, [](const Status &status, const rpc::ReportWorkerBacklogReply &reply) {
+        if (!status.ok()) {
+          RAY_LOG(INFO) << "Error reporting task backlog information: " << status;
+        }
+      });
+}
+
 Status raylet::RayletClient::ReturnWorker(int worker_port, const WorkerID &worker_id,
                                           bool disconnect_worker) {
   rpc::ReturnWorkerRequest request;
@@ -373,7 +394,7 @@ void raylet::RayletClient::CommitBundleResources(
 }

 void raylet::RayletClient::CancelResourceReserve(
-    BundleSpecification &bundle_spec,
+    const BundleSpecification &bundle_spec,
     const ray::rpc::ClientCallback<ray::rpc::CancelResourceReserveReply> &callback) {
   rpc::CancelResourceReserveRequest request;
   request.mutable_bundle_spec()->CopyFrom(bundle_spec.GetMessage());
diff --git a/src/ray/raylet_client/raylet_client.h b/src/ray/raylet_client/raylet_client.h
index 558fed24b24cf..547e8eaa7ee00 100644
--- a/src/ray/raylet_client/raylet_client.h
+++ b/src/ray/raylet_client/raylet_client.h
@@ -68,6 +68,10 @@ class WorkerLeaseInterface {
       const ray::TaskSpecification &resource_spec,
       const ray::rpc::ClientCallback<ray::rpc::RequestWorkerLeaseReply> &callback,
       const int64_t backlog_size = -1) = 0;
+  virtual void RequestWorkerLease(
+      const rpc::TaskSpec &task_spec,
+      const ray::rpc::ClientCallback<ray::rpc::RequestWorkerLeaseReply> &callback,
+      const int64_t backlog_size = -1) = 0;

   /// Returns a worker to the raylet.
   /// \param worker_port The local port of the worker on the raylet node.
@@ -89,6 +93,14 @@ class WorkerLeaseInterface {
       const TaskID &task_id,
       const rpc::ClientCallback<rpc::CancelWorkerLeaseReply> &callback) = 0;

+  /// Report the backlog size of a given worker and a given scheduling class to the
+  /// raylet.
+  /// \param worker_id The ID of the worker that reports the backlog size.
+  /// \param backlog_reports The backlog report for each scheduling class.
+  virtual void ReportWorkerBacklog(
+      const WorkerID &worker_id,
+      const std::vector<rpc::WorkerBacklogReport> &backlog_reports) = 0;
+
   virtual ~WorkerLeaseInterface(){};
 };

@@ -117,7 +129,7 @@ class ResourceReserveInterface {
       const ray::rpc::ClientCallback &callback) = 0;

   virtual void CancelResourceReserve(
-      BundleSpecification &bundle_spec,
+      const BundleSpecification &bundle_spec,
       const ray::rpc::ClientCallback<ray::rpc::CancelResourceReserveReply> &callback) = 0;

   virtual void ReleaseUnusedBundles(
@@ -360,12 +372,24 @@ class RayletClient : public RayletClientInterface {
   void RequestWorkerLease(
       const ray::TaskSpecification &resource_spec,
       const ray::rpc::ClientCallback<ray::rpc::RequestWorkerLeaseReply> &callback,
+      const int64_t backlog_size) override {
+    RequestWorkerLease(resource_spec.GetMessage(), callback, backlog_size);
+  }
+
+  void RequestWorkerLease(
+      const rpc::TaskSpec &resource_spec,
+      const ray::rpc::ClientCallback<ray::rpc::RequestWorkerLeaseReply> &callback,
       const int64_t backlog_size) override;

   /// Implements WorkerLeaseInterface.
   ray::Status ReturnWorker(int worker_port, const WorkerID &worker_id,
                            bool disconnect_worker) override;

+  /// Implements WorkerLeaseInterface.
+  void ReportWorkerBacklog(
+      const WorkerID &worker_id,
+      const std::vector<rpc::WorkerBacklogReport> &backlog_reports) override;
+
   /// Implements WorkerLeaseInterface.
  void ReleaseUnusedWorkers(
      const std::vector<WorkerID> &workers_in_use,
@@ -389,7 +413,7 @@ class RayletClient : public RayletClientInterface {

   /// Implements CancelResourceReserveInterface.
   void CancelResourceReserve(
-      BundleSpecification &bundle_spec,
+      const BundleSpecification &bundle_spec,
       const ray::rpc::ClientCallback<ray::rpc::CancelResourceReserveReply> &callback)
       override;
diff --git a/src/ray/rpc/common.cc b/src/ray/rpc/common.cc
index eef01f3e1e2f5..7526c1e6efc6f 100644
--- a/src/ray/rpc/common.cc
+++ b/src/ray/rpc/common.cc
@@ -12,11 +12,11 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+#include "ray/rpc/common.h"
+
 #include <fstream>
 #include <sstream>

-#include "ray/rpc/common.h"
-
 namespace ray::rpc {

 std::string ReadCert(const std::string &cert_filepath) {
@@ -26,4 +26,4 @@ std::string ReadCert(const std::string &cert_filepath) {
   return buffer.str();
 };

-}  // namespace rpc::ray
+}  // namespace ray::rpc
diff --git a/src/ray/rpc/common.h b/src/ray/rpc/common.h
index 929a555a942f6..314e1eccf382c 100644
--- a/src/ray/rpc/common.h
+++ b/src/ray/rpc/common.h
@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+#include <string>
+
 namespace ray::rpc {

 // Utility to read cert file from a particular location
diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc
index 3840527fb5a9a..8f3f98b67445c 100644
--- a/src/ray/rpc/grpc_server.cc
+++ b/src/ray/rpc/grpc_server.cc
@@ -36,11 +36,13 @@ DEFINE_stats(grpc_server_req_finished, "Finished request number in grpc server",
 namespace ray {
 namespace rpc {

-GrpcServer::GrpcServer(std::string name, const uint32_t port, int num_threads,
-                       bool use_tls, int64_t keepalive_time_ms)
+GrpcServer::GrpcServer(std::string name, const uint32_t port,
+                       bool listen_to_localhost_only, int num_threads,
+                       int64_t keepalive_time_ms, bool use_tls)
     : name_(std::move(name)),
       port_(port),
       use_tls_(use_tls),
+      listen_to_localhost_only_(listen_to_localhost_only),
       is_closed_(true),
       num_threads_(num_threads),
       keepalive_time_ms_(keepalive_time_ms) {
@@ -49,7 +51,8 @@ GrpcServer::GrpcServer(std::string name, const uint32_t port, int num_threads,

 void GrpcServer::Run() {
   uint32_t specified_port = port_;
-  std::string server_address("0.0.0.0:" + std::to_string(port_));
+  std::string server_address((listen_to_localhost_only_ ? "127.0.0.1:" : "0.0.0.0:") +
+                             std::to_string(port_));
   grpc::ServerBuilder builder;
   // Disable the SO_REUSEPORT option. We don't need it in ray. If the option is enabled
   // (default behavior in grpc), we may see multiple workers listen on the same port and
diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h
index 826efbdf260bb..c83628b72b2e8 100644
--- a/src/ray/rpc/grpc_server.h
+++ b/src/ray/rpc/grpc_server.h
@@ -61,9 +61,11 @@ class GrpcServer {
   /// \param[in] name Name of this server, used for logging and debugging purpose.
   /// \param[in] port The port to bind this server to. If it's 0, a random available port
   ///  will be chosen.
-  GrpcServer(std::string name, const uint32_t port, int num_threads = 1,
-             bool use_tls = false,
-             int64_t keepalive_time_ms = 7200000 /*2 hours, grpc default*/);
+
+  GrpcServer(std::string name, const uint32_t port, bool listen_to_localhost_only,
+             int num_threads = 1,
+             int64_t keepalive_time_ms = 7200000, /*2 hours, grpc default*/
+             bool use_tls = false);

   /// Destruct this gRPC server.
   ~GrpcServer() { Shutdown(); }
@@ -114,6 +116,9 @@ class GrpcServer {
   int port_;
   /// Whether to use TLS.
   bool use_tls_;
+  /// Listen to localhost (127.0.0.1) only if it's true; otherwise listen on all network
+  /// interfaces (0.0.0.0).
+  const bool listen_to_localhost_only_;
   /// Indicates whether this server has been closed.
   bool is_closed_;
   /// The `grpc::Service` objects which should be registered to `ServerBuilder`.
diff --git a/src/ray/rpc/node_manager/node_manager_client.h b/src/ray/rpc/node_manager/node_manager_client.h
index 341613a848e98..fad890c990e00 100644
--- a/src/ray/rpc/node_manager/node_manager_client.h
+++ b/src/ray/rpc/node_manager/node_manager_client.h
@@ -79,6 +79,9 @@ class NodeManagerWorkerClient
   /// Request a worker lease.
   VOID_RPC_CLIENT_METHOD(NodeManagerService, RequestWorkerLease, grpc_client_, )

+  /// Report task backlog information.
+  VOID_RPC_CLIENT_METHOD(NodeManagerService, ReportWorkerBacklog, grpc_client_, )
+
   /// Return a worker lease.
   VOID_RPC_CLIENT_METHOD(NodeManagerService, ReturnWorker, grpc_client_, )
diff --git a/src/ray/rpc/node_manager/node_manager_server.h b/src/ray/rpc/node_manager/node_manager_server.h
index 7f7d2a5a9738b..2cec90a4512f7 100644
--- a/src/ray/rpc/node_manager/node_manager_server.h
+++ b/src/ray/rpc/node_manager/node_manager_server.h
@@ -28,6 +28,7 @@ namespace rpc {
   RPC_SERVICE_HANDLER(NodeManagerService, UpdateResourceUsage, -1)   \
   RPC_SERVICE_HANDLER(NodeManagerService, RequestResourceReport, -1) \
   RPC_SERVICE_HANDLER(NodeManagerService, RequestWorkerLease, -1)    \
+  RPC_SERVICE_HANDLER(NodeManagerService, ReportWorkerBacklog, -1)   \
   RPC_SERVICE_HANDLER(NodeManagerService, ReturnWorker, -1)          \
   RPC_SERVICE_HANDLER(NodeManagerService, ReleaseUnusedWorkers, -1)  \
   RPC_SERVICE_HANDLER(NodeManagerService, CancelWorkerLease, -1)     \
@@ -70,6 +71,10 @@ class NodeManagerServiceHandler {
                                       RequestWorkerLeaseReply *reply,
                                       SendReplyCallback send_reply_callback) = 0;

+  virtual void HandleReportWorkerBacklog(const ReportWorkerBacklogRequest &request,
+                                         ReportWorkerBacklogReply *reply,
+                                         SendReplyCallback send_reply_callback) = 0;
+
   virtual void HandleReturnWorker(const ReturnWorkerRequest &request,
                                   ReturnWorkerReply *reply,
                                   SendReplyCallback send_reply_callback) = 0;
diff --git a/src/ray/rpc/server_call.h b/src/ray/rpc/server_call.h
index d3c199b50c6bb..9e2d50e8324e4 100644
--- a/src/ray/rpc/server_call.h
+++ b/src/ray/rpc/server_call.h
@@ -14,10 +14,11 @@

 #pragma once

+#include <google/protobuf/arena.h>
 #include
-#include
 #include
+
 #include "ray/common/asio/instrumented_io_context.h"
 #include "ray/common/grpc_util.h"
 #include "ray/common/status.h"
@@ -145,6 +146,7 @@ class ServerCallImpl : public ServerCall {
         response_writer_(&context_),
         io_service_(io_service),
         call_name_(std::move(call_name)) {
+    reply_ = google::protobuf::Arena::CreateMessage<Reply>(&arena_);
     // TODO: call_name_ sometimes gets corrupted due to memory issues.
     RAY_CHECK(!call_name_.empty()) << "Call name is empty";
     STATS_grpc_server_req_new.Record(1.0, call_name_);
@@ -187,7 +189,7 @@ class ServerCallImpl : public ServerCall {
       factory.CreateCall();
     }
     (service_handler_.*handle_request_function_)(
-        request_, &reply_,
+        request_, reply_,
        [this](Status status, std::function<void()> success,
               std::function<void()> failure) {
           // These two callbacks must be set before `SendReply`, because `SendReply`
@@ -222,9 +224,13 @@

   /// Tell gRPC to finish this request and send reply asynchronously.
   void SendReply(const Status &status) {
     state_ = ServerCallState::SENDING_REPLY;
-    response_writer_.Finish(reply_, RayStatusToGrpcStatus(status), this);
+    response_writer_.Finish(*reply_, RayStatusToGrpcStatus(status), this);
   }

+  /// The memory pool for this request. It is used for the reply.
+  /// With the arena, we can set up the reply without copying some fields.
+  google::protobuf::Arena arena_;
+
   /// State of this call.
   ServerCallState state_;

@@ -250,8 +256,9 @@
   /// The request message.
   Request request_;

-  /// The reply message.
-  Reply reply_;
+  /// The reply message. It is owned by the arena and is not valid beyond
+  /// the life-cycle of this call.
+  Reply *reply_;

   /// Human-readable name for this RPC call.
   std::string call_name_;
diff --git a/src/ray/rpc/test/grpc_server_client_test.cc b/src/ray/rpc/test/grpc_server_client_test.cc
index e7b602e6b316f..3bd86f5a24f63 100644
--- a/src/ray/rpc/test/grpc_server_client_test.cc
+++ b/src/ray/rpc/test/grpc_server_client_test.cc
@@ -13,6 +13,7 @@
 // limitations under the License.

 #include
+
 #include "gtest/gtest.h"
 #include "ray/rpc/grpc_client.h"
 #include "ray/rpc/grpc_server.h"
@@ -35,13 +36,14 @@ class TestServiceHandler {
       RAY_LOG(INFO) << "No reply!";
       return;
     }
-    send_reply_callback(ray::Status::OK(),
-                        /*reply_success=*/[]() { RAY_LOG(INFO) << "Reply success."; },
-                        /*reply_failure=*/
-                        [this]() {
-                          RAY_LOG(INFO) << "Reply failed.";
-                          reply_failure_count++;
-                        });
+    send_reply_callback(
+        ray::Status::OK(),
+        /*reply_success=*/[]() { RAY_LOG(INFO) << "Reply success."; },
+        /*reply_failure=*/
+        [this]() {
+          RAY_LOG(INFO) << "Reply failed.";
+          reply_failure_count++;
+        });
   }
   std::atomic<int> request_count{0};
   std::atomic<int> reply_failure_count{0};
@@ -83,7 +85,7 @@ class TestGrpcServerClientFixture : public ::testing::Test {
       handler_io_service_.run();
     });
     test_service_.reset(new TestGrpcService(handler_io_service_, test_service_handler_));
-    grpc_server_.reset(new GrpcServer("test", 0));
+    grpc_server_.reset(new GrpcServer("test", 0, true));
     grpc_server_->RegisterService(*test_service_);
     grpc_server_->Run();
diff --git a/src/ray/util/event.h b/src/ray/util/event.h
index 9caed946f3af1..4f2e98a4427c3 100644
--- a/src/ray/util/event.h
+++ b/src/ray/util/event.h
@@ -13,6 +13,8 @@
 // limitations under the License.
 #pragma once

+#include
+
 #include
 #include
 #include
@@ -22,6 +24,8 @@
 #include
 #include
 #include
+
+#include "nlohmann/json.hpp"
 #include "ray/util/logging.h"
 #include "ray/util/util.h"
 #include "spdlog/sinks/basic_file_sink.h"
@@ -29,10 +33,6 @@
 #include "spdlog/spdlog.h"
 #include "src/ray/protobuf/event.pb.h"

-#include "nlohmann/json.hpp"
-
-#include
-
 using json = nlohmann::json;

 namespace ray {
@@ -102,7 +102,7 @@ class EventManager final {

   // We added `const json &custom_fields` here because we need to support typed custom
   // fields.
-  // TODO(guyang.sgy): Remove the protobuf `rpc::Event` and use an internal struct
+  // TODO(SongGuyang): Remove the protobuf `rpc::Event` and use an internal struct
   // instead.
   void Publish(const rpc::Event &event, const json &custom_fields);
diff --git a/src/ray/util/util.h b/src/ray/util/util.h
index 9b2e3f443dbac..95500e91694a7 100644
--- a/src/ray/util/util.h
+++ b/src/ray/util/util.h
@@ -21,7 +21,6 @@
 #include
 #include
 #include
-
 #include
 #include "ray/util/logging.h"
@@ -167,7 +166,7 @@ class InitShutdownRAII {
   /// \param shutdown_func The shutdown function.
   /// \param args The arguments for the init function.
  template <class InitFunc, class ShutdownFunc, class... Args>
-  InitShutdownRAII(InitFunc init_func, ShutdownFunc shutdown_func, Args &&... args)
+  InitShutdownRAII(InitFunc init_func, ShutdownFunc shutdown_func, Args &&...args)
       : shutdown_(shutdown_func) {
     init_func(args...);
   }
@@ -259,7 +258,7 @@ template <typename T>
 class ThreadPrivate {
  public:
   template <class... Ts>
-  ThreadPrivate(Ts &&... ts) : t_(std::forward<Ts>(ts)...) {}
+  explicit ThreadPrivate(Ts &&...ts) : t_(std::forward<Ts>(ts)...) {}

   T &operator*() {
     ThreadCheck();
@@ -312,4 +311,43 @@ class ThreadPrivate {
   mutable std::mutex mutex_;
 };

+class ExponentialBackOff {
+ public:
+  ExponentialBackOff() = default;
+  ExponentialBackOff(const ExponentialBackOff &) = default;
+  ExponentialBackOff(ExponentialBackOff &&) = default;
+  ExponentialBackOff &operator=(const ExponentialBackOff &) = default;
+  ExponentialBackOff &operator=(ExponentialBackOff &&) = default;
+
+  /// Construct an exponential back off counter.
+  ///
+  /// \param[in] initial_value The start value for this counter.
+  /// \param[in] multiplier The multiplier for this counter.
+  /// \param[in] max_value The maximum value for this counter. By default it is
+  /// the maximum value of uint64_t.
+  ExponentialBackOff(uint64_t initial_value, double multiplier,
+                     uint64_t max_value = std::numeric_limits<uint64_t>::max())
+      : curr_value_(initial_value),
+        initial_value_(initial_value),
+        max_value_(max_value),
+        multiplier_(multiplier) {
+    RAY_CHECK(multiplier > 0.0) << "Multiplier must be greater than 0";
+  }
+
+  uint64_t Next() {
+    auto ret = curr_value_;
+    curr_value_ = curr_value_ * multiplier_;
+    curr_value_ = std::min(curr_value_, max_value_);
+    return ret;
+  }
+
+  void Reset() { curr_value_ = initial_value_; }
+
+ private:
+  uint64_t curr_value_;
+  uint64_t initial_value_;
+  uint64_t max_value_;
+  double multiplier_;
+};
+
 }  // namespace ray
diff --git a/src/ray/util/util_test.cc b/src/ray/util/util_test.cc
index 435f1598f4f69..3e13dedb10bf9 100644
--- a/src/ray/util/util_test.cc
+++ b/src/ray/util/util_test.cc
@@ -102,6 +102,23 @@ TEST(UtilTest, ParseCommandLineTest) {
   ASSERT_EQ(ParseCommandLine(R"(x' a \b')", win32), ArgList({R"(x')", R"(a)", R"(\b')"}));
 }

+TEST(UtilTest, ExponentialBackOffTest) {
+  auto exp = ExponentialBackOff(1, 2, 9);
+  ASSERT_EQ(1, exp.Next());
+  ASSERT_EQ(2, exp.Next());
+  ASSERT_EQ(4, exp.Next());
+  ASSERT_EQ(8, exp.Next());
+  ASSERT_EQ(9, exp.Next());
+  ASSERT_EQ(9, exp.Next());
+  exp.Reset();
+  ASSERT_EQ(1, exp.Next());
+  ASSERT_EQ(2, exp.Next());
+  ASSERT_EQ(4, exp.Next());
+  ASSERT_EQ(8, exp.Next());
+  ASSERT_EQ(9, exp.Next());
+  ASSERT_EQ(9, exp.Next());
+}
+
 TEST(UtilTest, ParseURLTest) {
   const std::string url = "http://abc?num_objects=9&offset=8388878&size=8388878";
   auto parsed_url = *ParseURL(url);
diff --git a/streaming/src/queue/queue_handler.h b/streaming/src/queue/queue_handler.h
index e04e34b359804..26bd863e85ecc 100644
--- a/streaming/src/queue/queue_handler.h
+++ b/streaming/src/queue/queue_handler.h
@@ -1,7 +1,7 @@
 #pragma once

 #include
-#include
+#include
 #include
 #include
diff --git a/streaming/src/test/mock_actor.cc b/streaming/src/test/mock_actor.cc
index c51b1a8a11a5b..5e5b575223a6b 100644
--- a/streaming/src/test/mock_actor.cc
+++ b/streaming/src/test/mock_actor.cc
@@ -639,7 +639,7 @@ class StreamingWorker {
 }  // namespace ray

 int main(int argc, char **argv) {
-  RAY_CHECK(argc == 5);
+  RAY_CHECK(argc >= 4);
   auto store_socket = std::string(argv[1]);
   auto raylet_socket = std::string(argv[2]);
   auto node_manager_port = std::stoi(std::string(argv[3]));
diff --git a/thirdparty/patches/prometheus-windows-pollfd.patch
b/thirdparty/patches/prometheus-windows-pollfd.patch index 1941b6cb247c0..3b30942bb85f2 100644 --- a/thirdparty/patches/prometheus-windows-pollfd.patch +++ b/thirdparty/patches/prometheus-windows-pollfd.patch @@ -6,17 +6,46 @@ Windows Vista and later SDKs define struct pollfd for WSAPoll(), but it has a pe civetweb provides its own implementation of poll, but it has a conflicting definition for pollfd. Hence we block Windows from defining pollfd (which this project doesn't use). --- - bazel/civetweb.BUILD | 1 + - 1 file changed, 1 insertion(+) + bazel/civetweb.BUILD | 7 +++++++ + 1 file changed, 7 insertions(+) diff --git bazel/civetweb.BUILD bazel/civetweb.BUILD --- bazel/civetweb.BUILD +++ bazel/civetweb.BUILD -@@ -34,5 +34,6 @@ cc_library( +@@ -9,6 +9,11 @@ config_setting( + values = {"cpu": "darwin_x86_64"}, + ) + ++config_setting( ++ name = "darwin_arm64", ++ values = {"cpu": "darwin_arm64"}, ++) ++ + config_setting( + name = "windows", + values = { "cpu": "x64_windows" }, +@@ -34,6 +39,7 @@ cc_library( "-DNO_CACHING", "-DNO_SSL", "-DNO_FILES", + "-D_WIN32_WINNT=0x0502", "-UDEBUG", ], --- + includes = [ +@@ -46,6 +52,7 @@ cc_library( + }) + select({ + ":darwin": [], + ":darwin_x86_64": [], ++ ":darwin_arm64": [], + ":windows": [], + ":windows_msvc": [], + "//conditions:default": ["-lrt"], +@@ -86,6 +93,7 @@ cc_library( + }) + select({ + ":darwin": [], + ":darwin_x86_64": [], ++ ":darwin_arm64": [], + ":windows": [], + ":windows_msvc": [], + "//conditions:default": ["-lrt"], +-- diff --git a/thirdparty/patches/rules_boost-undefine-boost_fallthrough.patch b/thirdparty/patches/rules_boost-undefine-boost_fallthrough.patch deleted file mode 100644 index 9cd53fe60f842..0000000000000 --- a/thirdparty/patches/rules_boost-undefine-boost_fallthrough.patch +++ /dev/null @@ -1,8 +0,0 @@ -diff --git BUILD.boost BUILD.boost ---- BUILD.boost -+++ BUILD.boost -@@ -1356,3 +1356,2 @@ boost_library( - defines = [ -- "BOOST_FALLTHROUGH", - ], --- diff --git a/thirdparty/patches/rules_boost-windows-linkopts.patch b/thirdparty/patches/rules_boost-windows-linkopts.patch index 28bda4eb06939..204443d3c7186 100644 --- a/thirdparty/patches/rules_boost-windows-linkopts.patch +++ b/thirdparty/patches/rules_boost-windows-linkopts.patch @@ -1,15 +1,12 @@ diff --git BUILD.boost BUILD.boost --- BUILD.boost +++ BUILD.boost -@@ -313,1 +313,9 @@ boost_library(name = "asio", -- linkopts = ["-lpthread"], -+ linkopts = select({ -+ ":linux": [ -+ "-lpthread", -+ ], -+ ":osx_x86_64": [ -+ "-lpthread", -+ ], -+ "//conditions:default": [], -+ }), --- +@@ -428,6 +428,7 @@ boost_library( + }), + linkopts = select({ + ":android": [], ++ ":windows": [], + "//conditions:default": ["-lpthread"], + }), + deps = [ +-- From aea3e4e593dceb06f51297dc06e01bcad28fa67c Mon Sep 17 00:00:00 2001 From: Oscar Knagg Date: Tue, 12 Oct 2021 16:00:45 +0100 Subject: [PATCH 40/56] Revert "Address comments" This reverts commit 53896b33f1f973d57f1f98bccabddf8e2e12d9d0. 
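A usage note on the ExponentialBackOff helper that this series adds to src/ray/util/util.h: Next() returns the current value and then multiplies it, clamping at max_value, and Reset() rewinds to the initial value, exactly as the unit test above exercises. A small sketch of the intended call pattern (the retry loop and TryConnectOnce are illustrative, not from the patch):

#include <chrono>
#include <thread>
#include "ray/util/util.h"  // ExponentialBackOff, as added above

bool TryConnectOnce() { return false; }  // stand-in for some fallible operation

bool ConnectWithRetries() {
  ray::ExponentialBackOff backoff(/*initial_value=*/100, /*multiplier=*/2.0,
                                  /*max_value=*/5000);  // delays in milliseconds
  for (int attempt = 0; attempt < 8; ++attempt) {
    if (TryConnectOnce()) return true;
    // Sleeps 100, 200, 400, ... ms, capped at 5000 ms.
    std::this_thread::sleep_for(std::chrono::milliseconds(backoff.Next()));
  }
  return false;
}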
--- .bazelrc | 10 +- .buildkite/pipeline.gpu.large.yml | 8 - .buildkite/pipeline.gpu.yml | 10 - .buildkite/pipeline.macos.yml | 10 +- .buildkite/pipeline.yml | 34 +- .buildkite/windows/install/bazel.ps1 | 2 +- .clang-tidy | 41 +- .flake8 | 16 - .github/CODEOWNERS | 12 +- .github/workflows/main.yml | 3 +- .gitpod/Dockerfile | 2 +- BUILD.bazel | 35 +- bazel/ray_deps_setup.bzl | 9 +- benchmarks/object_store/test_object_store.py | 1 - benchmarks/single_node/test_single_node.py | 3 +- ci/asan_tests/run_asan_tests.sh | 8 +- ci/travis/bazel.py | 42 +- ci/travis/ci.sh | 56 +-- ci/travis/format.sh | 4 - ci/travis/install-dependencies.sh | 2 +- ci/travis/test-worker-in-container.sh | 2 +- cpp/BUILD.bazel | 1 - cpp/src/ray/api.cc | 2 +- cpp/src/ray/runtime/abstract_ray_runtime.cc | 2 +- .../ray/runtime/object/native_object_store.cc | 2 +- .../runtime/task/local_mode_task_submitter.cc | 4 +- cpp/src/ray/runtime/task/task_executor.cc | 2 +- cpp/src/ray/runtime/task/task_executor.h | 4 +- cpp/src/ray/util/process_helper.cc | 11 +- dashboard/agent.py | 2 +- dashboard/client/src/pages/job/JobDetail.tsx | 11 - dashboard/client/src/pages/job/index.tsx | 19 +- dashboard/client/src/type/job.d.ts | 2 - dashboard/head.py | 11 +- dashboard/modules/job/job_agent.py | 4 +- .../modules/runtime_env/runtime_env_agent.py | 24 +- dashboard/modules/snapshot/snapshot_head.py | 3 +- .../modules/snapshot/snapshot_schema.json | 4 + dashboard/tests/test_dashboard.py | 2 +- doc/BUILD | 32 -- doc/Makefile | 2 +- doc/examples/dask_xgboost/README.rst | 1 - doc/examples/dask_xgboost/dask_xgboost.py | 321 ------------ doc/examples/dask_xgboost/dask_xgboost.yaml | 24 - doc/examples/modin_xgboost/README.rst | 1 - doc/examples/modin_xgboost/modin_xgboost.py | 233 --------- doc/examples/modin_xgboost/modin_xgboost.yaml | 24 - doc/examples/overview.rst | 12 +- doc/kubernetes/ray-cluster.yaml | 4 +- doc/source/advanced.rst | 35 +- doc/source/cluster/config.rst | 62 --- doc/source/cluster/ray-client.rst | 69 +-- doc/source/conf.py | 6 +- doc/source/configure.rst | 22 - doc/source/data/.gitignore | 1 - doc/source/data/_examples/README.rst | 1 - .../data/_examples/big_data_ingestion.py | 276 ----------- doc/source/data/big_data_ingestion.yaml | 54 -- doc/source/data/dask-on-ray.rst | 20 +- doc/source/data/dataset-pipeline.rst | 121 +---- doc/source/data/dataset-tensor-support.rst | 72 ++- doc/source/data/dataset.rst | 8 +- doc/source/data/package-ref.rst | 2 - doc/source/development.rst | 11 +- doc/source/index.rst | 6 +- doc/source/raysgd/raysgd.rst | 2 +- doc/source/raysgd/raysgd_pytorch.rst | 5 +- doc/source/raysgd/raysgd_tensorflow.rst | 5 +- doc/source/raysgd/raysgd_tune.rst | 3 - doc/source/raysgd/v2/api.rst | 21 +- doc/source/raysgd/v2/examples.rst | 3 - .../tune_cifar_pytorch_pbt_example.rst | 6 - doc/source/raysgd/v2/migration-guide.rst | 393 --------------- doc/source/raysgd/v2/raysgd.rst | 3 +- doc/source/raysgd/v2/user_guide.rst | 12 +- doc/source/serve/core-apis.rst | 43 +- doc/source/serve/deployment.rst | 10 +- doc/source/serve/ml-models.rst | 15 +- doc/source/tune/_tutorials/_faq.inc | 55 +-- doc/source/tune/api_docs/suggestion.rst | 3 +- doc/source/tune/user-guide.rst | 10 +- java/BUILD.bazel | 5 - java/dependencies.bzl | 2 - .../java/io/ray/runtime/RayNativeRuntime.java | 18 +- .../runtime/object/LocalModeObjectStore.java | 2 +- .../ray/runtime/object/NativeObjectStore.java | 6 +- .../io/ray/runtime/object/ObjectRefImpl.java | 2 +- .../io/ray/runtime/object/ObjectStore.java | 4 +- java/serve/pom.xml | 15 - 
.../src/main/java/io/ray/serve/Constants.java | 6 - .../java/io/ray/serve/DeploymentInfo.java | 38 -- .../io/ray/serve/DummyBackendReplica.java | 12 - .../main/java/io/ray/serve/HandleOptions.java | 15 - .../src/main/java/io/ray/serve/HttpProxy.java | 161 ------ .../main/java/io/ray/serve/ProxyActor.java | 175 ------- .../main/java/io/ray/serve/ProxyRouter.java | 72 --- .../java/io/ray/serve/RayServeConfig.java | 6 - .../java/io/ray/serve/RayServeHandle.java | 73 --- .../java/io/ray/serve/RayServeMetrics.java | 74 --- .../java/io/ray/serve/RayServeReplica.java | 211 +++----- .../io/ray/serve/RayServeWrappedReplica.java | 42 +- .../main/java/io/ray/serve/ReplicaConfig.java | 8 +- .../java/io/ray/serve/ReplicaContext.java | 2 +- .../main/java/io/ray/serve/ReplicaSet.java | 138 ------ .../src/main/java/io/ray/serve/Router.java | 64 --- .../java/io/ray/serve/ServeController.java | 6 - .../main/java/io/ray/serve/ServeProxy.java | 14 - .../main/java/io/ray/serve/api/Client.java | 72 --- .../src/main/java/io/ray/serve/api/Serve.java | 54 +- .../java/io/ray/serve/poll/KeyListener.java | 2 +- .../io/ray/serve/poll/LongPollClient.java | 69 +-- .../io/ray/serve/poll/LongPollNamespace.java | 4 +- .../java/io/ray/serve/poll/UpdatedObject.java | 33 ++ .../io/ray/serve/util/CollectionUtil.java | 10 - .../java/io/ray/serve/util/CommonUtil.java | 13 - .../java/io/ray/serve/util/ReflectUtil.java | 14 - .../io/ray/serve/util/ServeProtoUtil.java | 75 +-- .../java/io/ray/serve/util/SocketUtil.java | 49 -- .../io/ray/serve/DummyServeController.java | 21 - .../test/java/io/ray/serve/HttpProxyTest.java | 74 --- .../java/io/ray/serve/ProxyActorTest.java | 110 ----- .../java/io/ray/serve/ProxyRouterTest.java | 68 --- .../java/io/ray/serve/RayServeHandleTest.java | 76 --- .../io/ray/serve/RayServeReplicaTest.java | 46 +- .../java/io/ray/serve/ReplicaSetTest.java | 108 ---- .../test/java/io/ray/serve/RouterTest.java | 80 --- .../java/io/ray/serve/api/ClientTest.java | 47 -- .../test/java/io/ray/serve/api/ServeTest.java | 71 +-- .../java/io/ray/serve/poll/KeyTypeTest.java | 15 +- .../io/ray/serve/poll/LongPollClientTest.java | 29 +- python/build-wheel-windows.sh | 7 - python/ray/_private/client_mode_hook.py | 50 +- python/ray/_private/parameter.py | 16 +- python/ray/_private/runtime_env/__init__.py | 3 + python/ray/_private/runtime_env/conda.py | 4 +- .../ray/_private/runtime_env/conda_utils.py | 15 - python/ray/_private/runtime_env/context.py | 13 +- python/ray/_private/runtime_env/plugin.py | 70 --- python/ray/_private/runtime_env/validation.py | 458 ++++++----------- .../ray/_private/runtime_env/working_dir.py | 2 +- python/ray/_private/services.py | 51 +- python/ray/_private/test_utils.py | 68 ++- python/ray/_private/tls_utils.py | 85 ---- python/ray/_private/utils.py | 28 +- python/ray/_raylet.pxd | 2 +- python/ray/_raylet.pyx | 142 +++--- python/ray/actor.py | 68 ++- python/ray/autoscaler/_private/autoscaler.py | 12 +- python/ray/autoscaler/_private/docker.py | 2 +- .../_private/fake_multi_node/__init__.py | 0 .../_private/fake_multi_node/example.yaml | 55 --- .../_private/fake_multi_node/node_provider.py | 114 ----- python/ray/autoscaler/_private/gcp/node.py | 20 +- python/ray/autoscaler/_private/monitor.py | 12 +- .../ray/autoscaler/_private/node_launcher.py | 5 +- python/ray/autoscaler/_private/providers.py | 8 - .../_private/resource_demand_scheduler.py | 27 +- python/ray/autoscaler/gcp/tpu.yaml | 18 +- python/ray/autoscaler/node_provider.py | 12 - python/ray/autoscaler/ray-schema.json | 2 +- 
python/ray/cluster_utils.py | 69 --- python/ray/cross_language.py | 3 +- python/ray/data/__init__.py | 7 +- python/ray/data/block.py | 24 +- python/ray/data/dataset.py | 390 ++++----------- python/ray/data/dataset_pipeline.py | 188 +------ python/ray/data/datasource/__init__.py | 4 +- python/ray/data/datasource/datasource.py | 63 +-- .../data/datasource/file_based_datasource.py | 57 +-- .../ray/data/datasource/numpy_datasource.py | 13 +- python/ray/data/examples/demo_infer.py | 2 +- .../ray/data/extensions/tensor_extension.py | 8 +- python/ray/data/impl/arrow_block.py | 20 +- python/ray/data/impl/block_list.py | 7 - python/ray/data/impl/compute.py | 23 +- python/ray/data/impl/lazy_block_list.py | 57 +-- python/ray/data/impl/pipeline_executor.py | 15 +- python/ray/data/impl/progress_bar.py | 12 +- python/ray/data/impl/remote_fn.py | 5 +- python/ray/data/impl/simple_block.py | 4 +- python/ray/data/impl/tensor_block.py | 80 +++ python/ray/data/read_api.py | 77 +-- python/ray/data/tests/test_dataset.py | 466 ++++++------------ .../ray/data/tests/test_dataset_pipeline.py | 89 +--- python/ray/data/tests/test_raydp_dataset.py | 4 - python/ray/exceptions.py | 6 +- python/ray/experimental/array/remote/core.py | 4 +- python/ray/experimental/internal_kv.py | 12 +- python/ray/experimental/raysort/constants.py | 11 +- python/ray/experimental/raysort/main.py | 369 +++++--------- python/ray/experimental/raysort/sortlib.py | 8 +- .../ray/experimental/raysort/tracing_utils.py | 127 +---- python/ray/experimental/raysort/types.py | 12 +- python/ray/includes/common.pxd | 6 +- python/ray/includes/libcoreworker.pxd | 5 +- python/ray/job_config.py | 61 ++- python/ray/node.py | 31 +- python/ray/remote_function.py | 77 ++- python/ray/runtime_context.py | 16 +- python/ray/scripts/scripts.py | 27 +- python/ray/serialization.py | 2 +- python/ray/serve/BUILD | 10 +- python/ray/serve/api.py | 154 ++---- python/ray/serve/autoscaling_metrics.py | 5 +- python/ray/serve/autoscaling_policy.py | 1 + python/ray/serve/backend_state.py | 77 +-- .../serve/{replica.py => backend_worker.py} | 54 +- python/ray/serve/common.py | 5 +- python/ray/serve/config.py | 38 +- python/ray/serve/controller.py | 139 ++---- python/ray/serve/endpoint_state.py | 1 + python/ray/serve/examples/doc/conda_env.py | 23 +- python/ray/serve/handle.py | 15 +- python/ray/serve/http_proxy.py | 7 +- python/ray/serve/long_poll.py | 26 +- python/ray/serve/router.py | 10 +- python/ray/serve/storage/checkpoint_path.py | 4 +- python/ray/serve/storage/kv_store.py | 4 +- python/ray/serve/tests/conftest.py | 7 - python/ray/serve/tests/test_advanced.py | 12 +- .../serve/tests/test_autoscaling_metrics.py | 12 +- .../serve/tests/test_autoscaling_policy.py | 52 -- python/ray/serve/tests/test_backend_state.py | 77 ++- python/ray/serve/tests/test_config.py | 2 - python/ray/serve/tests/test_deploy.py | 57 +-- python/ray/serve/tests/test_get_deployment.py | 31 -- python/ray/serve/tests/test_handle.py | 27 +- python/ray/serve/tests/test_long_poll.py | 14 - python/ray/serve/tests/test_ray_client.py | 5 +- python/ray/serve/tests/test_regression.py | 2 +- python/ray/serve/tests/test_standalone.py | 9 +- python/ray/sgd/__init__.py | 3 +- python/ray/sgd/callbacks.py | 1 - python/ray/state.py | 14 +- python/ray/tests/BUILD | 13 +- python/ray/tests/client_test_utils.py | 17 - python/ray/tests/mock_setup_worker.py | 3 - python/ray/tests/test_advanced.py | 5 +- python/ray/tests/test_advanced_3.py | 13 +- python/ray/tests/test_autoscaler.py | 61 +-- 
.../tests/test_autoscaler_fake_multinode.py | 58 --- python/ray/tests/test_autoscaler_yaml.py | 3 - python/ray/tests/test_basic.py | 15 +- python/ray/tests/test_basic_3.py | 11 +- python/ray/tests/test_client.py | 59 +-- python/ray/tests/test_client_compat.py | 33 -- .../tests/test_client_library_integration.py | 8 +- python/ray/tests/test_client_proxy.py | 10 +- python/ray/tests/test_client_reconnect.py | 9 +- python/ray/tests/test_dashboard.py | 52 +- python/ray/tests/test_distributed_sort.py | 19 +- python/ray/tests/test_failure_2.py | 3 +- python/ray/tests/test_global_state.py | 10 +- python/ray/tests/test_multi_tenancy.py | 12 +- python/ray/tests/test_object_manager.py | 48 +- python/ray/tests/test_output.py | 8 +- python/ray/tests/test_placement_group.py | 7 - python/ray/tests/test_placement_group_3.py | 35 -- python/ray/tests/test_ray_debugger.py | 7 +- python/ray/tests/test_ray_init.py | 40 -- .../tests/test_resource_demand_scheduler.py | 130 ++--- python/ray/tests/test_runtime_context.py | 115 ----- python/ray/tests/test_runtime_env.py | 68 ++- .../ray/tests/test_runtime_env_complicated.py | 137 ++--- python/ray/tests/test_runtime_env_env_vars.py | 244 ++++++--- python/ray/tests/test_runtime_env_plugin.py | 75 --- .../ray/tests/test_runtime_env_validation.py | 379 -------------- python/ray/tests/test_scheduling.py | 99 +--- python/ray/tests/test_tls_auth.py | 66 +-- python/ray/tests/test_traceback.py | 39 -- python/ray/tune/BUILD | 2 +- .../ray/tune/analysis/experiment_analysis.py | 22 +- python/ray/tune/commands.py | 7 +- python/ray/tune/durable_trainable.py | 14 +- python/ray/tune/function_runner.py | 11 +- python/ray/tune/logger.py | 17 +- python/ray/tune/progress_reporter.py | 96 +--- python/ray/tune/ray_trial_executor.py | 27 +- python/ray/tune/registry.py | 22 +- python/ray/tune/result.py | 4 - python/ray/tune/schedulers/hyperband.py | 2 - python/ray/tune/schedulers/trial_scheduler.py | 6 - python/ray/tune/suggest/bohb.py | 4 +- python/ray/tune/tests/test_api.py | 8 - python/ray/tune/tests/test_cluster.py | 2 +- python/ray/tune/tests/test_logger.py | 22 + .../ray/tune/tests/test_progress_reporter.py | 200 +++----- .../ray/tune/tests/test_ray_trial_executor.py | 69 +-- python/ray/tune/tests/test_trial_runner_3.py | 3 +- .../tune/tests/test_trial_runner_callbacks.py | 2 +- python/ray/tune/tests/test_trial_scheduler.py | 1 - .../tune/tests/test_trial_scheduler_pbt.py | 12 +- python/ray/tune/trainable.py | 59 +-- python/ray/tune/trial.py | 65 +-- python/ray/tune/trial_runner.py | 104 +--- python/ray/tune/tune.py | 82 +-- python/ray/tune/utils/util.py | 9 +- python/ray/util/__init__.py | 2 +- python/ray/util/client/__init__.py | 3 +- python/ray/util/client/client_pickler.py | 15 +- python/ray/util/client/options.py | 1 + python/ray/util/client/server/proxier.py | 10 +- python/ray/util/client/server/server.py | 2 +- python/ray/util/client/worker.py | 20 +- python/ray/util/dask/scheduler_utils.py | 5 +- python/ray/util/placement_group.py | 4 +- python/ray/util/sgd/torch/torch_runner.py | 14 +- .../ray/util/sgd/torch/training_operator.py | 42 +- python/ray/util/sgd/v2/BUILD | 27 - python/ray/util/sgd/v2/__init__.py | 4 +- python/ray/util/sgd/v2/backends/backend.py | 26 +- python/ray/util/sgd/v2/backends/horovod.py | 2 - python/ray/util/sgd/v2/backends/torch.py | 2 - python/ray/util/sgd/v2/constants.py | 4 - .../v2/examples/tensorflow_mnist_example.py | 4 +- .../tune_cifar_pytorch_pbt_example.py | 200 -------- python/ray/util/sgd/v2/tests/test_backend.py | 4 - 
python/ray/util/sgd/v2/tests/test_gpu.py | 92 ---- python/ray/util/sgd/v2/tests/test_trainer.py | 82 ++- python/ray/util/sgd/v2/tests/test_tune.py | 31 +- python/ray/util/tracing/tracing_helper.py | 7 +- python/ray/worker.py | 47 +- python/ray/workers/setup_worker.py | 10 +- python/ray/workflow/common.py | 3 +- python/ray/workflow/execution.py | 7 +- python/ray/workflow/recovery.py | 29 +- python/ray/workflow/step_executor.py | 92 ++-- .../workflow/tests/test_basic_workflows_2.py | 43 +- python/ray/workflow/tests/test_lifetime.py | 26 +- python/ray/workflow/workflow_access.py | 4 +- python/ray/workflow/workflow_context.py | 109 +--- python/ray/workflow/workflow_storage.py | 27 +- python/requirements.txt | 5 +- python/requirements/ml/requirements_rllib.txt | 4 +- python/requirements/requirements_default.txt | 2 +- python/requirements_linters.txt | 1 - python/setup.py | 29 +- release/.buildkite/build_pipeline.py | 1 - release/RELEASE_CHECKLIST.md | 1 - release/RELEASE_PROCESS.rst | 3 - release/alerts/xgboost_tests.py | 4 +- release/e2e.py | 207 +++----- .../dask_xgboost_app_config.yaml | 5 +- .../golden_notebook_tests.yaml | 21 +- .../modin_xgboost_app_config.yaml | 5 +- .../workloads/dask_xgboost_test.py | 123 ++++- .../workloads/modin_xgboost_test.py | 119 ++++- .../workloads/torch_tune_serve_test.py | 4 +- .../golden_notebook_tests/workloads/util.py | 49 -- .../workloads/utils/utils.py | 5 + release/kubernetes_manual_tests/README.md | 25 - release/kubernetes_manual_tests/helm-test.sh | 8 - .../kubernetes_manual_tests/k8s-test-scale.sh | 11 - release/kubernetes_manual_tests/k8s-test.sh | 9 - .../k8s_release_tests.sh | 30 -- release/long_running_tests/tpl_cpu_1.yaml | 5 - .../large_scale_dask_on_ray_app_config.yaml | 1 + release/nightly_tests/dataset/app_config.yaml | 1 + .../dataset/dataset_shuffle_data_loader.py | 2 +- .../dataset/pipelined_ingestion_app.yaml | 1 + .../dataset/pipelined_training.py | 4 +- .../dataset/pipelined_training_app.yaml | 1 + .../dataset/shuffle_app_config.yaml | 1 + .../decision_tree_app_config.yaml | 1 + .../many_nodes_tests/app_config.yaml | 2 +- release/nightly_tests/nightly_tests.yaml | 25 +- .../placement_group_tests/app_config.yaml | 12 - .../placement_group_tests/cluster.py | 13 - .../placement_group_tests/compute.yaml | 27 - .../placement_group_tests/pg_run.py | 65 --- .../shuffle/shuffle_app_config.yaml | 2 + .../shuffle_data_loader_app_config.yaml | 1 + .../stress_tests/stress_tests_app_config.yaml | 1 + .../1.7.0/benchmarks/many_actors.txt | 10 - .../1.7.0/benchmarks/many_nodes.txt | 10 - .../1.7.0/benchmarks/many_pgs.txt | 10 - .../1.7.0/benchmarks/many_tasks.txt | 10 - release/release_logs/1.7.0/microbenchmark.txt | 134 ----- .../1.7.0/scalability/object_store.txt | 10 - .../1.7.0/scalability/single_node.txt | 16 - .../1.7.0/stress_tests/dead_actors.txt | 11 - .../1.7.0/stress_tests/many_tasks.txt | 19 - .../1.7.0/stress_tests/placement_group.txt | 9 - release/serve_tests/serve_tests.yaml | 15 - .../serve_cluster_fault_tolerance.py | 119 ----- .../workloads/serve_test_cluster_utils.py | 25 +- release/util/pip_download_test.sh | 2 +- rllib/BUILD | 89 +--- rllib/agents/a3c/a3c_tf_policy.py | 2 +- rllib/agents/a3c/a3c_torch_policy.py | 18 +- rllib/agents/a3c/tests/test_a2c.py | 11 +- rllib/agents/a3c/tests/test_a3c.py | 3 +- rllib/agents/ars/tests/test_ars.py | 10 +- rllib/agents/cql/cql.py | 3 +- rllib/agents/cql/cql_torch_policy.py | 67 ++- rllib/agents/cql/tests/test_cql.py | 11 +- rllib/agents/ddpg/ddpg_tf_model.py | 12 +- 
rllib/agents/ddpg/ddpg_tf_policy.py | 8 +- rllib/agents/ddpg/ddpg_torch_model.py | 12 +- rllib/agents/ddpg/ddpg_torch_policy.py | 37 +- rllib/agents/ddpg/tests/test_apex_ddpg.py | 6 +- rllib/agents/ddpg/tests/test_ddpg.py | 8 +- rllib/agents/ddpg/tests/test_td3.py | 3 +- rllib/agents/dqn/apex.py | 3 +- rllib/agents/dqn/dqn.py | 16 +- rllib/agents/dqn/dqn_torch_policy.py | 46 +- rllib/agents/dqn/learner_thread.py | 24 +- rllib/agents/dqn/r2d2.py | 14 +- rllib/agents/dqn/r2d2_tf_policy.py | 6 +- rllib/agents/dqn/r2d2_torch_policy.py | 44 +- rllib/agents/dqn/simple_q_tf_policy.py | 2 +- rllib/agents/dqn/simple_q_torch_policy.py | 17 +- rllib/agents/dqn/tests/test_apex_dqn.py | 15 +- rllib/agents/dqn/tests/test_dqn.py | 4 +- rllib/agents/dqn/tests/test_r2d2.py | 3 +- rllib/agents/dqn/tests/test_simple_q.py | 3 +- rllib/agents/dreamer/dreamer.py | 3 +- rllib/agents/impala/tests/test_impala.py | 12 +- rllib/agents/impala/vtrace_tf_policy.py | 26 +- rllib/agents/impala/vtrace_torch_policy.py | 45 +- rllib/agents/maml/maml.py | 17 +- rllib/agents/maml/tests/test_maml.py | 6 +- rllib/agents/marwil/tests/test_bc.py | 8 +- rllib/agents/marwil/tests/test_marwil.py | 8 +- rllib/agents/mbmpo/mbmpo.py | 17 +- rllib/agents/mbmpo/tests/test_mbmpo.py | 8 +- rllib/agents/pg/pg_torch_policy.py | 14 +- rllib/agents/pg/tests/test_pg.py | 10 +- rllib/agents/ppo/appo_tf_policy.py | 2 +- rllib/agents/ppo/appo_torch_policy.py | 47 +- rllib/agents/ppo/ddppo.py | 15 +- rllib/agents/ppo/ppo.py | 12 +- rllib/agents/ppo/ppo_torch_policy.py | 33 +- rllib/agents/ppo/tests/test_appo.py | 16 +- rllib/agents/ppo/tests/test_ddppo.py | 26 +- rllib/agents/ppo/tests/test_ppo.py | 29 +- rllib/agents/qmix/qmix_policy.py | 2 +- rllib/agents/sac/rnnsac.py | 7 + rllib/agents/sac/rnnsac_torch_policy.py | 32 +- rllib/agents/sac/sac_tf_model.py | 10 +- rllib/agents/sac/sac_tf_policy.py | 8 +- rllib/agents/sac/sac_torch_model.py | 8 +- rllib/agents/sac/sac_torch_policy.py | 62 ++- rllib/agents/sac/tests/test_rnnsac.py | 73 --- rllib/agents/sac/tests/test_sac.py | 38 +- rllib/agents/tests/test_trainer.py | 3 +- rllib/agents/trainer.py | 299 ++++------- .../alpha_zero/core/alpha_zero_policy.py | 7 +- rllib/contrib/bandits/agents/policy.py | 2 +- .../bandits/examples/LinTS_train_wheel_env.py | 3 +- rllib/contrib/maddpg/maddpg_policy.py | 2 +- rllib/contrib/sumo/connector.py | 5 +- rllib/env/base_env.py | 7 +- rllib/env/multi_agent_env.py | 3 +- rllib/env/policy_server_input.py | 16 +- rllib/env/remote_vector_env.py | 20 +- rllib/env/tests/test_local_inference.sh | 42 ++ .../tests/test_policy_client_server_setup.sh | 63 --- rllib/env/tests/test_remote_inference.sh | 41 ++ rllib/env/tests/test_remote_worker_envs.py | 98 ---- rllib/env/wrappers/unity3d_env.py | 16 +- .../collectors/simple_list_collector.py | 7 +- rllib/evaluation/metrics.py | 21 +- rllib/evaluation/rollout_worker.py | 30 +- rllib/examples/centralized_critic.py | 2 +- rllib/examples/custom_keras_model.py | 5 +- .../examples/custom_model_loss_and_metrics.py | 12 +- rllib/examples/deterministic_training.py | 6 +- .../env/coin_game_non_vectorized_env.py | 11 +- .../examples/env/coin_game_vectorized_env.py | 9 +- .../env/matrix_sequential_social_dilemma.py | 6 +- rllib/examples/env/random_env.py | 26 +- rllib/examples/pettingzoo_env.py | 18 +- .../remote_vector_env_with_custom_api.py | 3 +- .../rock_paper_scissors_multiagent.py | 8 +- rllib/examples/serving/cartpole_client.py | 2 +- rllib/examples/serving/unity3d_client.py | 14 +- .../examples/serving/unity3d_dummy_client.py | 144 
------ rllib/examples/serving/unity3d_server.py | 70 +-- rllib/examples/trajectory_view_api.py | 50 +- rllib/execution/common.py | 3 + rllib/execution/learner_thread.py | 25 +- rllib/execution/multi_gpu_learner_thread.py | 68 +-- rllib/execution/rollout_ops.py | 24 +- rllib/execution/train_ops.py | 79 +-- rllib/models/tests/test_preprocessors.py | 6 +- rllib/models/tf/complex_input_net.py | 8 +- rllib/models/torch/complex_input_net.py | 9 +- rllib/models/torch/torch_modelv2.py | 8 - rllib/policy/eager_tf_policy.py | 118 ++--- rllib/policy/policy.py | 114 +++-- rllib/policy/policy_template.py | 3 +- rllib/policy/sample_batch.py | 7 +- .../tests/test_compute_log_likelihoods.py | 2 +- rllib/policy/tf_policy.py | 31 +- rllib/policy/tf_policy_template.py | 9 +- rllib/policy/torch_policy.py | 38 +- rllib/tests/test_exec_api.py | 3 +- rllib/tests/test_supported_multi_agent.py | 26 +- rllib/tests/test_supported_spaces.py | 8 +- rllib/utils/__init__.py | 3 +- .../utils/exploration/stochastic_sampling.py | 20 +- rllib/utils/metrics/__init__.py | 0 rllib/utils/metrics/learner_info.py | 84 ---- rllib/utils/multi_agent.py | 21 +- rllib/utils/sgd.py | 55 ++- rllib/utils/test_utils.py | 316 +++--------- rllib/utils/tf_ops.py | 2 +- rllib/utils/tf_run_builder.py | 5 +- rllib/utils/torch_ops.py | 6 +- .../ray/gcs/gcs_server/gcs_node_manager.h | 1 - .../gcs_placement_group_scheduler.h | 65 ++- .../ray/gcs/gcs_server/gcs_resource_manager.h | 1 - src/mock/ray/gcs/pubsub/gcs_pub_sub.h | 27 - .../gcs/store_client/in_memory_store_client.h | 66 --- .../ray/gcs/store_client/redis_store_client.h | 67 --- src/mock/ray/gcs/store_client/store_client.h | 66 --- src/mock/ray/pubsub/publisher.h | 100 ---- src/mock/ray/pubsub/subscriber.h | 155 ------ src/mock/ray/raylet/node_manager.h | 5 - .../cluster_task_manager_interface.h | 2 + src/mock/ray/raylet_client/raylet_client.h | 47 +- src/mock/ray/rpc/worker/core_worker_client.h | 123 ----- .../ray/rpc/worker/core_worker_client_pool.h | 23 - src/ray/common/bundle_spec.cc | 25 +- src/ray/common/bundle_spec.h | 6 - src/ray/common/client_connection.cc | 2 +- src/ray/common/constants.h | 3 - src/ray/common/id.h | 1 - src/ray/common/network_util.h | 2 +- src/ray/common/ray_config_def.h | 37 +- src/ray/common/ray_internal_flag_def.h | 3 - src/ray/common/runtime_env_manager.cc | 16 +- src/ray/common/runtime_env_manager.h | 7 +- src/ray/common/task/task.cc | 9 +- src/ray/common/task/task.h | 10 +- src/ray/common/task/task_spec.cc | 32 +- src/ray/common/task/task_spec.h | 18 +- src/ray/common/task/task_util.h | 11 +- src/ray/core_worker/common.h | 24 +- src/ray/core_worker/context.cc | 15 +- src/ray/core_worker/context.h | 9 +- src/ray/core_worker/core_worker.cc | 326 ++++++------ src/ray/core_worker/core_worker.h | 105 ++-- ...io_ray_runtime_object_NativeObjectStore.cc | 5 +- .../io_ray_runtime_object_NativeObjectStore.h | 7 +- ...io_ray_runtime_task_NativeTaskSubmitter.cc | 6 +- src/ray/core_worker/reference_count.h | 2 + src/ray/core_worker/reference_count_test.cc | 6 +- .../memory_store/memory_store.cc | 40 +- .../memory_store/memory_store.h | 16 + src/ray/core_worker/test/core_worker_test.cc | 2 +- .../test/direct_task_transport_mock_test.cc | 4 +- .../test/direct_task_transport_test.cc | 358 +++----------- src/ray/core_worker/test/memory_store_test.cc | 11 +- .../transport/dependency_resolver.cc | 6 +- .../transport/direct_task_transport.cc | 93 +--- .../transport/direct_task_transport.h | 40 +- src/ray/gcs/asio.h | 2 +- .../gcs/gcs_client/service_based_accessor.cc | 2 +- 
.../test/global_state_accessor_test.cc | 1 - .../test/service_based_gcs_client_test.cc | 1 - .../gcs/gcs_server/gcs_actor_distribution.cc | 30 -- .../gcs/gcs_server/gcs_actor_distribution.h | 17 - src/ray/gcs/gcs_server/gcs_actor_manager.cc | 67 +-- src/ray/gcs/gcs_server/gcs_actor_manager.h | 13 +- src/ray/gcs/gcs_server/gcs_actor_scheduler.cc | 6 +- src/ray/gcs/gcs_server/gcs_actor_scheduler.h | 12 +- src/ray/gcs/gcs_server/gcs_node_manager.cc | 8 +- .../gcs_server/gcs_placement_group_manager.cc | 138 ++---- .../gcs_server/gcs_placement_group_manager.h | 52 +- .../gcs_placement_group_scheduler.cc | 28 +- .../gcs_placement_group_scheduler.h | 42 +- .../gcs/gcs_server/gcs_resource_manager.cc | 7 +- src/ray/gcs/gcs_server/gcs_server.cc | 19 +- src/ray/gcs/gcs_server/gcs_server.h | 1 + src/ray/gcs/gcs_server/gcs_server_main.cc | 5 +- src/ray/gcs/gcs_server/gcs_table_storage.h | 130 ++--- .../gcs_server/test/gcs_actor_manager_test.cc | 15 +- .../test/gcs_actor_scheduler_mock_test.cc | 139 ------ .../test/gcs_based_actor_scheduler_test.cc | 18 +- .../gcs_placement_group_manager_mock_test.cc | 174 ------- .../test/gcs_placement_group_manager_test.cc | 47 +- .../gcs_server/test/gcs_server_rpc_test.cc | 1 - .../gcs_server/test/gcs_server_test_util.h | 14 +- src/ray/gcs/pubsub/gcs_pub_sub.h | 3 - src/ray/gcs/redis_context.h | 2 +- src/ray/object_manager/object_buffer_pool.cc | 165 ++----- src/ray/object_manager/object_buffer_pool.h | 59 +-- src/ray/object_manager/object_manager.cc | 24 +- src/ray/object_manager/object_manager.h | 6 +- src/ray/object_manager/plasma/store.cc | 7 +- src/ray/object_manager/pull_manager.h | 2 +- src/ray/protobuf/agent_manager.proto | 1 - src/ray/protobuf/common.proto | 23 +- src/ray/protobuf/core_worker.proto | 1 - src/ray/protobuf/event.proto | 1 - src/ray/protobuf/gcs.proto | 31 +- src/ray/protobuf/gcs_service.proto | 2 +- src/ray/protobuf/job_agent.proto | 1 - src/ray/protobuf/node_manager.proto | 21 - src/ray/protobuf/object_manager.proto | 1 - src/ray/protobuf/pubsub.proto | 1 - src/ray/protobuf/ray_client.proto | 5 +- src/ray/protobuf/reporter.proto | 1 - src/ray/protobuf/runtime_env_agent.proto | 5 - src/ray/protobuf/serialization.proto | 1 - src/ray/protobuf/serve.proto | 47 +- src/ray/ray_version_script.lds | 1 + src/ray/raylet/agent_manager.cc | 54 +- src/ray/raylet/agent_manager.h | 9 +- src/ray/raylet/main.cc | 3 +- src/ray/raylet/node_manager.cc | 40 +- src/ray/raylet/node_manager.h | 13 +- .../placement_group_resource_manager.cc | 3 - src/ray/raylet/raylet.cc | 7 +- .../scheduling/cluster_resource_data.cc | 5 +- .../raylet/scheduling/cluster_resource_data.h | 4 +- .../scheduling/cluster_resource_scheduler.cc | 3 +- .../raylet/scheduling/cluster_task_manager.cc | 103 ++-- .../raylet/scheduling/cluster_task_manager.h | 22 +- .../cluster_task_manager_interface.h | 21 +- .../scheduling/cluster_task_manager_test.cc | 151 ++---- src/ray/raylet/scheduling/fixed_point.cc | 96 ++++ src/ray/raylet/scheduling/fixed_point.h | 115 ++--- .../raylet/scheduling/scheduling_policy.cc | 20 +- src/ray/raylet/scheduling/scheduling_policy.h | 13 +- .../scheduling/scheduling_policy_test.cc | 36 -- src/ray/raylet/worker.cc | 2 +- src/ray/raylet/worker_pool.cc | 124 ++--- src/ray/raylet/worker_pool.h | 7 +- src/ray/raylet/worker_pool_test.cc | 14 +- src/ray/raylet_client/raylet_client.cc | 33 +- src/ray/raylet_client/raylet_client.h | 28 +- src/ray/rpc/common.cc | 6 +- src/ray/rpc/common.h | 2 - src/ray/rpc/grpc_server.cc | 9 +- src/ray/rpc/grpc_server.h | 11 +- 
.../rpc/node_manager/node_manager_client.h | 3 - .../rpc/node_manager/node_manager_server.h | 5 - src/ray/rpc/server_call.h | 17 +- src/ray/rpc/test/grpc_server_client_test.cc | 18 +- src/ray/util/event.h | 10 +- src/ray/util/util.h | 44 +- src/ray/util/util_test.cc | 17 - streaming/src/queue/queue_handler.h | 2 +- streaming/src/test/mock_actor.cc | 2 +- .../patches/prometheus-windows-pollfd.patch | 37 +- ...les_boost-undefine-boost_fallthrough.patch | 8 + .../rules_boost-windows-linkopts.patch | 21 +- 650 files changed, 5489 insertions(+), 16804 deletions(-) delete mode 100644 .buildkite/pipeline.gpu.large.yml delete mode 100644 doc/examples/dask_xgboost/README.rst delete mode 100644 doc/examples/dask_xgboost/dask_xgboost.py delete mode 100644 doc/examples/dask_xgboost/dask_xgboost.yaml delete mode 100644 doc/examples/modin_xgboost/README.rst delete mode 100644 doc/examples/modin_xgboost/modin_xgboost.py delete mode 100644 doc/examples/modin_xgboost/modin_xgboost.yaml delete mode 100644 doc/source/data/.gitignore delete mode 100644 doc/source/data/_examples/README.rst delete mode 100644 doc/source/data/_examples/big_data_ingestion.py delete mode 100644 doc/source/data/big_data_ingestion.yaml delete mode 100644 doc/source/raysgd/v2/examples/tune_cifar_pytorch_pbt_example.rst delete mode 100644 doc/source/raysgd/v2/migration-guide.rst delete mode 100644 java/serve/src/main/java/io/ray/serve/DeploymentInfo.java delete mode 100644 java/serve/src/main/java/io/ray/serve/DummyBackendReplica.java delete mode 100644 java/serve/src/main/java/io/ray/serve/HandleOptions.java delete mode 100644 java/serve/src/main/java/io/ray/serve/HttpProxy.java delete mode 100644 java/serve/src/main/java/io/ray/serve/ProxyActor.java delete mode 100644 java/serve/src/main/java/io/ray/serve/ProxyRouter.java delete mode 100644 java/serve/src/main/java/io/ray/serve/RayServeConfig.java delete mode 100644 java/serve/src/main/java/io/ray/serve/RayServeHandle.java delete mode 100644 java/serve/src/main/java/io/ray/serve/RayServeMetrics.java delete mode 100644 java/serve/src/main/java/io/ray/serve/ReplicaSet.java delete mode 100644 java/serve/src/main/java/io/ray/serve/Router.java delete mode 100644 java/serve/src/main/java/io/ray/serve/ServeController.java delete mode 100644 java/serve/src/main/java/io/ray/serve/ServeProxy.java delete mode 100644 java/serve/src/main/java/io/ray/serve/api/Client.java create mode 100644 java/serve/src/main/java/io/ray/serve/poll/UpdatedObject.java delete mode 100644 java/serve/src/main/java/io/ray/serve/util/CollectionUtil.java delete mode 100644 java/serve/src/main/java/io/ray/serve/util/CommonUtil.java delete mode 100644 java/serve/src/main/java/io/ray/serve/util/SocketUtil.java delete mode 100644 java/serve/src/test/java/io/ray/serve/DummyServeController.java delete mode 100644 java/serve/src/test/java/io/ray/serve/HttpProxyTest.java delete mode 100644 java/serve/src/test/java/io/ray/serve/ProxyActorTest.java delete mode 100644 java/serve/src/test/java/io/ray/serve/ProxyRouterTest.java delete mode 100644 java/serve/src/test/java/io/ray/serve/RayServeHandleTest.java delete mode 100644 java/serve/src/test/java/io/ray/serve/ReplicaSetTest.java delete mode 100644 java/serve/src/test/java/io/ray/serve/RouterTest.java delete mode 100644 java/serve/src/test/java/io/ray/serve/api/ClientTest.java delete mode 100644 python/ray/_private/runtime_env/plugin.py delete mode 100644 python/ray/_private/tls_utils.py delete mode 100644 python/ray/autoscaler/_private/fake_multi_node/__init__.py delete mode 
100644 python/ray/autoscaler/_private/fake_multi_node/example.yaml delete mode 100644 python/ray/autoscaler/_private/fake_multi_node/node_provider.py create mode 100644 python/ray/data/impl/tensor_block.py rename python/ray/serve/{replica.py => backend_worker.py} (91%) delete mode 100644 python/ray/sgd/callbacks.py delete mode 100644 python/ray/tests/test_autoscaler_fake_multinode.py delete mode 100644 python/ray/tests/test_client_compat.py delete mode 100644 python/ray/tests/test_runtime_env_plugin.py delete mode 100644 python/ray/tests/test_runtime_env_validation.py delete mode 100644 python/ray/util/sgd/v2/examples/tune_cifar_pytorch_pbt_example.py delete mode 100644 python/ray/util/sgd/v2/tests/test_gpu.py delete mode 100644 release/golden_notebook_tests/workloads/util.py create mode 100644 release/golden_notebook_tests/workloads/utils/utils.py delete mode 100644 release/kubernetes_manual_tests/README.md delete mode 100755 release/kubernetes_manual_tests/helm-test.sh delete mode 100755 release/kubernetes_manual_tests/k8s-test-scale.sh delete mode 100755 release/kubernetes_manual_tests/k8s-test.sh delete mode 100644 release/kubernetes_manual_tests/k8s_release_tests.sh delete mode 100644 release/nightly_tests/placement_group_tests/app_config.yaml delete mode 100644 release/nightly_tests/placement_group_tests/cluster.py delete mode 100644 release/nightly_tests/placement_group_tests/compute.yaml delete mode 100644 release/nightly_tests/placement_group_tests/pg_run.py delete mode 100644 release/release_logs/1.7.0/benchmarks/many_actors.txt delete mode 100644 release/release_logs/1.7.0/benchmarks/many_nodes.txt delete mode 100644 release/release_logs/1.7.0/benchmarks/many_pgs.txt delete mode 100644 release/release_logs/1.7.0/benchmarks/many_tasks.txt delete mode 100644 release/release_logs/1.7.0/microbenchmark.txt delete mode 100644 release/release_logs/1.7.0/scalability/object_store.txt delete mode 100644 release/release_logs/1.7.0/scalability/single_node.txt delete mode 100644 release/release_logs/1.7.0/stress_tests/dead_actors.txt delete mode 100644 release/release_logs/1.7.0/stress_tests/many_tasks.txt delete mode 100644 release/release_logs/1.7.0/stress_tests/placement_group.txt delete mode 100644 release/serve_tests/workloads/serve_cluster_fault_tolerance.py delete mode 100644 rllib/agents/sac/tests/test_rnnsac.py create mode 100755 rllib/env/tests/test_local_inference.sh delete mode 100755 rllib/env/tests/test_policy_client_server_setup.sh create mode 100755 rllib/env/tests/test_remote_inference.sh delete mode 100644 rllib/env/tests/test_remote_worker_envs.py delete mode 100644 rllib/examples/serving/unity3d_dummy_client.py delete mode 100644 rllib/utils/metrics/__init__.py delete mode 100644 rllib/utils/metrics/learner_info.py delete mode 100644 src/mock/ray/gcs/pubsub/gcs_pub_sub.h delete mode 100644 src/mock/ray/gcs/store_client/in_memory_store_client.h delete mode 100644 src/mock/ray/gcs/store_client/redis_store_client.h delete mode 100644 src/mock/ray/gcs/store_client/store_client.h delete mode 100644 src/mock/ray/pubsub/publisher.h delete mode 100644 src/mock/ray/pubsub/subscriber.h delete mode 100644 src/mock/ray/rpc/worker/core_worker_client.h delete mode 100644 src/mock/ray/rpc/worker/core_worker_client_pool.h delete mode 100644 src/ray/gcs/gcs_server/test/gcs_actor_scheduler_mock_test.cc delete mode 100644 src/ray/gcs/gcs_server/test/gcs_placement_group_manager_mock_test.cc create mode 100644 src/ray/raylet/scheduling/fixed_point.cc create mode 100644 
thirdparty/patches/rules_boost-undefine-boost_fallthrough.patch diff --git a/.bazelrc b/.bazelrc index 2e4e7b36d10f9..a6ebeba272c0f 100644 --- a/.bazelrc +++ b/.bazelrc @@ -14,13 +14,12 @@ build:macos --copt="-g1" build:linux --cxxopt="-std=c++17" build:macos --cxxopt="-std=c++17" build:clang-cl --cxxopt="-std=c++17" -build:msvc-cl --cxxopt="/std:c++17" -build:windows --cxxopt="/std:c++17" +build:msvc --cxxopt="/std:c++17" # This workaround is needed to prevent Bazel from compiling the same file twice (once PIC and once not). build:linux --force_pic build:macos --force_pic build:clang-cl --compiler=clang-cl -build:msvc-cl --compiler=msvc-cl +build:msvc --compiler=msvc-cl # `LC_ALL` and `LANG` is needed for cpp worker tests, because they will call "ray start". # If we don't add them, python's `click` library will raise an error. build --action_env=LC_ALL @@ -39,7 +38,7 @@ build:windows --enable_runfiles build:linux --per_file_copt="-\\.(asm|S)$@-Werror" build:macos --per_file_copt="-\\.(asm|S)$@-Werror" build:clang-cl --per_file_copt="-\\.(asm|S)$@-Werror" -build:msvc-cl --per_file_copt="-\\.(asm|S)$@-WX" +build:msvc --per_file_copt="-\\.(asm|S)$@-WX" # Ignore warnings for protobuf generated files and external projects. build --per_file_copt="\\.pb\\.cc$@-w" build --per_file_copt="-\\.(asm|S)$,external/.*@-w" @@ -52,7 +51,7 @@ build --per_file_copt="-\\.(asm|S)$,external/com_github_grpc_grpc/.*@-DGRPC_BAZE # Don't generate warnings about kernel features we don't need https://github.com/ray-project/ray/issues/6832 build:linux --per_file_copt="-\\.(asm|S)$,external/com_github_grpc_grpc/.*@-DGPR_MANYLINUX1" # Ignore wchar_t -> char conversion warning on MSVC -build:msvc-cl --per_file_copt="external/boost/libs/regex/src/wc_regex_traits\\.cpp@-wd4244" +build:msvc --per_file_copt="external/boost/libs/regex/src/wc_regex_traits\\.cpp@-wd4244" build --http_timeout_scaling=5.0 build --verbose_failures build:iwyu --experimental_action_listener=//:iwyu_cpp @@ -178,7 +177,6 @@ build:debug --strip="never" # Undefined Behavior Sanitizer build:ubsan --strip=never build:ubsan --copt -fsanitize=undefined -build:ubsan --copt -fno-sanitize=vptr build:ubsan --copt -fno-sanitize-recover=all build:ubsan --copt -g build:ubsan --linkopt -fsanitize=undefined diff --git a/.buildkite/pipeline.gpu.large.yml b/.buildkite/pipeline.gpu.large.yml deleted file mode 100644 index 0bdbca8846841..0000000000000 --- a/.buildkite/pipeline.gpu.large.yml +++ /dev/null @@ -1,8 +0,0 @@ -- label: ":tv: :octopus: SGD GPU tests " - conditions: ["RAY_CI_SGD_AFFECTED"] - commands: - - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/travis/upload_build_info.sh; fi }; trap cleanup EXIT - - SGD_TESTING=1 INSTALL_HOROVOD=1 ./ci/travis/install-dependencies.sh - - pip install -Ur ./python/requirements_ml_docker.txt - - ./ci/travis/env_info.sh - - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=gpu,gpu_only python/ray/util/sgd/... 
diff --git a/.buildkite/pipeline.gpu.yml b/.buildkite/pipeline.gpu.yml index e89aeaa9f2d63..0c2c14ecf805f 100644 --- a/.buildkite/pipeline.gpu.yml +++ b/.buildkite/pipeline.gpu.yml @@ -1,13 +1,3 @@ -# Todo: Enable once tests are available -#- label: ":tv: :octopus: Tune GPU tests " -# conditions: ["RAY_CI_TUNE_AFFECTED"] -# commands: -# - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/travis/upload_build_info.sh; fi }; trap cleanup EXIT -# - TUNE_TESTING=1 ./ci/travis/install-dependencies.sh -# - pip install -Ur ./python/requirements_ml_docker.txt -# - ./ci/travis/env_info.sh -# - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=-flaky,gpu,gpu_only python/ray/tune/... - - label: ":tv: :brain: RLlib: GPU Examples {A/B}" conditions: ["RAY_CI_RLLIB_AFFECTED"] commands: diff --git a/.buildkite/pipeline.macos.yml b/.buildkite/pipeline.macos.yml index e3ba9347c7cc7..592347d44007c 100644 --- a/.buildkite/pipeline.macos.yml +++ b/.buildkite/pipeline.macos.yml @@ -64,7 +64,7 @@ steps: commands: - *prelude_commands - TORCH_VERSION=1.6 ./ci/travis/install-dependencies.sh - - bazel test --config=ci --test_env=CI $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=-flaky,-flaky-mac,-post_wheel_build -- + - bazel test --config=ci --test_env=CI $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=-flaky,-flaky-mac -- //:all python/ray/serve/... python/ray/dashboard/... -rllib/... -core_worker_test - *epilogue_commands @@ -82,7 +82,7 @@ steps: - bazel test $(./scripts/bazel_export_options) --config=ci --test_env=CONDA_EXE --test_env=CONDA_PYTHON_EXE --test_env=CONDA_SHLVL --test_env=CONDA_PREFIX --test_env=CONDA_DEFAULT_ENV --test_env=CONDA_PROMPT_MODIFIER --test_env=CI - --test_tag_filters=-kubernetes,-medium_size_python_tests_a_to_j,-medium_size_python_tests_k_to_z,-flaky,-flaky-mac + --test_tag_filters=-kubernetes,-jenkins_only,-medium_size_python_tests_a_to_j,-medium_size_python_tests_k_to_z,-flaky,-flaky-mac python/ray/tests/... - *epilogue_commands @@ -91,7 +91,7 @@ steps: commands: - *prelude_commands - bazel test --config=ci $(./scripts/bazel_export_options) --test_env=CI - --test_tag_filters=-kubernetes,medium_size_python_tests_a_to_j,-flaky,-flaky-mac + --test_tag_filters=-kubernetes,-jenkins_only,medium_size_python_tests_a_to_j,-flaky,-flaky-mac python/ray/tests/... - *epilogue_commands @@ -100,7 +100,7 @@ steps: commands: - *prelude_commands - bazel test --config=ci $(./scripts/bazel_export_options) --test_env=CI - --test_tag_filters=-kubernetes,medium_size_python_tests_k_to_z,-flaky,-flaky-mac + --test_tag_filters=-kubernetes,-jenkins_only,medium_size_python_tests_k_to_z,-flaky,-flaky-mac python/ray/tests/... - *epilogue_commands @@ -110,7 +110,7 @@ steps: - *prelude_commands - ./ci/travis/install-dependencies.sh - bazel test --config=ci $(./scripts/bazel_export_options) - --test_tag_filters=-kubernetes,flaky,flaky-mac + --test_tag_filters=-kubernetes,-jenkins_only,flaky,flaky-mac --test_env=CONDA_EXE --test_env=CONDA_PYTHON_EXE --test_env=CONDA_SHLVL diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 2941476580fc8..c0f6ccda286df 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -182,9 +182,7 @@ - TORCH_VERSION=1.6 ./ci/travis/install-dependencies.sh - ./dashboard/tests/run_ui_tests.sh - bazel test --config=ci $(./scripts/bazel_export_options) python/ray/dashboard/... 
- - bazel test --config=ci $(./scripts/bazel_export_options) - --test_tag_filters=-post_wheel_build - python/ray/serve/... + - bazel test --config=ci $(./scripts/bazel_export_options) python/ray/serve/... - label: ":python: Minimal install" conditions: ["RAY_CI_PYTHON_AFFECTED"] @@ -210,7 +208,7 @@ # --test_tag_filters=flaky # -- //:all -rllib/... -core_worker_test - bazel test --config=ci $(./scripts/bazel_export_options) - --test_tag_filters=-kubernetes,flaky + --test_tag_filters=-kubernetes,-jenkins_only,flaky --test_env=CONDA_EXE --test_env=CONDA_PYTHON_EXE --test_env=CONDA_SHLVL @@ -222,7 +220,7 @@ commands: - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/travis/upload_build_info.sh; fi }; trap cleanup EXIT - bazel test --config=ci $(./scripts/bazel_export_options) - --test_tag_filters=-kubernetes,-medium_size_python_tests_a_to_j,-medium_size_python_tests_k_to_z,-client_tests,-flaky,-post_wheel_build,-worker-container + --test_tag_filters=-kubernetes,-jenkins_only,-medium_size_python_tests_a_to_j,-medium_size_python_tests_k_to_z,-client_tests,-flaky,-post_wheel_build,-worker-container --test_env=CONDA_EXE --test_env=CONDA_PYTHON_EXE --test_env=CONDA_SHLVL @@ -230,7 +228,7 @@ --test_env=CONDA_DEFAULT_ENV python/ray/tests/... - bazel test --config=ci $(./scripts/bazel_export_options) - --test_tag_filters=-kubernetes,client_tests,-flaky + --test_tag_filters=-kubernetes,-jenkins_only,client_tests,-flaky --test_env=RAY_CLIENT_MODE=1 --test_env=RAY_PROFILING=1 python/ray/tests/... - label: ":python: (Medium A-J)" @@ -238,14 +236,14 @@ commands: - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/travis/upload_build_info.sh; fi }; trap cleanup EXIT - bazel test --config=ci $(./scripts/bazel_export_options) - --test_tag_filters=-kubernetes,medium_size_python_tests_a_to_j,-flaky + --test_tag_filters=-kubernetes,-jenkins_only,medium_size_python_tests_a_to_j,-flaky python/ray/tests/... - label: ":python: (Medium K-Z)" conditions: ["RAY_CI_PYTHON_AFFECTED"] commands: - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/travis/upload_build_info.sh; fi }; trap cleanup EXIT - bazel test --config=ci $(./scripts/bazel_export_options) - --test_tag_filters=-kubernetes,medium_size_python_tests_k_to_z,-flaky + --test_tag_filters=-kubernetes,-jenkins_only,medium_size_python_tests_k_to_z,-flaky python/ray/tests/... - label: ":core: Debug Test" commands: @@ -253,7 +251,7 @@ - pip uninstall -y ray - RAY_DEBUG_BUILD=debug ./ci/travis/ci.sh build - bazel test --config=ci-debug $(./scripts/bazel_export_options) - --test_tag_filters=-kubernetes,debug_tests,-flaky + --test_tag_filters=-kubernetes,-jenkins_only,debug_tests,-flaky python/ray/tests/... - label: ":core: (ASAN tests)" conditions: ["RAY_CI_PYTHON_AFFECTED"] @@ -262,7 +260,7 @@ - RLLIB_TESTING=1 ./ci/travis/install-dependencies.sh - bazel test --config=ci --config=asan $(./scripts/bazel_export_options) --config=asan-buildkite - --test_tag_filters=-kubernetes,asan_tests,-flaky + --test_tag_filters=-kubernetes,-jenkins_only,asan_tests,-flaky --test_env=CONDA_EXE --test_env=CONDA_PYTHON_EXE --test_env=CONDA_SHLVL @@ -464,16 +462,16 @@ commands: - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/travis/upload_build_info.sh; fi }; trap cleanup EXIT - TUNE_TESTING=1 ./ci/travis/install-dependencies.sh - - bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-example,-flaky,-py37,-soft_imports,-gpu_only python/ray/tune/... 
- - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=example,-tf,-pytorch,-py37,-flaky,-soft_imports,-gpu_only python/ray/tune/... + - bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-jenkins_only,-example,-flaky,-py37,-soft_imports python/ray/tune/... + - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=example,-tf,-pytorch,-py37,-flaky,-soft_imports python/ray/tune/... - label: ":octopus: Tune tests and examples {2/2}" conditions: ["RAY_CI_TUNE_AFFECTED"] commands: - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/travis/upload_build_info.sh; fi }; trap cleanup EXIT - TUNE_TESTING=1 ./ci/travis/install-dependencies.sh - - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=tf,-pytorch,-py37,-flaky,-soft_imports,-gpu_only python/ray/tune/... - - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=-tf,pytorch,-py37,-flaky,-soft_imports,-gpu_only python/ray/tune/... + - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=tf,-pytorch,-py37,-flaky,-soft_imports python/ray/tune/... + - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=-tf,pytorch,-py37,-flaky,-soft_imports python/ray/tune/... - label: ":octopus: Tune soft imports test" conditions: ["RAY_CI_TUNE_AFFECTED"] @@ -488,10 +486,10 @@ commands: - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/travis/upload_build_info.sh; fi }; trap cleanup EXIT - SGD_TESTING=1 INSTALL_HOROVOD=1 ./ci/travis/install-dependencies.sh - - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=tf,-pytorch,-py37,-flaky,-client,-gpu_only python/ray/util/sgd/... - - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=-tf,pytorch,-py37,-flaky,-client,-gpu_only python/ray/util/sgd/... - - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=client_unit_tests,-gpu_only --test_env=RAY_CLIENT_MODE=1 python/ray/util/sgd/... - - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=-gpu_only python/ray/util/sgd/v2/... + - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=tf,-pytorch,-py37,-flaky,-client python/ray/util/sgd/... + - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=-tf,pytorch,-py37,-flaky,-client python/ray/util/sgd/... + - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=client_unit_tests --test_env=RAY_CLIENT_MODE=1 python/ray/util/sgd/... + - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only python/ray/util/sgd/v2/... - label: ":octopus: Tune/SGD/Modin/Dask tests and examples. 
Python 3.7" conditions: ["RAY_CI_TUNE_AFFECTED", "RAY_CI_SGD_AFFECTED"] diff --git a/.buildkite/windows/install/bazel.ps1 b/.buildkite/windows/install/bazel.ps1 index adeee13df7209..46411cf3810f3 100644 --- a/.buildkite/windows/install/bazel.ps1 +++ b/.buildkite/windows/install/bazel.ps1 @@ -1,4 +1,4 @@ -$Env:BAZEL_URL="https://github.com/bazelbuild/bazel/releases/download/4.2.1/bazel-4.2.1-windows-x86_64.zip" +$Env:BAZEL_URL="https://github.com/bazelbuild/bazel/releases/download/3.2.0/bazel-3.2.0-windows-x86_64.zip" Write-Host ('Downloading {0} ...' -f $env:BAZEL_URL); [Net.ServicePointManager]::SecurityProtocol = [Net.SecurityProtocolType]::Tls12; Invoke-WebRequest -Uri $env:BAZEL_URL -OutFile 'bazel.zip'; diff --git a/.clang-tidy b/.clang-tidy index 607f19902f3f4..2aa176da910cc 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -1,64 +1,27 @@ -# Disable the following checks with reasons in parenthesis: -# -# -bugprone-macro-parentheses (inconsistent style) -# -google-readability-todo (potentially too restrictive) +# Disable the following checks due to frequent false positives, noisiness, +# inconsistent style with existing codebase and other reasons: # -misc-non-private-member-variables-in-classes (potentially too restrictive) # -misc-unused-parameters (can be cleaned up in batch and enabled) # -modernize-avoid-c-arrays (too restrictive) -# -modernize-concat-nested-namespaces (inconsistent style) # -modernize-pass-by-value (too restrictive) # -modernize-return-braced-init-list (inconsistent style) # -modernize-use-emplace (more subtle behavior) -# -modernize-use-nodiscard (too much noise) # -modernize-use-trailing-return-type (inconsistent style) -# -modernize-avoid-bind (incorrect conversion) -# -modernize-loop-convert (more subtle behavior) -# -modernize-replace-disallow-copy-and-assign-macro (inconsistent style) -# -modernize-make-unique (doesn't work with private constructor) -# -modernize-make-shared (doesn't work with private constructor) -# Other readability-* rules (potentially too noisy, inconsistent style) -# Other rules not mentioned here or below (not yet evaluated) # # TODO: enable google-* and readability-* families of checks. Checks: > abseil-*, bugprone-*, - -bugprone-macro-parentheses, - google-*, - -google-readability-todo, misc-*, -misc-non-private-member-variables-in-classes, -misc-unused-parameters, modernize-*, -modernize-avoid-c-arrays, - -modernize-concat-nested-namespaces, -modernize-pass-by-value, -modernize-return-braced-init-list, -modernize-use-emplace, - -modernize-use-nodiscard, -modernize-use-trailing-return-type, - -modernize-avoid-bind, - -modernize-loop-convert, - -modernize-replace-disallow-copy-and-assign-macro, - -modernize-make-unique, - -modernize-make-shared, performance-*, - readability-avoid-const-params-in-decls, - readability-braces-around-statements, - readability-const-return-type, - readability-container-size-empty, - readability-delete-null-pointer, - readability-else-after-return, - readability-implicit-bool-conversion, - readability-make-member-function-const, - readability-misleading-indentation, - readability-misplaced-array-index, - readability-named-parameter, - readability-non-const-parameter, - readability-redundant-*, - readability-static-definition-in-anonymous-namespace, - readability-string-compare, - readability-suspicious-call-argument, CheckOptions: # Reduce noisiness of the bugprone-narrowing-conversions check. 
diff --git a/.flake8 index cb93e3096d3ef..a4a3510a1bbeb 100644 --- a/.flake8 +++ b/.flake8 @@ -24,20 +24,4 @@ ignore = W605 I N - B001 - B002 - B003 - B004 - B005 - B007 - B008 - B009 - B010 - B011 - B012 - B013 - B014 - B015 - B016 - B017 avoid-escape = no diff --git a/.github/CODEOWNERS index 3502b7042bf20..c4e254c2dd0f9 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -18,9 +18,6 @@ # Dependencies /python/setup.py @richardliaw @ericl @edoakes -# Formatting tool -/ci/travis/format.sh @richardliaw @ericl @edoakes - # Python worker. #/python/ray/ @ray-project/ray-core-python #!/python/ray/tune/ @ray-project/ray-core-python @@ -33,6 +30,7 @@ /java/*/pom_template.xml @jovany-wang @kfstorm @raulchen /java/api/ @jovany-wang @kfstorm @raulchen + # Ray Client /src/ray/protobuf/ray_client.proto @ijrsvt @ameerhajali @ckw017 @mwtian @@ -41,14 +39,6 @@ # Ray tune. /python/ray/tune/ @ray-project/ray-tune -# Ray data. -/python/ray/data/ @ericl @scv119 -/doc/source/data/ @ericl @scv119 - -# Ray workflows. -/python/ray/workflow/ @ericl @iycheng -/doc/source/workflows/ @ericl @iycheng - # RLlib. #/python/ray/rllib/ @ray-project/rllib diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 9404a4a4d2517..8df9fe895df63 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -26,7 +26,7 @@ jobs: os: windows-2019 python-version: 3.8 # Can be 'msvc' or 'clang-cl' - config: msvc-cl + config: msvc env: BAZEL_CONFIG: ${{ matrix.config }} PYTHON: ${{ matrix.python-version }} @@ -111,6 +111,7 @@ jobs: TRAVIS_COMMIT: ${{ github.sha }} TRAVIS_JOB_ID: ${{ github.run_id }} run: | + # Multi-threading for gRPC on Windows is not working right now function clean_up() { echo "Performing cleanup" if [ "${GITHUB_EVENT_NAME}" != "pull_request" ]; then ./ci/travis/upload_build_info.sh; fi diff --git a/.gitpod/Dockerfile b/.gitpod/Dockerfile index ce2af682e0ed9..23682c0ed9687 100644 --- a/.gitpod/Dockerfile +++ b/.gitpod/Dockerfile @@ -15,7 +15,7 @@ RUN set -x; apt update \ && mv bazel.gpg /etc/apt/trusted.gpg.d/ \ && echo "deb [arch=amd64] https://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list \ && apt update && apt install bazel-3.7.2 -y \ - && pip3 install cython==0.29.0 pytest pandas tree tabulate pexpect sklearn joblib yapf==0.23.0 flake8==3.9.1 mypy==0.782 flake8-quotes flake8-bugbear==21.9.2 setproctitle==1.1.10 psutil \ + && pip3 install cython==0.29.0 pytest pandas tree tabulate pexpect sklearn joblib yapf==0.23.0 flake8==3.9.1 mypy==0.782 flake8-quotes setproctitle==1.1.10 psutil \ && python3 -c 'print("startup --output_base=/workspace/ray/.bazel-cache\nstartup --host_jvm_args=-Xmx1800m\nbuild --jobs=6")' > /etc/bazel.bazelrc RUN update-alternatives --install /usr/local/bin/python python /usr/bin/python3 30 \ diff --git a/BUILD.bazel b/BUILD.bazel index 7db7fc20f7cb7..ad6bd083fd4ad 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -414,6 +414,7 @@ cc_library( ], ) + [ "src/ray/raylet/scheduling/cluster_resource_data.cc", + "src/ray/raylet/scheduling/fixed_point.cc", "src/ray/raylet/scheduling/scheduling_ids.cc", ], hdrs = glob( @@ -552,7 +553,6 @@ cc_library( ":pubsub_lib", ":raylet_client_lib", ":worker_rpc", - "@com_google_absl//absl/container:btree", ], ) @@ -623,12 +623,10 @@ cc_library( "src/ray/stats/metric_exporter_client.cc", ], hdrs = [ - "src/ray/stats/metric.h", "src/ray/stats/metric_defs.h", "src/ray/stats/metric_exporter.h", "src/ray/stats/metric_exporter_client.h", "src/ray/stats/stats.h", -
"src/ray/stats/tag_defs.h", ], copts = COPTS, linkopts = select({ @@ -1183,22 +1181,6 @@ cc_test( ], ) -cc_test( - name = "gcs_placement_group_manager_mock_test", - size = "small", - srcs = [ - "src/ray/gcs/gcs_server/test/gcs_placement_group_manager_mock_test.cc", - ], - copts = COPTS, - tags = ["team:core"], - deps = [ - ":gcs_server_lib", - ":gcs_test_util_lib", - ":ray_mock", - "@com_google_googletest//:gtest_main", - ], -) - cc_test( name = "placement_group_resource_manager_test", size = "small", @@ -1531,21 +1513,6 @@ cc_test( ], ) -# cc_test( -# name = "gcs_actor_scheduler_mock_test", -# size = "small", -# srcs = [ -# "src/ray/gcs/gcs_server/test/gcs_actor_scheduler_mock_test.cc", -# ], -# copts = COPTS, -# tags = ["team:core"], -# deps = [ -# ":gcs_server_lib", -# ":ray_mock", -# "@com_google_googletest//:gtest_main", -# ], -# ) - cc_test( name = "gcs_based_actor_scheduler_test", size = "small", diff --git a/bazel/ray_deps_setup.bzl b/bazel/ray_deps_setup.bzl index 96131feadba41..1925aedfa4edb 100644 --- a/bazel/ray_deps_setup.bzl +++ b/bazel/ray_deps_setup.bzl @@ -151,8 +151,8 @@ def ray_deps_setup(): # declaring it here allows us to avoid patching the latter. name = "boost", build_file = "@com_github_nelhage_rules_boost//:BUILD.boost", - sha256 = "83bfc1507731a0906e387fc28b7ef5417d591429e51e788417fe9ff025e116b1", - url = "https://boostorg.jfrog.io/artifactory/main/release/1.74.0/source/boost_1_74_0.tar.bz2", + sha256 = "d73a8da01e8bf8c7eda40b4c84915071a8c8a0df4a6734537ddde4a8580524ee", + url = "https://boostorg.jfrog.io/artifactory/main/release/1.71.0/source/boost_1_71_0.tar.bz2", patches = [ "//thirdparty/patches:boost-exception-no_warn_typeid_evaluated.patch", ], @@ -161,9 +161,10 @@ def ray_deps_setup(): auto_http_archive( name = "com_github_nelhage_rules_boost", # If you update the Boost version, remember to update the 'boost' rule. 
- url = "https://github.com/nelhage/rules_boost/archive/652b21e35e4eeed5579e696da0facbe8dba52b1f.tar.gz", - sha256 = "c1b8b2adc3b4201683cf94dda7eef3fc0f4f4c0ea5caa3ed3feffe07e1fb5b15", + url = "https://github.com/nelhage/rules_boost/archive/2613d04ab3d22dfc4543ea0a083d9adeaa0daf09.tar.gz", + sha256 = "512f913240e026099d4ca4a98b1ce8048c99de77fdc8e8584e9e2539ee119ca2", patches = [ + "//thirdparty/patches:rules_boost-undefine-boost_fallthrough.patch", "//thirdparty/patches:rules_boost-windows-linkopts.patch", ], ) diff --git a/benchmarks/object_store/test_object_store.py b/benchmarks/object_store/test_object_store.py index 022cb17e8b890..5e251f55f8884 100644 --- a/benchmarks/object_store/test_object_store.py +++ b/benchmarks/object_store/test_object_store.py @@ -65,7 +65,6 @@ def sum(self, arr): if "TEST_OUTPUT_JSON" in os.environ: out_file = open(os.environ["TEST_OUTPUT_JSON"], "w") results = { - "broadcast_time": end - start, "object_size": OBJECT_SIZE, "num_nodes": NUM_NODES, "success": "1" diff --git a/benchmarks/single_node/test_single_node.py b/benchmarks/single_node/test_single_node.py index 3deaa389de600..fb44e7fe29ade 100644 --- a/benchmarks/single_node/test_single_node.py +++ b/benchmarks/single_node/test_single_node.py @@ -199,8 +199,7 @@ def test_large_object(): "num_args": MAX_ARGS, "returns_time": returns_time, "num_returns": MAX_RETURNS, - "get_time": get_time, - "num_get_args": MAX_RAY_GET_ARGS, + "get_time": MAX_RAY_GET_ARGS, "queued_time": queued_time, "num_queued": MAX_QUEUED_TASKS, "large_object_time": large_object_time, diff --git a/ci/asan_tests/run_asan_tests.sh b/ci/asan_tests/run_asan_tests.sh index ea2d4b8a697c5..5f84fe3ff6d40 100755 --- a/ci/asan_tests/run_asan_tests.sh +++ b/ci/asan_tests/run_asan_tests.sh @@ -39,10 +39,10 @@ asan_run() { cd "${RAY_DIR}" # Ray tests - bazel test --test_output=streamed python/ray/serve/... - bazel test --test_output=streamed python/ray/dashboard/... - bazel test --test_output=streamed python/ray/tests/... - bazel test --test_output=streamed python/ray/tune/... + bazel test --test_tag_filters=-jenkins_only --test_output=streamed python/ray/serve/... + bazel test --test_tag_filters=-jenkins_only --test_output=streamed python/ray/dashboard/... + bazel test --test_tag_filters=-jenkins_only --test_output=streamed python/ray/tests/... + bazel test --test_tag_filters=-jenkins_only --test_output=streamed python/ray/tune/... ) } diff --git a/ci/travis/bazel.py b/ci/travis/bazel.py index d462459fc1ead..d731734b6faa9 100755 --- a/ci/travis/bazel.py +++ b/ci/travis/bazel.py @@ -98,45 +98,35 @@ def info(self, *args): return result def aquery(self, *args): - out = self._call("aquery", "--output=jsonproto", *args) - return json.loads(out.decode(self.encoding)) + lines = self._call("aquery", "--output=textproto", *args).splitlines() + return textproto_parse(lines, self.encoding, json.JSONEncoder()) def parse_aquery_shell_calls(aquery_results): """Extracts and yields the command lines representing the genrule() rules from Bazel aquery results. """ - for action in aquery_results["actions"]: - if action["mnemonic"] != "Genrule": - continue - yield action["arguments"] + for (key, val) in aquery_results: + if key == "actions": + [mnemonic] = [pair[1] for pair in val if pair[0] == "mnemonic"] + if mnemonic == "Genrule": + yield [pair[1] for pair in val if pair[0] == "arguments"] def parse_aquery_output_artifacts(aquery_results): """Extracts and yields the file paths representing the output artifact from the provided Bazel aquery results. 
- - To understand the output of aquery command in textproto format, try: - bazel aquery --include_artifacts=true --output=jsonproto \ - 'mnemonic("Genrule", deps(//:*))' """ - fragments = {} - for fragment in aquery_results["pathFragments"]: - fragments[fragment["id"]] = fragment - artifacts = {} - for artifact in aquery_results["artifacts"]: - artifacts[artifact["id"]] = artifact - - def _path(fragment_id): - fragment = fragments[fragment_id] - parent = _path(fragment["parentId"]) if "parentId" in fragment else [] - return parent + [fragment["label"]] - - for action in aquery_results["actions"]: - for output_id in action["outputIds"]: - path = os.path.join(*_path(artifacts[output_id]["pathFragmentId"])) - yield path + artifacts = {} + for (key, val) in aquery_results: + if key == "artifacts": + [artifact_id] = [pair[1] for pair in val if pair[0] == "id"] + [exec_path] = [pair[1] for pair in val if pair[0] == "exec_path"] + artifacts[artifact_id] = exec_path + elif key == "actions": + output_ids = [pair[1] for pair in val if pair[0] == "output_ids"] + for output_id in output_ids: + yield artifacts[output_id] def textproto2json(infile, outfile): diff --git a/ci/travis/ci.sh index 7faa9ae02a5be..6aa33a22a2000 100755 --- a/ci/travis/ci.sh +++ b/ci/travis/ci.sh @@ -139,7 +139,6 @@ test_python() { args+=( python/ray/serve/... python/ray/tests/... -python/ray/serve:conda_env # runtime_env unsupported on Windows -python/ray/serve:test_api # segfault on windows? https://github.com/ray-project/ray/issues/12541 -python/ray/serve:test_router # timeout -python/ray/serve:test_handle # "fatal error" (?) https://github.com/ray-project/ray/pull/13695
-} - build_wheels() { - # Create wheel output directory and empty contents - # If buildkite runners are re-used, wheels from previous builds might be here, so we delete them. - mkdir -p .whl - rm -rf .whl/* || true - case "${OSTYPE}" in linux*) # Mount bazel cache dir to the docker container. @@ -400,6 +353,7 @@ build_wheels() { -e "RAY_DEBUG_BUILD=${RAY_DEBUG_BUILD:-}" ) + if [ -z "${BUILDKITE-}" ]; then # This command should be kept in sync with ray/python/README-building-wheels.md, # except the "${MOUNT_BAZEL_CACHE[@]}" part. @@ -407,25 +361,19 @@ build_wheels() { quay.io/pypa/manylinux2014_x86_64 /ray/python/build-wheel-manylinux2014.sh else rm -rf /ray-mount/* - rm -rf /ray-mount/.whl || true - rm -rf /ray/.whl || true cp -rT /ray /ray-mount - ls -a /ray-mount + ls /ray-mount docker run --rm -v /ray:/ray-mounted ubuntu:focal ls / docker run --rm -v /ray:/ray-mounted ubuntu:focal ls /ray-mounted docker run --rm -w /ray -v /ray:/ray "${MOUNT_BAZEL_CACHE[@]}" \ quay.io/pypa/manylinux2014_x86_64 /ray/python/build-wheel-manylinux2014.sh cp -rT /ray-mount /ray # copy new files back here find . | grep whl # testing - - validate_wheels_commit_str fi ;; darwin*) # This command should be kept in sync with ray/python/README-building-wheels.md. "${WORKSPACE_DIR}"/python/build-wheel-macos.sh - - validate_wheels_commit_str ;; msys*) keep_alive "${WORKSPACE_DIR}"/python/build-wheel-windows.sh diff --git a/ci/travis/format.sh b/ci/travis/format.sh index 7dbf608d18734..e31245faad61d 100755 --- a/ci/travis/format.sh +++ b/ci/travis/format.sh @@ -83,10 +83,6 @@ if [[ $(flake8 --version) != *"flake8_quotes"* ]]; then echo "WARNING: Ray uses flake8 with flake8_quotes. Might error without it. Install with: pip install flake8-quotes" fi -if [[ $(flake8 --version) != *"flake8-bugbear"* ]]; then - echo "WARNING: Ray uses flake8 with flake8-bugbear. Might error without it. Install with: pip install flake8-bugbear" -fi - SHELLCHECK_FLAGS=( --exclude=1090 # "Can't follow non-constant source. Use a directive to specify location." --exclude=1091 # "Not following {file} due to some error" diff --git a/ci/travis/install-dependencies.sh b/ci/travis/install-dependencies.sh index b52f75e8a4164..32b39ded1401e 100755 --- a/ci/travis/install-dependencies.sh +++ b/ci/travis/install-dependencies.sh @@ -408,7 +408,7 @@ install_dependencies() { # RLlib testing with TF 1.x. if [ "${RLLIB_TESTING-}" = 1 ] && { [ -n "${TF_VERSION-}" ] || [ -n "${TFP_VERSION-}" ]; }; then - pip install --upgrade tensorflow-probability=="${TFP_VERSION}" tensorflow=="${TF_VERSION}" gym==0.19 + pip install --upgrade tensorflow-probability=="${TFP_VERSION}" tensorflow=="${TF_VERSION}" gym fi # Additional Tune dependency for Horovod. diff --git a/ci/travis/test-worker-in-container.sh b/ci/travis/test-worker-in-container.sh index 00caeeb15839f..0d5b01eb49043 100644 --- a/ci/travis/test-worker-in-container.sh +++ b/ci/travis/test-worker-in-container.sh @@ -23,7 +23,7 @@ bash ./ci/travis/install-bazel.sh --system # shellcheck disable=SC2046 bazel test --test_timeout 60 --config=ci $(./scripts/bazel_export_options) \ ---test_tag_filters=-kubernetes,worker-container,-flaky \ +--test_tag_filters=-kubernetes,-jenkins_only,worker-container,-flaky \ python/ray/tests/... 
--test_output=all #pytest python/ray/tests/test_actor_in_container.py -s diff --git a/cpp/BUILD.bazel b/cpp/BUILD.bazel index 9603c863546c1..9d4e7416cda1b 100644 --- a/cpp/BUILD.bazel +++ b/cpp/BUILD.bazel @@ -90,7 +90,6 @@ genrule( mkdir -p "$$PY_CPP_DIR/lib/" && cp -f -r $$WORK_DIR/external/msgpack/include/* "$$PY_CPP_DIR/include" && cp -f -r "$$WORK_DIR/external/boost/boost/archive" "$$BOOST_DIR" && - cp -f -r "$$WORK_DIR/external/boost/boost/assert" "$$BOOST_DIR" && cp -f -r "$$WORK_DIR/external/boost/boost/bind" "$$BOOST_DIR" && cp -f -r "$$WORK_DIR/external/boost/boost/callable_traits" "$$BOOST_DIR" && cp -f -r "$$WORK_DIR/external/boost/boost/concept" "$$BOOST_DIR" && diff --git a/cpp/src/ray/api.cc b/cpp/src/ray/api.cc index ed2b1b89230cd..a1a8c6507541c 100644 --- a/cpp/src/ray/api.cc +++ b/cpp/src/ray/api.cc @@ -40,7 +40,7 @@ void Init() { bool IsInitialized() { return is_init_; } void Shutdown() { - // TODO(SongGuyang): Clean the ray runtime. + // TODO(guyang.sgy): Clean the ray runtime. internal::AbstractRayRuntime::DoShutdown(); is_init_ = false; } diff --git a/cpp/src/ray/runtime/abstract_ray_runtime.cc b/cpp/src/ray/runtime/abstract_ray_runtime.cc index db9fac32db4e8..177fae17d3122 100644 --- a/cpp/src/ray/runtime/abstract_ray_runtime.cc +++ b/cpp/src/ray/runtime/abstract_ray_runtime.cc @@ -145,7 +145,7 @@ InvocationSpec BuildInvocationSpec1(TaskType task_type, InvocationSpec invocation_spec; invocation_spec.task_type = task_type; invocation_spec.task_id = - TaskID::ForFakeTask(); // TODO(SongGuyang): make it from different task + TaskID::ForFakeTask(); // TODO(Guyang Song): make it from different task invocation_spec.remote_function_holder = remote_function_holder; invocation_spec.actor_id = actor; invocation_spec.args = TransformArgs(args); diff --git a/cpp/src/ray/runtime/object/native_object_store.cc b/cpp/src/ray/runtime/object/native_object_store.cc index 7add3b72b73af..d9326feb2ae66 100644 --- a/cpp/src/ray/runtime/object/native_object_store.cc +++ b/cpp/src/ray/runtime/object/native_object_store.cc @@ -116,7 +116,7 @@ std::vector NativeObjectStore::Wait(const std::vector &ids, int num_objects, int timeout_ms) { std::vector results; auto &core_worker = CoreWorkerProcess::GetCoreWorker(); - // TODO(SongGuyang): Support `fetch_local` option in API. + // TODO(guyang.sgy): Support `fetch_local` option in API. // Simply set `fetch_local` to be true. ::ray::Status status = core_worker.Wait(ids, num_objects, timeout_ms, &results, true); if (!status.ok()) { diff --git a/cpp/src/ray/runtime/task/local_mode_task_submitter.cc b/cpp/src/ray/runtime/task/local_mode_task_submitter.cc index 40b7845578a74..cb24e9d3a2b8d 100644 --- a/cpp/src/ray/runtime/task/local_mode_task_submitter.cc +++ b/cpp/src/ray/runtime/task/local_mode_task_submitter.cc @@ -32,7 +32,7 @@ LocalModeTaskSubmitter::LocalModeTaskSubmitter( ObjectID LocalModeTaskSubmitter::Submit(InvocationSpec &invocation, const ActorCreationOptions &options) { - /// TODO(SongGuyang): Make the information of TaskSpecification more reasonable + /// TODO(Guyang Song): Make the information of TaskSpecification more reasonable /// We just reuse the TaskSpecification class and make the single process mode work. /// Some information in the TaskSpecification may be unreasonable or invalid. /// We will enhance this after implementing the cluster mode.
@@ -82,7 +82,7 @@ ObjectID LocalModeTaskSubmitter::Submit(InvocationSpec &invocation, AbstractRayRuntime *runtime = &local_mode_ray_tuntime_; if (invocation.task_type == TaskType::ACTOR_CREATION_TASK || invocation.task_type == TaskType::ACTOR_TASK) { - /// TODO(SongGuyang): Handle task dependencies. + /// TODO(Guyang Song): Handle task dependencies. /// Execute the actor task directly in the main thread because we must guarantee that /// actor tasks execute in calling order. TaskExecutor::Invoke(task_specification, actor, runtime, actor_contexts_, diff --git a/cpp/src/ray/runtime/task/task_executor.cc b/cpp/src/ray/runtime/task/task_executor.cc index f0a1e12faaa78..be24fe98d9a27 100644 --- a/cpp/src/ray/runtime/task/task_executor.cc +++ b/cpp/src/ray/runtime/task/task_executor.cc @@ -75,7 +75,7 @@ std::shared_ptr TaskExecutor::current_actor_ = nullptr; TaskExecutor::TaskExecutor(AbstractRayRuntime &abstract_ray_tuntime_) : abstract_ray_tuntime_(abstract_ray_tuntime_) {} -// TODO(SongGuyang): Make a common task execution function used for both local mode and +// TODO(Guyang Song): Make a common task execution function used for both local mode and // cluster mode. std::unique_ptr TaskExecutor::Execute(InvocationSpec &invocation) { abstract_ray_tuntime_.GetWorkerContext(); diff --git a/cpp/src/ray/runtime/task/task_executor.h b/cpp/src/ray/runtime/task/task_executor.h index 825e5ca52ab20..a528f17e03af3 100644 --- a/cpp/src/ray/runtime/task/task_executor.h +++ b/cpp/src/ray/runtime/task/task_executor.h @@ -16,10 +16,8 @@ #include #include - #include #include - #include "absl/synchronization/mutex.h" #include "invocation_spec.h" #include "ray/common/id.h" @@ -64,7 +62,7 @@ class TaskExecutor { public: TaskExecutor(AbstractRayRuntime &abstract_ray_tuntime_); - /// TODO(SongGuyang): support multiple tasks execution + /// TODO(Guyang Song): support multiple tasks execution std::unique_ptr Execute(InvocationSpec &invocation); static void Invoke( diff --git a/cpp/src/ray/util/process_helper.cc b/cpp/src/ray/util/process_helper.cc index 35ecd8123daa2..40f115e646e95 100644 --- a/cpp/src/ray/util/process_helper.cc +++ b/cpp/src/ray/util/process_helper.cc @@ -12,10 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "process_helper.h" - #include +#include "process_helper.h" #include "ray/util/process.h" #include "ray/util/util.h" #include "src/ray/protobuf/gcs.pb.h" @@ -28,9 +27,9 @@ using ray::core::WorkerType; void ProcessHelper::StartRayNode(const int redis_port, const std::string redis_password, const std::vector &head_args) { - std::vector cmdargs( - {"ray", "start", "--head", "--port", std::to_string(redis_port), "--redis-password", - redis_password, "--node-ip-address", GetNodeIpAddress()}); + std::vector cmdargs({"ray", "start", "--head", "--port", + std::to_string(redis_port), "--redis-password", + redis_password}); if (!head_args.empty()) { cmdargs.insert(cmdargs.end(), head_args.begin(), head_args.end()); } @@ -125,7 +124,7 @@ void ProcessHelper::RayStart(CoreWorkerOptions::TaskExecutionCallback callback) if (!ConfigInternal::Instance().job_id.empty()) { options.job_id = JobID::FromHex(ConfigInternal::Instance().job_id); } else { - /// TODO(SongGuyang): Get next job id from core worker by GCS client. + /// TODO(Guyang Song): Get next job id from core worker by GCS client. /// Use a random number to avoid repeated job ids. /// Repeated job ids will lead to task hangs when a driver connects to an existing /// cluster more than once.
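
For context on the StartRayNode hunk above: after this change the helper launches the head node without an explicit --node-ip-address. The resulting command line corresponds to roughly the following Python sketch (subprocess stands in for the C++ process launcher; the function name, port, and password values are placeholders):

    import subprocess

    def start_ray_head(redis_port: int, redis_password: str, head_args=None):
        # Mirrors the argv built by ProcessHelper::StartRayNode above,
        # minus the dropped --node-ip-address flag.
        cmd = ["ray", "start", "--head",
               "--port", str(redis_port),
               "--redis-password", redis_password]
        cmd += head_args or []
        subprocess.check_call(cmd)

    # start_ray_head(6379, "hunter2")  # placeholder port/password
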
diff --git a/dashboard/agent.py b/dashboard/agent.py index f56e76f61fff9..7301b4299f95f 100644 --- a/dashboard/agent.py +++ b/dashboard/agent.py @@ -83,7 +83,7 @@ def __init__(self, assert self.ppid > 0 logger.info("Parent pid is %s", self.ppid) self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0), )) - self.grpc_port = ray._private.tls_utils.add_port_to_grpc_server( + self.grpc_port = ray._private.utils.add_port_to_grpc_server( self.server, f"[::]:{self.dashboard_agent_port}") logger.info("Dashboard agent grpc address: %s:%s", self.ip, self.grpc_port) diff --git a/dashboard/client/src/pages/job/JobDetail.tsx b/dashboard/client/src/pages/job/JobDetail.tsx index 892034937f107..b720b9c057de1 100644 --- a/dashboard/client/src/pages/job/JobDetail.tsx +++ b/dashboard/client/src/pages/job/JobDetail.tsx @@ -11,7 +11,6 @@ import { TableRow, Tabs, } from "@material-ui/core"; -import dayjs from "dayjs"; import React from "react"; import { Link, RouteComponentProps } from "react-router-dom"; import ActorTable from "../../components/ActorTable"; @@ -141,16 +140,6 @@ const JobDetailPage = (props: RouteComponentProps<{ id: string }>) => { Driver Pid:{" "} {jobInfo.driverPid} - - StartTime:{" "} - {dayjs(Number(jobInfo.startTime)).format("YYYY/MM/DD HH:mm:ss")} - - - EndTime:{" "} - {jobInfo.endTime > 0 - ? dayjs(Number(jobInfo.endTime)).format("YYYY/MM/DD HH:mm:ss") - : "-"} - {jobInfo.eventUrl && ( Event Link:{" "} diff --git a/dashboard/client/src/pages/job/index.tsx b/dashboard/client/src/pages/job/index.tsx index 81be74b03e2f4..e52af1ce5ec01 100644 --- a/dashboard/client/src/pages/job/index.tsx +++ b/dashboard/client/src/pages/job/index.tsx @@ -24,14 +24,7 @@ const useStyles = makeStyles((theme) => ({ }, })); -const columns = [ - "ID", - "DriverIpAddress", - "DriverPid", - "IsDead", - "StartTime", - "EndTime", -]; +const columns = ["ID", "DriverIpAddress", "DriverPid", "IsDead", "Timestamp"]; const JobList = () => { const classes = useStyles(); @@ -105,8 +98,7 @@ const JobList = () => { driverIpAddress, isDead, driverPid, - startTime, - endTime, + timestamp, }) => ( @@ -118,12 +110,7 @@ const JobList = () => { {isDead ? "true" : "false"} - {dayjs(Number(startTime)).format("YYYY/MM/DD HH:mm:ss")} - - - {endTime > 0 - ? 
dayjs(Number(endTime)).format("YYYY/MM/DD HH:mm:ss") - : "-"} + {dayjs(Number(timestamp)).format("YYYY/MM/DD HH:mm:ss")} ), diff --git a/dashboard/client/src/type/job.d.ts b/dashboard/client/src/type/job.d.ts index ef9181dd2c92d..c5ca4dce874c1 100644 --- a/dashboard/client/src/type/job.d.ts +++ b/dashboard/client/src/type/job.d.ts @@ -9,8 +9,6 @@ export type Job = { driverEntry: string; state: string; timestamp: number; - startTime: number; - endTime: number; namespaceId: string; driverPid: number; driverIpAddress: string; diff --git a/dashboard/head.py b/dashboard/head.py index c7cc857c5c787..7d7cb002b652a 100644 --- a/dashboard/head.py +++ b/dashboard/head.py @@ -7,7 +7,6 @@ import threading from grpc.experimental import aio as aiogrpc -from distutils.version import LooseVersion import ray._private.utils import ray._private.services @@ -121,7 +120,7 @@ def __init__(self, http_host, http_port, http_port_retries, redis_address, ip, port = redis_address.split(":") self.gcs_client = connect_to_gcs(ip, int(port), redis_password) self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0), )) - self.grpc_port = ray._private.tls_utils.add_port_to_grpc_server( + self.grpc_port = ray._private.utils.add_port_to_grpc_server( self.server, "[::]:0") logger.info("Dashboard head grpc address: %s:%s", self.ip, self.grpc_port) @@ -175,12 +174,8 @@ async def run(self): sys.exit(-1) # Create a http session for all modules. - # aiohttp<4.0.0 uses a 'loop' variable, aiohttp>=4.0.0 doesn't anymore - if LooseVersion(aiohttp.__version__) < LooseVersion("4.0.0"): - self.http_session = aiohttp.ClientSession( - loop=asyncio.get_event_loop()) - else: - self.http_session = aiohttp.ClientSession() + self.http_session = aiohttp.ClientSession( + loop=asyncio.get_event_loop()) # Waiting for GCS is ready. 
self.aiogrpc_gcs_channel = await make_gcs_grpc_channel( diff --git a/dashboard/modules/job/job_agent.py b/dashboard/modules/job/job_agent.py index f56a24db83586..34b72462501ab 100644 --- a/dashboard/modules/job/job_agent.py +++ b/dashboard/modules/job/job_agent.py @@ -202,9 +202,7 @@ def _gen_driver_code(self): # Per job config job_config_items = { - "runtime_env": { - "env_vars": self._job_info.env - }, + "worker_env": self._job_info.env, "code_search_path": [job_package_dir], } diff --git a/dashboard/modules/runtime_env/runtime_env_agent.py b/dashboard/modules/runtime_env/runtime_env_agent.py index 3c8b9c18bf9f3..5151278b1ab26 100644 --- a/dashboard/modules/runtime_env/runtime_env_agent.py +++ b/dashboard/modules/runtime_env/runtime_env_agent.py @@ -6,7 +6,6 @@ import os import time from typing import Dict, Set -from ray._private.utils import import_attr from ray.core.generated import runtime_env_agent_pb2 from ray.core.generated import runtime_env_agent_pb2_grpc @@ -18,8 +17,8 @@ _internal_kv_initialized) from ray._private.ray_logging import setup_component_logger from ray._private.runtime_env.conda import CondaManager -from ray._private.runtime_env.context import RuntimeEnvContext from ray._private.runtime_env.working_dir import WorkingDirManager +from ray._private.runtime_env import RuntimeEnvContext logger = logging.getLogger(__name__) @@ -79,20 +78,13 @@ def get_or_create_logger(self, job_id: bytes): return self._per_job_logger_cache[job_id] async def CreateRuntimeEnv(self, request, context): - async def _setup_runtime_env(serialized_runtime_env, - serialized_allocated_resource_instances): + async def _setup_runtime_env(serialized_runtime_env): # This function will be ran inside a thread def run_setup_with_logger(): runtime_env: dict = json.loads(serialized_runtime_env or "{}") - allocated_resource: dict = json.loads( - serialized_allocated_resource_instances or "{}") # Use a separate logger for each job. per_job_logger = self.get_or_create_logger(request.job_id) - # TODO(chenk008): Add log about allocated_resource to - # avoid lint error. That will be moved to cgroup plugin. 
- per_job_logger.debug(f"Worker has resource :" - f"{allocated_resource}") context = RuntimeEnvContext( env_vars=runtime_env.get("env_vars")) self._conda_manager.setup( @@ -106,15 +98,6 @@ def run_setup_with_logger(): self._working_dir_uri_to_envs[uri].add( serialized_runtime_env) - # Run setup function from all the plugins - for plugin_class_path in runtime_env.get("plugins", {}).keys(): - plugin_class = import_attr(plugin_class_path) - # TODO(simon): implement uri support - plugin_class.create("uri not implemented", runtime_env, - context) - plugin_class.modify_context("uri not implemented", - runtime_env, context) - return context loop = asyncio.get_event_loop() @@ -155,8 +138,7 @@ def run_setup_with_logger(): for _ in range(runtime_env_consts.RUNTIME_ENV_RETRY_TIMES): try: runtime_env_context = await _setup_runtime_env( - serialized_env, - request.serialized_allocated_resource_instances) + serialized_env) break except Exception as ex: logger.exception("Runtime env creation failed.") diff --git a/dashboard/modules/snapshot/snapshot_head.py b/dashboard/modules/snapshot/snapshot_head.py index 87082f5463147..424e41ff45e16 100644 --- a/dashboard/modules/snapshot/snapshot_head.py +++ b/dashboard/modules/snapshot/snapshot_head.py @@ -73,10 +73,11 @@ async def get_job_info(self): for job_table_entry in reply.job_info_list: job_id = job_table_entry.job_id.hex() config = { + "env_vars": dict(job_table_entry.config.worker_env), "namespace": job_table_entry.config.ray_namespace, "metadata": dict(job_table_entry.config.metadata), "runtime_env": json.loads( - job_table_entry.config.runtime_env.serialized_runtime_env), + job_table_entry.config.serialized_runtime_env), } entry = { "is_dead": job_table_entry.is_dead, diff --git a/dashboard/modules/snapshot/snapshot_schema.json b/dashboard/modules/snapshot/snapshot_schema.json index 4768c2a5e292c..f660813110f1e 100644 --- a/dashboard/modules/snapshot/snapshot_schema.json +++ b/dashboard/modules/snapshot/snapshot_schema.json @@ -39,6 +39,9 @@ "config": { "type": "object", "properties": { + "envVars": { + "type": "object" + }, "namespace": { "type": "string" }, @@ -50,6 +53,7 @@ } }, "required": [ + "envVars", "namespace", "metadata", "runtimeEnv" diff --git a/dashboard/tests/test_dashboard.py b/dashboard/tests/test_dashboard.py index 6565ea08814cf..ea335c61bad21 100644 --- a/dashboard/tests/test_dashboard.py +++ b/dashboard/tests/test_dashboard.py @@ -107,7 +107,7 @@ def _search_agent(processes): agent_proc.kill() agent_proc.wait() # The agent will be restarted for imports failure. - for _ in range(300): + for x in range(50): agent_proc = _search_agent(raylet_proc.children()) if agent_proc: agent_pids.add(agent_proc.pid) diff --git a/doc/BUILD b/doc/BUILD index 81c112530ffec..eed30be63b145 100644 --- a/doc/BUILD +++ b/doc/BUILD @@ -3,38 +3,6 @@ # Please keep these sorted alphabetically, but start with the # root directory. # -------------------------------------------------------------------- - -# Support for Dask has been dropped in 3.6. -py_test( - name = "dask_xgboost", - size = "medium", - main = "examples/dask_xgboost/dask_xgboost.py", - srcs = ["examples/dask_xgboost/dask_xgboost.py"], - tags = ["exclusive", "team:ml", "py37"], - args = ["--smoke-test", "--address ''", "--num-actors 4", - "--cpus-per-actor 1", "--num-actors-inference 4", - "--cpus-per-actor-inference 1"] -) - -# Support for Modin has been dropped in 3.6. 
-py_test( - name = "modin_xgboost", - size = "medium", - main = "examples/modin_xgboost/modin_xgboost.py", - srcs = ["examples/modin_xgboost/modin_xgboost.py"], - tags = ["exclusive", "team:ml", "py37"], - args = ["--smoke-test", "--address ''", "--num-actors 4", - "--cpus-per-actor 1", "--num-actors-inference 4", - "--cpus-per-actor-inference 1"] -) - -py_test( - name = "big_data_ingestion", - size = "small", - srcs = ["source/data/_examples/big_data_ingestion.py"], - tags = ["exclusive", "team:core", "py37"] -) - py_test( name = "plot_hyperparameter", size = "small", diff --git a/doc/Makefile b/doc/Makefile index 39013f4175b43..3b0914ab942fe 100644 --- a/doc/Makefile +++ b/doc/Makefile @@ -6,7 +6,7 @@ SPHINXOPTS = SPHINXBUILD = sphinx-build PAPER = BUILDDIR = _build -AUTOGALLERYDIR= source/auto_examples source/tune/tutorials source/tune/generated_guides source/data/examples +AUTOGALLERYDIR= source/auto_examples source/tune/tutorials source/tune/generated_guides # User-friendly check for sphinx-build ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) diff --git a/doc/examples/dask_xgboost/README.rst b/doc/examples/dask_xgboost/README.rst deleted file mode 100644 index 8feca331c5d78..0000000000000 --- a/doc/examples/dask_xgboost/README.rst +++ /dev/null @@ -1 +0,0 @@ -:orphan: diff --git a/doc/examples/dask_xgboost/dask_xgboost.py b/doc/examples/dask_xgboost/dask_xgboost.py deleted file mode 100644 index d4e50a33faf70..0000000000000 --- a/doc/examples/dask_xgboost/dask_xgboost.py +++ /dev/null @@ -1,321 +0,0 @@ -# flake8: noqa: E501 -""" -XGBoost-Ray with Dask -====================== - -This notebook includes an example workflow using -`XGBoost-Ray `_ and -`Dask `_ for distributed model training, -hyperparameter optimization, and prediction. -""" - -############################################################################### -# Cluster Setup -# ------------- -# -# First, we'll set up our Ray Cluster. The provided ``dask_xgboost.yaml`` -# cluster config can be used to set up an AWS cluster with 64 CPUs. -# -# The following steps assume you are in a directory with both -# ``dask_xgboost.yaml`` and this file saved as ``dask_xgboost.ipynb``. -# -# **Step 1:** Bring up the Ray cluster. -# -# .. code-block:: bash -# -# $ pip install ray boto3 -# $ ray up dask_xgboost.yaml -# -# **Step 2:** Move ``dask_xgboost.ipynb`` to the cluster and start Jupyter. -# -# .. code-block:: bash -# -# $ ray rsync_up dask_xgboost.yaml "./dask_xgboost.ipynb" \ -# "~/dask_xgboost.ipynb" -# $ ray exec dask_xgboost.yaml --port-forward=9999 "jupyter notebook \ -# --port=9999" -# -# You can then access this notebook at the URL that is output: -# ``http://localhost:9999/?token=`` - -############################################################################### -# Python Setup -# ------------ -# -# First, we'll import all the libraries we'll be using. This step also helps us -# verify that the environment is configured correctly. If any of the imports -# are missing, an exception will be raised. - -import argparse -import time - -import dask -import dask.dataframe as dd -from xgboost_ray import RayDMatrix, RayParams, train, predict - -import ray -from ray import tune -from ray.util.dask import ray_dask_get - -############################################################################### -# -# Next, let's parse some arguments. This will be used for executing the ``.py`` -# file, but not for the ``.ipynb``. If you are using the interactive notebook, -# you can directly override the arguments manually. 
- -parser = argparse.ArgumentParser() -parser.add_argument( - "--address", type=str, default="auto", help="The address to use for Ray.") -parser.add_argument( - "--smoke-test", - action="store_true", - help="Read a smaller dataset for quick testing purposes.") -parser.add_argument( - "--num-actors", - type=int, - default=4, - help="Sets number of actors for training.") -parser.add_argument( - "--cpus-per-actor", - type=int, - default=6, - help="The number of CPUs per actor for training.") -parser.add_argument( - "--num-actors-inference", - type=int, - default=16, - help="Sets number of actors for inference.") -parser.add_argument( - "--cpus-per-actor-inference", - type=int, - default=2, - help="The number of CPUs per actor for inference.") -# Ignore -f from ipykernel_launcher -args, _ = parser.parse_known_args() - -############################################################################### -# Override these arguments as needed: - -address = args.address -smoke_test = args.smoke_test -num_actors = args.num_actors -cpus_per_actor = args.cpus_per_actor -num_actors_inference = args.num_actors_inference -cpus_per_actor_inference = args.cpus_per_actor_inference - -############################################################################### -# Connecting to the Ray cluster -# ----------------------------- -# Now, let's connect our Python script to this newly deployed Ray cluster! - -if not ray.is_initialized(): - ray.init(address=address) - -############################################################################### -# Data Preparation -# ----------------- -# We will use the `HIGGS dataset from the UCI Machine Learning dataset -# repository `_. The HIGGS -# dataset consists of 11,000,000 samples and 28 attributes, which is large -# enough size to show the benefits of distributed computation. -# -# We set the Dask scheduler to ``ray_dask_get`` to use `Dask on Ray -# `_ backend. - -LABEL_COLUMN = "label" -if smoke_test: - # Test dataset with only 10,000 records. - FILE_URL = "https://ray-ci-higgs.s3.us-west-2.amazonaws.com/simpleHIGGS" \ - ".csv" -else: - # Full dataset. This may take a couple of minutes to load. - FILE_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases" \ - "/00280/HIGGS.csv.gz" -colnames = [LABEL_COLUMN] + ["feature-%02d" % i for i in range(1, 29)] -dask.config.set(scheduler=ray_dask_get) - -############################################################################### - -load_data_start_time = time.time() - -data = dd.read_csv(FILE_URL, names=colnames) -data = data[sorted(colnames)] -data = data.persist() - -load_data_end_time = time.time() -load_data_duration = load_data_end_time - load_data_start_time -print(f"Dataset loaded in {load_data_duration} seconds.") - -############################################################################### -# With the connection established, we can now create the Dask dataframe. -# -# We will split the data into a training set and a evaluation set using a 80-20 -# proportion. - -train_df, eval_df = data.random_split([0.8, 0.2]) -train_df, eval_df = train_df.persist(), eval_df.persist() -print(train_df, eval_df) - -############################################################################### -# Distributed Training -# -------------------- -# The ``train_xgboost`` function contains all of the logic necessary for -# training using XGBoost-Ray. -# -# Distributed training can not only speed up the process, but also allow you -# to use datasets that are to large to fit in memory of a single node. 
With -# distributed training, the dataset is sharded across different actors -# running on separate nodes. Those actors communicate with each other to -# create the final model. -# -# First, the dataframes are wrapped in ``RayDMatrix`` objects, which handle -# data sharding across the cluster. Then, the ``train`` function is called. -# The evaluation scores will be saved to ``evals_result`` dictionary. The -# function returns a tuple of the trained model (booster) and the evaluation -# scores. -# -# The ``ray_params`` variable expects a ``RayParams`` object that contains -# Ray-specific settings, such as the number of workers. - - -def train_xgboost(config, train_df, test_df, target_column, ray_params): - train_set = RayDMatrix(train_df, target_column) - test_set = RayDMatrix(test_df, target_column) - - evals_result = {} - - train_start_time = time.time() - - # Train the classifier - bst = train( - params=config, - dtrain=train_set, - evals=[(test_set, "eval")], - evals_result=evals_result, - ray_params=ray_params) - - train_end_time = time.time() - train_duration = train_end_time - train_start_time - print(f"Total time taken: {train_duration} seconds.") - - model_path = "model.xgb" - bst.save_model(model_path) - print("Final validation error: {:.4f}".format( - evals_result["eval"]["error"][-1])) - - return bst, evals_result - - -############################################################################### -# We can now pass our Dask dataframes and run the function. We will use -# ``RayParams`` to specify that the number of actors and CPUs to train with. -# -# The dataset has to be downloaded onto the cluster, which may take a few -# minutes. - -# standard XGBoost config for classification -config = { - "tree_method": "approx", - "objective": "binary:logistic", - "eval_metric": ["logloss", "error"], -} - -bst, evals_result = train_xgboost( - config, train_df, eval_df, LABEL_COLUMN, - RayParams(cpus_per_actor=cpus_per_actor, num_actors=num_actors)) -print(f"Results: {evals_result}") - -############################################################################### -# Hyperparameter optimization -# --------------------------- -# If we are not content with the results obtained with default XGBoost -# parameters, we can use `Ray Tune -# `_ for cutting-edge -# distributed hyperparameter tuning. XGBoost-Ray automatically integrates -# with Ray Tune, meaning we can use the same training function as before. -# -# In this workflow, we will tune three hyperparameters - ``eta``, ``subsample`` -# and ``max_depth``. We are using `Tune's samplers to define the search -# space `_. -# -# The experiment configuration is done through ``tune.run``. We set the amount -# of resources each trial (hyperparameter combination) requires by using the -# ``get_tune_resources`` method of ``RayParams``. The ``num_samples`` argument -# controls how many trials will be ran in total. In the end, the best -# combination of hyperparameters evaluated during the experiment will be -# returned. -# -# By default, Tune will use simple random search. However, Tune also -# provides various `search algorithms -# `_ and -# `schedulers `_ -# to further improve the optimization process. - - -def tune_xgboost(train_df, test_df, target_column): - # Set XGBoost config. 
- config = { - "tree_method": "approx", - "objective": "binary:logistic", - "eval_metric": ["logloss", "error"], - "eta": tune.loguniform(1e-4, 1e-1), - "subsample": tune.uniform(0.5, 1.0), - "max_depth": tune.randint(1, 9) - } - - ray_params = RayParams( - max_actor_restarts=1, - cpus_per_actor=cpus_per_actor, - num_actors=num_actors) - - tune_start_time = time.time() - - analysis = tune.run( - tune.with_parameters( - train_xgboost, - train_df=train_df, - test_df=test_df, - target_column=target_column, - ray_params=ray_params), - # Use the `get_tune_resources` helper function to set the resources. - resources_per_trial=ray_params.get_tune_resources(), - config=config, - num_samples=10, - metric="eval-error", - mode="min") - - tune_end_time = time.time() - tune_duration = tune_end_time - tune_start_time - print(f"Total time taken: {tune_duration} seconds.") - - accuracy = 1. - analysis.best_result["eval-error"] - print(f"Best model parameters: {analysis.best_config}") - print(f"Best model total accuracy: {accuracy:.4f}") - - return analysis.best_config - - -############################################################################### -# Hyperparameter optimization may take some time to complete. - -tune_xgboost(train_df, eval_df, LABEL_COLUMN) - -############################################################################### -# Prediction -# ---------- -# With the model trained, we can now predict on unseen data. For the -# purposes of this example, we will use the same dataset for prediction as -# for training. -# -# Since prediction is naively parallelizable, distributing it over multiple -# actors can measurably reduce the amount of time needed. - -inference_df = RayDMatrix(data, ignore=[LABEL_COLUMN, "partition"]) -results = predict( - bst, - inference_df, - ray_params=RayParams( - cpus_per_actor=cpus_per_actor_inference, - num_actors=num_actors_inference)) - -print(results) diff --git a/doc/examples/dask_xgboost/dask_xgboost.yaml b/doc/examples/dask_xgboost/dask_xgboost.yaml deleted file mode 100644 index e598a115069b6..0000000000000 --- a/doc/examples/dask_xgboost/dask_xgboost.yaml +++ /dev/null @@ -1,24 +0,0 @@ -cluster_name: dask_xgboost - -max_workers: 3 - -provider: - type: aws - region: us-west-1 - -auth: - ssh_user: ubuntu - -available_node_types: - 16_cpu_node: - min_workers: 3 - max_workers: 3 - node_config: - InstanceType: m5.4xlarge - ImageId: latest_dlami - resources: { } - -head_node_type: 16_cpu_node - -setup_commands: - - pip install -U jupyter ray[tune] xgboost_ray[default] dask pandas diff --git a/doc/examples/modin_xgboost/README.rst b/doc/examples/modin_xgboost/README.rst deleted file mode 100644 index 8feca331c5d78..0000000000000 --- a/doc/examples/modin_xgboost/README.rst +++ /dev/null @@ -1 +0,0 @@ -:orphan: diff --git a/doc/examples/modin_xgboost/modin_xgboost.py b/doc/examples/modin_xgboost/modin_xgboost.py deleted file mode 100644 index bcbe6c0968068..0000000000000 --- a/doc/examples/modin_xgboost/modin_xgboost.py +++ /dev/null @@ -1,233 +0,0 @@ -""" -XGBoost-Ray with Modin -====================== - -This notebook includes an example workflow using -`XGBoost-Ray `_ and -`Modin `_ for distributed model -training and prediction. -""" - -############################################################################### -# Cluster Setup -# ------------- -# -# First, we'll set up our Ray Cluster. The provided ``modin_xgboost.yaml`` -# cluster config can be used to set up an AWS cluster with 64 CPUs. 
-# -# The following steps assume you are in a directory with both -# ``modin_xgboost.yaml`` and this file saved as ``modin_xgboost.ipynb``. -# -# **Step 1:** Bring up the Ray cluster. -# -# .. code-block:: bash -# -# $ pip install ray boto3 -# $ ray up modin_xgboost.yaml -# -# **Step 2:** Move ``modin_xgboost.ipynb`` to the cluster and start Jupyter. -# -# .. code-block:: bash -# -# $ ray rsync_up modin_xgboost.yaml "./modin_xgboost.ipynb" \ -# "~/modin_xgboost.ipynb" -# $ ray exec modin_xgboost.yaml --port-forward=9999 "jupyter notebook \ -# --port=9999" -# -# You can then access this notebook at the URL that is output: -# ``http://localhost:9999/?token=`` - -############################################################################### -# Python Setup -# ------------ -# -# First, we'll import all the libraries we'll be using. This step also helps us -# verify that the environment is configured correctly. If any of the imports -# are missing, an exception will be raised. - -import argparse -import time - -import modin.pandas as pd -from modin.experimental.sklearn.model_selection import train_test_split -from xgboost_ray import RayDMatrix, RayParams, train, predict - -import ray - -############################################################################### -# -# Next, let's parse some arguments. This will be used for executing the ``.py`` -# file, but not for the ``.ipynb``. If you are using the interactive notebook, -# you can directly override the arguments manually. - -parser = argparse.ArgumentParser() -parser.add_argument( - "--address", type=str, default="auto", help="The address to use for Ray.") -parser.add_argument( - "--smoke-test", - action="store_true", - help="Read a smaller dataset for quick testing purposes.") -parser.add_argument( - "--num-actors", - type=int, - default=4, - help="Sets number of actors for training.") -parser.add_argument( - "--cpus-per-actor", - type=int, - default=8, - help="The number of CPUs per actor for training.") -parser.add_argument( - "--num-actors-inference", - type=int, - default=16, - help="Sets number of actors for inference.") -parser.add_argument( - "--cpus-per-actor-inference", - type=int, - default=2, - help="The number of CPUs per actor for inference.") -# Ignore -f from ipykernel_launcher -args, _ = parser.parse_known_args() - -############################################################################### -# Override these arguments as needed: - -address = args.address -smoke_test = args.smoke_test -num_actors = args.num_actors -cpus_per_actor = args.cpus_per_actor -num_actors_inference = args.num_actors_inference -cpus_per_actor_inference = args.cpus_per_actor_inference - -############################################################################### -# Connecting to the Ray cluster -# ----------------------------- -# Now, let's connect our Python script to this newly deployed Ray cluster! - -if not ray.is_initialized(): - ray.init(address=address) - -############################################################################### -# Data Preparation -# ----------------- -# We will use the `HIGGS dataset from the UCI Machine Learning dataset -# repository `_. The HIGGS -# dataset consists of 11,000,000 samples and 28 attributes, which is large -# enough size to show the benefits of distributed computation. - -LABEL_COLUMN = "label" -if smoke_test: - # Test dataset with only 10,000 records. - FILE_URL = "https://ray-ci-higgs.s3.us-west-2.amazonaws.com/simpleHIGGS" \ - ".csv" -else: - # Full dataset. 
This may take a couple of minutes to load. - FILE_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases" \ - "/00280/HIGGS.csv.gz" - -colnames = [LABEL_COLUMN] + ["feature-%02d" % i for i in range(1, 29)] - -############################################################################### - -load_data_start_time = time.time() - -df = pd.read_csv(FILE_URL, names=colnames) - -load_data_end_time = time.time() -load_data_duration = load_data_end_time - load_data_start_time -print(f"Dataset loaded in {load_data_duration} seconds.") - -############################################################################### -# Split data into training and validation. - -df_train, df_validation = train_test_split(df) -print(df_train, df_validation) - -############################################################################### -# Distributed Training -# -------------------- -# The ``train_xgboost`` function contains all of the logic necessary for -# training using XGBoost-Ray. -# -# Distributed training can not only speed up the process, but also allow you -# to use datasets that are to large to fit in memory of a single node. With -# distributed training, the dataset is sharded across different actors -# running on separate nodes. Those actors communicate with each other to -# create the final model. -# -# First, the dataframes are wrapped in ``RayDMatrix`` objects, which handle -# data sharding across the cluster. Then, the ``train`` function is called. -# The evaluation scores will be saved to ``evals_result`` dictionary. The -# function returns a tuple of the trained model (booster) and the evaluation -# scores. -# -# The ``ray_params`` variable expects a ``RayParams`` object that contains -# Ray-specific settings, such as the number of workers. - - -def train_xgboost(config, train_df, test_df, target_column, ray_params): - train_set = RayDMatrix(train_df, target_column) - test_set = RayDMatrix(test_df, target_column) - - evals_result = {} - - train_start_time = time.time() - - # Train the classifier - bst = train( - params=config, - dtrain=train_set, - evals=[(test_set, "eval")], - evals_result=evals_result, - verbose_eval=False, - num_boost_round=100, - ray_params=ray_params) - - train_end_time = time.time() - train_duration = train_end_time - train_start_time - print(f"Total time taken: {train_duration} seconds.") - - model_path = "model.xgb" - bst.save_model(model_path) - print("Final validation error: {:.4f}".format( - evals_result["eval"]["error"][-1])) - - return bst, evals_result - - -############################################################################### -# We can now pass our Modin dataframes and run the function. We will use -# ``RayParams`` to specify that the number of actors and CPUs to train with. - -# standard XGBoost config for classification -config = { - "tree_method": "approx", - "objective": "binary:logistic", - "eval_metric": ["logloss", "error"], -} - -bst, evals_result = train_xgboost( - config, df_train, df_validation, LABEL_COLUMN, - RayParams(cpus_per_actor=cpus_per_actor, num_actors=num_actors)) -print(f"Results: {evals_result}") - -############################################################################### -# Prediction -# ---------- -# With the model trained, we can now predict on unseen data. For the -# purposes of this example, we will use the same dataset for prediction as -# for training. -# -# Since prediction is naively parallelizable, distributing it over multiple -# actors can measurably reduce the amount of time needed. 
- -inference_df = RayDMatrix(df, ignore=[LABEL_COLUMN, "partition"]) -results = predict( - bst, - inference_df, - ray_params=RayParams( - cpus_per_actor=cpus_per_actor_inference, - num_actors=num_actors_inference)) - -print(results) diff --git a/doc/examples/modin_xgboost/modin_xgboost.yaml b/doc/examples/modin_xgboost/modin_xgboost.yaml deleted file mode 100644 index 914cbdb207af2..0000000000000 --- a/doc/examples/modin_xgboost/modin_xgboost.yaml +++ /dev/null @@ -1,24 +0,0 @@ -cluster_name: modin_xgboost - -max_workers: 3 - -provider: - type: aws - region: us-west-1 - -auth: - ssh_user: ubuntu - -available_node_types: - 16_cpu_node: - min_workers: 3 - max_workers: 3 - node_config: - InstanceType: m5.4xlarge - ImageId: latest_dlami - resources: { } - -head_node_type: 16_cpu_node - -setup_commands: - - pip install -U jupyter ray xgboost_ray[default] modin pandas diff --git a/doc/examples/overview.rst b/doc/examples/overview.rst index be438f3580783..8555799094ef9 100644 --- a/doc/examples/overview.rst +++ b/doc/examples/overview.rst @@ -61,8 +61,6 @@ Machine Learning Examples plot_lbfgs.rst plot_example-lm.rst plot_newsreader.rst - dask_xgboost/dask_xgboost.rst - modin_xgboost/modin_xgboost.rst .. customgalleryitem:: @@ -88,14 +86,6 @@ Machine Learning Examples :tooltip: Implementing a simple news reader using Ray. :description: :doc:`/auto_examples/plot_newsreader` -.. customgalleryitem:: - :tooltip: Train an XGBoost-Ray model using Dask for data processing. - :description: :doc:`/auto_examples/dask_xgboost/dask_xgboost` - -.. customgalleryitem:: - :tooltip: Train an XGBoost-Ray model using Modin for data processing. - :description: :doc:`/auto_examples/modin_xgboost/modin_xgboost` - .. raw:: html @@ -148,4 +138,4 @@ These are full guides on how you can use Ray with various Machine Learning libra .. customgalleryitem:: :tooltip: Using Ray with PyTorch Lightning. :figure: /images/pytorch_lightning_small.png - :description: :doc:`/auto_examples/using-ray-with-pytorch-lightning` + :description: :doc:`/auto_examples/using-ray-with-pytorch-lightning` \ No newline at end of file diff --git a/doc/kubernetes/ray-cluster.yaml b/doc/kubernetes/ray-cluster.yaml index f4f493152608c..1b3da82e9ccaa 100644 --- a/doc/kubernetes/ray-cluster.yaml +++ b/doc/kubernetes/ray-cluster.yaml @@ -3,7 +3,7 @@ apiVersion: v1 kind: Service metadata: namespace: ray - name: example-cluster-ray-head + name: ray-head spec: ports: - name: client @@ -111,7 +111,7 @@ spec: imagePullPolicy: IfNotPresent command: ["/bin/bash", "-c", "--"] args: - - "ray start --num-cpus=$MY_CPU_REQUEST --address=$EXAMPLE_CLUSTER_RAY_HEAD_SERVICE_HOST:$EXAMPLE_CLUSTER_RAY_HEAD_SERVICE_PORT_REDIS --object-manager-port=12345 --node-manager-port=12346 --block" + - "ray start --num-cpus=$MY_CPU_REQUEST --address=$RAY_HEAD_SERVICE_HOST:$RAY_HEAD_SERVICE_PORT_REDIS --object-manager-port=12345 --node-manager-port=12346 --block" # This volume allocates shared memory for Ray to use for its plasma # object store. If you do not provide this, Ray will fall back to # /tmp which cause slowdowns if is not a shared memory volume. diff --git a/doc/source/advanced.rst b/doc/source/advanced.rst index fa4ff9cffa65c..75ff25045592e 100644 --- a/doc/source/advanced.rst +++ b/doc/source/advanced.rst @@ -42,23 +42,17 @@ This often occurs for data loading and preprocessing. # hi there! # hi there! 
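The Service rename in the ``ray-cluster.yaml`` hunk above also explains the new worker arguments: Kubernetes derives injected environment variable names from the Service name, so a Service called ``ray-head`` with a port named ``redis`` yields ``RAY_HEAD_SERVICE_HOST`` and ``RAY_HEAD_SERVICE_PORT_REDIS``. A minimal, illustrative sketch of how a worker pod could resolve the head address from those variables:

.. code-block:: python

    import os

    # Kubernetes injects <SERVICE_NAME>_SERVICE_HOST and
    # <SERVICE_NAME>_SERVICE_PORT_<PORT_NAME> into pods in the namespace,
    # which is what the worker command above references.
    head_host = os.environ["RAY_HEAD_SERVICE_HOST"]
    redis_port = os.environ["RAY_HEAD_SERVICE_PORT_REDIS"]
    print(f"ray start --address={head_host}:{redis_port}")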
-Multi-node synchronization using an Actor -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Multi-node synchronization using ``SignalActor`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -When you have multiple tasks that need to wait on some condition or otherwise -need to synchronize across tasks & actors on a cluster, you can use a central -actor to coordinate among them. Below is an example of using a ``SignalActor`` -that wraps an ``asyncio.Event`` for basic synchronization. +When you have multiple tasks that need to wait on some condition, you can use a ``SignalActor`` to coordinate. .. code-block:: python - import asyncio - + # Also available via `from ray._private.test_utils import SignalActor` import ray + import asyncio - ray.init() - - # We set num_cpus to zero because this actor will mostly just block on I/O. @ray.remote(num_cpus=0) class SignalActor: def __init__(self): @@ -79,6 +73,7 @@ that wraps an ``asyncio.Event`` for basic synchronization. print("go!") + ray.init() signal = SignalActor.remote() tasks = [wait_and_go.remote(signal) for _ in range(4)] print("ready...") @@ -446,7 +441,7 @@ On Mac OS and Linux, Ray 1.4+ supports dynamically setting the runtime environme The ``runtime_env`` is a (JSON-serializable) dictionary that can be passed as an option to tasks and actors, and can also be passed to ``ray.init()``. The runtime environment defines the dependencies required for your workload. -You can specify a runtime environment for your whole job using ``ray.init()`` or Ray Client: +You can specify a runtime environment for your whole job using ``ray.init()`` or Ray Client... .. literalinclude:: ../examples/doc_code/runtime_env_example.py :language: python @@ -461,20 +456,19 @@ You can specify a runtime environment for your whole job using ``ray.init()`` or # Using Ray Client ray.init("ray://localhost:10001", runtime_env=runtime_env) -Or specify per-actor or per-task in the ``@ray.remote()`` decorator or by using ``.options()``: +...or specify per-actor or per-task in the ``@ray.remote()`` decorator or by using ``.options()``: .. literalinclude:: ../examples/doc_code/runtime_env_example.py :language: python :start-after: __per_task_per_actor_start__ :end-before: __per_task_per_actor_end__ -Note: specifying within the ``@ray.remote()`` decorator is currently unsupported while using Ray Client; please use ``.options()`` instead in this case. - The ``runtime_env`` is a Python dictionary including one or more of the following arguments: - ``working_dir`` (Path): Specifies the working directory for your job. This must be an existing local directory. It will be cached on the cluster, so the next time you connect with Ray Client you will be able to skip uploading the directory contents. - All Ray workers for your job will be started in their node's local copy of this working directory. + Furthermore, if you locally make a small change to your directory, the next time you connect only the updated part will be uploaded. + All Ray workers for your job will be started in their node's copy of this working directory. - Examples @@ -492,7 +486,7 @@ The ``runtime_env`` is a Python dictionary including one or more of the followin - Example: ``["my_file.txt", "path/to/dir", "*.log"]`` - ``pip`` (List[str] | str): Either a list of pip packages, or a string containing the path to a pip - `“requirements.txt” `_ file. The path may be an absolute path or a relative path. + `“requirements.txt” `_ file. The path may be an absolute path or a relative path. 
(Note: A relative path will be interpreted relative to ``working_dir`` if ``working_dir`` is specified.) This will be dynamically installed in the ``runtime_env``. To use a library like Ray Serve or Ray Tune, you will need to include ``"ray[serve]"`` or ``"ray[tune]"`` here. @@ -500,7 +494,7 @@ The ``runtime_env`` is a Python dictionary including one or more of the followin - Example: ``"./requirements.txt"`` -- ``conda`` (dict | str): Either (1) a dict representing the conda environment YAML, (2) a string containing the absolute or relative path to a +- ``conda`` (dict | str): Either (1) a dict representing the conda environment YAML, (2) a string containing the path to a `conda “environment.yml” `_ file, or (3) the name of a local conda environment already installed on each node in your cluster (e.g., ``"pytorch_p36"``). In the first two cases, the Ray and Python dependencies will be automatically injected into the environment to ensure compatibility, so there is no need to manually include them. @@ -512,15 +506,12 @@ The ``runtime_env`` is a Python dictionary including one or more of the followin - Example: ``"pytorch_p36"`` + Note: if specifying the path to an "environment.yml" file, you may provide an absolute path or a relative path. A relative path will be interpreted relative to ``working_dir`` if ``working_dir`` is specified. - ``env_vars`` (Dict[str, str]): Environment variables to set. - Example: ``{"OMP_NUM_THREADS": "32", "TF_WARNINGS": "none"}`` -- ``eager_install`` (bool): A boolean indicates whether to install runtime env eagerly before the workers are leased. This flag is set to false by default. - - - Example: ``{"eager_install": True}`` - The runtime environment is inheritable, so it will apply to all tasks/actors within a job and all child tasks/actors of a task or actor, once set. If a child actor or task specifies a new ``runtime_env``, it will be merged with the parent’s ``runtime_env`` via a simple dict update. diff --git a/doc/source/cluster/config.rst b/doc/source/cluster/config.rst index 867e8398e6985..7ba7e2ccbcbef 100644 --- a/doc/source/cluster/config.rst +++ b/doc/source/cluster/config.rst @@ -109,8 +109,6 @@ Provider :ref:`region `: str :ref:`availability_zone `: str :ref:`cache_stopped_nodes `: bool - :ref:`security_group `: - :ref:`Security Group ` .. group-tab:: Azure @@ -132,20 +130,6 @@ Provider :ref:`project_id `: str :ref:`cache_stopped_nodes `: bool -.. _cluster-configuration-security-group-type: - -Security Group -~~~~~~~~~~~~~~ - -.. tabs:: - .. group-tab:: AWS - - .. parsed-literal:: - - :ref:`GroupName `: str - :ref:`IpPermissions `: - - `IpPermission `_ - .. _cluster-configuration-node-types-type: Node types @@ -939,52 +923,6 @@ If enabled, nodes will be *stopped* when the cluster scales down. If disabled, n * **Type:** Boolean * **Default:** ``True`` -.. _cluster-configuration-security-group: - -``provider.security_group`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. tabs:: - .. group-tab:: AWS - - A security group that can be used to specify custom inbound rules. - - * **Required:** No - * **Importance:** Medium - * **Type:** :ref:`Security Group ` - - .. group-tab:: Azure - - Not available. - - .. group-tab:: GCP - - Not available. - - -.. _cluster-configuration-group-name: - -``security_group.GroupName`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -The name of the security group. This name must be unique within the VPC. - -* **Required:** No -* **Importance:** Low -* **Type:** String -* **Default:** ``"ray-autoscaler-{cluster-name}"`` - -.. 
_cluster-configuration-ip-permissions: - -``security_group.IpPermissions`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -The inbound rules associated with the security group. - -* **Required:** No -* **Importance:** Medium -* **Type:** `IpPermission `_ - .. _cluster-configuration-node-config: ``available_node_types..node_type.node_config`` diff --git a/doc/source/cluster/ray-client.rst b/doc/source/cluster/ray-client.rst index 1b9099160f9ae..550bb75480127 100644 --- a/doc/source/cluster/ray-client.rst +++ b/doc/source/cluster/ray-client.rst @@ -62,9 +62,9 @@ Step 1: set up your Ray cluster First, you'll want to create a remote Ray cluster. Follow the directions in :ref:`ref-cluster-quick-start` to do this. -If using the :doc:`Ray cluster launcher `, the remote cluster will be listening on port ``10001`` of the head node. If necessary, you can modify this port by setting ``--ray-client-server-port`` to the ``ray start`` `command `_. +If using the `Ray cluster launcher `_, the remote cluster will be listening on port ``10001`` of the head node. If necessary, you can modify this port by setting ``--ray-client-server-port`` to the ``ray start`` `command `_. -If not using the :doc:`Ray cluster launcher `, you can start the "Ray Client Server" manually on the head node of your remote cluster by running the following: +If not using the `Ray cluster launcher `_, you can start the "Ray Client Server" manually on the head node of your remote cluster by running the following: .. code-block:: bash @@ -77,32 +77,6 @@ Ensure that the Ray Client port on the head node is reachable from your local ma This means opening that port up by configuring security groups or other access controls (on `EC2 `_) or proxying from your local machine to the cluster (on `K8s `_). -.. tabs:: - .. group-tab:: AWS - - With the Ray cluster launcher, you can configure the security group - to allow inbound access by defining :ref:`cluster-configuration-security-group` - in your `cluster.yaml`. - - .. code-block:: yaml - - # An unique identifier for the head node and workers of this cluster. - cluster_name: minimal_security_group - - # Cloud-provider specific configuration. - provider: - type: aws - region: us-west-2 - security_group: - GroupName: ray_client_security_group - IpPermissions: - - FromPort: 10001 - ToPort: 10001 - IpProtocol: TCP - IpRanges: - # This will enable inbound access from ALL IPv4 addresses. - - CidrIp: 0.0.0.0/0 - Step 3: Run Ray code ~~~~~~~~~~~~~~~~~~~~ @@ -125,43 +99,8 @@ Now, connect to the Ray Cluster with the following and then use Ray like you nor #.... -Alternative Approach: SSH Port Forwarding -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -As an alternative to configuring inbound traffic rules, you can also set up -Ray Client via port forwarding. While this approach does require an open SSH -connection, it can be useful in a test environment where the -``head_node_host`` often changes. - -First, open up an SSH connection with your Ray cluster and forward the -listening port (``10001``). - -.. code-block:: bash - - $ ray up cluster.yaml - $ ray attach cluster.yaml -p 10001 - -Then, you can connect to the Ray cluster using ``localhost`` as the -``head_node_host``. - -.. code-block:: python - - import ray - - # This will connect to the cluster via the open SSH session. - ray.init("ray://localhost:10001") - - # Normal Ray code follows - @ray.remote - def do_work(x): - return x ** x - - do_work.remote(2) - - #.... 
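To make the port wiring from Steps 1 and 2 above concrete, here is a single-machine smoke test; it is only a sketch, assuming a local Ray installation, and uses the default client port ``10001`` together with the ``--ray-client-server-port`` flag mentioned earlier:

.. code-block:: python

    import subprocess
    import ray

    # Start a head node whose Ray Client server listens on the default port.
    subprocess.run(["ray", "start", "--head",
                    "--ray-client-server-port=10001"], check=True)

    # With the port reachable (trivially so on one machine), connect to it.
    ray.init("ray://localhost:10001")
    print(ray.cluster_resources())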
- -Connect to multiple ray clusters (Experimental) ------------------------------------------------ +Connect to multiple ray clusters +-------------------------------- Ray client allows connecting to multiple ray clusters in one Python process. To do this, just pass ``allow_multiple=True`` to ``ray.init``: diff --git a/doc/source/conf.py b/doc/source/conf.py index c554dfec1eda9..05cc18898b7dc 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -162,10 +162,10 @@ def __getattr__(cls, name): versionwarning_body_selector = "#main-content" sphinx_gallery_conf = { - "examples_dirs": ["../examples", "tune/_tutorials", - "data/_examples"], # path to example scripts + "examples_dirs": ["../examples", + "tune/_tutorials"], # path to example scripts # path where to save generated examples - "gallery_dirs": ["auto_examples", "tune/tutorials", "data/examples"], + "gallery_dirs": ["auto_examples", "tune/tutorials"], "ignore_pattern": "../examples/doc_code/", "plot_gallery": "False", "min_reported_time": sys.maxsize, diff --git a/doc/source/configure.rst b/doc/source/configure.rst index 186255d855373..5e93b2c6e4f82 100644 --- a/doc/source/configure.rst +++ b/doc/source/configure.rst @@ -234,28 +234,6 @@ to localhost when the ray is started using ``ray.init``. See the `Redis security documentation `__ for more information. -TLS Authentication ------------------- - -Ray can be configured to use TLS on it's gRPC channels. -This has means that connecting to the Ray client on the head node will -require an appropriate set of credentials and also that data exchanged between -various processes (client, head, workers) will be encrypted. - -Enabling TLS will cause a performance hit due to the extra overhead of mutual -authentication and encryption. -Testing has shown that this overhead is large for small workloads and becomes -relatively smaller for large workloads. -The exact overhead will depend on the nature of your workload. - -TLS is enabled by setting environment variables. - -- ``RAY_USE_TLS``: Either 1 or 0 to use/not-use TLS. If this is set to 1 then all of the environment variables below must be set. Default: 0. -- ``RAY_TLS_SERVER_CERT``: Location of a `certificate file` which is presented to other endpoints so as to achieve mutual authentication. -- ``RAY_TLS_SERVER_KEY``: Location of a `private key file` which is the cryptographic means to prove to other endpoints that you are the authorized user of a given certificate. -- ``RAY_TLS_CA_CERT``: Location of a `CA certificate file` which allows TLS to decide whether an endpoint's certificate has been signed by the correct authority. 
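The section removed above enumerates the four TLS environment variables. For orientation, a minimal sketch of enabling them is shown below; the certificate paths are placeholders, and in practice these variables are exported in the shell before ``ray start`` rather than set in-process:

.. code-block:: python

    import os

    # Placeholder paths; the variable names come from the section above.
    os.environ["RAY_USE_TLS"] = "1"
    os.environ["RAY_TLS_SERVER_CERT"] = "/etc/ray/tls/server.crt"
    os.environ["RAY_TLS_SERVER_KEY"] = "/etc/ray/tls/server.key"
    os.environ["RAY_TLS_CA_CERT"] = "/etc/ray/tls/ca.crt"

    import ray
    ray.init()  # gRPC channels now require the configured credentials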
- - Java Applications ----------------- diff --git a/doc/source/data/.gitignore b/doc/source/data/.gitignore deleted file mode 100644 index d838da9865693..0000000000000 --- a/doc/source/data/.gitignore +++ /dev/null @@ -1 +0,0 @@ -examples/ diff --git a/doc/source/data/_examples/README.rst b/doc/source/data/_examples/README.rst deleted file mode 100644 index 8feca331c5d78..0000000000000 --- a/doc/source/data/_examples/README.rst +++ /dev/null @@ -1 +0,0 @@ -:orphan: diff --git a/doc/source/data/_examples/big_data_ingestion.py b/doc/source/data/_examples/big_data_ingestion.py deleted file mode 100644 index 7cc569ea8e161..0000000000000 --- a/doc/source/data/_examples/big_data_ingestion.py +++ /dev/null @@ -1,276 +0,0 @@ -# flake8: noqa: E501 -""" -Example: Large-scale ML Ingest -================================================= - -In this example, you will learn how to build, deploy and scale up a machine -learning shuffle ingestion pipeline using -`Ray Dataset `_ and -`Dataset Pipelines `_. - -In particular, we will show you: - -* How to build a shuffle ingestion pipeline that loads, shuffles and feeds data - into distributed trainers in a few lines of code; -* How to scale the pipeline from ingesting 100MiB data to - 500GiB data. - -.. image:: ../../data/dataset-repeat-2.svg - :align: center - -""" - -############################################################################### -# Python Setup -# ------------ -# -# First, we'll import all of the libraries we'll be using. This step also helps us -# verify that the environment is configured correctly. If any of the imports -# are missing, an exception will be raised. - -import argparse -import tempfile -import time -from typing import List - -import pandas -import pyarrow - -import ray -from ray.data.dataset_pipeline import DatasetPipeline -from ray.data.datasource.datasource import RandomIntRowDatasource - -####################################################################### -# Build shuffle ingestion pipeline -# ---------------------------------- -# -# A typical machine learning ingestion pipeline consists of the following 4 -# steps: -# -# 1. Load the training data from external storage; -# 2. Iterate over the data for multiple epochs; -# 3. In each epoch, applying global shuffle to decorrelate the data; -# 4. In each epoch, split the shuffled data into shards, and feed shards to -# distributed trainers; -# -# Let’s see how we implement such pipeline using Ray Dataset: - - -def create_shuffle_pipeline(training_data_dir: str, num_epochs: int, - num_shards: int) -> List[DatasetPipeline]: - - return ray.data.read_parquet(training_data_dir) \ - .repeat(num_epochs) \ - .random_shuffle_each_window() \ - .split(num_shards, equal=True) - - -############################################################################ -# We’ve now defined a ``create_shuffle_pipeline`` function that creates an -# ingestion pipeline. -# It reads ``training_data_dir``, iterates for ``num_epochs`` times, -# where in each epoch it -# shuffles and splits the training data into ``num_shards``. - -############################################################################### -# Feed the pipeline into trainers -# ----------------------------------- -# Let’s also implement a ``TrainingWorker`` which consumes the shuffled data -# from each shard. -# -# For simplicity, we will define a -# `Ray Actor `_ that emulates -# training workers. Specifically, -# -# 1. It takes one shard of the shuffle pipeline for training; -# 2. 
It iterates over the shard to get a training dataset per epoch; -# 3. It then consumes the dataset by batches; - - -@ray.remote -class TrainingWorker: - def __init__(self, rank: int, shard: DatasetPipeline): - self.rank = rank - self.shard = shard - - def train(self): - for epoch, training_dataset in enumerate(self.shard.iter_datasets()): - # Following code emulates epoch based SGD training. - print(f"Training... worker: {self.rank}, epoch: {epoch}") - for i, batch in enumerate(training_dataset.iter_batches()): - # TODO: replace the code for real training. - pass - - -########################################################################### -# Let's run it -# ----------------------------- -# -# Now let’s run the data pipeline end-to-end: -# -# First, let's parse some arguments. - -parser = argparse.ArgumentParser() -parser.add_argument( - "--large-scale-test", - action="store_true", - help="Run large scale test (500GiB of data).") - -args, _ = parser.parse_known_args() - -############################################################################### -# -# After that, let's generate 100MiB of Parquet files, -# create the shuffle pipeline by reading those generated Parquet files, -# and use training workers to consume the pipeline. - -if not args.large_scale_test: - - NUM_TRAINING_WORKERS = 4 - NUM_EPOCHS = 5 - NUM_COLUMNS = 10 - SIZE_100MiB = 100 * 1024 * 1024 - - # create a local ray cluster. - ray.init() - - def generate_example_files(size_bytes: int) -> str: - tmpdir = tempfile.mkdtemp() - ray.data.read_datasource( - RandomIntRowDatasource(), - n=size_bytes // 8 // NUM_COLUMNS, - num_columns=NUM_COLUMNS).write_parquet(tmpdir) - return tmpdir - - example_files_dir = generate_example_files(SIZE_100MiB) - - splits = create_shuffle_pipeline(example_files_dir, NUM_EPOCHS, - NUM_TRAINING_WORKERS) - - training_workers = [ - TrainingWorker.remote(rank, shard) for rank, shard in enumerate(splits) - ] - - # Let's run the e2e pipeline - start = time.time() - ray.get([worker.train.remote() for worker in training_workers]) - print(f"total ingestion time: {int(time.time() - start)}s") - - # -> Write Progress: 100%|████████████████████| 201/201 [00:00<00:00, 228.67it/s] - # -> Stage 0: 0%| | 0/5 [00:00 Stage 0: 40%|████ | 2/5 [00:11<00:17, 5.75s/it] - # -> Stage 0: 60%|██████ | 3/5 [00:23<00:16, 8.15s/it] - # -> ... - # -> (TrainingWorker pid=1651600) Training... worker: 2, epoch: 0 - # -> Stage 0: 80%|████████ | 4/5 [00:35<00:09, 9.59s/it] - # -> ... - # -> (TrainingWorker pid=1651599) Training... worker: 0, epoch: 1 - # -> Stage 0: 100%|██████████| 5/5 [00:46<00:00, 10.34s/it] - # -> ... - # -> (TrainingWorker pid=1651387) Training... worker: 3, epoch: 4 - # -> total ingestion time: 61s - -################################################################################# -# Scale the shuffle ingestion pipeline -# -------------------------------------------------------- -# -# Scaling the shuffle ingestion pipeline is simple. With Ray, we can linearly -# scale the pipeline from ingesting 100MiB of data to 500GiB of data by adding -# more machines. -# -# To ingest 500GiB of data, we'll set up a Ray Cluster. -# The provided :download:`big_data_ingestion.yaml <../big_data_ingestion.yaml>` -# cluster config can be used to set up an AWS cluster with 70 CPU nodes and -# 16 GPU nodes. Using following command to bring up the Ray cluster. -# -# .. 
code-block:: bash -# -# $ pip install ray boto3 -# $ ray up big_data_ingestion.yaml -# -# After the cluster is started, let's implement our large scale ingestion test: -# -# First, since we are runing on a cluster, let's create the pipeline from -# RandomIntRowDatasource directly. In this way we don't need to set up S3 for storing -# generated data. - - -def create_large_shuffle_pipeline(data_size_bytes: int, num_epochs: int, - num_columns: int, - num_shards: int) -> List[DatasetPipeline]: - # _spread_resource_prefix is used to ensure tasks are evenly spread to all - # CPU nodes. - return ray.data.read_datasource( - RandomIntRowDatasource(), n=data_size_bytes // 8 // num_columns, - num_columns=num_columns, - _spread_resource_prefix="node:") \ - .repeat(num_epochs) \ - .random_shuffle_each_window(_spread_resource_prefix="node:") \ - .split(num_shards, equal=True) - - -################################################################################# -# -# Now, it's time to implement the 500GiB shuffle ingestion pipeline. - -if args.large_scale_test: - NUM_TRAINING_WORKERS = 16 - NUM_EPOCHS = 5 - NUM_COLUMNS = 10 - GiB = 1024 * 1024 * 1024 - SIZE_500GiB = 500 * GiB - TOTAL_NUM_NODES = 70 + 16 + 1 - - # use the AWS cluster we just set up. - ray.init(address="auto") - - # waiting for cluster nodes to come up. - while len(ray.nodes()) < TOTAL_NUM_NODES: - print( - f"waiting for nodes to start up: {len(ray.nodes())}/{TOTAL_NUM_NODES}" - ) - time.sleep(5) - - splits = create_large_shuffle_pipeline(SIZE_500GiB, NUM_EPOCHS, - NUM_COLUMNS, NUM_TRAINING_WORKERS) - - # Note we set num_gpus=1 for workers so that - # the workers will only run on GPU nodes. - training_workers = [ - TrainingWorker.options(num_gpus=1) \ - .remote(rank, shard) for rank, shard in enumerate(splits) - ] - - start = time.time() - - # Let's run the large scale test. - ray.get([worker.train.remote() for worker in training_workers]) - print(f"total ingestion time: {int(time.time() - start)}s") - throughput = SIZE_500GiB * NUM_EPOCHS / (time.time() - start) / GiB - print("throughput: {0:0.2f}GiB/s".format(throughput)) - -################################################################################# -# -# Finally, let's run our pipeline on the cluster we just started: -# -# .. code-block:: bash -# -# $ ray submit ./big_data_ingestion.yaml ./big_data_ingestion.py --large-scale-test -# # -> Connecting to existing Ray cluster at address: 172.31.47.38:6379 -# # -> waiting for nodes to start up: 1/87 -# # -> ... -# # -> waiting for nodes to start up: 87/87 -# # -> Stage 0: 0%| | 0/5 [00:00 Stage 0: 20%|██ | 1/5 [00:00<00:02, 1.77it/s] -# # -> Stage 0: 40%|████ | 2/5 [00:38<00:35, 11.67s/it] -# # -> Stage 0: 60%|██████ | 3/5 [01:13<00:37, 18.83s/it] -# # -> ... -# # -> (TrainingWorker pid=5084, ip=172.31.35.245) Training... worker: 12, epoch: 0 -# # -> Stage 0: 80%|████████ | 4/5 [03:15<00:49, 49.63s/it] -# # -> ... -# # -> (TrainingWorker pid=5076, ip=172.31.40.190) Training... worker: 9, epoch: 1 -# # -> Stage 0: 100%|██████████| 5/5 [05:02<00:00, 67.01s/it] -# # -> ... -# # -> (TrainingWorker pid=5074, ip=172.31.40.190) Training... 
worker: 0, epoch: 4 -# # -> total ingestion time: 291s -# # -> throughput: 8.56GiB/s diff --git a/doc/source/data/big_data_ingestion.yaml b/doc/source/data/big_data_ingestion.yaml deleted file mode 100644 index 2609afdf4426d..0000000000000 --- a/doc/source/data/big_data_ingestion.yaml +++ /dev/null @@ -1,54 +0,0 @@ -cluster_name: big_data_ingestion.yaml - -max_workers: 86 - -provider: - type: aws - region: us-west-1 - -auth: - ssh_user: ubuntu - -available_node_types: - head: - node_config: - InstanceType: i3.8xlarge - BlockDeviceMappings: - - DeviceName: /dev/sda1 - Ebs: - VolumeSize: 300 - resources: { } - - gpu_nodes: - min_workers: 16 - max_workers: 16 - node_config: - InstanceType: i3.8xlarge - BlockDeviceMappings: - - DeviceName: /dev/sda1 - Ebs: - VolumeSize: 300 - resources: - GPU: 1 - - memory_nodes: - min_workers: 70 - max_workers: 70 - node_config: - InstanceType: i3.8xlarge - BlockDeviceMappings: - - DeviceName: /dev/sda1 - Ebs: - VolumeSize: 300 - resources: { } - -head_node_type: head - -setup_commands: - - pip install -U ray ray[default] pyarrow pandas - -head_start_ray_commands: - - ray start --head --port=6379 --object-manager-port=8076 --object-store-memory=90000000000 --autoscaling-config=~/ray_bootstrap_config.yaml - -worker_start_ray_commands: - - ray start --address=$RAY_HEAD_IP:6379 --object-manager-port=8076 --object-store-memory=90000000000 diff --git a/doc/source/data/dask-on-ray.rst b/doc/source/data/dask-on-ray.rst index 9e08977bdb16e..6057b740db441 100644 --- a/doc/source/data/dask-on-ray.rst +++ b/doc/source/data/dask-on-ray.rst @@ -6,16 +6,16 @@ Dask on Ray `Dask `__ is a Python parallel computing library geared towards scaling analytics and scientific computing workloads. It provides `big data collections `__ that mimic the APIs of -the familiar `NumPy `__ and `Pandas `__ libraries, +the familiar `NumPy `__ and `Pandas `__ libraries, allowing those abstractions to represent -larger-than-memory data and/or allowing operations on that data to be run on a multi-machine cluster, +larger-than-memory data and/or allowing operations on that data to be run on a multi-machine cluster, while also providing automatic data parallelism, smart scheduling, and optimized operations. Operations on these collections create a task graph, which is executed by a scheduler. Ray provides a scheduler for Dask (`dask_on_ray`) which allows you to build data analyses using Dask's collections and execute -the underlying tasks on a Ray cluster. +the underlying tasks on a Ray cluster. `dask_on_ray` uses Dask's scheduler API, which allows you to specify any callable as the scheduler that you would like Dask to use to execute your @@ -30,12 +30,8 @@ workload. Using the Dask-on-Ray scheduler, the entire Dask ecosystem can be exec * - Ray Version - Dask Version - * - ``1.7.0`` - - ``2021.9.1`` - * - ``1.6.0`` - - ``2021.8.1`` * - ``1.5.0`` - - ``2021.7.0`` + - ``2021.7.0`` * - ``1.4.1`` - ``2021.6.1`` * - ``1.4.0`` @@ -86,7 +82,7 @@ In this case, there are two recommended setup. # Head node. Set `num_cpus=0` to avoid tasks are being scheduled on a head node. RAY_SCHEDULER_SPREAD_THRESHOLD=0.0 ray start --head --num-cpus=0 - # Worker node. + # Worker node. RAY_SCHEDULER_SPREAD_THRESHOLD=0.0 ray start --address=[head-node-address] Out-of-Core Data Processing @@ -105,10 +101,10 @@ Persist .. 
_dask-on-ray-persist: -Dask-on-Ray patches `dask.persist() -`__ in order to match `Dask +Dask-on-Ray patches `dask.persist() +`__ in order to match `Dask Distributed's persist semantics -`; namely, calling `dask.persist()` with a Dask-on-Ray +`; namely, calling `dask.persist()` with a Dask-on-Ray scheduler will submit the tasks to the Ray cluster and return Ray futures inlined in the Dask collection. This is nice if you wish to compute some base collection (such as a Dask array), followed by multiple different downstream computations (such as diff --git a/doc/source/data/dataset-pipeline.rst b/doc/source/data/dataset-pipeline.rst index d954df8051eb5..8b60ca3cb7985 100644 --- a/doc/source/data/dataset-pipeline.rst +++ b/doc/source/data/dataset-pipeline.rst @@ -6,12 +6,12 @@ Overview Datasets execute their transformations synchronously in blocking calls. However, it can be useful to overlap dataset computations with output. This can be done with a `DatasetPipeline `__. -A DatasetPipeline is an unified iterator over a (potentially infinite) sequence of Ray Datasets, each of which represents a *window* over the original data. Conceptually it is similar to a `Spark DStream `__, but manages execution over a bounded amount of source data instead of an unbounded stream. Ray computes each dataset window on-demand and stitches their output together into a single logical data iterator. DatasetPipeline implements most of the same transformation and output methods as Datasets (e.g., map, filter, split, iter_rows, to_torch, etc.). +A DatasetPipeline is a unified iterator over a (potentially infinite) sequence of Ray Datasets. Conceptually it is similar to a `Spark DStream `__, but manages execution over a bounded amount of source data instead of an unbounded stream. Ray computes each dataset on-demand and stitches their output together into a single logical data iterator. DatasetPipeline implements most of the same transformation and output methods as Datasets (e.g., map, filter, split, iter_rows, to_torch, etc.). Creating a DatasetPipeline ~~~~~~~~~~~~~~~~~~~~~~~~~~ -A DatasetPipeline can be constructed in two ways: either by pipelining the execution of an existing Dataset (via ``Dataset.window``), or generating repeats of an existing Dataset (via ``Dataset.repeat``). Similar to Datasets, you can freely pass DatasetPipelines between Ray tasks, actors, and libraries. Get started with this synthetic data example: +A DatasetPipeline can be constructed in two ways: either by pipelining the execution of an existing Dataset (via ``Dataset.pipeline``), or generating repeats of an existing Dataset (via ``Dataset.repeat``). Similar to Datasets, you can freely pass DatasetPipelines between Ray tasks, actors, and libraries. Get started with this synthetic data example: .. code-block:: python @@ -30,16 +30,16 @@ A DatasetPipeline can be constructed in two ways: either by pipelining the execu base = ray.data.range(1000000) print(base) # -> Dataset(num_blocks=200, num_rows=1000000, schema=) - pipe = base.window(blocks_per_window=10) + pipe = base.pipeline(parallelism=10) print(pipe) - # -> DatasetPipeline(num_windows=20, num_stages=1) + # -> DatasetPipeline(length=20, num_stages=1) # Applying transforms to pipelines adds more pipeline stages. pipe = pipe.map(func1) pipe = pipe.map(func2) pipe = pipe.map(func3) print(pipe) - # -> DatasetPipeline(num_windows=20, num_stages=4) + # -> DatasetPipeline(length=20, num_stages=4) # Output can be pulled from the pipeline concurrently with its execution. 
num_rows = 0 @@ -53,7 +53,8 @@ A DatasetPipeline can be constructed in two ways: either by pipelining the execu print("Total num rows", num_rows) # -> Total num rows 1000000 -You can also create a DatasetPipeline from a custom iterator over dataset creators using ``DatasetPipeline.from_iterable``. For example, this is how you would implement ``Dataset.repeat`` and ``Dataset.window`` using ``from_iterable``: + +You can also create a DatasetPipeline from a custom iterator over dataset creators using ``DatasetPipeline.from_iterable``. For example, this is how you would implement ``Dataset.repeat`` and ``Dataset.pipeline`` using ``from_iterable``: .. code-block:: python @@ -65,52 +66,10 @@ You can also create a DatasetPipeline from a custom iterator over dataset creato pipe = DatasetPipeline.from_iterable( [lambda: source, lambda: source, lambda: source, lambda: source]) - # Equivalent to ray.data.range(1000).window(blocks_per_window=10) + # Equivalent to ray.data.range(1000).pipeline(parallelism=10) splits = ray.data.range(1000, parallelism=200).split(20) pipe = DatasetPipeline.from_iterable([lambda s=s: s for s in splits]) -Per-Window Transformations -~~~~~~~~~~~~~~~~~~~~~~~~~~ - -While most Dataset operations are per-row (e.g., map, filter), some operations apply to the Dataset as a whole (e.g., sort, shuffle). When applied to a pipeline, holistic transforms like shuffle are applied separately to each window in the pipeline: - -.. code-block:: python - - # Example of randomly shuffling each window of a pipeline. - ray.data.range(5).repeat(2).random_shuffle_each_window().show_windows() - # -> - # === Window 0 === - # 4 - # 3 - # 1 - # 0 - # 2 - # === Window 1 === - # 2 - # 1 - # 4 - # 0 - # 3 - -You can also apply arbitrary transformations to each window using ``DatasetPipeline.foreach_window()``: - -.. code-block:: python - - # Equivalent transformation using .foreach_window() - ray.data.range(5).repeat(2).foreach_window(lambda w: w.random_shuffle()).show_windows() - # -> - # === Window 0 === - # 1 - # 0 - # 4 - # 2 - # 3 - # === Window 1 === - # 4 - # 2 - # 0 - # 3 - # 1 Example: Pipelined Batch Inference ---------------------------------- @@ -150,28 +109,28 @@ Ignoring the output, the above script has three separate stages: loading, prepro Enabling Pipelining ~~~~~~~~~~~~~~~~~~~ -We can optimize this by *pipelining* the execution of the dataset with the ``.window()`` call, which returns a DatasetPipeline instead of a Dataset object. The pipeline supports similar transformations to the original Dataset: +We can optimize this by *pipelining* the execution of the dataset with the ``.pipeline()`` call, which returns a DatasetPipeline instead of a Dataset object. The pipeline supports similar transformations to the original Dataset: .. code-block:: python # Convert the Dataset into a DatasetPipeline. pipe: DatasetPipeline = ray.data \ .read_binary_files("s3://bucket/image-dir") \ - .window(blocks_per_window=2) + .pipeline(parallelism=2) # The remainder of the steps do not change. pipe = pipe.map(preprocess) pipe = pipe.map_batches(BatchInferModel, compute="actors", batch_size=256, num_gpus=1) pipe.write_json("/tmp/results") -Here we specified ``blocks_per_window=2``, which means that the Dataset is split into smaller sub-Datasets of two blocks each. Each transformation or *stage* of the pipeline is operating over these two-block Datasets in parallel. 
This means batch inference processing can start as soon as two blocks are read and preprocessed, greatly reducing the GPU idle time: +Here we specified ``parallelism=2``, which means that the Dataset is split into smaller sub-Datasets of two blocks each. Each transformation or *stage* of the pipeline is operating over these two-block Datasets in parallel. This means batch inference processing can start as soon as two blocks are read and preprocessed, greatly reducing the GPU idle time: .. image:: dataset-pipeline-2.svg Tuning Parallelism ~~~~~~~~~~~~~~~~~~ -Tune the throughput vs latency of your pipeline with the ``blocks_per_window`` setting. As a rule of thumb, higher parallelism settings perform better, however ``blocks_per_window == num_blocks`` effectively disables pipelining, since the DatasetPipeline will only contain a single Dataset. The other extreme is setting ``blocks_per_window=1``, which minimizes the latency to initial output but only allows one concurrent transformation task per stage: +Tune the throughput vs latency of your pipeline with the ``parallelism`` setting. As a rule of thumb, higher parallelism settings perform better, however ``parallelism == num_blocks`` effectively disables pipelining, since the DatasetPipeline will only contain a single Dataset. The other extreme is setting ``parallelism=1``, which minimizes the latency to initial output but only allows one concurrent transformation task per stage: .. image:: dataset-pipeline-3.svg @@ -196,7 +155,7 @@ Transformations made prior to the Dataset prior to the call to ``.repeat()`` are pipe: DatasetPipeline = ray.data \ .read_datasource(...) \ .repeat() \ - .random_shuffle_each_window() + .random_shuffle() @ray.remote(num_gpus=1) def train_func(pipe: DatasetPipeline): @@ -225,7 +184,7 @@ Similar to how you can ``.split()`` a Dataset, you can also split a DatasetPipel pipe: DatasetPipeline = ray.data \ .read_parquet("s3://bucket/dir") \ .repeat() \ - .random_shuffle_each_window() + .random_shuffle() @ray.remote(num_gpus=1) class TrainingWorker: @@ -242,55 +201,3 @@ Similar to how you can ``.split()`` a Dataset, you can also split a DatasetPipel **Pipeline**: .. image:: dataset-repeat-2.svg - -Changing Pipeline Structure ---------------------------- - -Sometimes, you may want to change the structure of an existing pipeline. For example, after generating a pipeline with ``ds.window(k)``, you may want to repeat that windowed pipeline ``n`` times. This can be done with ``ds.window(k).repeat(n)``. As another example, suppose you have a repeating pipeline generated with ``ds.repeat(n)``. The windowing of that pipeline can be changed with ``ds.repeat(n).rewindow(k)``. Note the subtle difference in the two examples: the former is repeating a windowed pipeline that has a base window size of ``k``, while the latter is re-windowing a pipeline of initial window size of ``ds.num_blocks()``. The latter may produce windows that span multiple copies of the same original data: - -.. code-block:: python - - # Window followed by repeat. - ray.data.range(5) \ - .window(blocks_per_window=2) \ - .repeat(2) \ - .show_windows() - # -> - # === Window 0 === - # 0 - # 1 - # === Window 1 === - # 2 - # 3 - # === Window 2 === - # 4 - # === Window 3 === - # 0 - # 1 - # === Window 4 === - # 2 - # 3 - # === Window 5 === - # 4 - - # Repeat followed by window. 
- ray.data.range(5) \ - .repeat(2) \ - .rewindow(blocks_per_window=2) \ - .show_windows() - # -> - # === Window 0 === - # 0 - # 1 - # === Window 1 === - # 2 - # 3 - # === Window 2 === - # 4 - # 0 - # === Window 3 === - # 1 - # 2 - # === Window 4 === - # 3 - # 4 diff --git a/doc/source/data/dataset-tensor-support.rst b/doc/source/data/dataset-tensor-support.rst index d2d3ebf40c6f1..b8a4ad68eed4e 100644 --- a/doc/source/data/dataset-tensor-support.rst +++ b/doc/source/data/dataset-tensor-support.rst @@ -3,34 +3,66 @@ Dataset Tensor Support ====================== -Tables with tensor columns --------------------------- - -Datasets supports tables with fixed-shape tensor columns, where each element in the column is a tensor (n-dimensional array) with the same shape. As an example, this allows you to use Pandas and Ray Datasets to read, write, and manipulate e.g., images. All conversions between Pandas, Arrow, and Parquet, and all application of aggregations/operations to the underlying image ndarrays are taken care of by Ray Datasets. - -With our Pandas extension type, :class:`TensorDtype `, and extension array, :class:`TensorArray `, you can do familiar aggregations and arithmetic, comparison, and logical operations on a DataFrame containing a tensor column and the operations will be applied to the underlying tensors as expected. With our Arrow extension type, :class:`ArrowTensorType `, and extension array, :class:`ArrowTensorArray `, you'll be able to import that DataFrame into Ray Datasets and read/write the data from/to the Parquet format. - -Automatic conversion between the Pandas and Arrow extension types/arrays keeps the details under-the-hood, so you only have to worry about casting the column to a tensor column using our Pandas extension type when first ingesting the table into a ``Dataset``, whether from storage or in-memory. All table operations downstream from that cast should work automatically. +Tensor-typed values +------------------- -Single-column tensor datasets -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -The most basic case is when a dataset only has a single column, which is of tensor type. This kind of dataset can be created with ``.range_tensor()``, and can be read from and written to ``.npy`` files. Here are some examples: +Datasets support tensor-typed values, which are represented in-memory as Arrow tensors (i.e., np.ndarray format). Tensor datasets can be read from and written to ``.npy`` files. Here are some examples: .. code-block:: python # Create a Dataset of tensor-typed values. ds = ray.data.range_tensor(10000, shape=(3, 5)) # -> Dataset(num_blocks=200, num_rows=10000, - # schema={value: }) + # schema=) + + ds.map_batches(lambda t: t + 2).show(2) + # -> [[2 2 2 2 2] + # [2 2 2 2 2] + # [2 2 2 2 2]] + # [[3 3 3 3 3] + # [3 3 3 3 3] + # [3 3 3 3 3]] # Save to storage. - ds.write_numpy("/tmp/tensor_out", column="value") + ds.write_numpy("/tmp/tensor_out") # Read from storage. ray.data.read_numpy("/tmp/tensor_out") # -> Dataset(num_blocks=200, num_rows=?, - # schema={value: }) + # schema=) + +Tensor datasets are also created whenever an array type is returned from a map function: + +.. code-block:: python + + # Create a dataset of Python integers. + ds = ray.data.range(10) + # -> Dataset(num_blocks=10, num_rows=10, schema=) + + # It is now converted into a Tensor dataset. 
+ ds = ds.map_batches(lambda x: np.array(x)) + # -> Dataset(num_blocks=10, num_rows=10, + # schema=) + +Tensor datasets can also be created from NumPy ndarrays that are already stored in the Ray object store: + +.. code-block:: python + + import numpy as np + + # Create a Dataset from a list of NumPy ndarray objects. + arr1 = np.arange(0, 10) + arr2 = np.arange(10, 20) + ds = ray.data.from_numpy([ray.put(arr1), ray.put(arr2)]) + +Tables with tensor columns +-------------------------- + +In addition to tensor datasets, Datasets also supports tables with fixed-shape tensor columns, where each element in the column is a tensor (n-dimensional array) with the same shape. As an example, this allows you to use both Pandas and Ray Datasets to read, write, and manipulate a table with a column of e.g. images (2D arrays), with all conversions between Pandas, Arrow, and Parquet, and all application of aggregations/operations to the underlying image ndarrays, being taken care of by Ray Datasets. + +With our Pandas extension type, :class:`TensorDtype `, and extension array, :class:`TensorArray `, you can do familiar aggregations and arithmetic, comparison, and logical operations on a DataFrame containing a tensor column and the operations will be applied to the underlying tensors as expected. With our Arrow extension type, :class:`ArrowTensorType `, and extension array, :class:`ArrowTensorArray `, you'll be able to import that DataFrame into Ray Datasets and read/write the data from/to the Parquet format. + +Automatic conversion between the Pandas and Arrow extension types/arrays keeps the details under-the-hood, so you only have to worry about casting the column to a tensor column using our Pandas extension type when first ingesting the table into a ``Dataset``, whether from storage or in-memory. All table operations downstream from that cast should work automatically. Reading existing serialized tensor columns ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -55,7 +87,7 @@ If you already have a Parquet dataset with columns containing serialized tensors # Write the dataset to Parquet. The tensor column will be written as an # array of opaque byte blobs. - ds = ray.data.from_pandas([df]) + ds = ray.data.from_pandas([ray.put(df)]) ds.write_parquet(path) # Read the Parquet files into a new Dataset, with the serialized tensors @@ -85,7 +117,7 @@ If your serialized tensors don't fit the above constraints (e.g. they're stored # Write the dataset to Parquet. The tensor column will be written as an # array of opaque byte blobs. - ds = ray.data.from_pandas([df]) + ds = ray.data.from_pandas([ray.put(df)]) ds.write_parquet(path) # Manually deserialize the tensor pickle bytes and cast to our tensor @@ -118,7 +150,7 @@ Now that the tensor column is properly typed and in a ``Dataset``, we can perfor # Arrow and Pandas is now aware of this tensor column, so we can do the # typical DataFrame operations on this column. - ds = ds.map_batches(lambda x: 2 * (x + 1), batch_format="pandas") + ds = ds.map_batches(lambda x: 2 * (x + 1), format="pandas") # -> Map Progress: 100%|████████████████████| 200/200 [00:00<00:00, 1123.54it/s] print(ds) # -> Dataset( @@ -212,7 +244,7 @@ If working with in-memory Pandas DataFrames that you want to analyze, manipulate # In addition to doing Pandas operations on the tensor column, # you can now put the DataFrame directly into a Dataset. 
- ds = ray.data.from_pandas([df]) + ds = ray.data.from_pandas([ray.put(df)]) # Internally, this column is represented with the corresponding # Arrow tensor extension type. print(ds.schema()) @@ -227,7 +259,7 @@ If working with in-memory Pandas DataFrames that you want to analyze, manipulate # -> one: int64 # two: extension> - read_df = read_ds.to_pandas() + read_df = ray.get(read_ds.to_pandas())[0] print(read_df.dtypes) # -> one int64 # two TensorDtype diff --git a/doc/source/data/dataset.rst b/doc/source/data/dataset.rst index 20018765c1a69..7142691e5df45 100644 --- a/doc/source/data/dataset.rst +++ b/doc/source/data/dataset.rst @@ -16,7 +16,7 @@ Ray Datasets are the standard way to load and exchange data in Ray libraries and Concepts -------- -Ray Datasets implement `Distributed Arrow `__. A Dataset consists of a list of Ray object references to *blocks*. Each block holds a set of items in either an `Arrow table `__ or a Python list (for Arrow incompatible objects). Having multiple blocks in a dataset allows for parallel transformation and ingest of the data. +Ray Datasets implement `Distributed Arrow `__. A Dataset consists of a list of Ray object references to *blocks*. Each block holds a set of items in either an `Arrow table `__, `Arrow tensor `__, or a Python list (for Arrow incompatible objects). Having multiple blocks in a dataset allows for parallel transformation and ingest of the data. The following figure visualizes a Dataset that has three Arrow table blocks, each block holding 1000 rows each: @@ -145,10 +145,6 @@ Datasource Compatibility Matrices Creating Datasets ----------------- -.. tip:: - - Run ``pip install ray[data]`` to get started! - Get started by creating Datasets from synthetic data using ``ray.data.range()`` and ``ray.data.from_items()``. Datasets can hold either plain Python objects (schema is a Python type), or Arrow records (schema is Arrow). .. code-block:: python @@ -202,7 +198,7 @@ Finally, you can create a ``Dataset`` from existing data in the Ray object store # Create a Dataset from a list of Pandas DataFrame objects. pdf = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]}) - ds = ray.data.from_pandas([pdf]) + ds = ray.data.from_pandas([ray.put(pdf)]) # Create a Dataset from a Dask-on-Ray DataFrame. dask_df = dd.from_pandas(pdf, npartitions=10) diff --git a/doc/source/data/package-ref.rst b/doc/source/data/package-ref.rst index afdace98bf719..0af38ba8297c7 100644 --- a/doc/source/data/package-ref.rst +++ b/doc/source/data/package-ref.rst @@ -15,13 +15,11 @@ Creating a Dataset .. autofunction:: ray.data.read_datasource .. autofunction:: ray.data.from_items .. autofunction:: ray.data.from_arrow -.. autofunction:: ray.data.from_arrow_refs .. autofunction:: ray.data.from_spark .. autofunction:: ray.data.from_dask .. autofunction:: ray.data.from_modin .. autofunction:: ray.data.from_mars .. autofunction:: ray.data.from_pandas -.. autofunction:: ray.data.from_pandas_refs .. autofunction:: ray.data.from_numpy Dataset API diff --git a/doc/source/development.rst b/doc/source/development.rst index f41b48d14fdef..d672c2b3fb5a0 100644 --- a/doc/source/development.rst +++ b/doc/source/development.rst @@ -100,9 +100,8 @@ Ray can be built from the repository as follows. git clone https://github.com/ray-project/ray.git # Install Bazel. + # (Windows users: please manually place Bazel in your PATH, and point BAZEL_SH to MSYS2's Bash.) 
ray/ci/travis/install-bazel.sh - # (Windows users: please manually place Bazel in your PATH, and point - # BAZEL_SH to MSYS2's Bash: ``set BAZEL_SH=C:\Program Files\Git\bin\bash.exe``) # Build the dashboard # (requires Node.js, see https://nodejs.org/ for more information). @@ -127,7 +126,7 @@ Building Ray on Windows (full) The following links were correct during the writing of this section. In case the URLs changed, search at the organizations' sites. -- bazel 4.2 (https://github.com/bazelbuild/bazel/releases/tag/4.2.1) +- bazel 3.4 (https://github.com/bazelbuild/bazel/releases/tag/3.4.0) - Microsoft Visual Studio 2019 (or Microsoft Build Tools 2019 - https://visualstudio.microsoft.com/downloads/#build-tools-for-visual-studio-2019) - JDK 15 (https://www.oracle.com/java/technologies/javase-jdk15-downloads.html) - Miniconda 3 (https://docs.conda.io/en/latest/miniconda.html) @@ -150,11 +149,7 @@ The following links were correct during the writing of this section. In case the 3. Define an environment variable BAZEL_SH to point to bash.exe. If git for Windows was installed for all users, bash's path should be ``C:\Program Files\Git\bin\bash.exe``. If git was installed for a single user, adjust the path accordingly. -4. Bazel 4.2 installation. Go to bazel 4.2 release web page and download -bazel-4.2.1-windows-x86_64.exe. Copy the exe into the directory of your choice. -Define an environment variable BAZEL_PATH to full exe path (example: -``set BAZEL_PATH=C:\bazel\bazel.exe``). Also add the bazel directory to the -``PATH`` (example: ``set PATH=%PATH%;C:\bazel``) +4. Bazel 3.4 installation. Go to bazel 3.4 release web page and download bazel-3.4.0-windows-x86_64.exe. Copy the exe into the directory of your choice. Define an environment variable BAZEL_PATH to full exe path (example: ``C:\bazel\bazel-3.4.0-windows-x86_64.exe``) 5. Install cython and pytest: diff --git a/doc/source/index.rst b/doc/source/index.rst index 784df20c59e07..2024802af37d7 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -277,9 +277,8 @@ Papers :caption: Ray Data data/dataset.rst - data/dataset-pipeline.rst - data/examples/big_data_ingestion data/dataset-tensor-support.rst + data/dataset-pipeline.rst data/package-ref.rst data/dask-on-ray.rst data/mars-on-ray.rst @@ -339,7 +338,6 @@ Papers raysgd/v2/examples.rst raysgd/v2/architecture.rst raysgd/v2/api.rst - raysgd/v2/migration-guide.rst RaySGD v1: Distributed Training Wrappers .. toctree:: @@ -367,7 +365,7 @@ Papers .. toctree:: :hidden: :maxdepth: -1 - :caption: Contributor Guide + :caption: Contributing getting-involved.rst development.rst diff --git a/doc/source/raysgd/raysgd.rst b/doc/source/raysgd/raysgd.rst index 55ddcdb389fc1..87696e68d6535 100644 --- a/doc/source/raysgd/raysgd.rst +++ b/doc/source/raysgd/raysgd.rst @@ -6,7 +6,7 @@ RaySGD: Distributed Training Wrappers .. warning:: This is an older version of Ray SGD. A newer, more light-weight version of Ray SGD is in alpha as of Ray 1.7. - See the documentation :ref:`here `. To migrate from v1 to v2 you can follow the :ref:`migration guide `. + See the documentation :ref:`here `. RaySGD is a lightweight library for distributed deep learning, providing thin wrappers around PyTorch and TensorFlow native modules for data parallel training. 
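The v1 wrapper pattern that these RaySGD pages describe reduces to roughly the following minimal sketch (``MyTrainingOperator`` is a hypothetical placeholder for a user-defined ``TrainingOperator`` subclass; argument values are illustrative only):

.. code-block:: python

    from ray.util.sgd import TorchTrainer, TrainingOperator

    class MyTrainingOperator(TrainingOperator):
        # Hypothetical operator: v1 expects the setup and per-epoch
        # training logic to live in a TrainingOperator subclass.
        ...

    # Each worker is a Ray actor running the operator data-parallel.
    trainer = TorchTrainer(
        training_operator_cls=MyTrainingOperator,
        num_workers=2,
        use_gpu=False)

    # v1 drives training one epoch at a time from the driver.
    for _ in range(3):
        print(trainer.train())

    trainer.shutdown()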
diff --git a/doc/source/raysgd/raysgd_pytorch.rst b/doc/source/raysgd/raysgd_pytorch.rst index 635d003e55032..5e9c1ce099141 100644 --- a/doc/source/raysgd/raysgd_pytorch.rst +++ b/doc/source/raysgd/raysgd_pytorch.rst @@ -3,16 +3,13 @@ Distributed PyTorch =================== -.. warning:: This is an older version of Ray SGD. A newer, more light-weight version of Ray SGD is in alpha as of Ray 1.7. - See the documentation :ref:`here `. To migrate from v1 to v2 you can follow the :ref:`migration guide `. - The RaySGD ``TorchTrainer`` simplifies distributed model training for PyTorch. .. image:: raysgd-actors.svg :align: center -.. tip:: Get in touch with us if you're using or considering using `RaySGD `_! +.. tip:: Get in touch with us if you're using or considering using `RaySGD `_! The ``TorchTrainer`` is a wrapper around ``torch.distributed.launch`` with a Python API to easily incorporate distributed training into a larger Python application, as opposed to needing to wrap your training code in bash scripts. diff --git a/doc/source/raysgd/raysgd_tensorflow.rst b/doc/source/raysgd/raysgd_tensorflow.rst index 2cbf01da2e3c3..f18d7f9ec3924 100644 --- a/doc/source/raysgd/raysgd_tensorflow.rst +++ b/doc/source/raysgd/raysgd_tensorflow.rst @@ -1,9 +1,6 @@ Distributed TensorFlow ====================== -.. warning:: This is an older version of Ray SGD. A newer, more light-weight version of Ray SGD is in alpha as of Ray 1.7. - See the documentation :ref:`here `. To migrate from v1 to v2 you can follow the :ref:`migration guide `. - RaySGD's ``TFTrainer`` simplifies distributed model training for Tensorflow. The ``TFTrainer`` is a wrapper around ``MultiWorkerMirroredStrategy`` with a Python API to easily incorporate distributed training into a larger Python application, as opposed to write custom logic of setting environments and starting separate processes. Under the hood, ``TFTrainer`` will create *replicas* of your model (controlled by ``num_replicas``), each of which is managed by a Ray actor. @@ -11,7 +8,7 @@ Under the hood, ``TFTrainer`` will create *replicas* of your model (controlled b .. image:: raysgd-actors.svg :align: center -.. tip:: Get in touch with us if you're using or considering using `RaySGD `_! +.. tip:: We need your feedback! RaySGD is currently early in its development, and we're hoping to get feedback from people using or considering it. We'd love `to get in touch `_! ---------- diff --git a/doc/source/raysgd/raysgd_tune.rst b/doc/source/raysgd/raysgd_tune.rst index 740ff78b0390c..cacaea0a20c4e 100644 --- a/doc/source/raysgd/raysgd_tune.rst +++ b/doc/source/raysgd/raysgd_tune.rst @@ -3,9 +3,6 @@ RaySGD Hyperparameter Tuning ============================ -.. warning:: This is an older version of Ray SGD. A newer, more light-weight version of Ray SGD is in alpha as of Ray 1.7. - See the documentation :ref:`here `. To migrate from v1 to v2 you can follow the :ref:`migration guide `. - RaySGD integrates with :ref:`Ray Tune ` to easily run distributed hyperparameter tuning experiments with your RaySGD Trainer. PyTorch diff --git a/doc/source/raysgd/v2/api.rst b/doc/source/raysgd/v2/api.rst index 97b48a26b11ce..fc3028bc9fc19 100644 --- a/doc/source/raysgd/v2/api.rst +++ b/doc/source/raysgd/v2/api.rst @@ -22,8 +22,10 @@ SGDIterator .. _sgd-api-backend-config: -Backend Configurations ----------------------- +BackendConfig +------------- + +.. autoclass:: ray.sgd.BackendConfig .. _sgd-api-torch-config: @@ -46,14 +48,10 @@ HorovodConfig .. 
autoclass:: ray.sgd.HorovodConfig - -Callbacks ---------- - .. _sgd-api-callback: SGDCallback -~~~~~~~~~~~ +----------- .. autoclass:: ray.sgd.SGDCallback :members: @@ -63,22 +61,19 @@ SGDCallback JsonLoggerCallback ~~~~~~~~~~~~~~~~~~ -.. autoclass:: ray.sgd.callbacks.JsonLoggerCallback +.. autoclass:: ray.sgd.JsonLoggerCallback .. _sgd-api-tbx-logger-callback: TBXLoggerCallback ~~~~~~~~~~~~~~~~~ -.. autoclass:: ray.sgd.callbacks.TBXLoggerCallback - -Checkpointing -------------- +.. autoclass:: ray.sgd.TBXLoggerCallback .. _sgd-api-checkpoint-strategy: CheckpointStrategy -~~~~~~~~~~~~~~~~~~ +------------------ .. autoclass:: ray.sgd.CheckpointStrategy diff --git a/doc/source/raysgd/v2/examples.rst b/doc/source/raysgd/v2/examples.rst index 3edee334aea2a..a35f394c7593c 100644 --- a/doc/source/raysgd/v2/examples.rst +++ b/doc/source/raysgd/v2/examples.rst @@ -61,9 +61,6 @@ Ray Tune Integration Examples * :doc:`/raysgd/v2/examples/tune_tensorflow_mnist_example`: End-to-end example for tuning a TensorFlow model. -* :doc:`/raysgd/v2/examples/tune_cifar_pytorch_pbt_example`: - End-to-end example for tuning a PyTorch model with PBT. - .. TODO implement these examples! diff --git a/doc/source/raysgd/v2/examples/tune_cifar_pytorch_pbt_example.rst b/doc/source/raysgd/v2/examples/tune_cifar_pytorch_pbt_example.rst deleted file mode 100644 index 31aabc7ca78ab..0000000000000 --- a/doc/source/raysgd/v2/examples/tune_cifar_pytorch_pbt_example.rst +++ /dev/null @@ -1,6 +0,0 @@ -:orphan: - -tune_cifar_pytorch_pbt_example -============================== - -.. literalinclude:: /../../python/ray/util/sgd/v2/examples/tune_cifar_pytorch_pbt_example.py diff --git a/doc/source/raysgd/v2/migration-guide.rst b/doc/source/raysgd/v2/migration-guide.rst deleted file mode 100644 index 08effe4b25e98..0000000000000 --- a/doc/source/raysgd/v2/migration-guide.rst +++ /dev/null @@ -1,393 +0,0 @@ -.. _sgd-migration: - -Migrating from Ray SGD v1 -========================= - -In Ray 1.7, we are rolling out a new and more streamlined version of Ray SGD. Ray SGD v2 focuses on usability and composability - it has a much simpler API, has support for more deep learning backends, integrates better with other libraries in the Ray ecosystem, and will continue to be actively developed with more features. - -This guide will help you easily migrate existing code from Ray SGD v1 to Ray SGD v2. If you are new to Ray SGD as a whole, you should get started with :ref:`Ray SGD v2 directly `. - -For a full list of features that Ray SGD v2 provides, please check out the :ref:`user guide`. - -.. note:: If there are any issues or anything missing with this guide or any feedback on Ray SGD v2 overall, please file a `Github issue on the Ray repo `_! - -What are the API differences? ------------------------------ - -There are 3 primary API differences between Ray SGD v1 and v2. - -1. There is a single ``Trainer`` interface for all backends (torch, tensorflow, horovod), and the backend is simply specified via an argument: ``Trainer(backend="torch")``\ , ``Trainer(backend="horovod")``\ , etc. Any features that we add to Ray SGD will be supported for all backends, and there won't be any API divergence like there was with a separate ``TorchTrainer`` and ``TFTrainer``. -2. The ``TrainingOperator`` and creator functions are replaced by a more natural user-defined training function. You no longer have to make your training logic fit into a restrictive interface. 
In Ray SGD v2, you simply have to provide a training function that describes the full logic for your training execution and this will be distributed by Ray SGD v2. - - .. code-block:: python - - from torch.nn.parallel import DistributedDataParallel - from torch import nn, optim - - # Torch Example - def train_func_distributed(): - num_epochs = 3 - model = NeuralNetwork() - model = DistributedDataParallel(model) - loss_fn = nn.MSELoss() - optimizer = optim.SGD(model.parameters(), lr=0.1) - - for epoch in range(num_epochs): - output = model(input) - loss = loss_fn(output, labels) - optimizer.zero_grad() - loss.backward() - optimizer.step() - print(f"epoch: {epoch}, loss: {loss.item()}") - - from ray.sgd import Trainer - - trainer = Trainer(backend="torch", num_workers=4) - trainer.start() - results = trainer.run(train_func_distributed) - trainer.shutdown() - -Currently, this means that you are now responsible for modifying your code to support distributed training (specifying ``DistributedDataParallel`` for ``torch`` or ``MultiWorkerMirroredStrategy`` for ``tensorflow``) as opposed to having this be automatically handled internally. However, we have plans to provide utilities that you can use to automatically handle these recipes for you. - -3. Rather than iteratively calling ``trainer.train()`` or ``trainer.validate()`` for each epoch, in Ray SGD v2 the training function defines the full training execution and is run via ``trainer.run(train_func)``. - -In the following sections, we will guide you through the steps to migrate: - -1. :ref:`sgd-migration-logic` -2. :ref:`Interacting with Trainer state (intermediate metrics, checkpointing) ` -3. :ref:`Hyperparameter Tuning with Ray Tune ` - -.. _sgd-migration-logic: - -Training Logic --------------- -The main change you will have to make is how you define your training logic. In Ray SGD v1, the API for defining training logic differed for `TorchTrainer` vs. `TFTrainer`, so the steps to migrate will be different for each of these. - -PyTorch -~~~~~~~ -In v1, the training logic is defined through the ``train_epoch`` and ``train_batch`` methods of a ``TrainingOperator`` class which is passed into the ``TorchTrainer``. To migrate to Ray SGD v2, there are 2 options: - -1. If you felt the ``TrainingOperator`` is too unnecessary and complex, or you had to customize it extensively, you can define your own training function. -2. If you liked having your training logic in the ``TrainingOperator``, you can continue to use the ``TrainingOperator`` with Ray SGD v2. - -**Alternative 1: Custom Training Function** -You can define your own custom training function, and use only the parts from ``TrainingOperator.train_epoch``, ``TrainingOperator.setup``, and ``TrainingOperator.validate`` that are necessary for your application. - -You can see a full example on how to :ref:`port over regular PyTorch DDP code to Ray SGD here ` - -**Alternative 2: Continue to use TrainingOperator** -Alternatively, if you liked having the ``TrainingOperator``, you can define a training function that instantiates your `TrainingOperator` and you can call methods directly on the operator object. - -So instead of - -.. code-block:: python - - from ray.util.sgd import TrainingOperator, TorchTrainer - - class MyTrainingOperator(TrainingOperator): - ... 
- - trainer = TorchTrainer(training_operator_cls=MyTrainingOperator, num_workers=4, use_gpu=True) - - num_epochs=10 - for _ in range(num_epochs): - trainer.train() - trainer.validate() - - final_model = trainer.get_model() - - -you would do - -.. code-block:: python - - from ray.util.sgd import TrainingOperator - from ray.sgd import Trainer - from ray import sgd - - class MyTrainingOperator(TrainingOperator): - ... - - def train_func(config): - device = torch.device(f"cuda:{sgd.local_rank()}" if - torch.cuda.is_available() else "cpu") - if torch.cuda.is_available(): - torch.cuda.set_device(device) - - # Set the args to whatever values you want. - training_operator = MyTrainingOperator( - config=config, - world_rank=sgd.world_rank(), - local_rank=sgd.local_rank(), - is_distributed=True, - device=device, - use_gpu=True, - wrap_ddp=True, - add_dist_sampler=True - - training_operator.setup(config) - - for idx in range(config["num_epochs"]): - train_loader = training_operator._get_train_loader() - # If using DistributedSampler, set the epoch here. - train_loader.set_epoch(idx) - training_operator.train_epoch(epoch_idx=idx, iter(train_loader)) - - validation_loader = training_operator._get_validation_loader() - training_operator.validate(iterator=iter(validation_loader)) - - if sgd.world_rank() == 0: - return training_operator._get_original_models() - else: - return None - - trainer = Trainer(backend="torch", num_workers=4, use_gpu=True) - trainer.start() - results = trainer.run(train_func, config={"num_epochs": 10}) - final_model = results[0] - -Tensorflow -~~~~~~~~~~ - -The API for ``TFTrainer`` uses creator functions instead of a ``TrainingOperator`` to define the training logic. To port over Ray SGD v1 Tensorflow code to v2 you can do the following: - -.. code-block:: python - - from tensorflow.distribute import MultiWorkerMirroredStrategy - - from ray.sgd import Trainer - from ray import sgd - - def train_func(config): - train_dataset, val_dataset = data_creator(config) - strategy = MultiWorkerMirroredStrategy() - with strategy.scope(): - model = model_creator(config) - - for epoch_idx in range(config["num_epochs"]): - model.fit(train_dataset) - - if sgd.world_rank() == 0: - return model - else: - return None - - trainer = Trainer(backend="tensorflow", num_workers=4, config={"num_epochs": 3, ...}) - trainer.start() - model = trainer.run(train_func)[0] - -You can see a full example :ref:`here `. - -.. _sgd-migration-trainer: - -Interacting with the ``Trainer`` --------------------------------- - -In Ray SGD v1, you can iteratively call ``trainer.train()`` or ``trainer.validate()`` for each epoch, and can then interact with the trainer to get certain state (model, checkpoints, results, etc.). In Ray SGD v2, this is replaced by a single training function that defines the full training & validation loop for all epochs. - -There are 3 ways to get state during or after the training execution: - - -#. Return values from your training function -#. Intermediate results via ``sgd.report()`` -#. Saving & loading checkpoints via ``sgd.save_checkpoint()`` and ``sgd.load_checkpoint()`` - -Return Values -~~~~~~~~~~~~~ - -To get any state from training *after* training has completed, you can simply return it from your training function. The return values from each the workers will be added to a list and returned from the ``trainer.run()`` call. - -For example, to get the final model: - -**SGD v1** - -.. 
code-block:: python - - from ray.util.sgd import TorchTrainer, TrainingOperator - - class MyTrainingOperator(TrainingOperator): - ... - - trainer = TorchTrainer(training_operator_cls=MyTrainingOperator, num_workers=2) - - trainer.train() - - trained_model = trainer.get_model() - -**SGD v2** - -.. code-block:: python - - from ray.sgd import Trainer - - def train_func(): - model = Net() - trainer_loader = MyDataset() - for batch in train_loader: - model.train(batch) - - return model - - trainer = Trainer(backend="torch") - trainer.start() - results = trainer.run(train_func, num_workers=2) - assert len(results) == 2 - trained_model = results[0] - -Intermediate Reporting -~~~~~~~~~~~~~~~~~~~~~~ - -If you want to access any values *during* the training process, you can do so via ``sgd.report()``. You can pass in any values to ``sgd.report()`` and these values from all workers will be sent to any callbacks passed into your ``Trainer``. - -**SGD v1** - -.. code-block:: python - - from ray.util.sgd import TorchTrainer, TrainingOperator - - class MyTrainingOperator(TrainingOperator): - ... - - trainer = TorchTrainer(training_operator_cls=MyTrainingOperator, num_workers=2) - - for _ in range(3): - print(trainer.train(reduce_results=False)) - - -**SGD v2** - -.. code-block:: python - - from ray import sgd - from ray.sgd Trainer - from ray.sgd.callbacks import SGDCallback - from typing import List, Dict - - class PrintingCallback(SGDCallback): - def handle_result(self, results: List[Dict], **info): - print(results) - - def train_func(): - for i in range(3): - sgd.report(epoch=i) - - trainer = Trainer(backend="torch", num_workers=2) - trainer.start() - result = trainer.run( - train_func, - callbacks=[PrintingCallback()] - ) - # [{'epoch': 0, '_timestamp': 1630471763, '_time_this_iter_s': 0.0020279884338378906, '_training_iteration': 1}, {'epoch': 0, '_timestamp': 1630471763, '_time_this_iter_s': 0.0014922618865966797, '_training_iteration': 1}] - # [{'epoch': 1, '_timestamp': 1630471763, '_time_this_iter_s': 0.0008401870727539062, '_training_iteration': 2}, {'epoch': 1, '_timestamp': 1630471763, '_time_this_iter_s': 0.0007486343383789062, '_training_iteration': 2}] - # [{'epoch': 2, '_timestamp': 1630471763, '_time_this_iter_s': 0.0014500617980957031, '_training_iteration': 3}, {'epoch': 2, '_timestamp': 1630471763, '_time_this_iter_s': 0.0015292167663574219, '_training_iteration': 3}] - trainer.shutdown() - -See the :ref:`v2 User Guide ` for more details. - -Checkpointing -~~~~~~~~~~~~~ - -Finally, you can also use ``sgd.save_checkpoint()`` and ``sgd.load_checkpoint()`` to write checkpoints to disk during the training process, and to load from the most recently saved checkpoint in the case of node failures. - -See the :ref:`Checkpointing ` and :ref:`Fault Tolerance & Elastic Training ` sections on the user guide for more info. - -For example, in order to save checkpoints after every epoch: - -**SGD v1** - -.. code-block:: python - - from ray.util.sgd import TorchTrainer, TrainingOperator - - class MyTrainingOperator(TrainingOperator): - ... - - trainer = TorchTrainer(training_operator_cls=MyTrainingOperator, num_workers=2) - - for _ in range(3): - trainer.train() - trainer.save_checkpoint(checkpoint_dir="~/ray_results") - - -**SGD v2** - -.. 
code-block:: python - - from ray.sgd import Trainer - from ray import sgd - - def train_func(): - model = Net() - trainer_loader = MyDataset() - for i in range(3): - for batch in train_loader: - model.train(batch) - sgd.save_checkpoint(epoch=i, model=model.state_dict())) - - trainer = Trainer(backend="torch") - trainer.start() - trainer.run(train_func, num_workers=2) - - -.. _sgd-migration-tune: - -Hyperparameter Tuning with Ray Tune ------------------------------------ - -Ray SGD v2 also comes with an easier to use interface for Hyperparameter Tuning with Ray Tune using Tune's function API instead of its Class API. In particular, it is much easier to define custom procedures because the logic is entirely defined by your training function. - -There is a 1:1 mapping between rank 0 worker's ``sgd.report()``\ , ``sgd.save_checkpoint()``\ , and ``sgd.load_checkpoint()`` with ``tune.report()``\ , ``tune.save_checkpoint()``\ , and ``tune.load_checkpoint()``. - -**SGD v1** - -.. code-block:: python - - from ray import tune - from ray.util.sgd import TrainingOperator, TorchTrainer - - class MyTrainingOperator(TrainingOperator): - ... - - def custom_step(trainer, info): - train_stats = trainer.train() - return train_stats - - # TorchTrainable is subclass of BaseTorchTrainable. - TorchTrainable = TorchTrainer.as_trainable( - training_operator_cls=MyTrainingOperator, - num_workers=2, - use_gpu=True, - override_tune_step=custom_step - ) - - analysis = tune.run( - TorchTrainable, - config={"input": tune.grid_search([1, 2, 3])} - ) - - - -**SGD v2** - -.. code-block:: python - - from ray import tune - from ray import sgd - from ray.sgd import Trainer - - def train_func(config) - # In this example, nothing is expected to change over epochs, - # and the output metric is equivalent to the input value. - for _ in range(config["num_epochs"]): - sgd.report(output=config["input"]) - - trainer = Trainer(backend="torch", num_workers=2) - trainable = trainer.to_tune_trainable(train_func) - analysis = tune.run(trainable, config={ - "num_epochs": 2, - "input": tune.grid_search([1, 2, 3]) - }) - print(analysis.get_best_config(metric="output", mode="max")) - # {'num_epochs': 2, 'input': 3} - -For more information see :ref:`sgd-tune` \ No newline at end of file diff --git a/doc/source/raysgd/v2/raysgd.rst b/doc/source/raysgd/v2/raysgd.rst index a37e583a7fe7e..02111cdae1672 100644 --- a/doc/source/raysgd/v2/raysgd.rst +++ b/doc/source/raysgd/v2/raysgd.rst @@ -5,8 +5,6 @@ RaySGD: Deep Learning on Ray .. _`issue on GitHub`: https://github.com/ray-project/ray/issues -.. tip:: Get in touch with us if you're using or considering using `RaySGD `_! - RaySGD is a lightweight library for distributed deep learning, allowing you to scale up and speed up training for your deep learning models. @@ -23,6 +21,7 @@ The main features are: `issue on GitHub`_. If you are looking for the previous API documentation, see :ref:`sgd-index`. + Intro to RaySGD --------------- diff --git a/doc/source/raysgd/v2/user_guide.rst b/doc/source/raysgd/v2/user_guide.rst index 2c34e59dd29f2..fe33949342af0 100644 --- a/doc/source/raysgd/v2/user_guide.rst +++ b/doc/source/raysgd/v2/user_guide.rst @@ -3,8 +3,6 @@ RaySGD User Guide ================= -.. tip:: Get in touch with us if you're using or considering using `RaySGD `_! - In this guide, we cover examples for the following use cases: * How do I :ref:`port my code ` to using RaySGD? @@ -90,7 +88,6 @@ training. 
If you are using GPUs, you need to make sure to the CUDA devices are properly setup inside your training function. This involves 3 steps: - 1. Use the local rank to set the default CUDA device for the worker. 2. Move the model to the default CUDA device (or a specific CUDA device). 3. Specify ``device_ids`` when wrapping in ``DistributedDataParallel``. @@ -344,8 +341,7 @@ You can plug all of these into RaySGD with the following interface: .. code-block:: python from ray import sgd - from ray.sgd Trainer - from ray.sgd.callbacks import SGDCallback + from ray.sgd import SGDCallback, Trainer from typing import List, Dict class PrintingCallback(SGDCallback): @@ -399,7 +395,7 @@ A simple example for creating a callback that will print out results: .. code-block:: python - from ray.sgd.callbacks import SGDCallback + from ray.sgd import SGDCallback class PrintingCallback(SGDCallback): def handle_result(self, results: List[Dict], **info): @@ -639,7 +635,7 @@ Underneath the hood, RaySGD will automatically shard the given dataset. return model trainer = Trainer(num_workers=8, backend="torch") - dataset = ray.data.read_csv("...").filter().window(blocks_per_window=50) + dataset = ray.data.read_csv("...").filter().pipeline(length=50) result = trainer.run( train_func, @@ -742,7 +738,7 @@ A couple caveats: # Declare the specification for training. trainer = Trainer(backend="torch", num_workers=12, use_gpu=True) - dataset = ray.dataset.window() + dataset = ray.dataset.pipeline() # Convert this to a trainable. trainable = trainer.to_tune_trainable(training_func, dataset=dataset) diff --git a/doc/source/serve/core-apis.rst b/doc/source/serve/core-apis.rst index 2bd1f834c465d..e5130821c98be 100644 --- a/doc/source/serve/core-apis.rst +++ b/doc/source/serve/core-apis.rst @@ -35,14 +35,7 @@ Deployments can be exposed in two ways: over HTTP or in Python via the :ref:`ser By default, HTTP requests will be forwarded to the ``__call__`` method of the class (or the function) and a ``Starlette Request`` object will be the sole argument. You can also define a deployment that wraps a FastAPI app for more flexible handling of HTTP requests. See :ref:`serve-fastapi-http` for details. -To serve multiple deployments defined by the same class, use the ``name`` option: - -.. code-block:: python - - MyFirstDeployment.options(name="hello_service").deploy("Hello!") - MyFirstDeployment.options(name="hi_service").deploy("Hi!) - -You can also list all available deployments and dynamically get references to them: +We can also list all available deployments and dynamically get a reference to them: .. code-block:: python @@ -245,31 +238,27 @@ Ray Serve supports serving deployments with different (possibly conflicting) Python dependencies. For example, you can simultaneously serve one deployment that uses legacy Tensorflow 1 and another that uses Tensorflow 2. -This is supported on Mac OS and Linux using Ray's :ref:`runtime-environments` feature. -As with all other Ray actor options, pass the runtime environment in via ``ray_actor_options`` in -your deployment. Be sure to first run ``pip install "ray[default]"`` to ensure the -Runtime Environments feature is installed. - -Example: +Currently this is supported on Mac OS and Linux using `conda `_ +via Ray's built-in ``runtime_env`` option for actors. +As with all other actor options, pass these in via ``ray_actor_options`` in +your deployment. +You must have a conda environment set up for each set of +dependencies you want to isolate. 
If using a multi-node cluster, the +desired conda environment must be present on all nodes. Also, the Python patch version +(e.g. 3.8.10) must be identical on all nodes (this is a requirement for any Ray cluster). +See :ref:`runtime-environments` for details. + +Here's an example script. For it to work, first create a conda +environment named ``ray-tf1`` with Ray Serve and Tensorflow 1 installed, +and another named ``ray-tf2`` with Ray Serve and Tensorflow 2. The Ray and +Python versions must be the same in both environments. .. literalinclude:: ../../../python/ray/serve/examples/doc/conda_env.py -.. note:: - When using a Ray library (for example, Ray Serve) in a runtime environment, it must - explicitly be included in the dependencies, as in the above example. This is not - required when just using Ray Core. - -.. tip:: - Avoid dynamically installing packages that install from source: these can be slow and - use up all resources while installing, leading to problems with the Ray cluster. Consider - precompiling such packages in a private repository or Docker image. - The dependencies required in the deployment may be different than the dependencies installed in the driver program (the one running Serve API calls). In this case, you should use a delayed import within the class to avoid -importing unavailable packages in the driver. This applies even when not -using runtime environments. - +importing unavailable packages in the driver. Example: .. literalinclude:: ../../../python/ray/serve/examples/doc/imported_backend.py diff --git a/doc/source/serve/deployment.rst b/doc/source/serve/deployment.rst index a57c152561481..200b97f5dc710 100644 --- a/doc/source/serve/deployment.rst +++ b/doc/source/serve/deployment.rst @@ -25,7 +25,7 @@ to update the Serve instance, you can run another script that connects to the sa All non-detached Serve instances will be started in the current namespace that was specified when connecting to the cluster. If a namespace is specified for a detached Serve instance, it will be used. Otherwise if the current namespace is anonymous, the Serve instance will be started in the ``serve`` namespace. -If ``serve.start()`` is called again in a process in which there is already a running Serve instance, Serve will re-connect to the existing instance (regardless of whether the original instance was detached or not). To reconnect to a Serve instance that exists in the Ray cluster but not in the current process, connect to the cluster with the same namespace that was specified when starting the instance and run ``serve.start()``. +If ``serve.start()`` is called again in a process in which there is already a running Serve instance, Serve will re-connect to the existing instance (regardless of whether the original instance was detached or not). To reconnect to a Serve instance that exists in the Ray cluster but not in the current process, connect to the cluster with the same namespace that was specified when starting the instance and run ``serve.start()``. Deploying on a Single Node ========================== @@ -244,7 +244,7 @@ To automatically include the current deployment and replica in your logs, simply ``logger = logging.getLogger("ray")``, and use ``logger`` within your deployment code: .. 
literalinclude:: ../../../python/ray/serve/examples/doc/snippet_logger.py - :lines: 1, 9, 11-14, 16-17 + :lines: 1, 9, 11-13, 15-16 Querying a Serve endpoint with the above deployment will produce a log line like the following: @@ -290,7 +290,7 @@ Save the following file as ``promtail-local-config.yaml``: job: ray __path__: /tmp/ray/session_latest/logs/*.* -The relevant part for Ray is the ``static_configs`` field, where we have indicated the location of our log files with ``__path__``. +The relevant part for Ray is the ``static_configs`` field, where we have indicated the location of our log files with ``__path__``. The expression ``*.*`` will match all files, but not directories, which cause an error with Promtail. We will run Loki locally. Grab the default config file for Loki with the following command in your terminal: @@ -334,7 +334,7 @@ Now click "Explore" in the left-side panel. You are ready to run some queries! To filter all these Ray logs for the ones relevant to our deployment, use the following `LogQL `__ query: -.. code-block:: shell +.. code-block:: shell {job="ray"} |= "deployment=Counter" @@ -377,7 +377,7 @@ The following metrics are exposed by Ray Serve: - The number of requests processed by the router. * - ``serve_handle_request_counter`` - The number of requests processed by this ServeHandle. - * - ``serve_deployment_queued_queries`` + * - ``serve_deployment_queued_queries`` - The number of queries for this deployment waiting to be assigned to a replica. To see this in action, run ``ray start --head --metrics-export-port=8080`` in your terminal, and then run the following script: diff --git a/doc/source/serve/ml-models.rst b/doc/source/serve/ml-models.rst index 192207b041ac5..8fe3330af0498 100644 --- a/doc/source/serve/ml-models.rst +++ b/doc/source/serve/ml-models.rst @@ -70,10 +70,10 @@ Integration with Model Registries Ray Serve is flexible. If you can load your model as a Python function or class, then you can scale it up and serve it with Ray Serve. -For example, if you are using the +For example, if you are using the `MLflow Model Registry `_ to manage your models, the following wrapper -class will allow you to load a model using its MLflow `Model URI`: +class will allow you to load a model using its MLflow `Model URI`: .. code-block:: python @@ -93,19 +93,12 @@ class will allow you to load a model using its MLflow `Model URI`: model_uri = "model:/my_registered_model/Production" MLflowDeployment.deploy(model_uri) -To serve multiple different MLflow models in the same program, use the ``name`` option: - -.. code-block:: python - - MLflowDeployment.options(name="my_mlflow_model_1").deploy(model_uri) - - -.. tip:: +.. tip:: The above approach will work for any model registry, not just MLflow. Namely, load the model from the registry in ``__init__``, and forward the request to the model in ``__call__``. -For an even more hands-off and seamless integration with MLflow, check out the +For an even more hands-off and seamless integration with MLflow, check out the `Ray Serve MLflow deployment plugin `__. A full tutorial is available `here `__. diff --git a/doc/source/tune/_tutorials/_faq.inc b/doc/source/tune/_tutorials/_faq.inc index c14a0aa4504cd..d9bb39e1f94dc 100644 --- a/doc/source/tune/_tutorials/_faq.inc +++ b/doc/source/tune/_tutorials/_faq.inc @@ -19,18 +19,10 @@ Deciding on which to use mostly depends on your problem: * How many hyperparameters would you like to tune? * What values are valid for hyperparameters? 
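The guidance that follows converges on random search plus ASHA as a sensible default; as a minimal sketch (the training function and its metric are invented for illustration):

.. code-block:: python

    from ray import tune
    from ray.tune.schedulers import ASHAScheduler

    def train_fn(config):
        # Stand-in objective: report an incremental metric every step so
        # that ASHA can terminate unpromising trials early.
        for step in range(100):
            tune.report(mean_accuracy=config["lr"] * step)

    tune.run(
        train_fn,
        config={"lr": tune.loguniform(1e-4, 1e-1)},  # randomly sampled
        num_samples=20,
        scheduler=ASHAScheduler(metric="mean_accuracy", mode="max"))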
-**If your model returns incremental results** (eg. results per epoch in deep learning, -results per each added tree in GBDTs, etc.) using early stopping usually allows for sampling -more configurations, as unpromising trials are pruned before they run their full course. -Please note that not all search algorithms can use information from pruned trials. -Early stopping cannot be used without incremental results - in case of the functional API, -that means that ``tune.report()`` has to be called more than once - usually in a loop. - **If your model is small**, you can usually try to run many different configurations. A **random search** can be used to generate configurations. You can also grid search over some values. You should probably still use -:ref:`ASHA for early termination of bad trials ` (if your problem -supports early stopping). +:ref:`ASHA for early termination of bad trials `. **If your model is large**, you can try to either use **Bayesian Optimization-based search algorithms** like :ref:`BayesOpt ` or @@ -41,19 +33,14 @@ Alternatively, you can use :ref:`Population Based Training ` works well with few trials, e.g. 8 or even 4. However, this will output a hyperparameter *schedule* rather than one fixed set of hyperparameters. -**If you have a small number of hyperparameters**, Bayesian Optimization methods -work well. Take a look at :ref:`BOHB ` or :ref:`Optuna ` -with the :ref:`ASHA ` scheduler to combine the -benefits of Bayesian Optimization with early stopping. +**If you have a small number of hyperparameters**, Bayesian Optimization methods +work well. Take a look at :ref:`BOHB ` to combine the +benefits of Bayesian Optimization with early stopping. **If you only have continuous values for hyperparameters** this will work well -with most Bayesian Optimization methods. Discrete or categorical variables still +with most Bayesian Optimization methods. Discrete or categorical variables still work, but less good with an increasing number of categories. -**If you have many categorical values for hyperparameters**, consider using random search, -or a TPE-based Bayesian Optimization algorithm such as :ref:`Optuna ` or -:ref:`HyperOpt `. - **Our go-to solution** is usually to use **random search** with :ref:`ASHA for early stopping ` for smaller problems. Use :ref:`BOHB ` for **larger problems** with a **small number of hyperparameters** and :ref:`Population Based Training ` for **larger problems** with a **large number of hyperparameters** @@ -261,34 +248,6 @@ on other nodes as well. Please refer to the :ref:`placement groups documentation ` to learn more about these placement strategies. -Why is my training stuck and Ray reporting that pending actor or tasks cannot be scheduled? -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -This is usually caused by Ray actors or tasks being started by the -trainable without the trainable resources accounting for them, leading to a deadlock. -This can also be "stealthly" caused by using other libraries in the trainable that are -based on Ray, such as Modin. In order to fix the issue, request additional resources for -the trial using :ref:`placement groups `, as outlined in -the section above. - -For example, if your trainable is using Modin dataframes, operations on those will spawn -Ray tasks. By allocating an additional CPU bundle to the trial, those tasks will be able -to run without being starved of resources. - -.. 
code-block:: python
-
-    import modin.pandas as pd
-
-    def train_fn(config, checkpoint_dir=None):
-        # some Modin operations here
-        tune.report(metric=metric)
-
-    tune.run(
-        train_fn,
-        resources_per_trial=tune.PlacementGroupFactory([
-            {"CPU": 1}, # this bundle will be used by the trainable itself
-            {"CPU": 1}, # this bundle will be used by Modin
-        ], strategy="PACK")
 
 How can I pass further parameter values to my trainable?
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
@@ -327,8 +286,8 @@ also works with class trainables.
 Please see :ref:`here for further details ` and examples.
 
-How can I reproduce experiments?
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+How can I reproduce experiments
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 Reproducing experiments and experiment results means that you get the exact same
 results when running an experiment again and again. To achieve this, the conditions
 have to be exactly the same each time you run the experiment.
diff --git a/doc/source/tune/api_docs/suggestion.rst b/doc/source/tune/api_docs/suggestion.rst
index 4795f0c97816f..32728c4ab2273 100644
--- a/doc/source/tune/api_docs/suggestion.rst
+++ b/doc/source/tune/api_docs/suggestion.rst
@@ -16,7 +16,6 @@ Summary
 -------
 
 .. list-table::
-   :widths: 5 5 2 10
    :header-rows: 1
 
    * - SearchAlgorithm
@@ -138,6 +137,8 @@ identifier.
     search_alg2.restore_from_dir(
         os.path.join("~/my_results", "my-experiment-1"))
 
+.. note:: This is currently not implemented for: AxSearch, TuneBOHB, SigOptSearch, and DragonflySearch.
+
 .. _tune-basicvariant:
 
 Random search and grid search (tune.suggest.basic_variant.BasicVariantGenerator)
diff --git a/doc/source/tune/user-guide.rst b/doc/source/tune/user-guide.rst
index 962d53bdad848..a4f522add908e 100644
--- a/doc/source/tune/user-guide.rst
+++ b/doc/source/tune/user-guide.rst
@@ -50,8 +50,7 @@ the respective placement group. If not enough resources are available, this will
 If your trainable function starts more remote workers, you will need to pass placement groups
 factory objects to request these resources.
 See the :class:`PlacementGroupFactory documentation `
-for further information. This also applies if you are using other libraries making use of Ray, such
-as Modin. Failure to set resources correctly may result in a deadlock, "hanging" the cluster.
+for further information.
 
 Using GPUs
 ~~~~~~~~~~
@@ -871,10 +870,6 @@ These are the environment variables Ray Tune currently considers:
     Ctrl+C) to gracefully shutdown and do a final checkpoint. Setting
     this variable to ``1`` will disable signal handling and stop execution right away. Defaults to
     ``0``.
-* **TUNE_FORCE_TRIAL_CLEANUP_S**: By default, Ray Tune will gracefully terminate trials,
-    letting them finish the current training step and any user-defined cleanup.
-    Setting this variable to a non-zero, positive integer will cause trials to be forcefully
-    terminated after a grace period of that many seconds. Defaults to ``0``.
 * **TUNE_FUNCTION_THREAD_TIMEOUT_S**: Time in seconds the function API waits
     for threads to finish after instructing them to complete. Defaults to ``2``.
 * **TUNE_GLOBAL_CHECKPOINT_S**: Time in seconds that limits how often Tune's
@@ -908,9 +903,6 @@ These are the environment variables Ray Tune currently considers:
     to the driver. Enabling this might delay scheduling decisions, as trainables are speculatively
     continued. Setting this to ``0`` disables result buffering. Defaults to 1000 (results), or to 1
     (no buffering) if used with ``checkpoint_at_end``.
-* **TUNE_RESULT_DELIM**: Delimiter used for nested entries in - :class:`ExperimentAnalysis ` dataframes. Defaults to ``.`` (but will be - changed to ``/`` in future versions of Ray). * **TUNE_RESULT_BUFFER_MAX_TIME_S**: Similarly, Ray Tune buffers results up to ``number_of_trial/10`` seconds, but never longer than this value. Defaults to 100 (seconds). * **TUNE_RESULT_BUFFER_MIN_TIME_S**: Additionally, you can specify a minimum time to buffer results. Defaults to 0. diff --git a/java/BUILD.bazel b/java/BUILD.bazel index 06a974befed02..2d98abb9402ba 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -147,12 +147,9 @@ define_java_module( ":io_ray_ray_api", ":io_ray_ray_runtime", ":io_ray_ray_serve", - "@maven//:com_google_code_gson_gson", "@maven//:com_google_guava_guava", "@maven//:com_google_protobuf_protobuf_java", "@maven//:org_apache_commons_commons_lang3", - "@maven//:org_apache_httpcomponents_client5_httpclient5", - "@maven//:org_apache_httpcomponents_core5_httpcore5", "@maven//:org_slf4j_slf4j_api", "@maven//:org_testng_testng", ], @@ -160,11 +157,9 @@ define_java_module( deps = [ ":io_ray_ray_api", ":io_ray_ray_runtime", - "@maven//:com_google_code_gson_gson", "@maven//:com_google_guava_guava", "@maven//:com_google_protobuf_protobuf_java", "@maven//:org_apache_commons_commons_lang3", - "@maven//:org_apache_httpcomponents_core5_httpcore5", "@maven//:org_slf4j_slf4j_api", ], ) diff --git a/java/dependencies.bzl b/java/dependencies.bzl index e6bb9e384d1cf..9c411a1bd9982 100644 --- a/java/dependencies.bzl +++ b/java/dependencies.bzl @@ -24,8 +24,6 @@ def gen_java_deps(): "com.lmax:disruptor:3.3.4", "org.yaml:snakeyaml:1.26", "net.java.dev.jna:jna:5.5.0", - "org.apache.httpcomponents.client5:httpclient5:5.0.3", - "org.apache.httpcomponents.core5:httpcore5:5.0.2", maven.artifact( group = "org.testng", artifact = "testng", diff --git a/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java index 172ff78dfa397..acda82aa6f1d6 100644 --- a/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java +++ b/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java @@ -1,7 +1,6 @@ package io.ray.runtime; import com.google.common.base.Preconditions; -import com.google.gson.Gson; import io.ray.api.BaseActorHandle; import io.ray.api.id.ActorId; import io.ray.api.id.JobId; @@ -11,7 +10,6 @@ import io.ray.runtime.exception.RayIntentionalSystemExitException; import io.ray.runtime.gcs.GcsClient; import io.ray.runtime.gcs.GcsClientOptions; -import io.ray.runtime.generated.Common.RuntimeEnv; import io.ray.runtime.generated.Common.WorkerType; import io.ray.runtime.generated.Gcs.GcsNodeInfo; import io.ray.runtime.generated.Gcs.JobConfig; @@ -22,8 +20,6 @@ import io.ray.runtime.task.TaskExecutor; import io.ray.runtime.util.BinaryFileUtil; import io.ray.runtime.util.JniUtils; -import java.util.HashMap; -import java.util.Map; import java.util.Optional; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReadWriteLock; @@ -106,20 +102,8 @@ public void start() { JobConfig.newBuilder() .setNumJavaWorkersPerProcess(rayConfig.numWorkersPerProcess) .addAllJvmOptions(rayConfig.jvmOptionsForJavaWorker) + .putAllWorkerEnv(rayConfig.workerEnv) .addAllCodeSearchPath(rayConfig.codeSearchPath); - RuntimeEnv.Builder runtimeEnvBuilder = RuntimeEnv.newBuilder(); - if (!rayConfig.workerEnv.isEmpty()) { - // TODO(SongGuyang): Suppport complete runtime env interface for users. 
- // Set worker env to the serialized runtime env json. - Gson gson = new Gson(); - Map> runtimeEnv = new HashMap<>(); - runtimeEnv.put("env_vars", rayConfig.workerEnv); - String gsonString = gson.toJson(runtimeEnv); - runtimeEnvBuilder.setSerializedRuntimeEnv(gsonString); - } else { - runtimeEnvBuilder.setSerializedRuntimeEnv("{}"); - } - jobConfigBuilder.setRuntimeEnv(runtimeEnvBuilder.build()); serializedJobConfig = jobConfigBuilder.build().toByteArray(); } diff --git a/java/runtime/src/main/java/io/ray/runtime/object/LocalModeObjectStore.java b/java/runtime/src/main/java/io/ray/runtime/object/LocalModeObjectStore.java index fc139985955c9..131d71c5fa2f9 100644 --- a/java/runtime/src/main/java/io/ray/runtime/object/LocalModeObjectStore.java +++ b/java/runtime/src/main/java/io/ray/runtime/object/LocalModeObjectStore.java @@ -117,7 +117,7 @@ public Address getOwnerAddress(ObjectId id) { } @Override - public byte[] getOwnershipInfo(ObjectId objectId) { + public byte[] promoteAndGetOwnershipInfo(ObjectId objectId) { return new byte[0]; } diff --git a/java/runtime/src/main/java/io/ray/runtime/object/NativeObjectStore.java b/java/runtime/src/main/java/io/ray/runtime/object/NativeObjectStore.java index 136712c096cd8..7e0ddc5c9aa74 100644 --- a/java/runtime/src/main/java/io/ray/runtime/object/NativeObjectStore.java +++ b/java/runtime/src/main/java/io/ray/runtime/object/NativeObjectStore.java @@ -78,8 +78,8 @@ public void removeLocalReference(UniqueId workerId, ObjectId objectId) { } @Override - public byte[] getOwnershipInfo(ObjectId objectId) { - return nativeGetOwnershipInfo(objectId.getBytes()); + public byte[] promoteAndGetOwnershipInfo(ObjectId objectId) { + return nativePromoteAndGetOwnershipInfo(objectId.getBytes()); } @Override @@ -132,7 +132,7 @@ private static native List nativeWait( private static native byte[] nativeGetOwnerAddress(byte[] objectId); - private static native byte[] nativeGetOwnershipInfo(byte[] objectId); + private static native byte[] nativePromoteAndGetOwnershipInfo(byte[] objectId); private static native void nativeRegisterOwnershipInfoAndResolveFuture( byte[] objectId, byte[] outerObjectId, byte[] ownerAddress); diff --git a/java/runtime/src/main/java/io/ray/runtime/object/ObjectRefImpl.java b/java/runtime/src/main/java/io/ray/runtime/object/ObjectRefImpl.java index a352ca22632ef..cb9b35becd02d 100644 --- a/java/runtime/src/main/java/io/ray/runtime/object/ObjectRefImpl.java +++ b/java/runtime/src/main/java/io/ray/runtime/object/ObjectRefImpl.java @@ -63,7 +63,7 @@ public void writeExternal(ObjectOutput out) throws IOException { out.writeObject(this.getId()); out.writeObject(this.getType()); RayRuntimeInternal runtime = (RayRuntimeInternal) Ray.internal(); - byte[] ownerAddress = runtime.getObjectStore().getOwnershipInfo(this.getId()); + byte[] ownerAddress = runtime.getObjectStore().promoteAndGetOwnershipInfo(this.getId()); out.writeInt(ownerAddress.length); out.write(ownerAddress); ObjectSerializer.addContainedObjectId(this.getId()); diff --git a/java/runtime/src/main/java/io/ray/runtime/object/ObjectStore.java b/java/runtime/src/main/java/io/ray/runtime/object/ObjectStore.java index 6db39cc1e4bd6..d61694fab7e93 100644 --- a/java/runtime/src/main/java/io/ray/runtime/object/ObjectStore.java +++ b/java/runtime/src/main/java/io/ray/runtime/object/ObjectStore.java @@ -224,12 +224,12 @@ public WaitResult wait( public abstract Address getOwnerAddress(ObjectId id); /** - * Get the ownership info. 
+ * Promote the given object to the underlying object store, and get the ownership info.
    *
    * @param objectId The ID of the object to promote
    * @return the serialized ownership address
    */
-  public abstract byte[] getOwnershipInfo(ObjectId objectId);
+  public abstract byte[] promoteAndGetOwnershipInfo(ObjectId objectId);
 
   /**
    * Add a reference to an ObjectID that will be deserialized. This will also start the process to
diff --git a/java/serve/pom.xml b/java/serve/pom.xml
index 7291d4ec79666..d945f8fe83172 100644
--- a/java/serve/pom.xml
+++ b/java/serve/pom.xml
@@ -27,11 +27,6 @@
       <artifactId>ray-runtime</artifactId>
       <version>${project.version}</version>
     </dependency>
-    <dependency>
-      <groupId>com.google.code.gson</groupId>
-      <artifactId>gson</artifactId>
-      <version>2.8.5</version>
-    </dependency>
     <dependency>
       <groupId>com.google.guava</groupId>
       <artifactId>guava</artifactId>
@@ -47,16 +42,6 @@
       <artifactId>commons-lang3</artifactId>
       <version>3.4</version>
     </dependency>
-    <dependency>
-      <groupId>org.apache.httpcomponents.client5</groupId>
-      <artifactId>httpclient5</artifactId>
-      <version>5.0.3</version>
-    </dependency>
-    <dependency>
-      <groupId>org.apache.httpcomponents.core5</groupId>
-      <artifactId>httpcore5</artifactId>
-      <version>5.0.2</version>
-    </dependency>
     <dependency>
       <groupId>org.slf4j</groupId>
       <artifactId>slf4j-api</artifactId>
diff --git a/java/serve/src/main/java/io/ray/serve/Constants.java b/java/serve/src/main/java/io/ray/serve/Constants.java
index 1ca1739f8d734..2d8ac4f702839 100644
--- a/java/serve/src/main/java/io/ray/serve/Constants.java
+++ b/java/serve/src/main/java/io/ray/serve/Constants.java
@@ -16,10 +16,4 @@ public class Constants {
 
   /** Name of controller listen_for_change method. */
   public static final String CONTROLLER_LISTEN_FOR_CHANGE_METHOD = "listen_for_change";
-
-  public static final String SERVE_CONTROLLER_NAME = "SERVE_CONTROLLER_ACTOR";
-
-  public static final String DEFAULT_CALL_METHOD = "call";
-
-  public static final String UTF8 = "UTF-8";
 }
diff --git a/java/serve/src/main/java/io/ray/serve/DeploymentInfo.java b/java/serve/src/main/java/io/ray/serve/DeploymentInfo.java
deleted file mode 100644
index 2ab02deeeeaeb..0000000000000
--- a/java/serve/src/main/java/io/ray/serve/DeploymentInfo.java
+++ /dev/null
@@ -1,38 +0,0 @@
-package io.ray.serve;
-
-import java.io.Serializable;
-
-public class DeploymentInfo implements Serializable {
-
-  private static final long serialVersionUID = -4198364411759931955L;
-
-  private byte[] backendConfig;
-
-  private ReplicaConfig replicaConfig;
-
-  private byte[] backendVersion;
-
-  public byte[] getBackendConfig() {
-    return backendConfig;
-  }
-
-  public void setBackendConfig(byte[] backendConfig) {
-    this.backendConfig = backendConfig;
-  }
-
-  public ReplicaConfig getReplicaConfig() {
-    return replicaConfig;
-  }
-
-  public void setReplicaConfig(ReplicaConfig replicaConfig) {
-    this.replicaConfig = replicaConfig;
-  }
-
-  public byte[] getBackendVersion() {
-    return backendVersion;
-  }
-
-  public void setBackendVersion(byte[] backendVersion) {
-    this.backendVersion = backendVersion;
-  }
-}
diff --git a/java/serve/src/main/java/io/ray/serve/DummyBackendReplica.java b/java/serve/src/main/java/io/ray/serve/DummyBackendReplica.java
deleted file mode 100644
index 874a71c26d6db..0000000000000
--- a/java/serve/src/main/java/io/ray/serve/DummyBackendReplica.java
+++ /dev/null
@@ -1,12 +0,0 @@
-package io.ray.serve;
-
-import java.util.concurrent.atomic.AtomicInteger;
-
-public class DummyBackendReplica {
-
-  private AtomicInteger counter = new AtomicInteger();
-
-  public String call() {
-    return String.valueOf(counter.incrementAndGet());
-  }
-}
diff --git a/java/serve/src/main/java/io/ray/serve/HandleOptions.java b/java/serve/src/main/java/io/ray/serve/HandleOptions.java
deleted file mode 100644
index e301332976ea3..0000000000000
--- a/java/serve/src/main/java/io/ray/serve/HandleOptions.java
+++ /dev/null
@@ -1,15 +0,0 @@
-package io.ray.serve;
-
-/** Options for each ServeHandle instances. These fields are immutable.
*/ -public class HandleOptions { - - private String methodName = "call"; - - public String getMethodName() { - return methodName; - } - - public void setMethodName(String methodName) { - this.methodName = methodName; - } -} diff --git a/java/serve/src/main/java/io/ray/serve/HttpProxy.java b/java/serve/src/main/java/io/ray/serve/HttpProxy.java deleted file mode 100644 index 809337e75d902..0000000000000 --- a/java/serve/src/main/java/io/ray/serve/HttpProxy.java +++ /dev/null @@ -1,161 +0,0 @@ -package io.ray.serve; - -import com.google.common.collect.ImmutableMap; -import io.ray.api.Ray; -import io.ray.runtime.metric.Count; -import io.ray.runtime.metric.Metrics; -import io.ray.runtime.metric.TagKey; -import io.ray.runtime.serializer.MessagePackSerializer; -import io.ray.serve.util.LogUtil; -import io.ray.serve.util.SocketUtil; -import java.io.IOException; -import java.net.HttpURLConnection; -import java.net.InetAddress; -import java.nio.charset.Charset; -import java.util.HashMap; -import java.util.Map; -import java.util.Optional; -import org.apache.hc.core5.http.ClassicHttpRequest; -import org.apache.hc.core5.http.ClassicHttpResponse; -import org.apache.hc.core5.http.HttpEntity; -import org.apache.hc.core5.http.HttpException; -import org.apache.hc.core5.http.impl.bootstrap.HttpServer; -import org.apache.hc.core5.http.impl.bootstrap.ServerBootstrap; -import org.apache.hc.core5.http.io.HttpRequestHandler; -import org.apache.hc.core5.http.io.entity.ByteArrayEntity; -import org.apache.hc.core5.http.io.entity.EntityUtils; -import org.apache.hc.core5.http.io.entity.StringEntity; -import org.apache.hc.core5.http.protocol.HttpContext; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class HttpProxy implements ServeProxy { - - private static final Logger LOGGER = LoggerFactory.getLogger(HttpProxy.class); - - public static final String PROXY_NAME = "HTTP_PROXY"; - - public static final String PROXY_HTTP_PORT = "ray.serve.proxy.http.port"; - - public static final String PROXY_HTTP_METHODS = "ray.serve.proxy.http.methods"; - - private int port; - - private Count requestCounter; - - private HttpServer httpServer; - - private ProxyRouter proxyRouter; - - private Object asyncContext = Ray.getAsyncContext(); - - @Override - public void init(Map config, ProxyRouter proxyRouter) { - this.port = - Optional.ofNullable(config) - .map(conf -> conf.get(PROXY_HTTP_PORT)) - .map(httpPort -> Integer.valueOf(httpPort)) - .orElse(SocketUtil.findAvailableTcpPort(8000)); - this.proxyRouter = proxyRouter; - RayServeMetrics.execute( - () -> - this.requestCounter = - Metrics.count() - .name("serve_num_http_requests") - .description("The number of HTTP requests processed.") - .unit("") - .tags(new HashMap<>()) - .register()); - startupHttpServer(port); - LOGGER.info("Proxy {} has been started with port:{}", getName(), this.port); - } - - private void startupHttpServer(int port) { - try { - this.httpServer = - ServerBootstrap.bootstrap() - .setListenerPort(port) - .register("*", new ServeHttpHandler()) - .registerVirtual( - InetAddress.getLocalHost().getHostAddress(), "*", new ServeHttpHandler()) - .create(); - this.httpServer.start(); - } catch (Throwable e) { - String errMsg = - LogUtil.format( - "Proxy {} failed to startup HTTP server on port {}.", getName(), this.port); - LOGGER.error(errMsg); - throw new RayServeException(errMsg, e); - } - } - - @Override - public String getName() { - return PROXY_NAME; - } - - private class ServeHttpHandler implements HttpRequestHandler { - - @Override - public void 
handle( - ClassicHttpRequest request, ClassicHttpResponse response, HttpContext context) - throws HttpException, IOException { - - Ray.setAsyncContext(asyncContext); - - int code = HttpURLConnection.HTTP_OK; - Object result = null; - String route = request.getPath(); - try { - RayServeMetrics.execute( - () -> - requestCounter.update( - 1.0, - ImmutableMap.of( - new TagKey(RayServeMetrics.TAG_ROUTE), - route))); // TODO the old tag will be covered, it may be a bug. - - Object[] parameters = null; - HttpEntity httpEntity = request.getEntity(); - if (null == httpEntity) { - parameters = new Object[0]; - } else { - byte[] body = EntityUtils.toByteArray(httpEntity); - parameters = MessagePackSerializer.decode(body, Object[].class); - } - - RayServeHandle rayServeHandle = proxyRouter.matchRoute(route); - if (rayServeHandle == null) { - code = HttpURLConnection.HTTP_NOT_FOUND; - } else { - result = rayServeHandle.remote(parameters).get(); - } - - } catch (Throwable e) { - LOGGER.error("HTTP Proxy failed to process request.", e); - code = HttpURLConnection.HTTP_INTERNAL_ERROR; - } finally { - response.setCode(code); - if (code == HttpURLConnection.HTTP_NOT_FOUND) { - response.setEntity( - new StringEntity( - LogUtil.format( - "Path '{}' not found. Please ping http://.../-/routes for route table.", - route), - Charset.forName(Constants.UTF8))); - } else if (result != null) { - response.setEntity( - new ByteArrayEntity(MessagePackSerializer.encode(result).getLeft(), null)); - } - } - } - } - - public int getPort() { - return port; - } - - public ProxyRouter getProxyRouter() { - return proxyRouter; - } -} diff --git a/java/serve/src/main/java/io/ray/serve/ProxyActor.java b/java/serve/src/main/java/io/ray/serve/ProxyActor.java deleted file mode 100644 index ac5d1cf870ea9..0000000000000 --- a/java/serve/src/main/java/io/ray/serve/ProxyActor.java +++ /dev/null @@ -1,175 +0,0 @@ -package io.ray.serve; - -import com.google.common.base.Preconditions; -import com.google.common.collect.Lists; -import io.ray.api.BaseActorHandle; -import io.ray.api.Ray; -import io.ray.serve.api.Serve; -import io.ray.serve.generated.EndpointInfo; -import io.ray.serve.generated.EndpointSet; -import io.ray.serve.poll.KeyListener; -import io.ray.serve.poll.KeyType; -import io.ray.serve.poll.LongPollClient; -import io.ray.serve.poll.LongPollNamespace; -import io.ray.serve.util.CollectionUtil; -import io.ray.serve.util.LogUtil; -import io.ray.serve.util.ReflectUtil; -import java.lang.reflect.InvocationTargetException; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.ServiceLoader; -import java.util.concurrent.ConcurrentHashMap; -import org.apache.commons.lang3.StringUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class ProxyActor { - - private static final Logger LOGGER = LoggerFactory.getLogger(ProxyActor.class); - - private Map config; - - private Map proxies = new ConcurrentHashMap<>(); - - /** Used only for displaying the route table. Key: route, value: endpoint. */ - private volatile Map routeInfo = new HashMap<>(); - - private LongPollClient longPollClient; - - private ProxyRouter proxyRouter = new ProxyRouter(); - - public ProxyActor(String controllerName, Map config) { - this.config = config; - - // Set the controller name so that serve will connect to the controller instance this proxy is - // running in. 
- Serve.setInternalReplicaContext(null, null, controllerName, null); - - Optional optional = Ray.getActor(controllerName); - Preconditions.checkState(optional.isPresent(), "Controller does not exist"); - - Map keyListeners = new HashMap<>(); - keyListeners.put( - new KeyType(LongPollNamespace.ROUTE_TABLE, null), endpoints -> updateRoutes(endpoints)); - this.longPollClient = new LongPollClient(optional.get(), keyListeners); - this.longPollClient.start(); - this.run(); - } - - private void run() { - startupProxy(); - registerServiceDiscovery(); - } - - private void startupProxy() { - - List serveProxies = null; - - // Get proxy instances according to class names. - String proxyClassNames = config != null ? config.get(RayServeConfig.PROXY_CLASS) : null; - if (StringUtils.isNotBlank(proxyClassNames)) { - try { - serveProxies = ReflectUtil.getInstancesByClassNames(proxyClassNames, ServeProxy.class); - } catch (ClassNotFoundException - | InstantiationException - | IllegalAccessException - | IllegalArgumentException - | InvocationTargetException - | NoSuchMethodException - | SecurityException e) { - String errorMsg = - LogUtil.format("Failed to initialize proxies by class names : {}", proxyClassNames); - LOGGER.error(errorMsg, e); - throw new RayServeException(errorMsg, e); - } - } - - // Get proxy instances through SPI. - if (CollectionUtil.isEmpty(serveProxies)) { - List spiProxies = new ArrayList<>(); - ServiceLoader serviceLoader = ServiceLoader.load(ServeProxy.class); - serviceLoader.forEach(serveProxy -> spiProxies.add(serveProxy)); - serveProxies = spiProxies; - } - - // Set the default proxy if proxies still empty. - if (CollectionUtil.isEmpty(serveProxies)) { - serveProxies = Lists.newArrayList(new HttpProxy()); - } - - if (!CollectionUtil.isEmpty(serveProxies)) { - for (ServeProxy serveProxy : serveProxies) { - if (proxies.containsKey(serveProxy.getName())) { - String errorMsg = - LogUtil.format( - "Proxy {} name {} is duplicate with proxy {} name {}", - serveProxy.getClass().getName(), - serveProxy.getName(), - proxies.get(serveProxy.getName()).getClass().getName(), - proxies.get(serveProxy.getName()).getName()); - LOGGER.error(errorMsg); - throw new RayServeException(errorMsg); - } - proxies.put(serveProxy.getName(), serveProxy); - serveProxy.init(config, proxyRouter); - LOGGER.info("Proxy actor initialized proxy: {}", serveProxy.getName()); - } - } - } - - public void registerServiceDiscovery() { - proxies.forEach((key, value) -> value.registerServiceDiscovery()); - } - - public void updateRoutes(Object endpoints) { - Map endpointInfos = ((EndpointSet) endpoints).getEndpointsMap(); - Map routeInfo = new HashMap<>(); - if (endpointInfos != null) { - endpointInfos.forEach( - (key, value) -> - routeInfo.put( - StringUtils.isNotBlank(value.getRoute()) ? 
value.getRoute() : key, value)); - } - this.routeInfo = routeInfo; - this.proxyRouter.updateRoutes(endpointInfos); - } - - public void ready() { - return; - } - - public void blockUntilEndpointExists(String endpoint, double timeoutS) { - long timeoutMs = (long) (timeoutS * 1000); - long startTime = System.currentTimeMillis(); - while (true) { - if (System.currentTimeMillis() - startTime > timeoutMs) { - throw new RayServeException( - LogUtil.format("Waited {} for {} to propagate.", timeoutS, endpoint)); - } - for (EndpointInfo endpointInfo : routeInfo.values()) { - if (StringUtils.equals(endpointInfo.getEndpointName(), endpoint)) { - return; - } - } - try { - Thread.sleep(200); - } catch (InterruptedException e) { - LOGGER.error( - "The sleeping was interrupted when waiting for the endpoint {} being existing.", - endpoint, - e); - } - } - } - - public ProxyRouter getProxyRouter() { - return proxyRouter; - } - - public Map getProxies() { - return proxies; - } -} diff --git a/java/serve/src/main/java/io/ray/serve/ProxyRouter.java b/java/serve/src/main/java/io/ray/serve/ProxyRouter.java deleted file mode 100644 index 041da46bfee08..0000000000000 --- a/java/serve/src/main/java/io/ray/serve/ProxyRouter.java +++ /dev/null @@ -1,72 +0,0 @@ -package io.ray.serve; - -import io.ray.serve.api.Serve; -import io.ray.serve.generated.EndpointInfo; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import org.apache.commons.lang3.StringUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** Default common router for proxy to match incomming routes. */ -public class ProxyRouter { - - private static final Logger LOGGER = LoggerFactory.getLogger(ProxyRouter.class); - - /** Key: route, value: endpoint. */ - private Map routeInfo = new HashMap<>(); - - /** Key: endpointName, value: handle. */ - private Map handles = new ConcurrentHashMap<>(); - - public void updateRoutes(Map endpoints) { - LOGGER.info("Got updated endpoints: {}.", endpoints); - - Set existingHandles = new HashSet<>(handles.keySet()); - Map routeInfo = new HashMap<>(); - - if (endpoints != null) { - for (Map.Entry entry : endpoints.entrySet()) { - String route = - StringUtils.isNotBlank(entry.getValue().getRoute()) - ? entry.getValue().getRoute() - : entry.getKey(); - routeInfo.put(route, entry.getValue()); - - if (handles.containsKey(entry.getKey())) { - existingHandles.remove(entry.getKey()); - } else { - handles.put(entry.getKey(), Serve.getGlobalClient().getHandle(entry.getKey(), true)); - } - } - } - - this.routeInfo = routeInfo; - for (String endpoint : existingHandles) { - handles.remove(endpoint); - } - LOGGER.info("The final route info: {}.", routeInfo); - } - - /** - * Return the longest prefix match among existing routes for the route. - * - * @param route route to match against. - * @return serve_handle (RayServeHandle) if found, else null. - */ - public RayServeHandle matchRoute(String route) { - EndpointInfo endpointInfo = routeInfo.get(route); - return endpointInfo == null ? 
null : handles.get(endpointInfo.getEndpointName()); - } - - public Map getRouteInfo() { - return routeInfo; - } - - public Map getHandles() { - return handles; - } -} diff --git a/java/serve/src/main/java/io/ray/serve/RayServeConfig.java b/java/serve/src/main/java/io/ray/serve/RayServeConfig.java deleted file mode 100644 index 5762aae40be4e..0000000000000 --- a/java/serve/src/main/java/io/ray/serve/RayServeConfig.java +++ /dev/null @@ -1,6 +0,0 @@ -package io.ray.serve; - -public class RayServeConfig { - - public static final String PROXY_CLASS = "ray.serve.proxy.class"; -} diff --git a/java/serve/src/main/java/io/ray/serve/RayServeHandle.java b/java/serve/src/main/java/io/ray/serve/RayServeHandle.java deleted file mode 100644 index abcf6ac5abdf2..0000000000000 --- a/java/serve/src/main/java/io/ray/serve/RayServeHandle.java +++ /dev/null @@ -1,73 +0,0 @@ -package io.ray.serve; - -import com.google.common.collect.ImmutableMap; -import io.ray.api.BaseActorHandle; -import io.ray.api.ObjectRef; -import io.ray.runtime.metric.Count; -import io.ray.runtime.metric.Metrics; -import io.ray.serve.generated.RequestMetadata; -import org.apache.commons.lang3.RandomStringUtils; - -public class RayServeHandle { - - private String endpointName; - - private HandleOptions handleOptions; - - private String handleTag; - - private Count requestCounter; - - private Router router; - - public RayServeHandle( - BaseActorHandle controllerHandle, - String endpointName, - HandleOptions handleOptions, - Router router) { - this.endpointName = endpointName; - this.handleOptions = handleOptions != null ? handleOptions : new HandleOptions(); - this.handleTag = endpointName + "#" + RandomStringUtils.randomAlphabetic(6); - this.router = router != null ? router : new Router(controllerHandle, endpointName); - RayServeMetrics.execute( - () -> - this.requestCounter = - Metrics.count() - .name(RayServeMetrics.SERVE_HANDLE_REQUEST_COUNTER.name()) - .description(RayServeMetrics.SERVE_HANDLE_REQUEST_COUNTER.getDescription()) - .unit("") - .tags( - ImmutableMap.of( - RayServeMetrics.TAG_HANDLE, - handleTag, - RayServeMetrics.TAG_ENDPOINT, - endpointName)) - .register()); - } - - /** - * Returns a Ray ObjectRef whose results can be waited for or retrieved using ray.wait or ray.get - * (or ``await object_ref``), respectively. - * - * @param parameters The input parameters of the specified method to invoke on the backend. - * @return ray.ObjectRef - */ - public ObjectRef remote(Object[] parameters) { - RayServeMetrics.execute(() -> requestCounter.inc(1.0)); - RequestMetadata.Builder requestMetadata = RequestMetadata.newBuilder(); - requestMetadata.setRequestId(RandomStringUtils.randomAlphabetic(10)); - requestMetadata.setEndpoint(endpointName); - requestMetadata.setCallMethod( - handleOptions != null ? 
handleOptions.getMethodName() : Constants.DEFAULT_CALL_METHOD); - return router.assignRequest(requestMetadata.build(), parameters); - } - - public RayServeHandle setMethodName(String methodName) { - handleOptions.setMethodName(methodName); - return this; - } - - public Router getRouter() { - return router; - } -} diff --git a/java/serve/src/main/java/io/ray/serve/RayServeMetrics.java b/java/serve/src/main/java/io/ray/serve/RayServeMetrics.java deleted file mode 100644 index f7b1fac730da9..0000000000000 --- a/java/serve/src/main/java/io/ray/serve/RayServeMetrics.java +++ /dev/null @@ -1,74 +0,0 @@ -package io.ray.serve; - -import io.ray.api.Ray; - -public enum RayServeMetrics { - SERVE_HANDLE_REQUEST_COUNTER( - "serve_handle_request_counter", - "The number of handle.remote() calls that have been made on this handle."), - - SERVE_NUM_ROUTER_REQUESTS( - "serve_num_router_requests", "The number of requests processed by the router."), - - SERVE_DEPLOYMENT_QUEUED_QUERIES( - "serve_deployment_queued_queries", - "The current number of queries to this deployment waiting to be assigned to a replica."), - - SERVE_BACKEND_REQUEST_COUNTER( - "serve_backend_request_counter", - "The number of queries that have been processed in this replica."), - - SERVE_BACKEND_ERROR_COUNTER( - "serve_backend_error_counter", - "The number of exceptions that have occurred in this replica."), - - SERVE_BACKEND_REPLICA_STARTS( - "serve_backend_replica_starts", - "The number of times this replica has been restarted due to failure."), - - SERVE_BACKEND_PROCESSING_LATENCY_MS( - "serve_backend_processing_latency_ms", "The latency for queries to be processed."), - - SERVE_REPLICA_PROCESSING_QUERIES( - "serve_replica_processing_queries", "The current number of queries being processed."), - ; - - public static final String TAG_HANDLE = "handle"; - - public static final String TAG_ENDPOINT = "endpoint"; - - public static final String TAG_DEPLOYMENT = "deployment"; - - public static final String TAG_ROUTE = "route"; - - public static final String TAG_BACKEND = "backend"; - - public static final String TAG_REPLICA = "replica"; - - private static final boolean isMetricsEnabled = - Ray.isInitialized() && !Ray.getRuntimeContext().isSingleProcess(); - - private String name; - - private String description; - - private RayServeMetrics(String name, String description) { - this.name = name; - this.description = description; - } - - public static void execute(Runnable runnable) { - if (!isMetricsEnabled) { - return; - } - runnable.run(); - } - - public String getName() { - return name; - } - - public String getDescription() { - return description; - } -} diff --git a/java/serve/src/main/java/io/ray/serve/RayServeReplica.java b/java/serve/src/main/java/io/ray/serve/RayServeReplica.java index 259c8555cf3e4..9949115fbbd72 100644 --- a/java/serve/src/main/java/io/ray/serve/RayServeReplica.java +++ b/java/serve/src/main/java/io/ray/serve/RayServeReplica.java @@ -1,16 +1,16 @@ package io.ray.serve; import com.google.common.collect.ImmutableMap; -import com.google.protobuf.ByteString; import io.ray.api.BaseActorHandle; +import io.ray.api.Ray; import io.ray.runtime.metric.Count; import io.ray.runtime.metric.Gauge; import io.ray.runtime.metric.Histogram; +import io.ray.runtime.metric.MetricConfig; import io.ray.runtime.metric.Metrics; import io.ray.runtime.serializer.MessagePackSerializer; import io.ray.serve.api.Serve; import io.ray.serve.generated.BackendConfig; -import io.ray.serve.generated.BackendVersion; import 
io.ray.serve.generated.RequestWrapper; import io.ray.serve.poll.KeyListener; import io.ray.serve.poll.KeyType; @@ -18,6 +18,7 @@ import io.ray.serve.poll.LongPollNamespace; import io.ray.serve.util.LogUtil; import io.ray.serve.util.ReflectUtil; +import io.ray.serve.util.ServeProtoUtil; import java.lang.reflect.Method; import java.util.HashMap; import java.util.Map; @@ -40,6 +41,8 @@ public class RayServeReplica { private Object callable; + private boolean metricsRegistered = false; + private Count requestCounter; private Count errorCounter; @@ -52,20 +55,13 @@ public class RayServeReplica { private LongPollClient longPollClient; - private BackendVersion version; - - private boolean isDeleted = false; - public RayServeReplica( - Object callable, - BackendConfig backendConfig, - BackendVersion version, - BaseActorHandle actorHandle) { + Object callable, BackendConfig backendConfig, BaseActorHandle actorHandle) { this.backendTag = Serve.getReplicaContext().getBackendTag(); this.replicaTag = Serve.getReplicaContext().getReplicaTag(); this.callable = callable; this.config = backendConfig; - this.version = version; + this.reconfigure(ServeProtoUtil.parseUserConfig(backendConfig)); Map keyListeners = new HashMap<>(); keyListeners.put( @@ -77,84 +73,55 @@ public RayServeReplica( } private void registerMetrics() { - RayServeMetrics.execute( - () -> - requestCounter = - Metrics.count() - .name(RayServeMetrics.SERVE_BACKEND_REQUEST_COUNTER.getName()) - .description(RayServeMetrics.SERVE_BACKEND_REQUEST_COUNTER.getDescription()) - .unit("") - .tags( - ImmutableMap.of( - RayServeMetrics.TAG_BACKEND, - backendTag, - RayServeMetrics.TAG_REPLICA, - replicaTag)) - .register()); - - RayServeMetrics.execute( - () -> - errorCounter = - Metrics.count() - .name(RayServeMetrics.SERVE_BACKEND_ERROR_COUNTER.getName()) - .description(RayServeMetrics.SERVE_BACKEND_ERROR_COUNTER.getDescription()) - .unit("") - .tags( - ImmutableMap.of( - RayServeMetrics.TAG_BACKEND, - backendTag, - RayServeMetrics.TAG_REPLICA, - replicaTag)) - .register()); - - RayServeMetrics.execute( - () -> - restartCounter = - Metrics.count() - .name(RayServeMetrics.SERVE_BACKEND_REPLICA_STARTS.getName()) - .description(RayServeMetrics.SERVE_BACKEND_REPLICA_STARTS.getDescription()) - .unit("") - .tags( - ImmutableMap.of( - RayServeMetrics.TAG_BACKEND, - backendTag, - RayServeMetrics.TAG_REPLICA, - replicaTag)) - .register()); - - RayServeMetrics.execute( - () -> - processingLatencyTracker = - Metrics.histogram() - .name(RayServeMetrics.SERVE_BACKEND_PROCESSING_LATENCY_MS.getName()) - .description( - RayServeMetrics.SERVE_BACKEND_PROCESSING_LATENCY_MS.getDescription()) - .unit("") - .boundaries(Constants.DEFAULT_LATENCY_BUCKET_MS) - .tags( - ImmutableMap.of( - RayServeMetrics.TAG_BACKEND, - backendTag, - RayServeMetrics.TAG_REPLICA, - replicaTag)) - .register()); - - RayServeMetrics.execute( - () -> - numProcessingItems = - Metrics.gauge() - .name(RayServeMetrics.SERVE_REPLICA_PROCESSING_QUERIES.getName()) - .description(RayServeMetrics.SERVE_REPLICA_PROCESSING_QUERIES.getDescription()) - .unit("") - .tags( - ImmutableMap.of( - RayServeMetrics.TAG_BACKEND, - backendTag, - RayServeMetrics.TAG_REPLICA, - replicaTag)) - .register()); - - RayServeMetrics.execute(() -> restartCounter.inc(1.0)); + if (!Ray.isInitialized() || Ray.getRuntimeContext().isSingleProcess()) { + return; + } + + Metrics.init(MetricConfig.DEFAULT_CONFIG); + requestCounter = + Metrics.count() + .name("serve_backend_request_counter") + .description("The number of queries that have 
been processed in this replica.") + .unit("") + .tags(ImmutableMap.of("backend", backendTag, "replica", replicaTag)) + .register(); + + errorCounter = + Metrics.count() + .name("serve_backend_error_counter") + .description("The number of exceptions that have occurred in this replica.") + .unit("") + .tags(ImmutableMap.of("backend", backendTag, "replica", replicaTag)) + .register(); + + restartCounter = + Metrics.count() + .name("serve_backend_replica_starts") + .description("The number of times this replica has been restarted due to failure.") + .unit("") + .tags(ImmutableMap.of("backend", backendTag, "replica", replicaTag)) + .register(); + + processingLatencyTracker = + Metrics.histogram() + .name("serve_backend_processing_latency_ms") + .description("The latency for queries to be processed.") + .unit("") + .boundaries(Constants.DEFAULT_LATENCY_BUCKET_MS) + .tags(ImmutableMap.of("backend", backendTag, "replica", replicaTag)) + .register(); + + numProcessingItems = + Metrics.gauge() + .name("serve_replica_processing_queries") + .description("The current number of queries being processed.") + .unit("") + .tags(ImmutableMap.of("backend", backendTag, "replica", replicaTag)) + .register(); + + metricsRegistered = true; + + restartCounter.inc(1.0); } public Object handleRequest(Query request) { @@ -163,7 +130,7 @@ public Object handleRequest(Query request) { "Replica {} received request {}", replicaTag, request.getMetadata().getRequestId()); numOngoingRequests.incrementAndGet(); - RayServeMetrics.execute(() -> numProcessingItems.update(numOngoingRequests.get())); + reportMetrics(() -> numProcessingItems.update(numOngoingRequests.get())); Object result = invokeSingle(request); numOngoingRequests.decrementAndGet(); @@ -190,10 +157,10 @@ private Object invokeSingle(Query requestItem) { Object[] args = parseRequestItem(requestItem); methodToCall = getRunnerMethod(requestItem.getMetadata().getCallMethod(), args); Object result = methodToCall.invoke(callable, args); - RayServeMetrics.execute(() -> requestCounter.inc(1.0)); + reportMetrics(() -> requestCounter.inc(1.0)); return result; } catch (Throwable e) { - RayServeMetrics.execute(() -> errorCounter.inc(1.0)); + reportMetrics(() -> errorCounter.inc(1.0)); throw new RayServeException( LogUtil.format( "Replica {} failed to invoke method {}", @@ -201,8 +168,7 @@ private Object invokeSingle(Query requestItem) { methodToCall == null ? "unknown" : methodToCall.getName()), e); } finally { - RayServeMetrics.execute( - () -> processingLatencyTracker.update(System.currentTimeMillis() - start)); + reportMetrics(() -> processingLatencyTracker.update(System.currentTimeMillis() - start)); } } @@ -243,12 +209,10 @@ private Method getRunnerMethod(String methodName, Object[] args) { * Perform graceful shutdown. Trigger a graceful shutdown protocol that will wait for all the * queued tasks to be completed and return to the controller. */ - public synchronized boolean prepareForShutdown() { + public void drainPendingQueries() { while (true) { - // Sleep first because we want to make sure all the routers receive the notification to remove - // this replica first. 
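+      // The loop below sleeps for the configured graceful-shutdown interval (in
+      // seconds, per the log message), re-checks numOngoingRequests, and once no
+      // requests remain in flight exits the replica's actor via Ray.exitActor().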
try {
-        Thread.sleep((long) (config.getGracefulShutdownWaitLoopS() * 1000));
+        Thread.sleep((long) (config.getExperimentalGracefulShutdownWaitLoopS() * 1000));
       } catch (InterruptedException e) {
         LOGGER.error(
             "Replica {} was interrupted in sleep when draining pending queries", replicaTag);
@@ -256,27 +220,13 @@
       if (numOngoingRequests.get() == 0) {
         break;
       } else {
-        LOGGER.info(
+        LOGGER.debug(
             "Waiting for an additional {}s to shut down because there are {} ongoing requests.",
-            config.getGracefulShutdownWaitLoopS(),
+            config.getExperimentalGracefulShutdownWaitLoopS(),
             numOngoingRequests.get());
       }
     }
-
-    // Explicitly call the del method to trigger clean up. We set isDeleted = true after
-    // succssifully calling it so the destructor is called only once.
-    try {
-      if (!isDeleted) {
-        ReflectUtil.getMethod(callable.getClass(), "del").invoke(callable);
-      }
-    } catch (NoSuchMethodException e) {
-      LOGGER.warn("Deployment {} has no del method.", backendTag);
-    } catch (Throwable e) {
-      LOGGER.error("Exception during graceful shutdown of replica.");
-    } finally {
-      isDeleted = true;
-    }
-    return true;
+    Ray.exitActor();
   }
 
   /**
@@ -284,34 +234,28 @@ public synchronized boolean prepareForShutdown() {
    *
    * @param userConfig new user's configuration
    */
-  public BackendVersion reconfigure(Object userConfig) {
-    BackendVersion.Builder builder = BackendVersion.newBuilder();
-    builder.setCodeVersion(version.getCodeVersion());
-    if (userConfig != null) {
-      builder.setUserConfig(ByteString.copyFrom((byte[]) userConfig));
+  private void reconfigure(Object userConfig) {
+    if (userConfig == null) {
+      return;
     }
-    version = builder.build();
-
     try {
       Method reconfigureMethod =
           ReflectUtil.getMethod(
               callable.getClass(),
              Constants.BACKEND_RECONFIGURE_METHOD,
-              userConfig != null
-                  ?
MessagePackSerializer.decode((byte[]) userConfig, Object[].class) - : new Object[0]); // TODO cache reconfigure method + userConfig); // TODO cache reconfigureMethod reconfigureMethod.invoke(callable, userConfig); } catch (NoSuchMethodException e) { - LOGGER.warn( - "user_config specified but backend {} missing {} method", - backendTag, - Constants.BACKEND_RECONFIGURE_METHOD); + throw new RayServeException( + LogUtil.format( + "user_config specified but backend {} missing {} method", + backendTag, + Constants.BACKEND_RECONFIGURE_METHOD)); } catch (Throwable e) { throw new RayServeException( LogUtil.format("Backend {} failed to reconfigure user_config {}", backendTag, userConfig), e); } - return version; } /** @@ -321,9 +265,12 @@ public BackendVersion reconfigure(Object userConfig) { */ private void updateBackendConfigs(Object newConfig) { config = (BackendConfig) newConfig; + reconfigure(((BackendConfig) newConfig).getUserConfig()); } - public BackendVersion getVersion() { - return version; + private void reportMetrics(Runnable runnable) { + if (metricsRegistered) { + runnable.run(); + } } } diff --git a/java/serve/src/main/java/io/ray/serve/RayServeWrappedReplica.java b/java/serve/src/main/java/io/ray/serve/RayServeWrappedReplica.java index 53e0854044c71..9ccc6c6f7a448 100644 --- a/java/serve/src/main/java/io/ray/serve/RayServeWrappedReplica.java +++ b/java/serve/src/main/java/io/ray/serve/RayServeWrappedReplica.java @@ -7,7 +7,6 @@ import io.ray.runtime.serializer.MessagePackSerializer; import io.ray.serve.api.Serve; import io.ray.serve.generated.BackendConfig; -import io.ray.serve.generated.BackendVersion; import io.ray.serve.generated.RequestMetadata; import io.ray.serve.util.ReflectUtil; import io.ray.serve.util.ServeProtoUtil; @@ -28,7 +27,6 @@ public RayServeWrappedReplica( String backendDef, byte[] initArgsbytes, byte[] backendConfigBytes, - byte[] backendVersionBytes, String controllerName) throws ClassNotFoundException, NoSuchMethodException, InstantiationException, IllegalAccessException, IllegalArgumentException, InvocationTargetException, IOException { @@ -54,26 +52,7 @@ public RayServeWrappedReplica( Serve.setInternalReplicaContext(backendTag, replicaTag, controllerName, callable); // Construct worker replica. - backend = - new RayServeReplica( - callable, - backendConfig, - ServeProtoUtil.parseBackendVersion(backendVersionBytes), - optional.get()); - } - - public RayServeWrappedReplica( - String backendTag, String replicaTag, DeploymentInfo deploymentInfo, String controllerName) - throws ClassNotFoundException, NoSuchMethodException, InstantiationException, - IllegalAccessException, IllegalArgumentException, InvocationTargetException, IOException { - this( - backendTag, - replicaTag, - deploymentInfo.getReplicaConfig().getBackendDef(), - deploymentInfo.getReplicaConfig().getInitArgs(), - deploymentInfo.getBackendConfig(), - deploymentInfo.getBackendVersion(), - controllerName); + backend = new RayServeReplica(callable, backendConfig, optional.get()); } private Object[] parseInitArgs(byte[] initArgsbytes, BackendConfig backendConfig) @@ -122,21 +101,8 @@ public void ready() { return; } - /** - * Wait until there is no request in processing. It is used for stopping replica gracefully. - * - * @return true if it is ready for shutdown. 
- */
-  public boolean prepareForShutdown() {
-    return backend.prepareForShutdown();
-  }
-
-  public byte[] reconfigure(Object userConfig) {
-    BackendVersion backendVersion = backend.reconfigure(userConfig);
-    return backendVersion.toByteArray();
-  }
-
-  public byte[] getVersion() {
-    return backend.getVersion().toByteArray();
+  /** Wait until there are no requests being processed. It is used for stopping the replica gracefully. */
+  public void drainPendingQueries() {
+    backend.drainPendingQueries();
   }
 }
diff --git a/java/serve/src/main/java/io/ray/serve/ReplicaConfig.java b/java/serve/src/main/java/io/ray/serve/ReplicaConfig.java
index a24ceea124963..ff19348098027 100644
--- a/java/serve/src/main/java/io/ray/serve/ReplicaConfig.java
+++ b/java/serve/src/main/java/io/ray/serve/ReplicaConfig.java
@@ -12,13 +12,13 @@ public class ReplicaConfig implements Serializable {
 
   private String backendDef;
 
-  private byte[] initArgs;
+  private Object[] initArgs;
 
   private Map rayActorOptions;
 
   private Map resource;
 
-  public ReplicaConfig(String backendDef, byte[] initArgs, Map rayActorOptions) {
+  public ReplicaConfig(String backendDef, Object[] initArgs, Map rayActorOptions) {
     this.backendDef = backendDef;
     this.initArgs = initArgs;
     this.rayActorOptions = rayActorOptions;
@@ -89,11 +89,11 @@ public void setBackendDef(String backendDef) {
     this.backendDef = backendDef;
   }
 
-  public byte[] getInitArgs() {
+  public Object[] getInitArgs() {
     return initArgs;
   }
 
-  public void setInitArgs(byte[] initArgs) {
+  public void setInitArgs(Object[] initArgs) {
     this.initArgs = initArgs;
   }
diff --git a/java/serve/src/main/java/io/ray/serve/ReplicaContext.java b/java/serve/src/main/java/io/ray/serve/ReplicaContext.java
index 7bd768f7cdd53..10c62cf7eb411 100644
--- a/java/serve/src/main/java/io/ray/serve/ReplicaContext.java
+++ b/java/serve/src/main/java/io/ray/serve/ReplicaContext.java
@@ -3,7 +3,7 @@
 /** Stores data for Serve API calls from within the user's backend code. */
 public class ReplicaContext {
 
-  private String backendTag; // TODO deployment
+  private String backendTag;
 
   private String replicaTag;
 
diff --git a/java/serve/src/main/java/io/ray/serve/ReplicaSet.java b/java/serve/src/main/java/io/ray/serve/ReplicaSet.java
deleted file mode 100644
index 1c7e757bba449..0000000000000
--- a/java/serve/src/main/java/io/ray/serve/ReplicaSet.java
+++ /dev/null
@@ -1,138 +0,0 @@
-package io.ray.serve;
-
-import com.google.common.collect.ImmutableMap;
-import com.google.common.collect.Sets;
-import io.ray.api.ActorHandle;
-import io.ray.api.ObjectRef;
-import io.ray.api.Ray;
-import io.ray.runtime.metric.Gauge;
-import io.ray.runtime.metric.Metrics;
-import io.ray.runtime.metric.TagKey;
-import io.ray.serve.generated.ActorSet;
-import io.ray.serve.generated.BackendConfig;
-import io.ray.serve.util.CollectionUtil;
-import java.util.ArrayList;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.atomic.AtomicInteger;
-import org.apache.commons.lang3.RandomUtils;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-/** Data structure representing a set of replica actor handles.
*/ -public class ReplicaSet { - - private static final Logger LOGGER = LoggerFactory.getLogger(ReplicaSet.class); - - private volatile int maxConcurrentQueries = 8; - - private final Map, Set>> inFlightQueries; - - private AtomicInteger numQueuedQueries = new AtomicInteger(); - - private Gauge numQueuedQueriesGauge; - - public ReplicaSet(String backendTag) { - this.inFlightQueries = new ConcurrentHashMap<>(); - RayServeMetrics.execute( - () -> - this.numQueuedQueriesGauge = - Metrics.gauge() - .name(RayServeMetrics.SERVE_DEPLOYMENT_QUEUED_QUERIES.getName()) - .description(RayServeMetrics.SERVE_DEPLOYMENT_QUEUED_QUERIES.getDescription()) - .unit("") - .tags(ImmutableMap.of(RayServeMetrics.TAG_DEPLOYMENT, backendTag)) - .register()); - } - - public void setMaxConcurrentQueries(Object backendConfig) { - int newValue = ((BackendConfig) backendConfig).getMaxConcurrentQueries(); - if (newValue != this.maxConcurrentQueries) { - this.maxConcurrentQueries = newValue; - LOGGER.info("ReplicaSet: changing max_concurrent_queries to {}", newValue); - } - } - - public int getMaxConcurrentQueries() { - return maxConcurrentQueries; - } - - @SuppressWarnings("unchecked") - public synchronized void updateWorkerReplicas(Object actorSet) { - List actorNames = ((ActorSet) actorSet).getNamesList(); - Set> workerReplicas = new HashSet<>(); - if (!CollectionUtil.isEmpty(actorNames)) { - actorNames.forEach( - name -> - workerReplicas.add((ActorHandle) Ray.getActor(name).get())); - } - - Set> added = - new HashSet<>(Sets.difference(workerReplicas, inFlightQueries.keySet())); - Set> removed = - new HashSet<>(Sets.difference(inFlightQueries.keySet(), workerReplicas)); - - added.forEach(actorHandle -> inFlightQueries.put(actorHandle, Sets.newConcurrentHashSet())); - removed.forEach(actorHandle -> inFlightQueries.remove(actorHandle)); - - if (added.size() > 0 || removed.size() > 0) { - LOGGER.info("ReplicaSet: +{}, -{} replicas.", added.size(), removed.size()); - } - } - - /** - * Given a query, submit it to a replica and return the object ref. This method will keep track of - * the in flight queries for each replicas and only send a query to available replicas (determined - * by the backend max_concurrent_quries value.) - * - * @param query the incoming query. - * @return ray.ObjectRef - */ - public ObjectRef assignReplica(Query query) { - String endpoint = query.getMetadata().getEndpoint(); - numQueuedQueries.incrementAndGet(); - RayServeMetrics.execute( - () -> - numQueuedQueriesGauge.update( - numQueuedQueries.get(), - TagKey.tagsFromMap(ImmutableMap.of(RayServeMetrics.TAG_ENDPOINT, endpoint)))); - ObjectRef assignedRef = - tryAssignReplica(query); // TODO controll concurrency using maxConcurrentQueries - numQueuedQueries.decrementAndGet(); - RayServeMetrics.execute( - () -> - numQueuedQueriesGauge.update( - numQueuedQueries.get(), - TagKey.tagsFromMap(ImmutableMap.of(RayServeMetrics.TAG_ENDPOINT, endpoint)))); - return assignedRef; - } - - /** - * Try to assign query to a replica, return the object ref if succeeded or return None if it can't - * assign this query to any replicas. - * - * @param query query the incoming query. 
- * @return ray.ObjectRef - */ - private ObjectRef tryAssignReplica(Query query) { - - List> handles = new ArrayList<>(inFlightQueries.keySet()); - if (CollectionUtil.isEmpty(handles)) { - throw new RayServeException("ReplicaSet found no replica."); - } - int randomIndex = RandomUtils.nextInt(0, handles.size()); - ActorHandle replica = - handles.get(randomIndex); // TODO controll concurrency using maxConcurrentQueries - LOGGER.debug("Assigned query {} to replica {}.", query.getMetadata().getRequestId(), replica); - return replica - .task(RayServeWrappedReplica::handleRequest, query.getMetadata(), query.getArgs()) - .remote(); - } - - public Map, Set>> getInFlightQueries() { - return inFlightQueries; - } -} diff --git a/java/serve/src/main/java/io/ray/serve/Router.java b/java/serve/src/main/java/io/ray/serve/Router.java deleted file mode 100644 index 5ef339d77767c..0000000000000 --- a/java/serve/src/main/java/io/ray/serve/Router.java +++ /dev/null @@ -1,64 +0,0 @@ -package io.ray.serve; - -import com.google.common.collect.ImmutableMap; -import io.ray.api.BaseActorHandle; -import io.ray.api.ObjectRef; -import io.ray.runtime.metric.Count; -import io.ray.runtime.metric.Metrics; -import io.ray.serve.generated.RequestMetadata; -import io.ray.serve.poll.KeyListener; -import io.ray.serve.poll.KeyType; -import io.ray.serve.poll.LongPollClient; -import io.ray.serve.poll.LongPollNamespace; -import java.util.HashMap; -import java.util.Map; - -/** Router process incoming queries: choose backend, and assign replica. */ -public class Router { - - private ReplicaSet replicaSet; - - private Count numRouterRequests; - - private LongPollClient longPollClient; - - public Router(BaseActorHandle controllerHandle, String backendTag) { - this.replicaSet = new ReplicaSet(backendTag); - - RayServeMetrics.execute( - () -> - this.numRouterRequests = - Metrics.count() - .name(RayServeMetrics.SERVE_NUM_ROUTER_REQUESTS.getName()) - .description(RayServeMetrics.SERVE_NUM_ROUTER_REQUESTS.getDescription()) - .unit("") - .tags(ImmutableMap.of(RayServeMetrics.TAG_DEPLOYMENT, backendTag)) - .register()); - - Map keyListeners = new HashMap<>(); - keyListeners.put( - new KeyType(LongPollNamespace.BACKEND_CONFIGS, backendTag), - backendConfig -> replicaSet.setMaxConcurrentQueries(backendConfig)); // cross language - keyListeners.put( - new KeyType(LongPollNamespace.REPLICA_HANDLES, backendTag), - workerReplicas -> replicaSet.updateWorkerReplicas(workerReplicas)); // cross language - this.longPollClient = new LongPollClient(controllerHandle, keyListeners); - this.longPollClient.start(); - } - - /** - * Assign a query and returns an object ref represent the result. - * - * @param requestMetadata the metadata of incoming queries. - * @param requestArgs the request body of incoming queries. 
- * @return ray.ObjectRef - */ - public ObjectRef assignRequest(RequestMetadata requestMetadata, Object[] requestArgs) { - RayServeMetrics.execute(() -> numRouterRequests.inc(1.0)); - return replicaSet.assignReplica(new Query(requestMetadata, requestArgs)); - } - - public ReplicaSet getReplicaSet() { - return replicaSet; - } -} diff --git a/java/serve/src/main/java/io/ray/serve/ServeController.java b/java/serve/src/main/java/io/ray/serve/ServeController.java deleted file mode 100644 index 1589f4c73b4c2..0000000000000 --- a/java/serve/src/main/java/io/ray/serve/ServeController.java +++ /dev/null @@ -1,6 +0,0 @@ -package io.ray.serve; - -public interface ServeController { - - byte[] getAllEndpoints(); -} diff --git a/java/serve/src/main/java/io/ray/serve/ServeProxy.java b/java/serve/src/main/java/io/ray/serve/ServeProxy.java deleted file mode 100644 index 532a2413f9ba5..0000000000000 --- a/java/serve/src/main/java/io/ray/serve/ServeProxy.java +++ /dev/null @@ -1,14 +0,0 @@ -package io.ray.serve; - -import java.util.Map; - -public interface ServeProxy { - - void init(Map config, ProxyRouter proxyRouter); - - default String getName() { - return getClass().getName(); - } - - default void registerServiceDiscovery() {} -} diff --git a/java/serve/src/main/java/io/ray/serve/api/Client.java b/java/serve/src/main/java/io/ray/serve/api/Client.java deleted file mode 100644 index e5c63b5c8e184..0000000000000 --- a/java/serve/src/main/java/io/ray/serve/api/Client.java +++ /dev/null @@ -1,72 +0,0 @@ -package io.ray.serve.api; - -import io.ray.api.ActorHandle; -import io.ray.api.BaseActorHandle; -import io.ray.api.PyActorHandle; -import io.ray.api.function.PyActorMethod; -import io.ray.serve.RayServeException; -import io.ray.serve.RayServeHandle; -import io.ray.serve.ServeController; -import io.ray.serve.generated.EndpointInfo; -import io.ray.serve.util.LogUtil; -import io.ray.serve.util.ServeProtoUtil; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class Client { - - private static final Logger LOGGER = LoggerFactory.getLogger(Client.class); - - private BaseActorHandle controller; - - private Map handleCache = new ConcurrentHashMap<>(); - - public Client(BaseActorHandle controller, String controllerName, boolean detached) { - this.controller = controller; - } - - /** - * Retrieve RayServeHandle for service endpoint to invoke it from Python. - * - * @param endpointName A registered service endpoint. - * @param missingOk If true, then Serve won't check the endpoint is registered. False by default. 
-   * @return the handle of the specified endpoint.
-   */
-  @SuppressWarnings("unchecked")
-  public RayServeHandle getHandle(String endpointName, boolean missingOk) {
-
-    String cacheKey = endpointName + "_" + missingOk;
-    if (handleCache.containsKey(cacheKey)) {
-      return handleCache.get(cacheKey);
-    }
-
-    Map<String, EndpointInfo> endpoints = null;
-    if (controller instanceof PyActorHandle) {
-      endpoints =
-          ServeProtoUtil.parseEndpointSet(
-              (byte[])
-                  ((PyActorHandle) controller)
-                      .task(PyActorMethod.of("get_all_endpoints"))
-                      .remote()
-                      .get());
-    } else {
-      LOGGER.warn("Client only supports a Python controller for now.");
-      endpoints =
-          ServeProtoUtil.parseEndpointSet(
-              ((ActorHandle<ServeController>) controller)
-                  .task(ServeController::getAllEndpoints)
-                  .remote()
-                  .get());
-    }
-
-    if (!missingOk && (endpoints == null || !endpoints.containsKey(endpointName))) {
-      throw new RayServeException(LogUtil.format("Endpoint {} does not exist.", endpointName));
-    }
-
-    RayServeHandle handle = new RayServeHandle(controller, endpointName, null, null);
-    handleCache.put(cacheKey, handle);
-    return handle;
-  }
-}
diff --git a/java/serve/src/main/java/io/ray/serve/api/Serve.java b/java/serve/src/main/java/io/ray/serve/api/Serve.java
index 3b2c0ed7a2833..8133e5bd7f23e 100644
--- a/java/serve/src/main/java/io/ray/serve/api/Serve.java
+++ b/java/serve/src/main/java/io/ray/serve/api/Serve.java
@@ -1,20 +1,12 @@
 package io.ray.serve.api;
 
-import com.google.common.base.Preconditions;
-import io.ray.api.BaseActorHandle;
-import io.ray.api.Ray;
-import io.ray.serve.Constants;
 import io.ray.serve.RayServeException;
 import io.ray.serve.ReplicaContext;
-import io.ray.serve.util.LogUtil;
-import java.util.Optional;
 
 /** Ray Serve global API. TODO: will be enriched in the Java SDK/API PR. */
 public class Serve {
 
-  private static ReplicaContext INTERNAL_REPLICA_CONTEXT;
-
-  private static Client GLOBAL_CLIENT;
+  public static ReplicaContext INTERNAL_REPLICA_CONTEXT;
 
 /**
  * Set replica information to global context.
@@ -26,14 +18,11 @@ public class Serve {
  */
   public static void setInternalReplicaContext(
       String backendTag, String replicaTag, String controllerName, Object servableObject) {
+    // TODO singleton.
     INTERNAL_REPLICA_CONTEXT =
         new ReplicaContext(backendTag, replicaTag, controllerName, servableObject);
   }
 
-  public static void setInternalReplicaContext(ReplicaContext replicaContext) {
-    INTERNAL_REPLICA_CONTEXT = replicaContext;
-  }
-
 /**
  * Get the global replica context.
  *
@@ -46,43 +35,4 @@ public static ReplicaContext getReplicaContext() {
     }
     return INTERNAL_REPLICA_CONTEXT;
   }
-
-  public static Client getGlobalClient() {
-    if (GLOBAL_CLIENT != null) {
-      return GLOBAL_CLIENT;
-    }
-    synchronized (Client.class) {
-      if (GLOBAL_CLIENT != null) {
-        return GLOBAL_CLIENT;
-      }
-      return connect();
-    }
-  }
-
-  public static void setGlobalClient(Client client) {
-    GLOBAL_CLIENT = client;
-  }
-
-  public static Client connect() {
-
-    if (!Ray.isInitialized()) {
-      Ray.init();
-    }
-
-    String controllerName =
-        INTERNAL_REPLICA_CONTEXT != null
-            ? INTERNAL_REPLICA_CONTEXT.getInternalControllerName()
-            : Constants.SERVE_CONTROLLER_NAME;
-
-    Optional<BaseActorHandle> optional = Ray.getActor(controllerName);
-    Preconditions.checkState(
-        optional.isPresent(),
-        LogUtil.format(
-            "There is no instance running on this Ray cluster. 
" - + "Please call `serve.start(detached=True) to start one.")); - - Client client = new Client(optional.get(), controllerName, true); - setGlobalClient(client); - return client; - } } diff --git a/java/serve/src/main/java/io/ray/serve/poll/KeyListener.java b/java/serve/src/main/java/io/ray/serve/poll/KeyListener.java index 514193e28c37d..91e9ceca04723 100644 --- a/java/serve/src/main/java/io/ray/serve/poll/KeyListener.java +++ b/java/serve/src/main/java/io/ray/serve/poll/KeyListener.java @@ -4,5 +4,5 @@ @FunctionalInterface public interface KeyListener { - void notifyChanged(Object updatedObject); + void notifyChanged(Object object); } diff --git a/java/serve/src/main/java/io/ray/serve/poll/LongPollClient.java b/java/serve/src/main/java/io/ray/serve/poll/LongPollClient.java index 308391254e109..4017be3af9db9 100644 --- a/java/serve/src/main/java/io/ray/serve/poll/LongPollClient.java +++ b/java/serve/src/main/java/io/ray/serve/poll/LongPollClient.java @@ -1,7 +1,6 @@ package io.ray.serve.poll; import com.google.common.base.Preconditions; -import com.google.protobuf.InvalidProtocolBufferException; import io.ray.api.BaseActorHandle; import io.ray.api.ObjectRef; import io.ray.api.PyActorHandle; @@ -9,16 +8,8 @@ import io.ray.runtime.exception.RayActorException; import io.ray.runtime.exception.RayTaskException; import io.ray.serve.Constants; -import io.ray.serve.RayServeException; -import io.ray.serve.generated.ActorSet; -import io.ray.serve.generated.UpdatedObject; -import io.ray.serve.util.LogUtil; -import io.ray.serve.util.ServeProtoUtil; -import java.util.HashMap; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import java.util.function.Function; -import org.apache.commons.lang3.builder.ReflectionToStringBuilder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -42,26 +33,6 @@ public class LongPollClient { /** An async thread to post the callback into. */ private Thread pollThread; - private static final Map> DESERIALIZERS = - new HashMap<>(); - - static { - DESERIALIZERS.put( - LongPollNamespace.BACKEND_CONFIGS, body -> ServeProtoUtil.parseBackendConfig(body)); - DESERIALIZERS.put( - LongPollNamespace.REPLICA_HANDLES, body -> ServeProtoUtil.parseEndpointSet(body)); - DESERIALIZERS.put( - LongPollNamespace.REPLICA_HANDLES, - body -> { - try { - return ActorSet.parseFrom(body); - } catch (InvalidProtocolBufferException e) { - throw new RayServeException( - LogUtil.format("Failed to parse ActorSet from protobuf bytes."), e); - } - }); - } - public LongPollClient(BaseActorHandle hostActor, Map keyListeners) { Preconditions.checkArgument(keyListeners != null && keyListeners.size() != 0); @@ -80,7 +51,7 @@ public LongPollClient(BaseActorHandle hostActor, Map keyLi try { pollNext(); } catch (RayActorException e) { - LOGGER.error("LongPollClient failed to connect to host. Shutting down."); + LOGGER.debug("LongPollClient failed to connect to host. Shutting down."); break; } catch (RayTaskException e) { LOGGER.error("LongPollHost errored", e); @@ -100,44 +71,24 @@ public void start() { pollThread.start(); } - /** - * Poll the update. - * - * @throws InvalidProtocolBufferException if the protobuf deserialization fails. - */ - public void pollNext() throws InvalidProtocolBufferException { + /** Poll the update. 
*/
+  @SuppressWarnings("unchecked")
+  public void pollNext() {
     currentRef =
         ((PyActorHandle) hostActor)
             .task(PyActorMethod.of(Constants.CONTROLLER_LISTEN_FOR_CHANGE_METHOD), snapshotIds)
             .remote();
-    processUpdate(ServeProtoUtil.parseUpdatedObjects((byte[]) currentRef.get()));
+    processUpdate((Map<KeyType, UpdatedObject>) currentRef.get());
   }
 
   public void processUpdate(Map<KeyType, UpdatedObject> updates) {
-    if (updates == null || updates.isEmpty()) {
-      LOGGER.info("LongPollClient received nothing.");
-      return;
-    }
-    LOGGER.info("LongPollClient received updates for keys: {}", updates.keySet());
+
+    LOGGER.debug("LongPollClient received updates for keys: {}", updates.keySet());
+
     for (Map.Entry<KeyType, UpdatedObject> entry : updates.entrySet()) {
-      KeyType keyType = entry.getKey();
-      UpdatedObject updatedObject = entry.getValue();
-
-      Object objectSnapshot =
-          DESERIALIZERS
-              .get(keyType.getLongPollNamespace())
-              .apply(updatedObject.getObjectSnapshot().toByteArray());
-
-      if (LOGGER.isDebugEnabled()) {
-        LOGGER.debug(
-            "The updated object for key {} is {}",
-            keyType,
-            ReflectionToStringBuilder.toString(objectSnapshot));
-      }
-
-      keyListeners.get(entry.getKey()).notifyChanged(objectSnapshot);
-      objectSnapshots.put(entry.getKey(), objectSnapshot);
+      objectSnapshots.put(entry.getKey(), entry.getValue().getObjectSnapshot());
       snapshotIds.put(entry.getKey(), entry.getValue().getSnapshotId());
+      keyListeners.get(entry.getKey()).notifyChanged(entry.getValue().getObjectSnapshot());
     }
   }
diff --git a/java/serve/src/main/java/io/ray/serve/poll/LongPollNamespace.java b/java/serve/src/main/java/io/ray/serve/poll/LongPollNamespace.java
index 71b3a2e8baa1e..466af829167e8 100644
--- a/java/serve/src/main/java/io/ray/serve/poll/LongPollNamespace.java
+++ b/java/serve/src/main/java/io/ray/serve/poll/LongPollNamespace.java
@@ -4,7 +4,9 @@
 public enum LongPollNamespace {
   REPLICA_HANDLES,
+  TRAFFIC_POLICIES,
+
   BACKEND_CONFIGS,
-  ROUTE_TABLE;
+  ROUTE_TABLE
 }
diff --git a/java/serve/src/main/java/io/ray/serve/poll/UpdatedObject.java b/java/serve/src/main/java/io/ray/serve/poll/UpdatedObject.java
new file mode 100644
index 0000000000000..3f3ddc63c1ae2
--- /dev/null
+++ b/java/serve/src/main/java/io/ray/serve/poll/UpdatedObject.java
@@ -0,0 +1,33 @@
+package io.ray.serve.poll;
+
+import java.io.Serializable;
+
+/** The updated object that the long poll client received. */
+public class UpdatedObject implements Serializable {
+
+  private static final long serialVersionUID = 6245682414826079438L;
+
+  private Object objectSnapshot;
+
+  /**
+   * The identifier for the object's version. There is no sequential relation among different
+   * objects' snapshot_ids.
+   */
+  private int snapshotId;
+
+  public Object getObjectSnapshot() {
+    return objectSnapshot;
+  }
+
+  public void setObjectSnapshot(Object objectSnapshot) {
+    this.objectSnapshot = objectSnapshot;
+  }
+
+  public int getSnapshotId() {
+    return snapshotId;
+  }
+
+  public void setSnapshotId(int snapshotId) {
+    this.snapshotId = snapshotId;
+  }
+}
diff --git a/java/serve/src/main/java/io/ray/serve/util/CollectionUtil.java b/java/serve/src/main/java/io/ray/serve/util/CollectionUtil.java
deleted file mode 100644
index cd66932f48276..0000000000000
--- a/java/serve/src/main/java/io/ray/serve/util/CollectionUtil.java
+++ /dev/null
@@ -1,10 +0,0 @@
-package io.ray.serve.util;
-
-import java.util.Collection;
-
-public class CollectionUtil {
-
-  public static <T> boolean isEmpty(Collection<T> collection) {
-    return collection == null || collection.isEmpty();
-  }
-}
diff --git a/java/serve/src/main/java/io/ray/serve/util/CommonUtil.java b/java/serve/src/main/java/io/ray/serve/util/CommonUtil.java
deleted file mode 100644
index a32ee212196d8..0000000000000
--- a/java/serve/src/main/java/io/ray/serve/util/CommonUtil.java
+++ /dev/null
@@ -1,13 +0,0 @@
-package io.ray.serve.util;
-
-import org.apache.commons.lang3.StringUtils;
-
-public class CommonUtil {
-
-  public static String formatActorName(String controllerName, String actorName) {
-    if (StringUtils.isBlank(controllerName)) {
-      return actorName;
-    }
-    return controllerName + ":" + actorName;
-  }
-}
diff --git a/java/serve/src/main/java/io/ray/serve/util/ReflectUtil.java b/java/serve/src/main/java/io/ray/serve/util/ReflectUtil.java
index ae449dd714733..5de1142433008 100644
--- a/java/serve/src/main/java/io/ray/serve/util/ReflectUtil.java
+++ b/java/serve/src/main/java/io/ray/serve/util/ReflectUtil.java
@@ -2,7 +2,6 @@
 
 import java.lang.reflect.Constructor;
 import java.lang.reflect.Executable;
-import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.Method;
 import java.util.ArrayList;
 import java.util.List;
@@ -179,17 +178,4 @@ public static List<String> getMethodStrings(Class targetClass) {
     }
     return methodStrings;
   }
-
-  @SuppressWarnings("unchecked")
-  public static <T> List<T> getInstancesByClassNames(String classNames, Class<T> cls)
-      throws ClassNotFoundException, InstantiationException, IllegalAccessException,
-          IllegalArgumentException, InvocationTargetException, NoSuchMethodException,
-          SecurityException {
-    String[] classNameArray = StringUtils.split(classNames, ";");
-    List<T> instances = new ArrayList<>();
-    for (String className : classNameArray) {
-      instances.add((T) Class.forName(className).getConstructor().newInstance());
-    }
-    return instances;
-  }
 }
diff --git a/java/serve/src/main/java/io/ray/serve/util/ServeProtoUtil.java b/java/serve/src/main/java/io/ray/serve/util/ServeProtoUtil.java
index 1a1c0c082d3f8..b1d02a046063e 100644
--- a/java/serve/src/main/java/io/ray/serve/util/ServeProtoUtil.java
+++ b/java/serve/src/main/java/io/ray/serve/util/ServeProtoUtil.java
@@ -2,42 +2,26 @@
 
 import com.google.common.base.Preconditions;
 import com.google.common.collect.Lists;
-import com.google.gson.Gson;
 import com.google.protobuf.InvalidProtocolBufferException;
 import io.ray.runtime.serializer.MessagePackSerializer;
-import io.ray.serve.Constants;
 import io.ray.serve.RayServeException;
 import io.ray.serve.generated.BackendConfig;
 import io.ray.serve.generated.BackendLanguage;
-import io.ray.serve.generated.BackendVersion;
-import io.ray.serve.generated.EndpointInfo;
-import io.ray.serve.generated.EndpointSet;
-import 
io.ray.serve.generated.LongPollResult; import io.ray.serve.generated.RequestMetadata; import io.ray.serve.generated.RequestWrapper; -import io.ray.serve.generated.UpdatedObject; -import io.ray.serve.poll.KeyType; -import java.util.HashMap; -import java.util.Map; import org.apache.commons.lang3.StringUtils; public class ServeProtoUtil { - private static final Gson GSON = new Gson(); - - public static BackendConfig parseBackendConfig(byte[] backendConfigBytes) { + public static BackendConfig parseBackendConfig(byte[] backendConfigBytes) + throws InvalidProtocolBufferException { // Get a builder from BackendConfig(bytes) or create a new one. BackendConfig.Builder builder = null; if (backendConfigBytes == null) { builder = BackendConfig.newBuilder(); } else { - BackendConfig backendConfig = null; - try { - backendConfig = BackendConfig.parseFrom(backendConfigBytes); - } catch (InvalidProtocolBufferException e) { - throw new RayServeException("Failed to parse BackendConfig from protobuf bytes.", e); - } + BackendConfig backendConfig = BackendConfig.parseFrom(backendConfigBytes); if (backendConfig == null) { builder = BackendConfig.newBuilder(); } else { @@ -56,12 +40,12 @@ public static BackendConfig parseBackendConfig(byte[] backendConfigBytes) { builder.setMaxConcurrentQueries(100); } - if (builder.getGracefulShutdownWaitLoopS() == 0) { - builder.setGracefulShutdownWaitLoopS(2); + if (builder.getExperimentalGracefulShutdownWaitLoopS() == 0) { + builder.setExperimentalGracefulShutdownWaitLoopS(2); } - if (builder.getGracefulShutdownTimeoutS() == 0) { - builder.setGracefulShutdownTimeoutS(20); + if (builder.getExperimentalGracefulShutdownTimeoutS() == 0) { + builder.setExperimentalGracefulShutdownTimeoutS(20); } if (builder.getBackendLanguage() == BackendLanguage.UNRECOGNIZED) { @@ -100,7 +84,7 @@ public static RequestMetadata parseRequestMetadata(byte[] requestMetadataBytes) // Set default values. 
    if (StringUtils.isBlank(builder.getCallMethod())) {
-      builder.setCallMethod(Constants.DEFAULT_CALL_METHOD);
+      builder.setCallMethod("call");
     }
 
     return builder.build();
   }
@@ -124,47 +108,4 @@ public static RequestWrapper parseRequestWrapper(byte[] httpRequestWrapperBytes)
 
     return builder.build();
   }
-
-  public static Map<KeyType, UpdatedObject> parseUpdatedObjects(byte[] longPollResultBytes)
-      throws InvalidProtocolBufferException {
-    if (longPollResultBytes == null) {
-      return null;
-    }
-    LongPollResult longPollResult = LongPollResult.parseFrom(longPollResultBytes);
-    Map<String, UpdatedObject> updatedObjects = longPollResult.getUpdatedObjectsMap();
-    if (updatedObjects == null || updatedObjects.isEmpty()) {
-      return null;
-    }
-    Map<KeyType, UpdatedObject> updates = new HashMap<>(updatedObjects.size());
-    updatedObjects.forEach(
-        (key, value) -> updates.put(ServeProtoUtil.GSON.fromJson(key, KeyType.class), value));
-    return updates;
-  }
-
-  public static Map<String, EndpointInfo> parseEndpointSet(byte[] endpointSetBytes) {
-    if (endpointSetBytes == null) {
-      return null;
-    }
-    EndpointSet endpointSet = null;
-    try {
-      endpointSet = EndpointSet.parseFrom(endpointSetBytes);
-    } catch (InvalidProtocolBufferException e) {
-      throw new RayServeException("Failed to parse EndpointSet from protobuf bytes.", e);
-    }
-    if (endpointSet == null) {
-      return null;
-    }
-    return endpointSet.getEndpointsMap();
-  }
-
-  public static BackendVersion parseBackendVersion(byte[] backendVersionBytes) {
-    if (backendVersionBytes == null) {
-      return null;
-    }
-    try {
-      return BackendVersion.parseFrom(backendVersionBytes);
-    } catch (InvalidProtocolBufferException e) {
-      throw new RayServeException("Failed to parse BackendVersion from protobuf bytes.", e);
-    }
-  }
 }
diff --git a/java/serve/src/main/java/io/ray/serve/util/SocketUtil.java b/java/serve/src/main/java/io/ray/serve/util/SocketUtil.java
deleted file mode 100644
index ab93a6e152210..0000000000000
--- a/java/serve/src/main/java/io/ray/serve/util/SocketUtil.java
+++ /dev/null
@@ -1,49 +0,0 @@
-package io.ray.serve.util;
-
-import java.io.IOException;
-import java.net.InetSocketAddress;
-import java.net.ServerSocket;
-
-public class SocketUtil {
-
-  public static final int PORT_RANGE_MAX = 65535;
-
-  public static int findAvailableTcpPort(int minPort) {
-    int portRange = PORT_RANGE_MAX - minPort;
-    int candidatePort = minPort;
-    int searchCounter = 0;
-    while (!isPortAvailable(candidatePort)) {
-      candidatePort++;
-      if (++searchCounter > portRange) {
-        throw new IllegalStateException(
-            String.format(
-                "Could not find an available tcp port in the range [%d, %d] after %d attempts.",
-                minPort, PORT_RANGE_MAX, searchCounter));
-      }
-    }
-    return candidatePort;
-  }
-
-  public static boolean isPortAvailable(int port) {
-    ServerSocket socket;
-    try {
-      socket = new ServerSocket();
-    } catch (IOException e) {
-      throw new IllegalStateException("Unable to create ServerSocket.", e);
-    }
-
-    try {
-      InetSocketAddress sa = new InetSocketAddress(port);
-      socket.bind(sa);
-      return true;
-    } catch (IOException ex) {
-      return false;
-    } finally {
-      try {
-        socket.close();
-      } catch (IOException ex) {
-        // ignore this exception for now
-      }
-    }
-  }
-}
diff --git a/java/serve/src/test/java/io/ray/serve/DummyServeController.java b/java/serve/src/test/java/io/ray/serve/DummyServeController.java
deleted file mode 100644
index 6ee319a477898..0000000000000
--- a/java/serve/src/test/java/io/ray/serve/DummyServeController.java
+++ /dev/null
@@ -1,21 +0,0 @@
-package io.ray.serve;
-
-import io.ray.serve.generated.EndpointInfo;
-import io.ray.serve.generated.EndpointSet;
-import 
java.util.Map; - -public class DummyServeController implements ServeController { - - private Map endpoints; - - @Override - public byte[] getAllEndpoints() { - EndpointSet.Builder builder = EndpointSet.newBuilder(); - builder.putAllEndpoints(endpoints); - return builder.build().toByteArray(); - } - - public void setEndpoints(Map endpoints) { - this.endpoints = endpoints; - } -} diff --git a/java/serve/src/test/java/io/ray/serve/HttpProxyTest.java b/java/serve/src/test/java/io/ray/serve/HttpProxyTest.java deleted file mode 100644 index 5166603662c82..0000000000000 --- a/java/serve/src/test/java/io/ray/serve/HttpProxyTest.java +++ /dev/null @@ -1,74 +0,0 @@ -package io.ray.serve; - -import io.ray.api.ActorHandle; -import io.ray.api.Ray; -import io.ray.serve.api.Serve; -import io.ray.serve.generated.EndpointInfo; -import io.ray.serve.util.CommonUtil; -import java.io.IOException; -import java.net.HttpURLConnection; -import java.util.HashMap; -import java.util.Map; -import org.apache.commons.lang3.RandomStringUtils; -import org.apache.hc.client5.http.classic.HttpClient; -import org.apache.hc.client5.http.classic.methods.HttpPost; -import org.apache.hc.client5.http.impl.classic.CloseableHttpResponse; -import org.apache.hc.client5.http.impl.classic.HttpClientBuilder; -import org.testng.Assert; -import org.testng.annotations.Test; - -public class HttpProxyTest { - - @Test - public void test() throws IOException { - - boolean inited = Ray.isInitialized(); - Ray.init(); - - try { - String controllerName = - CommonUtil.formatActorName( - Constants.SERVE_CONTROLLER_NAME, RandomStringUtils.randomAlphabetic(6)); - String endpointName = "HTTPProxyTest"; - String route = "/route"; - - // Controller - ActorHandle controllerHandle = - Ray.actor(DummyServeController::new).setName(controllerName).remote(); - - Map endpointInfos = new HashMap<>(); - endpointInfos.put( - endpointName, - EndpointInfo.newBuilder().setEndpointName(endpointName).setRoute(route).build()); - controllerHandle.task(DummyServeController::setEndpoints, endpointInfos).remote(); - - Serve.setInternalReplicaContext(null, null, controllerName, null); - - // ProxyRouter updates routes. - ProxyRouter proxyRouter = new ProxyRouter(); - proxyRouter.updateRoutes(endpointInfos); - - // HTTP proxy. - HttpProxy httpProxy = new HttpProxy(); - httpProxy.init(null, proxyRouter); - - // Send request. - HttpClient httpClient = HttpClientBuilder.create().build(); - HttpPost httpPost = new HttpPost("http://localhost:" + httpProxy.getPort() + route); - try (CloseableHttpResponse httpResponse = - (CloseableHttpResponse) httpClient.execute(httpPost)) { - - // No Backend replica, so error is expected. 
- int status = httpResponse.getCode(); - Assert.assertEquals(status, HttpURLConnection.HTTP_INTERNAL_ERROR); - } - - } finally { - if (!inited) { - Ray.shutdown(); - } - Serve.setInternalReplicaContext(null); - Serve.setGlobalClient(null); - } - } -} diff --git a/java/serve/src/test/java/io/ray/serve/ProxyActorTest.java b/java/serve/src/test/java/io/ray/serve/ProxyActorTest.java deleted file mode 100644 index 6b1daa11b1141..0000000000000 --- a/java/serve/src/test/java/io/ray/serve/ProxyActorTest.java +++ /dev/null @@ -1,110 +0,0 @@ -package io.ray.serve; - -import io.ray.api.ActorHandle; -import io.ray.api.Ray; -import io.ray.runtime.serializer.MessagePackSerializer; -import io.ray.serve.api.Serve; -import io.ray.serve.generated.ActorSet; -import io.ray.serve.generated.BackendConfig; -import io.ray.serve.generated.BackendVersion; -import io.ray.serve.generated.EndpointInfo; -import io.ray.serve.util.CommonUtil; -import java.io.IOException; -import java.net.HttpURLConnection; -import java.util.HashMap; -import java.util.Map; -import org.apache.commons.lang3.RandomStringUtils; -import org.apache.hc.client5.http.classic.HttpClient; -import org.apache.hc.client5.http.classic.methods.HttpPost; -import org.apache.hc.client5.http.impl.classic.CloseableHttpResponse; -import org.apache.hc.client5.http.impl.classic.HttpClientBuilder; -import org.apache.hc.core5.http.io.entity.EntityUtils; -import org.testng.Assert; -import org.testng.annotations.Test; - -public class ProxyActorTest { - - @Test - public void test() throws IOException { - boolean inited = Ray.isInitialized(); - Ray.init(); - - try { - String prefix = "ProxyActorTest"; - String controllerName = - CommonUtil.formatActorName( - Constants.SERVE_CONTROLLER_NAME, RandomStringUtils.randomAlphabetic(6)); - String backendTag = prefix; - String replicaTag = prefix; - String endpointName = prefix; - String route = "/route"; - String version = "v1"; - - // Controller - ActorHandle controller = - Ray.actor(DummyServeController::new).setName(controllerName).remote(); - Map endpointInfos = new HashMap<>(); - endpointInfos.put( - endpointName, - EndpointInfo.newBuilder().setEndpointName(endpointName).setRoute(route).build()); - controller.task(DummyServeController::setEndpoints, endpointInfos).remote(); - - // Replica - DeploymentInfo deploymentInfo = new DeploymentInfo(); - deploymentInfo.setBackendConfig(BackendConfig.newBuilder().build().toByteArray()); - deploymentInfo.setBackendVersion( - BackendVersion.newBuilder().setCodeVersion(version).build().toByteArray()); - deploymentInfo.setReplicaConfig( - new ReplicaConfig(DummyBackendReplica.class.getName(), null, new HashMap<>())); - - ActorHandle replica = - Ray.actor( - RayServeWrappedReplica::new, - backendTag, - replicaTag, - deploymentInfo, - controllerName) - .setName(replicaTag) - .remote(); - replica.task(RayServeWrappedReplica::ready).remote(); - - // ProxyActor - ProxyActor proxyActor = new ProxyActor(controllerName, null); - proxyActor.getProxyRouter().updateRoutes(endpointInfos); - proxyActor - .getProxyRouter() - .getHandles() - .get(endpointName) - .getRouter() - .getReplicaSet() - .updateWorkerReplicas(ActorSet.newBuilder().addNames(replicaTag).build()); - - // Send request. 
- HttpClient httpClient = HttpClientBuilder.create().build(); - HttpPost httpPost = - new HttpPost( - "http://localhost:" - + ((HttpProxy) proxyActor.getProxies().get(HttpProxy.PROXY_NAME)).getPort() - + route); - try (CloseableHttpResponse httpResponse = - (CloseableHttpResponse) httpClient.execute(httpPost)) { - - int status = httpResponse.getCode(); - Assert.assertEquals(status, HttpURLConnection.HTTP_OK); - Object result = - MessagePackSerializer.decode( - EntityUtils.toByteArray(httpResponse.getEntity()), Object.class); - - Assert.assertNotNull(result); - Assert.assertEquals("1", result.toString()); - } - - } finally { - if (!inited) { - Ray.shutdown(); - } - Serve.setInternalReplicaContext(null); - Serve.setGlobalClient(null); - } - } -} diff --git a/java/serve/src/test/java/io/ray/serve/ProxyRouterTest.java b/java/serve/src/test/java/io/ray/serve/ProxyRouterTest.java deleted file mode 100644 index 03535a0575a79..0000000000000 --- a/java/serve/src/test/java/io/ray/serve/ProxyRouterTest.java +++ /dev/null @@ -1,68 +0,0 @@ -package io.ray.serve; - -import io.ray.api.ActorHandle; -import io.ray.api.Ray; -import io.ray.serve.api.Serve; -import io.ray.serve.generated.EndpointInfo; -import io.ray.serve.util.CommonUtil; -import java.util.HashMap; -import java.util.Map; -import org.apache.commons.lang3.RandomStringUtils; -import org.testng.Assert; -import org.testng.annotations.Test; - -public class ProxyRouterTest { - - @Test - public void test() { - - boolean inited = Ray.isInitialized(); - Ray.init(); - - try { - String prefix = "ProxyRouterTest"; - String controllerName = - CommonUtil.formatActorName( - Constants.SERVE_CONTROLLER_NAME, RandomStringUtils.randomAlphabetic(6)); - String endpointName1 = prefix + "_1"; - String endpointName2 = prefix + "_2"; - String route1 = "/route1"; - - // Controller - ActorHandle controllerHandle = - Ray.actor(DummyServeController::new).setName(controllerName).remote(); - Map endpointInfos = new HashMap<>(); - endpointInfos.put( - endpointName1, - EndpointInfo.newBuilder().setEndpointName(endpointName1).setRoute(route1).build()); - endpointInfos.put( - endpointName2, EndpointInfo.newBuilder().setEndpointName(endpointName2).build()); - controllerHandle.task(DummyServeController::setEndpoints, endpointInfos).remote(); - - Serve.setInternalReplicaContext(null, null, controllerName, null); - - // ProxyRouter updates routes. - ProxyRouter proxyRouter = new ProxyRouter(); - proxyRouter.updateRoutes(endpointInfos); - - // Check result. 
- Map routeInfo = proxyRouter.getRouteInfo(); - Assert.assertNotNull(routeInfo); - Assert.assertNotNull(routeInfo.get(route1)); - Assert.assertEquals(routeInfo.get(route1).getRoute(), route1); - Assert.assertEquals(routeInfo.get(route1).getEndpointName(), endpointName1); - Assert.assertNotNull(routeInfo.get(endpointName2)); - Assert.assertEquals(routeInfo.get(endpointName2).getEndpointName(), endpointName2); - Map handles = proxyRouter.getHandles(); - Assert.assertNotNull(handles); - Assert.assertNotNull(handles.get(endpointName1)); - Assert.assertNotNull(handles.get(endpointName2)); - } finally { - if (!inited) { - Ray.shutdown(); - } - Serve.setInternalReplicaContext(null); - Serve.setGlobalClient(null); - } - } -} diff --git a/java/serve/src/test/java/io/ray/serve/RayServeHandleTest.java b/java/serve/src/test/java/io/ray/serve/RayServeHandleTest.java deleted file mode 100644 index 9e4ac68b612fd..0000000000000 --- a/java/serve/src/test/java/io/ray/serve/RayServeHandleTest.java +++ /dev/null @@ -1,76 +0,0 @@ -package io.ray.serve; - -import io.ray.api.ActorHandle; -import io.ray.api.ObjectRef; -import io.ray.api.Ray; -import io.ray.runtime.serializer.MessagePackSerializer; -import io.ray.serve.generated.ActorSet; -import io.ray.serve.generated.BackendConfig; -import io.ray.serve.generated.BackendLanguage; -import io.ray.serve.generated.BackendVersion; -import java.util.HashMap; -import org.testng.Assert; -import org.testng.annotations.Test; - -public class RayServeHandleTest { - - @Test - public void test() { - boolean inited = Ray.isInitialized(); - Ray.init(); - - try { - String backendTag = "RayServeHandleTest"; - String controllerName = backendTag + "_controller"; - String replicaTag = backendTag + "_replica"; - String actorName = replicaTag; - String version = "v1"; - - // Controller - ActorHandle controllerHandle = - Ray.actor(DummyServeController::new).setName(controllerName).remote(); - - // Replica - BackendConfig.Builder backendConfigBuilder = BackendConfig.newBuilder(); - backendConfigBuilder.setBackendLanguage(BackendLanguage.JAVA); - byte[] backendConfigBytes = backendConfigBuilder.build().toByteArray(); - - Object[] initArgs = new Object[] {backendTag, replicaTag, controllerName, new Object()}; - byte[] initArgsBytes = MessagePackSerializer.encode(initArgs).getLeft(); - - DeploymentInfo deploymentInfo = new DeploymentInfo(); - deploymentInfo.setBackendConfig(backendConfigBytes); - deploymentInfo.setBackendVersion( - BackendVersion.newBuilder().setCodeVersion(version).build().toByteArray()); - deploymentInfo.setReplicaConfig( - new ReplicaConfig("io.ray.serve.ReplicaContext", initArgsBytes, new HashMap<>())); - - ActorHandle replicaHandle = - Ray.actor( - RayServeWrappedReplica::new, - backendTag, - replicaTag, - deploymentInfo, - controllerName) - .setName(actorName) - .remote(); - replicaHandle.task(RayServeWrappedReplica::ready).remote(); - - // RayServeHandle - RayServeHandle rayServeHandle = - new RayServeHandle(controllerHandle, backendTag, null, null) - .setMethodName("getBackendTag"); - ActorSet.Builder builder = ActorSet.newBuilder(); - builder.addNames(actorName); - rayServeHandle.getRouter().getReplicaSet().updateWorkerReplicas(builder.build()); - - // remote - ObjectRef resultRef = rayServeHandle.remote(null); - Assert.assertEquals((String) resultRef.get(), backendTag); - } finally { - if (!inited) { - Ray.shutdown(); - } - } - } -} diff --git a/java/serve/src/test/java/io/ray/serve/RayServeReplicaTest.java 
b/java/serve/src/test/java/io/ray/serve/RayServeReplicaTest.java index 065b74ac1fc0e..7cc7746ff165c 100644 --- a/java/serve/src/test/java/io/ray/serve/RayServeReplicaTest.java +++ b/java/serve/src/test/java/io/ray/serve/RayServeReplicaTest.java @@ -6,12 +6,9 @@ import io.ray.runtime.serializer.MessagePackSerializer; import io.ray.serve.generated.BackendConfig; import io.ray.serve.generated.BackendLanguage; -import io.ray.serve.generated.BackendVersion; import io.ray.serve.generated.RequestMetadata; import io.ray.serve.generated.RequestWrapper; import java.io.IOException; -import java.util.HashMap; -import org.apache.commons.lang3.RandomStringUtils; import org.testng.Assert; import org.testng.annotations.Test; @@ -20,6 +17,7 @@ public class RayServeReplicaTest { @SuppressWarnings("unused") @Test public void test() throws IOException { + boolean inited = Ray.isInitialized(); Ray.init(); @@ -27,40 +25,38 @@ public void test() throws IOException { String controllerName = "RayServeReplicaTest"; String backendTag = "b_tag"; String replicaTag = "r_tag"; - String version = "v1"; - ActorHandle controllerHandle = - Ray.actor(DummyServeController::new).setName(controllerName).remote(); + ActorHandle controllerHandle = + Ray.actor(ReplicaContext::new, backendTag, replicaTag, controllerName, new Object()) + .setName(controllerName) + .remote(); BackendConfig.Builder backendConfigBuilder = BackendConfig.newBuilder(); backendConfigBuilder.setBackendLanguage(BackendLanguage.JAVA); + byte[] backendConfigBytes = backendConfigBuilder.build().toByteArray(); + Object[] initArgs = new Object[] {backendTag, replicaTag, controllerName, new Object()}; - byte[] initArgsBytes = MessagePackSerializer.encode(initArgs).getLeft(); - DeploymentInfo deploymentInfo = new DeploymentInfo(); - deploymentInfo.setBackendConfig(backendConfigBytes); - deploymentInfo.setBackendVersion( - BackendVersion.newBuilder().setCodeVersion(version).build().toByteArray()); - deploymentInfo.setReplicaConfig( - new ReplicaConfig("io.ray.serve.ReplicaContext", initArgsBytes, new HashMap<>())); + byte[] initArgsBytes = MessagePackSerializer.encode(initArgs).getLeft(); ActorHandle backendHandle = Ray.actor( RayServeWrappedReplica::new, backendTag, replicaTag, - deploymentInfo, + "io.ray.serve.ReplicaContext", + initArgsBytes, + backendConfigBytes, controllerName) .remote(); - // ready backendHandle.task(RayServeWrappedReplica::ready).remote(); - // handle request RequestMetadata.Builder requestMetadata = RequestMetadata.newBuilder(); - requestMetadata.setRequestId(RandomStringUtils.randomAlphabetic(10)); + requestMetadata.setRequestId("RayServeReplicaTest"); requestMetadata.setCallMethod("getBackendTag"); + RequestWrapper.Builder requestWrapper = RequestWrapper.newBuilder(); ObjectRef resultRef = @@ -70,22 +66,8 @@ public void test() throws IOException { requestMetadata.build().toByteArray(), requestWrapper.build().toByteArray()) .remote(); - Assert.assertEquals((String) resultRef.get(), backendTag); - - // reconfigure - ObjectRef versionRef = - backendHandle.task(RayServeWrappedReplica::reconfigure, (Object) null).remote(); - Assert.assertEquals(BackendVersion.parseFrom(versionRef.get()).getCodeVersion(), version); - - // get version - versionRef = backendHandle.task(RayServeWrappedReplica::getVersion).remote(); - Assert.assertEquals(BackendVersion.parseFrom(versionRef.get()).getCodeVersion(), version); - - // prepare for shutdown - ObjectRef shutdownRef = - backendHandle.task(RayServeWrappedReplica::prepareForShutdown).remote(); - 
Assert.assertTrue(shutdownRef.get()); + Assert.assertEquals((String) resultRef.get(), backendTag); } finally { if (!inited) { Ray.shutdown(); diff --git a/java/serve/src/test/java/io/ray/serve/ReplicaSetTest.java b/java/serve/src/test/java/io/ray/serve/ReplicaSetTest.java deleted file mode 100644 index 513d27e4bb6b1..0000000000000 --- a/java/serve/src/test/java/io/ray/serve/ReplicaSetTest.java +++ /dev/null @@ -1,108 +0,0 @@ -package io.ray.serve; - -import io.ray.api.ActorHandle; -import io.ray.api.ObjectRef; -import io.ray.api.Ray; -import io.ray.runtime.serializer.MessagePackSerializer; -import io.ray.serve.generated.ActorSet; -import io.ray.serve.generated.BackendConfig; -import io.ray.serve.generated.BackendLanguage; -import io.ray.serve.generated.BackendVersion; -import io.ray.serve.generated.RequestMetadata; -import java.util.HashMap; -import java.util.Map; -import java.util.Set; -import org.apache.commons.lang3.RandomStringUtils; -import org.testng.Assert; -import org.testng.annotations.Test; - -public class ReplicaSetTest { - - private String backendTag = "ReplicaSetTest"; - - @Test - public void setMaxConcurrentQueriesTest() { - ReplicaSet replicaSet = new ReplicaSet(backendTag); - BackendConfig.Builder builder = BackendConfig.newBuilder(); - builder.setMaxConcurrentQueries(200); - - replicaSet.setMaxConcurrentQueries(builder.build()); - Assert.assertEquals(replicaSet.getMaxConcurrentQueries(), 200); - } - - @Test - public void updateWorkerReplicasTest() { - ReplicaSet replicaSet = new ReplicaSet(backendTag); - ActorSet.Builder builder = ActorSet.newBuilder(); - - replicaSet.updateWorkerReplicas(builder.build()); - Map, Set>> inFlightQueries = - replicaSet.getInFlightQueries(); - Assert.assertTrue(inFlightQueries.isEmpty()); - } - - @SuppressWarnings("unused") - @Test - public void assignReplicaTest() { - boolean inited = Ray.isInitialized(); - Ray.init(); - - try { - String controllerName = backendTag + "_controller"; - String replicaTag = backendTag + "_replica"; - String actorName = replicaTag; - String version = "v1"; - - // Controller - ActorHandle controllerHandle = - Ray.actor(DummyServeController::new).setName(controllerName).remote(); - - // Replica - BackendConfig.Builder backendConfigBuilder = BackendConfig.newBuilder(); - backendConfigBuilder.setBackendLanguage(BackendLanguage.JAVA); - byte[] backendConfigBytes = backendConfigBuilder.build().toByteArray(); - - Object[] initArgs = new Object[] {backendTag, replicaTag, controllerName, new Object()}; - byte[] initArgsBytes = MessagePackSerializer.encode(initArgs).getLeft(); - - DeploymentInfo deploymentInfo = new DeploymentInfo(); - deploymentInfo.setBackendConfig(backendConfigBytes); - deploymentInfo.setBackendVersion( - BackendVersion.newBuilder().setCodeVersion(version).build().toByteArray()); - deploymentInfo.setReplicaConfig( - new ReplicaConfig("io.ray.serve.ReplicaContext", initArgsBytes, new HashMap<>())); - - ActorHandle replicaHandle = - Ray.actor( - RayServeWrappedReplica::new, - backendTag, - replicaTag, - deploymentInfo, - controllerName) - .setName(actorName) - .remote(); - replicaHandle.task(RayServeWrappedReplica::ready).remote(); - - // ReplicaSet - ReplicaSet replicaSet = new ReplicaSet(backendTag); - ActorSet.Builder builder = ActorSet.newBuilder(); - builder.addNames(actorName); - replicaSet.updateWorkerReplicas(builder.build()); - - // assign - - RequestMetadata.Builder requestMetadata = RequestMetadata.newBuilder(); - requestMetadata.setRequestId(RandomStringUtils.randomAlphabetic(10)); - 
requestMetadata.setCallMethod("getBackendTag"); - - Query query = new Query(requestMetadata.build(), null); - ObjectRef resultRef = replicaSet.assignReplica(query); - - Assert.assertEquals((String) resultRef.get(), backendTag); - } finally { - if (!inited) { - Ray.shutdown(); - } - } - } -} diff --git a/java/serve/src/test/java/io/ray/serve/RouterTest.java b/java/serve/src/test/java/io/ray/serve/RouterTest.java deleted file mode 100644 index 3312179912e38..0000000000000 --- a/java/serve/src/test/java/io/ray/serve/RouterTest.java +++ /dev/null @@ -1,80 +0,0 @@ -package io.ray.serve; - -import io.ray.api.ActorHandle; -import io.ray.api.ObjectRef; -import io.ray.api.Ray; -import io.ray.runtime.serializer.MessagePackSerializer; -import io.ray.serve.generated.ActorSet; -import io.ray.serve.generated.BackendConfig; -import io.ray.serve.generated.BackendLanguage; -import io.ray.serve.generated.BackendVersion; -import io.ray.serve.generated.RequestMetadata; -import java.util.HashMap; -import org.apache.commons.lang3.RandomStringUtils; -import org.testng.Assert; -import org.testng.annotations.Test; - -public class RouterTest { - - @Test - public void test() { - boolean inited = Ray.isInitialized(); - Ray.init(); - - try { - String backendTag = "RouterTest"; - String controllerName = backendTag + "_controller"; - String replicaTag = backendTag + "_replica"; - String actorName = replicaTag; - String version = "v1"; - - // Controller - ActorHandle controllerHandle = - Ray.actor(DummyServeController::new).setName(controllerName).remote(); - - // Replica - BackendConfig.Builder backendConfigBuilder = BackendConfig.newBuilder(); - backendConfigBuilder.setBackendLanguage(BackendLanguage.JAVA); - byte[] backendConfigBytes = backendConfigBuilder.build().toByteArray(); - - Object[] initArgs = new Object[] {backendTag, replicaTag, controllerName, new Object()}; - byte[] initArgsBytes = MessagePackSerializer.encode(initArgs).getLeft(); - - DeploymentInfo deploymentInfo = new DeploymentInfo(); - deploymentInfo.setBackendConfig(backendConfigBytes); - deploymentInfo.setBackendVersion( - BackendVersion.newBuilder().setCodeVersion(version).build().toByteArray()); - deploymentInfo.setReplicaConfig( - new ReplicaConfig("io.ray.serve.ReplicaContext", initArgsBytes, new HashMap<>())); - - ActorHandle replicaHandle = - Ray.actor( - RayServeWrappedReplica::new, - backendTag, - replicaTag, - deploymentInfo, - controllerName) - .setName(actorName) - .remote(); - replicaHandle.task(RayServeWrappedReplica::ready).remote(); - - // Router - Router router = new Router(controllerHandle, backendTag); - ActorSet.Builder builder = ActorSet.newBuilder(); - builder.addNames(actorName); - router.getReplicaSet().updateWorkerReplicas(builder.build()); - - // assign - RequestMetadata.Builder requestMetadata = RequestMetadata.newBuilder(); - requestMetadata.setRequestId(RandomStringUtils.randomAlphabetic(10)); - requestMetadata.setCallMethod("getBackendTag"); - - ObjectRef resultRef = router.assignRequest(requestMetadata.build(), null); - Assert.assertEquals((String) resultRef.get(), backendTag); - } finally { - if (!inited) { - Ray.shutdown(); - } - } - } -} diff --git a/java/serve/src/test/java/io/ray/serve/api/ClientTest.java b/java/serve/src/test/java/io/ray/serve/api/ClientTest.java deleted file mode 100644 index c3489bc1a1a19..0000000000000 --- a/java/serve/src/test/java/io/ray/serve/api/ClientTest.java +++ /dev/null @@ -1,47 +0,0 @@ -package io.ray.serve.api; - -import io.ray.api.ActorHandle; -import io.ray.api.Ray; -import 
io.ray.serve.DummyServeController; -import io.ray.serve.RayServeHandle; -import io.ray.serve.generated.EndpointInfo; -import java.util.HashMap; -import java.util.Map; -import org.testng.Assert; -import org.testng.annotations.Test; - -public class ClientTest { - - @Test - public void getHandleTest() { - - boolean inited = Ray.isInitialized(); - Ray.init(); - - try { - String prefix = "ClientTest"; - String controllerName = prefix + "_controller"; - String endpointName = prefix + "_endpoint"; - - // Controller. - ActorHandle controllerHandle = - Ray.actor(DummyServeController::new).setName(controllerName).remote(); - - // Mock endpoints. - Map endpoints = new HashMap<>(); - endpoints.put(endpointName, EndpointInfo.newBuilder().setEndpointName(endpointName).build()); - controllerHandle.task(DummyServeController::setEndpoints, endpoints).remote(); - - // Client. - Client client = new Client(controllerHandle, controllerName, true); - - // Get handle. - RayServeHandle rayServeHandle = client.getHandle(endpointName, false); - Assert.assertNotNull(rayServeHandle); - } finally { - if (!inited) { - Ray.shutdown(); - } - } - } -} diff --git a/java/serve/src/test/java/io/ray/serve/api/ServeTest.java b/java/serve/src/test/java/io/ray/serve/api/ServeTest.java index cf470e8ce2248..b63a709a167de 100644 --- a/java/serve/src/test/java/io/ray/serve/api/ServeTest.java +++ b/java/serve/src/test/java/io/ray/serve/api/ServeTest.java @@ -1,12 +1,7 @@ package io.ray.serve.api; -import io.ray.api.ActorHandle; -import io.ray.api.Ray; -import io.ray.serve.Constants; -import io.ray.serve.DummyServeController; +import io.ray.serve.RayServeException; import io.ray.serve.ReplicaContext; -import io.ray.serve.util.CommonUtil; -import org.apache.commons.lang3.RandomStringUtils; import org.testng.Assert; import org.testng.annotations.Test; @@ -15,53 +10,31 @@ public class ServeTest { @Test public void replicaContextTest() { + ReplicaContext preContext = Serve.INTERNAL_REPLICA_CONTEXT; + ReplicaContext replicaContext; + + // Test null replica context. + Serve.INTERNAL_REPLICA_CONTEXT = null; try { - // Test context setting and getting. - String backendTag = "backendTag"; - String replicaTag = "replicaTag"; - String controllerName = "controllerName"; - Object servableObject = new Object(); - Serve.setInternalReplicaContext(backendTag, replicaTag, controllerName, servableObject); + replicaContext = Serve.getReplicaContext(); + Assert.assertTrue(false, "expect RayServeException"); + } catch (RayServeException e) { - ReplicaContext replicaContext = Serve.getReplicaContext(); - Assert.assertNotNull(replicaContext, "no replica context"); - Assert.assertEquals(replicaContext.getBackendTag(), backendTag); - Assert.assertEquals(replicaContext.getReplicaTag(), replicaTag); - Assert.assertEquals(replicaContext.getInternalControllerName(), controllerName); - } finally { - // Recover context. - Serve.setInternalReplicaContext(null); } - } - @SuppressWarnings("unused") - @Test - public void getGlobalClientTest() { - boolean inited = Ray.isInitialized(); - Ray.init(); - try { - Client client = null; - try { - client = Serve.getGlobalClient(); - Assert.assertTrue(false, "Expect IllegalStateException here!"); - } catch (IllegalStateException e) { - } - Assert.assertNull(client); + // Test context setting and getting. 
+ String backendTag = "backendTag"; + String replicaTag = "replicaTag"; + String controllerName = "controllerName"; + Object servableObject = new Object(); + Serve.setInternalReplicaContext(backendTag, replicaTag, controllerName, servableObject); - String controllerName = - CommonUtil.formatActorName( - Constants.SERVE_CONTROLLER_NAME, RandomStringUtils.randomAlphabetic(6)); - ActorHandle actorHandle = - Ray.actor(DummyServeController::new).setName(controllerName).remote(); - Serve.setInternalReplicaContext(null, null, controllerName, null); - client = Serve.getGlobalClient(); - Assert.assertNotNull(client); - } finally { - if (!inited) { - Ray.shutdown(); - } - Serve.setInternalReplicaContext(null); - Serve.setGlobalClient(null); - } + replicaContext = Serve.getReplicaContext(); + Assert.assertNotNull(replicaContext, "no replica context"); + Assert.assertEquals(replicaContext.getBackendTag(), backendTag); + Assert.assertEquals(replicaContext.getReplicaTag(), replicaTag); + Assert.assertEquals(replicaContext.getInternalControllerName(), controllerName); + + Serve.INTERNAL_REPLICA_CONTEXT = preContext; } } diff --git a/java/serve/src/test/java/io/ray/serve/poll/KeyTypeTest.java b/java/serve/src/test/java/io/ray/serve/poll/KeyTypeTest.java index 710ad97128ede..628f5ff4a89c4 100644 --- a/java/serve/src/test/java/io/ray/serve/poll/KeyTypeTest.java +++ b/java/serve/src/test/java/io/ray/serve/poll/KeyTypeTest.java @@ -1,15 +1,12 @@ package io.ray.serve.poll; -import com.google.gson.Gson; import org.testng.Assert; import org.testng.annotations.Test; public class KeyTypeTest { - private static final Gson GSON = new Gson(); - @Test - public void hashTest() { + public void test() { KeyType k1 = new KeyType(LongPollNamespace.BACKEND_CONFIGS, "k1"); KeyType k2 = new KeyType(LongPollNamespace.BACKEND_CONFIGS, "k1"); KeyType k3 = new KeyType(LongPollNamespace.BACKEND_CONFIGS, null); @@ -31,14 +28,4 @@ public void hashTest() { Assert.assertNotEquals(k1.hashCode(), k4.hashCode()); Assert.assertFalse(k1.equals(k4)); } - - @Test - public void jsonTest() { - KeyType k1 = new KeyType(LongPollNamespace.BACKEND_CONFIGS, "k1"); - String json = GSON.toJson(k1); - - KeyType k2 = GSON.fromJson(json, KeyType.class); - Assert.assertEquals(k1, k2); - Assert.assertEquals(k1.hashCode(), k2.hashCode()); - } } diff --git a/java/serve/src/test/java/io/ray/serve/poll/LongPollClientTest.java b/java/serve/src/test/java/io/ray/serve/poll/LongPollClientTest.java index 7ee254806fad3..3d172d87bedc7 100644 --- a/java/serve/src/test/java/io/ray/serve/poll/LongPollClientTest.java +++ b/java/serve/src/test/java/io/ray/serve/poll/LongPollClientTest.java @@ -1,8 +1,5 @@ package io.ray.serve.poll; -import com.google.protobuf.ByteString; -import io.ray.serve.generated.BackendConfig; -import io.ray.serve.generated.UpdatedObject; import java.util.HashMap; import java.util.Map; import org.testng.Assert; @@ -13,35 +10,25 @@ public class LongPollClientTest { @Test public void test() throws Throwable { - String[] a = new String[] {"test"}; - - // Construct a listener map. KeyType keyType = new KeyType(LongPollNamespace.BACKEND_CONFIGS, "backendTag"); + int[] a = new int[] {0}; Map keyListeners = new HashMap<>(); - keyListeners.put( - keyType, (object) -> a[0] = String.valueOf(((BackendConfig) object).getNumReplicas())); - - // Initialize LongPollClient. + keyListeners.put(keyType, (object) -> a[0] = (Integer) object); LongPollClient longPollClient = new LongPollClient(null, keyListeners); - // Construct updated object. 
- BackendConfig.Builder backendConfig = BackendConfig.newBuilder(); - backendConfig.setNumReplicas(20); int snapshotId = 10; - UpdatedObject.Builder updatedObject = UpdatedObject.newBuilder(); + int objectSnapshot = 20; + UpdatedObject updatedObject = new UpdatedObject(); updatedObject.setSnapshotId(snapshotId); - updatedObject.setObjectSnapshot(ByteString.copyFrom(backendConfig.build().toByteArray())); + updatedObject.setObjectSnapshot(objectSnapshot); - // Process update. Map updates = new HashMap<>(); - updates.put(keyType, updatedObject.build()); + updates.put(keyType, updatedObject); longPollClient.processUpdate(updates); - // Validation. Assert.assertEquals(longPollClient.getSnapshotIds().get(keyType).intValue(), snapshotId); Assert.assertEquals( - ((BackendConfig) longPollClient.getObjectSnapshots().get(keyType)).getNumReplicas(), - backendConfig.getNumReplicas()); - Assert.assertEquals(a[0], String.valueOf(backendConfig.getNumReplicas())); + ((Integer) longPollClient.getObjectSnapshots().get(keyType)).intValue(), objectSnapshot); + Assert.assertEquals(a[0], objectSnapshot); } } diff --git a/python/build-wheel-windows.sh b/python/build-wheel-windows.sh index c7c282acaa421..cb36f901bd61c 100755 --- a/python/build-wheel-windows.sh +++ b/python/build-wheel-windows.sh @@ -81,13 +81,6 @@ build_wheel_windows() { unset PYTHON2_BIN_PATH PYTHON3_BIN_PATH # make sure these aren't set by some chance install_ray cd "${WORKSPACE_DIR}"/python - # Set the commit SHA in __init__.py. - if [ -n "$TRAVIS_COMMIT" ]; then - sed -i.bak "s/{{RAY_COMMIT_SHA}}/$TRAVIS_COMMIT/g" ray/__init__.py && rm ray/__init__.py.bak - else - echo "TRAVIS_COMMIT variable not set - required to populated ray.__commit__." - exit 1 - fi # build ray wheel python setup.py --quiet bdist_wheel # build ray-cpp wheel diff --git a/python/ray/_private/client_mode_hook.py b/python/ray/_private/client_mode_hook.py index ef3df68206303..e6abf5f5a98f0 100644 --- a/python/ray/_private/client_mode_hook.py +++ b/python/ray/_private/client_mode_hook.py @@ -1,6 +1,6 @@ import os from contextlib import contextmanager -from functools import partial, wraps +from functools import wraps import threading # Attr set on func defs to mark they have been converted to client mode. @@ -15,8 +15,6 @@ is_client_mode_enabled_by_default = is_client_mode_enabled os.environ.update({"RAY_CLIENT_MODE": "0"}) -is_init_called = False - # Local setting of whether to ignore client hook conversion. This defaults # to TRUE and is disabled when the underlying 'real' Ray function is needed. _client_hook_status_on_thread = threading.local() @@ -77,27 +75,13 @@ def enable_client_mode(): _explicitly_disable_client_mode() -def client_mode_hook(func=None, *, auto_init: bool): - """Decorator for whether to use the 'regular' ray version of a function, - or the Ray Client version of that function. - - Args: - func (callable): This function. This is set when this function is used - as a decorator. - auto_init (bool): Whether `ray.init()` should be transparently called when - the wrapped function is called. This should be `True` for functions - that are *NOT* part of the initialization path (e.g. `init` or - `is_initialized`) or for functions that do not require Ray to be - initialized (e.g., KV operations, `shutdown`). 
- """ - if func is None: - return partial(client_mode_hook, auto_init=auto_init) - +def client_mode_hook(func): + """Decorator for ray module methods to delegate to ray client""" from ray.util.client import ray @wraps(func) def wrapper(*args, **kwargs): - if client_mode_should_convert(auto_init=auto_init): + if client_mode_should_convert(): # Legacy code # we only convert init function if RAY_CLIENT_MODE=1 if func.__name__ != "init" or is_client_mode_enabled_by_default: @@ -107,23 +91,13 @@ def wrapper(*args, **kwargs): return wrapper -def client_mode_should_convert(*, auto_init: bool): - """Determines if functions should be converted to client mode & if - Ray should be auto-initialized. - - NOTE: `auto_init` must happen before we branch into regular ray or client - code because the initialization may result in either mode. - """ - if auto_init: - import ray - if os.environ.get("RAY_ENABLE_AUTO_CONNECT", - "") != "0" and not ray.is_initialized(): - ray.init() - - # `is_client_mode_enabled_by_default` is used for testing with - # `RAY_CLIENT_MODE=1`. This flag means all tests run with client mode. +def client_mode_should_convert(): + # This is for testing with RAY_CLIENT_MODE. + # When RAY_CLIENT_MODE=1, it means that for all the tests + # will run with client mode. + # is_client_mode_enabled will be set to be off when client is off return (is_client_mode_enabled or is_client_mode_enabled_by_default) and \ - _get_client_hook_status_on_thread() + _get_client_hook_status_on_thread() def client_mode_wrap(func): @@ -141,9 +115,7 @@ def client_mode_wrap(func): @wraps(func) def wrapper(*args, **kwargs): - # Directly pass this through since `client_mode_wrap` is for - # Placement Group APIs - if client_mode_should_convert(auto_init=True): + if client_mode_should_convert(): f = ray.remote(num_cpus=0)(func) ref = f.remote(*args, **kwargs) return ray.get(ref) diff --git a/python/ray/_private/parameter.py b/python/ray/_private/parameter.py index ac727fc2dec01..4303808609a48 100644 --- a/python/ray/_private/parameter.py +++ b/python/ray/_private/parameter.py @@ -72,8 +72,8 @@ class RayParams: be created. worker_path (str): The path of the source code that will be run by the worker. - setup_worker_path (str): The path of the Python file that will set up - the environment for the worker process. + setup_worker_path (str): The path of the Python file that will run + worker_setup_hook to set up the environment for the worker process. huge_pages: Boolean flag indicating whether to start the Object Store with hugetlbfs support. Requires plasma_directory. include_dashboard: Boolean flag indicating whether to start the web @@ -116,7 +116,6 @@ class RayParams: ray_debugger_external (bool): If true, make the Ray debugger for a worker available externally to the node it is running on. This will bind on 0.0.0.0 instead of localhost. - env_vars (dict): Override environment variables for the raylet. 
""" def __init__(self, @@ -169,7 +168,7 @@ def __init__(self, metrics_export_port=None, tracing_startup_hook=None, no_monitor=False, - env_vars=None): + lru_evict=False): self.object_ref_seed = object_ref_seed self.external_addresses = external_addresses self.redis_address = redis_address @@ -216,11 +215,18 @@ def __init__(self, self.start_initial_python_workers_for_first_job = ( start_initial_python_workers_for_first_job) self.ray_debugger_external = ray_debugger_external - self.env_vars = env_vars self._system_config = _system_config or {} self._enable_object_reconstruction = enable_object_reconstruction self._check_usage() + # Set the internal config options for LRU eviction. + if lru_evict: + raise DeprecationWarning( + "The lru_evict flag is deprecated as Ray natively " + "supports object spilling. Please read " + "https://docs.ray.io/en/master/memory-management.html#object-spilling " # noqa + "for more details.") + # Set the internal config options for object reconstruction. if enable_object_reconstruction: # Turn off object pinning. diff --git a/python/ray/_private/runtime_env/__init__.py b/python/ray/_private/runtime_env/__init__.py index e69de29bb2d1d..20401cb96f021 100644 --- a/python/ray/_private/runtime_env/__init__.py +++ b/python/ray/_private/runtime_env/__init__.py @@ -0,0 +1,3 @@ +from ray._private.runtime_env.context import RuntimeEnvContext # noqa: F401 +from ray._private.runtime_env.validation import ( # noqa: F401 + override_task_or_actor_runtime_env, RuntimeEnvDict) # noqa: F401 diff --git a/python/ray/_private/runtime_env/conda.py b/python/ray/_private/runtime_env/conda.py index 92bc3d8cb1139..d9c810b89b75b 100644 --- a/python/ray/_private/runtime_env/conda.py +++ b/python/ray/_private/runtime_env/conda.py @@ -12,9 +12,9 @@ from pathlib import Path import ray +from ray._private.runtime_env import RuntimeEnvContext from ray._private.runtime_env.conda_utils import (get_conda_activate_commands, get_or_create_conda_env) -from ray._private.runtime_env.context import RuntimeEnvContext from ray._private.utils import (get_wheel_filename, get_master_wheel_url, get_release_wheel_url, try_to_create_directory) @@ -81,7 +81,7 @@ def get_conda_dict(runtime_env, runtime_env_dir) -> Optional[Dict[Any, Any]]: else: return None if runtime_env.get("pip"): - requirements_txt = "\n".join(runtime_env["pip"]) + "\n" + requirements_txt = runtime_env["pip"] pip_hash = hashlib.sha1(requirements_txt.encode("utf-8")).hexdigest() pip_hash_str = f"pip-generated-{pip_hash}" diff --git a/python/ray/_private/runtime_env/conda_utils.py b/python/ray/_private/runtime_env/conda_utils.py index 2339da036b60c..5d61c9e8c5f45 100644 --- a/python/ray/_private/runtime_env/conda_utils.py +++ b/python/ray/_private/runtime_env/conda_utils.py @@ -126,21 +126,6 @@ def get_or_create_conda_env(conda_env_path: str, return env_name -def get_conda_env_list() -> list: - """ - Get conda env list. 
- """ - conda_path = get_conda_bin_executable("conda") - try: - exec_cmd([conda_path, "--help"], throw_on_error=False) - except EnvironmentError: - raise ValueError(f"Could not find Conda executable at {conda_path}.") - _, stdout, _ = exec_cmd([conda_path, "env", "list", "--json"]) - envs = json.loads(stdout)["envs"] - print(f"Conda env len {len(envs)}") - return envs - - class ShellCommandException(Exception): pass diff --git a/python/ray/_private/runtime_env/context.py b/python/ray/_private/runtime_env/context.py index c5db64437ce2d..af3409f310ca5 100644 --- a/python/ray/_private/runtime_env/context.py +++ b/python/ray/_private/runtime_env/context.py @@ -4,13 +4,9 @@ import sys from typing import Dict, List, Optional -from ray.util.annotations import DeveloperAPI -from ray.core.generated.common_pb2 import Language - logger = logging.getLogger(__name__) -@DeveloperAPI class RuntimeEnvContext: """A context used to describe the created runtime env.""" @@ -35,13 +31,10 @@ def serialize(self) -> str: def deserialize(json_string): return RuntimeEnvContext(**json.loads(json_string)) - def exec_worker(self, passthrough_args: List[str], language: Language): + def exec_worker(self, passthrough_args: List[str]): os.environ.update(self.env_vars) - if language == Language.PYTHON: - executable = f"exec {self.py_executable}" - else: - executable = "exec" - exec_command = " ".join([executable] + passthrough_args) + exec_command = " ".join([f"exec {self.py_executable}"] + + passthrough_args) command_str = " && ".join(self.command_prefix + [exec_command]) logger.info(f"Exec'ing worker with command: {command_str}") os.execvp("bash", ["bash", "-c", command_str]) diff --git a/python/ray/_private/runtime_env/plugin.py b/python/ray/_private/runtime_env/plugin.py deleted file mode 100644 index 5e411c141fc08..0000000000000 --- a/python/ray/_private/runtime_env/plugin.py +++ /dev/null @@ -1,70 +0,0 @@ -from abc import ABC, abstractstaticmethod - -from ray.util.annotations import DeveloperAPI -from ray._private.runtime_env.context import RuntimeEnvContext - - -@DeveloperAPI -class RuntimeEnvPlugin(ABC): - @abstractstaticmethod - def validate(runtime_env_dict: dict) -> str: - """Validate user entry and returns a URI uniquely describing resource. - - This method will be called at ``f.options(runtime_env=...)`` or - ``ray.init(runtime_env=...)`` time and it should check the runtime env - dictionary for any errors. For example, it can raise "TypeError: - expected string for "conda" field". - - Args: - runtime_env_dict(dict): the entire dictionary passed in by user. - - Returns: - uri(str): a URI uniquely describing this resource (e.g., a hash of - the conda spec). - """ - raise NotImplementedError() - - def create(uri: str, runtime_env_dict: dict, - ctx: RuntimeEnvContext) -> float: - """Create and install the runtime environment. - - Gets called in the runtime env agent at install time. The URI can be - used as a caching mechanism. - - Args: - uri(str): a URI uniquely describing this resource. - runtime_env_dict(dict): the entire dictionary passed in by user. - ctx(RuntimeEnvContext): auxiliary information supplied by Ray. - - Returns: - the disk space taken up by this plugin installation for this - environment. e.g. for working_dir, this downloads the files to the - local node. - """ - return 0 - - def modify_context(uri: str, runtime_env_dict: dict, - ctx: RuntimeEnvContext) -> None: - """Modify context to change worker startup behavior. 
- - For example, you can use this to preprend "cd " command to worker - startup, or add new environment variables. - - Args: - uri(str): a URI uniquely describing this resource. - runtime_env_dict(dict): the entire dictionary passed in by user. - ctx(RuntimeEnvContext): auxiliary information supplied by Ray. - """ - return - - def delete(uri: str, ctx: RuntimeEnvContext) -> float: - """Delete the the runtime environment given uri. - - Args: - uri(str): a URI uniquely describing this resource. - ctx(RuntimeEnvContext): auxiliary information supplied by Ray. - - Returns: - the amount of space reclaimed by the deletion. - """ - return 0 diff --git a/python/ray/_private/runtime_env/validation.py b/python/ray/_private/runtime_env/validation.py index 0e41bb6b30bd0..e113e4151424d 100644 --- a/python/ray/_private/runtime_env/validation.py +++ b/python/ray/_private/runtime_env/validation.py @@ -1,15 +1,12 @@ -import copy import json import logging import os from pathlib import Path import sys -from typing import Any, Dict, List, Optional, Set, Union +from typing import Any, Dict, Optional import yaml import ray -from ray._private.runtime_env.plugin import RuntimeEnvPlugin -from ray._private.utils import import_attr # We need to setup this variable before # using this module @@ -23,198 +20,19 @@ GCS_STORAGE_MAX_SIZE = 100 * 1024 * 1024 # 100MiB -def parse_and_validate_working_dir(working_dir: str, - is_task_or_actor: bool = False) -> str: - """Parses and validates a user-provided 'working_dir' option. +class RuntimeEnvDict: + """Parses and validates the runtime env dictionary from the user. - The working_dir may not be specified per-task or per-actor. - - Otherwise, it should be a valid path to a local directory. - """ - assert working_dir is not None - - if is_task_or_actor: - raise NotImplementedError( - "Overriding working_dir for tasks and actors isn't supported. " - "Please use ray.init(runtime_env={'working_dir': ...}) " - "to configure the environment per-job instead.") - elif not isinstance(working_dir, str): - raise TypeError("`working_dir` must be a string, got " - f"{type(working_dir)}.") - elif not Path(working_dir).is_dir(): - raise ValueError( - f"working_dir {working_dir} is not a valid directory.") - - return working_dir - - -def parse_and_validate_conda(conda: Union[str, dict], - is_task_or_actor: bool = False - ) -> Union[str, dict]: - """Parses and validates a user-provided 'conda' option. - - Conda can be one of three cases: - 1) A dictionary describing the env. This is passed through directly. - 2) A string referring to a preinstalled conda env. - 3) A string pointing to a local conda YAML file. This is detected - by looking for a '.yaml' or '.yml' suffix. In this case, the file - will be read as YAML and passed through as a dictionary. - """ - assert conda is not None - - result = None - if sys.platform == "win32": - raise NotImplementedError("The 'conda' field in runtime_env " - "is not currently supported on " - "Windows.") - elif isinstance(conda, str): - yaml_file = Path(conda) - if yaml_file.suffix in (".yaml", ".yml"): - if not yaml_file.is_file(): - raise ValueError(f"Can't find conda YAML file {yaml_file}.") - try: - result = yaml.safe_load(yaml_file.read_text()) - except Exception as e: - raise ValueError( - f"Failed to read conda file {yaml_file}: {e}.") - else: - # Assume it's a pre-existing conda environment name. 
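# For reference, a standalone sketch of the three accepted forms of the
# "conda" option described above (inline dict spec, preinstalled env name,
# or path to a YAML file); this helper only mirrors that dispatch and is
# not the Ray implementation itself.
from pathlib import Path
import yaml

def normalize_conda_option(conda):
    if isinstance(conda, dict):
        return conda  # 1) inline env spec, passed through unchanged
    if isinstance(conda, str):
        path = Path(conda)
        if path.suffix in (".yaml", ".yml"):
            if not path.is_file():
                raise ValueError(f"Can't find conda YAML file {path}")
            return yaml.safe_load(path.read_text())  # 3) local YAML file
        return conda  # 2) name of a preinstalled conda environment
    raise TypeError(f"conda must be str or dict, got {type(conda)}")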
- result = conda - elif isinstance(conda, dict): - result = conda - else: - raise TypeError("runtime_env['conda'] must be of type str or " - f"dict, got {type(conda)}.") - - return result - - -def parse_and_validate_pip(pip: Union[str, List[str]], - is_task_or_actor: bool = False - ) -> Optional[List[str]]: - """Parses and validates a user-provided 'pip' option. - - Conda can be one of two cases: - 1) A List[str] describing the requirements. This is passed through. - 2) A string pointing to a local requirements file. In this case, the - file contents will be read split into a list. - """ - assert pip is not None - - result = None - if sys.platform == "win32": - raise NotImplementedError("The 'pip' field in runtime_env " - "is not currently supported on " - "Windows.") - elif isinstance(pip, str): - # We have been given a path to a requirements.txt file. - pip_file = Path(pip) - if not pip_file.is_file(): - raise ValueError(f"{pip_file} is not a valid file") - result = pip_file.read_text().strip().split("\n") - elif isinstance(pip, list) and all(isinstance(dep, str) for dep in pip): - if len(pip) == 0: - result = None - else: - result = pip - else: - raise TypeError("runtime_env['pip'] must be of type str or " - f"List[str], got {type(pip)}") - - return result - - -def parse_and_validate_uris(uris: List[str], - is_task_or_actor: bool = False) -> List[str]: - """Parses and validates a user-provided 'uris' option. - - These are passed through without validation (for now). - """ - assert uris is not None - return uris - - -def parse_and_validate_container(container: List[str], - is_task_or_actor: bool = False) -> List[str]: - """Parses and validates a user-provided 'container' option. - - This is passed through without validation (for now). - """ - assert container is not None - return container - - -def parse_and_validate_excludes(excludes: List[str], - is_task_or_actor: bool = False) -> List[str]: - """Parses and validates a user-provided 'excludes' option. - - This is validated to verify that it is of type List[str]. - - If an empty list is passed, we return `None` for consistency. - """ - assert excludes is not None - - if isinstance(excludes, list) and len(excludes) == 0: - return None - - if (isinstance(excludes, list) - and all(isinstance(path, str) for path in excludes)): - return excludes - else: - raise TypeError("runtime_env['excludes'] must be of type " - f"List[str], got {type(excludes)}") - - -def parse_and_validate_env_vars(env_vars: Dict[str, str], - is_task_or_actor: bool = False - ) -> Optional[Dict[str, str]]: - """Parses and validates a user-provided 'env_vars' option. - - This is validated to verify that all keys and vals are strings. - - If an empty dictionary is passed, we return `None` for consistency. - """ - assert env_vars is not None - if len(env_vars) == 0: - return None - - if not (isinstance(env_vars, dict) and all( - isinstance(k, str) and isinstance(v, str) - for (k, v) in env_vars.items())): - raise TypeError("runtime_env['env_vars'] must be of type " - "Dict[str, str]") - - return env_vars - - -# Dictionary mapping runtime_env options with the function to parse and -# validate them. 
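# For reference, a sketch of the two "pip" input forms handled above
# (path to a requirements file vs. explicit list of packages);
# illustrative only, simplified from the validator being removed here.
from pathlib import Path
from typing import List, Optional, Union

def normalize_pip_option(pip: Union[str, List[str]]) -> Optional[List[str]]:
    if isinstance(pip, str):
        # A path to a requirements.txt file: read and split into lines.
        return Path(pip).read_text().strip().split("\n")
    if isinstance(pip, list) and all(isinstance(dep, str) for dep in pip):
        return pip or None  # an empty list normalizes to None
    raise TypeError(f"pip must be str or List[str], got {type(pip)}")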
-OPTION_TO_VALIDATION_FN = { - "working_dir": parse_and_validate_working_dir, - "excludes": parse_and_validate_excludes, - "conda": parse_and_validate_conda, - "pip": parse_and_validate_pip, - "uris": parse_and_validate_uris, - "env_vars": parse_and_validate_env_vars, - "container": parse_and_validate_container, -} - - -class ParsedRuntimeEnv(dict): - """An internal wrapper for runtime_env that is parsed and validated. - - This should be constructed from user-provided input (the API runtime_env) - and used everywhere that the runtime_env is passed around internally. - - All options in the resulting dictionary will have non-None values. - - Currently supported options: + Attributes: working_dir (Path): Specifies the working directory of the worker. This can either be a local directory or zip file. Examples: "." # cwd "local_project.zip" # archive is unpacked into directory - uris (List[str]): A list of URIs that define the working_dir. + py_modules (List[Path]): Similar to working_dir, but specifies python + modules to add to the `sys.path`. + Examples: + ["/path/to/other_module", "/other_path/local_project.zip"] pip (List[str] | str): Either a list of pip packages, or a string containing the path to a pip requirements.txt file. conda (dict | str): Either the conda YAML config, the name of a @@ -246,136 +64,170 @@ class ParsedRuntimeEnv(dict): {"OMP_NUM_THREADS": "32", "TF_WARNINGS": "none"} """ - known_fields: Set[str] = { - "working_dir", "conda", "pip", "uris", "containers", "excludes", - "env_vars", "_ray_release", "_ray_commit", "_inject_current_ray", - "plugins" - } - def __init__(self, - runtime_env: Dict[str, Any], - is_task_or_actor: bool = False, - _validate: bool = True): - super().__init__() - - # Blindly trust that the runtime_env has already been validated. - # This is dangerous and should only be used internally (e.g., on the - # deserialization codepath. - if not _validate: - self.update(runtime_env) - return - - if runtime_env.get("conda") and runtime_env.get("pip"): - raise ValueError( - "The 'pip' field and 'conda' field of " - "runtime_env cannot both be specified.\n" - f"specified pip field: {runtime_env['pip']}\n" - f"specified conda field: {runtime_env['conda']}\n" - "To use pip with conda, please only set the 'conda' " - "field, and specify your pip dependencies " - "within the conda YAML config dict: see " - "https://conda.io/projects/conda/en/latest/" - "user-guide/tasks/manage-environments.html" - "#create-env-file-manually") - - for option, validate_fn in OPTION_TO_VALIDATION_FN.items(): - option_val = runtime_env.get(option) - if option_val is not None: - validated_option_val = validate_fn( - option_val, is_task_or_actor=is_task_or_actor) - if validated_option_val is not None: - self[option] = validated_option_val - - if "_ray_release" in runtime_env: - self["_ray_release"] = runtime_env["_ray_release"] - - if "_ray_commit" in runtime_env: - self["_ray_commit"] = runtime_env["_ray_commit"] + runtime_env_json: dict, + working_dir: Optional[str] = None): + # Simple dictionary with all options validated. This will always + # contain all supported keys; values will be set to None if + # unspecified. However, if all values are None this is set to {}. + self._dict = dict() + + if "working_dir" in runtime_env_json: + self._dict["working_dir"] = runtime_env_json["working_dir"] + if not isinstance(self._dict["working_dir"], str): + raise TypeError("`working_dir` must be a string. 
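# For reference, the dispatch-table validation pattern used by
# OPTION_TO_VALIDATION_FN above, reduced to a standalone sketch: each known
# option maps to a parse-and-validate function, and only validated,
# non-None values are copied into the result. Simplified, illustrative.
def validate_runtime_env(runtime_env: dict, validators: dict) -> dict:
    result = {}
    for option, validate_fn in validators.items():
        value = runtime_env.get(option)
        if value is not None:
            validated = validate_fn(value)
            if validated is not None:
                result[option] = validated
    return result

validators = {"env_vars": lambda d: dict(d) or None}
print(validate_runtime_env({"env_vars": {"A": "1"}}, validators))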
Type " + f"{type(self._dict['working_dir'])} received.") + working_dir = Path(self._dict["working_dir"]).absolute() + else: + self._dict["working_dir"] = None + working_dir = Path(working_dir).absolute() if working_dir else None + + self._dict["conda"] = None + if "conda" in runtime_env_json: + if sys.platform == "win32": + raise NotImplementedError("The 'conda' field in runtime_env " + "is not currently supported on " + "Windows.") + conda = runtime_env_json["conda"] + if isinstance(conda, str): + yaml_file = Path(conda) + if yaml_file.suffix in (".yaml", ".yml"): + if working_dir and not yaml_file.is_absolute(): + yaml_file = working_dir / yaml_file + if not yaml_file.is_file(): + raise ValueError( + f"Can't find conda YAML file {yaml_file}") + try: + self._dict["conda"] = yaml.safe_load( + yaml_file.read_text()) + except Exception as e: + raise ValueError( + f"Invalid conda file {yaml_file} with error {e}") + else: + logger.info( + f"Using preinstalled conda environment: {conda}") + self._dict["conda"] = conda + elif isinstance(conda, dict): + self._dict["conda"] = conda + elif conda is not None: + raise TypeError("runtime_env['conda'] must be of type str or " + "dict") + + self._dict["pip"] = None + if "pip" in runtime_env_json: + if sys.platform == "win32": + raise NotImplementedError("The 'pip' field in runtime_env " + "is not currently supported on " + "Windows.") + if ("conda" in runtime_env_json + and runtime_env_json["conda"] is not None): + raise ValueError( + "The 'pip' field and 'conda' field of " + "runtime_env cannot both be specified.\n" + f"specified pip field: {runtime_env_json['pip']}\n" + f"specified conda field: {runtime_env_json['conda']}\n" + "To use pip with conda, please only set the 'conda' " + "field, and specify your pip dependencies " + "within the conda YAML config dict: see " + "https://conda.io/projects/conda/en/latest/" + "user-guide/tasks/manage-environments.html" + "#create-env-file-manually") + pip = runtime_env_json["pip"] + if isinstance(pip, str): + # We have been given a path to a requirements.txt file. + pip_file = Path(pip) + if working_dir and not pip_file.is_absolute(): + pip_file = working_dir / pip_file + if not pip_file.is_file(): + raise ValueError(f"{pip_file} is not a valid file") + self._dict["pip"] = pip_file.read_text() + elif isinstance(pip, list) and all( + isinstance(dep, str) for dep in pip): + # Construct valid pip requirements.txt from list of packages. 
+ self._dict["pip"] = "\n".join(pip) + "\n" + else: + raise TypeError("runtime_env['pip'] must be of type str or " + "List[str]") + + if "uris" in runtime_env_json: + self._dict["uris"] = runtime_env_json["uris"] + + if "container" in runtime_env_json: + self._dict["container"] = runtime_env_json["container"] + + self._dict["env_vars"] = None + if "env_vars" in runtime_env_json: + env_vars = runtime_env_json["env_vars"] + self._dict["env_vars"] = env_vars + if not (isinstance(env_vars, dict) and all( + isinstance(k, str) and isinstance(v, str) + for (k, v) in env_vars.items())): + raise TypeError("runtime_env['env_vars'] must be of type" + "Dict[str, str]") + + if "_ray_release" in runtime_env_json: + self._dict["_ray_release"] = runtime_env_json["_ray_release"] + + if "_ray_commit" in runtime_env_json: + self._dict["_ray_commit"] = runtime_env_json["_ray_commit"] else: - if self.get("pip") or self.get("conda"): - self["_ray_commit"] = ray.__commit__ + if self._dict.get("pip") or self._dict.get("conda"): + self._dict["_ray_commit"] = ray.__commit__ # Used for testing wheels that have not yet been merged into master. # If this is set to True, then we do not inject Ray into the conda # or pip dependencies. - if "_inject_current_ray" in runtime_env: - self["_inject_current_ray"] = runtime_env["_inject_current_ray"] - elif "RAY_RUNTIME_ENV_LOCAL_DEV_MODE" in os.environ: - self["_inject_current_ray"] = True - - if "plugins" in runtime_env: - self["plugins"] = dict() - for class_path, plugin_field in runtime_env["plugins"].items(): - plugin_class: RuntimeEnvPlugin = import_attr(class_path) - if not issubclass(plugin_class, RuntimeEnvPlugin): - # TODO(simon): move the inferface to public once ready. - raise TypeError( - f"{class_path} must be inherit from " - "ray._private.runtime_env.plugin.RuntimeEnvPlugin.") - # TODO(simon): implement uri support. - _ = plugin_class.validate(runtime_env) - # Validation passed, add the entry to parsed runtime env. - self["plugins"][class_path] = plugin_field + if os.environ.get("RAY_RUNTIME_ENV_LOCAL_DEV_MODE"): + runtime_env_json["_inject_current_ray"] = True + if "_inject_current_ray" in runtime_env_json: + self._dict["_inject_current_ray"] = runtime_env_json[ + "_inject_current_ray"] - unknown_fields = ( - set(runtime_env.keys()) - ParsedRuntimeEnv.known_fields) - if len(unknown_fields): - logger.warning( - "The following unknown entries in the runtime_env dictionary " - f"will be ignored: {unknown_fields}. If you intended to use " - "them as plugins, they must be nested in the `plugins` field.") + # TODO(ekl) we should have better schema validation here. + # TODO(ekl) support py_modules + # TODO(architkulkarni) support docker # TODO(architkulkarni) This is to make it easy for the worker caching # code in C++ to check if the env is empty without deserializing and # parsing it. We should use a less confusing approach here. - if all(val is None for val in self.values()): + if all(val is None for val in self._dict.values()): self._dict = {} - @classmethod - def deserialize(cls, serialized: str) -> "ParsedRuntimeEnv": - return cls(json.loads(serialized), _validate=False) + def get_parsed_dict(self) -> dict: + return self._dict def serialize(self) -> str: - # Sort the keys we can compare the serialized string for equality. - return json.dumps(self, sort_keys=True) + # Use sort_keys=True because we will use the output as a key to cache + # workers by, so we need the serialization to be independent of the + # dict order. 
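# For reference, the normalization step above in isolation: a runtime env
# whose values are all None collapses to {}, so the C++ worker-caching code
# can recognize an empty env without deserializing it. Sketch only.
def normalize_empty_env(env: dict) -> dict:
    return {} if all(val is None for val in env.values()) else env

assert normalize_empty_env({"pip": None, "conda": None}) == {}
assert normalize_empty_env({"pip": "requests\n", "conda": None}) != {}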
+ return json.dumps(self._dict, sort_keys=True) + def set_uris(self, uris): + self._dict["uris"] = uris -def override_task_or_actor_runtime_env( - child_runtime_env: ParsedRuntimeEnv, - parent_runtime_env: ParsedRuntimeEnv) -> ParsedRuntimeEnv: - """Merge the given child runtime env with the parent runtime env. - - If running in a driver, the current runtime env comes from the - JobConfig. Otherwise, we are running in a worker for an actor or - task, and the current runtime env comes from the current TaskSpec. - - By default, the child runtime env inherits non-specified options from the - parent. There are two exceptions to this: - - working_dir is not inherited (only URIs). - - The env_vars dictionaries are merged, so environment variables - not specified by the child are still inherited from the parent. - - Returns: - The resulting merged ParsedRuntimeEnv. - """ - assert child_runtime_env is not None - assert parent_runtime_env is not None - - # Override environment variables. - result_env_vars = copy.deepcopy(parent_runtime_env.get("env_vars") or {}) - child_env_vars = child_runtime_env.get("env_vars") or {} - result_env_vars.update(child_env_vars) - # Inherit all other non-specified options from the parent. - result = copy.deepcopy(parent_runtime_env) - result.update(child_runtime_env) - if len(result_env_vars) > 0: - result["env_vars"] = result_env_vars - if "working_dir" in result: - del result["working_dir"] # working_dir should not be in child env. +def override_task_or_actor_runtime_env( + runtime_env: Optional[Dict[str, Any]], + parent_runtime_env: Dict[str, Any]) -> Dict[str, Any]: + if runtime_env: + if runtime_env.get("working_dir"): + raise NotImplementedError( + "Overriding working_dir for actors is not supported. " + "Please use ray.init(runtime_env={'working_dir': ...}) " + "to configure per-job environment instead.") + # NOTE(edoakes): this is sort of hacky, but we pass in the parent + # working_dir here so the relative path to a requirements.txt file + # works. The right solution would be to merge the runtime_env with the + # parent runtime env before validation. + runtime_env_dict = RuntimeEnvDict( + runtime_env, working_dir=parent_runtime_env.get( + "working_dir")).get_parsed_dict() + else: + runtime_env_dict = {} - # NOTE(architkulkarni): This allows worker caching code in C++ to - # check if a runtime env is empty without deserializing it. - assert all(val is not None for val in result.values()) + # If per-actor URIs aren't specified, override them with those in the + # job config. 
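# For reference, why sort_keys=True matters in serialize() above: the JSON
# string is used as a worker cache key, so dicts with equal contents but
# different insertion order must serialize identically.
import json

a = json.dumps({"pip": None, "conda": "env1"}, sort_keys=True)
b = json.dumps({"conda": "env1", "pip": None}, sort_keys=True)
assert a == b  # same cache key regardless of construction order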
+ if "uris" not in runtime_env_dict and "uris" in parent_runtime_env: + runtime_env_dict["uris"] = parent_runtime_env.get("uris") - return result + return runtime_env_dict diff --git a/python/ray/_private/runtime_env/working_dir.py b/python/ray/_private/runtime_env/working_dir.py index e5034caf74a27..964cf4aafcf5d 100644 --- a/python/ray/_private/runtime_env/working_dir.py +++ b/python/ray/_private/runtime_env/working_dir.py @@ -15,7 +15,7 @@ _internal_kv_initialized) from ray.job_config import JobConfig from ray._private.thirdparty.pathspec import PathSpec -from ray._private.runtime_env.context import RuntimeEnvContext +from ray._private.runtime_env import RuntimeEnvContext default_logger = logging.getLogger(__name__) diff --git a/python/ray/_private/services.py b/python/ray/_private/services.py index 4bf6c9bc9e055..8d135129b78fa 100644 --- a/python/ray/_private/services.py +++ b/python/ray/_private/services.py @@ -21,7 +21,6 @@ import ray import ray.ray_constants as ray_constants import redis -from ray.core.generated.common_pb2 import Language # Import psutil and colorama after ray so the packaged version is used. import colorama @@ -399,11 +398,6 @@ def node_ip_address_from_perspective(address): def get_node_ip_address(address="8.8.8.8:53"): if ray.worker._global_node is not None: return ray.worker._global_node.node_ip_address - if sys.platform == "darwin": - # Due to the mac osx firewall, - # we use loopback ip as the ip address - # to prevent security popups. - return "127.0.0.1" return node_ip_address_from_perspective(address) @@ -872,8 +866,7 @@ def start_redis(node_ip_address, stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, fate_share=fate_share, - port_denylist=port_denylist, - listen_to_localhost_only=(node_ip_address == "127.0.0.1")) + port_denylist=port_denylist) processes.append(p) redis_address = address(node_ip_address, port) primary_redis_client = redis.StrictRedis( @@ -929,8 +922,7 @@ def start_redis(node_ip_address, stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, fate_share=fate_share, - port_denylist=port_denylist, - listen_to_localhost_only=(node_ip_address == "127.0.0.1")) + port_denylist=port_denylist) processes.append(p) shard_address = address(node_ip_address, redis_shard_port) @@ -952,8 +944,7 @@ def _start_redis_instance(executable, password=None, redis_max_memory=None, fate_share=None, - port_denylist=None, - listen_to_localhost_only=False): + port_denylist=None): """Start a single Redis server. Notes: @@ -979,9 +970,6 @@ def _start_redis_instance(executable, will start LRU eviction of entries. port_denylist (set): A set of denylist ports that shouldn't be used when allocating a new port. - listen_to_localhost_only (bool): Redis server only listens to - localhost (127.0.0.1) if it's true, - otherwise it listens to all network interfaces. 
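# For reference, the standard trick behind node_ip_address_from_perspective
# used by get_node_ip_address above: connecting a UDP socket toward a remote
# address selects the local interface that routes there, without sending
# any packets. Standalone sketch.
import socket

def ip_address_from_perspective(address: str = "8.8.8.8:53") -> str:
    host, port = address.split(":")
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        s.connect((host, int(port)))  # no data is actually sent over UDP
        return s.getsockname()[0]
    finally:
        s.close()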
Returns: A tuple of the port used by Redis and ProcessInfo for the process that @@ -1002,8 +990,6 @@ def _start_redis_instance(executable, raise ValueError("Spaces not permitted in redis password.") command += ["--requirepass", password] command += (["--port", str(port), "--loglevel", "warning"]) - if listen_to_localhost_only: - command += ["--bind", "127.0.0.1"] process_info = start_ray_process( command, ray_constants.PROCESS_TYPE_REDIS_SERVER, @@ -1361,8 +1347,7 @@ def start_raylet(redis_address, start_initial_python_workers_for_first_job=False, max_bytes=0, backup_count=0, - ray_debugger_external=False, - env_updates=None): + ray_debugger_external=False): """Start a raylet, which is a combined local scheduler and object manager. Args: @@ -1375,8 +1360,8 @@ def start_raylet(redis_address, to. worker_path (str): The path of the Python file that new worker processes will execute. - setup_worker_path (str): The path of the Python file that will set up - the environment for the worker process. + setup_worker_path (str): The path of the Python file that will run + worker_setup_hook to set up the environment for the worker process. temp_dir (str): The path of the temporary directory Ray will use. session_dir (str): The path of this session. resource_dir(str): The path of resource of this session . @@ -1408,8 +1393,6 @@ def start_raylet(redis_address, RotatingFileHandler's backupCount. ray_debugger_external (bool): True if the Ray debugger should be made available externally to this node. - env_updates (dict): Environment variable overrides. - Returns: ProcessInfo for the process that was started. """ @@ -1454,7 +1437,6 @@ def start_raylet(redis_address, redis_password, session_dir, node_ip_address, - setup_worker_path, ) else: java_worker_command = [] @@ -1585,8 +1567,7 @@ def check_should_start_agent(): use_perftools_profiler=("RAYLET_PERFTOOLS_PATH" in os.environ), stdout_file=stdout_file, stderr_file=stderr_file, - fate_share=fate_share, - env_updates=env_updates) + fate_share=fate_share) return process_info @@ -1610,7 +1591,6 @@ def build_java_worker_command( redis_password, session_dir, node_ip_address, - setup_worker_path, ): """This method assembles the command used to start a Java worker. @@ -1622,8 +1602,6 @@ def build_java_worker_command( redis_password (str): The password of connect to redis. session_dir (str): The path of this session. node_ip_address (str): The ip address for this node. - setup_worker_path (str): The path of the Python file that will set up - the environment for the worker process. Returns: The command string for starting Java worker. 
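# For reference, a sketch of how build_java_worker_command turns its
# (key, value) pairs into JVM system properties, as in the hunk just below;
# the pair values here are illustrative.
pairs = [
    ("ray.address", "127.0.0.1:6379"),
    ("ray.node-ip", "127.0.0.1"),
    ("ray.home", "/opt/ray"),
]
command = ["java"] + ["-D{}={}".format(*pair) for pair in pairs]
print(" ".join(command))
# java -Dray.address=127.0.0.1:6379 -Dray.node-ip=127.0.0.1 -Dray.home=/opt/ray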
""" @@ -1648,9 +1626,7 @@ def build_java_worker_command( pairs.append(("ray.home", RAY_HOME)) pairs.append(("ray.logging.dir", os.path.join(session_dir, "logs"))) pairs.append(("ray.session-dir", session_dir)) - command = [sys.executable] + [setup_worker_path] + ["java"] + [ - "-D{}={}".format(*pair) for pair in pairs - ] + command = ["java"] + ["-D{}={}".format(*pair) for pair in pairs] # Add ray jars path to java classpath ray_jars = os.path.join(get_ray_jars_dir(), "*") @@ -1932,14 +1908,9 @@ def start_ray_client_server( ray_constants.SETUP_WORKER_FILENAME) command = [ - sys.executable, - setup_worker_path, - "-m", - "ray.util.client.server", - f"--redis-address={redis_address}", - f"--port={ray_client_server_port}", - f"--mode={server_type}", - f"--language={Language.Name(Language.PYTHON)}", + sys.executable, setup_worker_path, "-m", "ray.util.client.server", + f"--redis-address={redis_address}", f"--port={ray_client_server_port}", + f"--mode={server_type}" ] if redis_password: command.append(f"--redis-password={redis_password}") diff --git a/python/ray/_private/test_utils.py b/python/ray/_private/test_utils.py index 8da4ac9f03e69..50bb3d13c008b 100644 --- a/python/ray/_private/test_utils.py +++ b/python/ray/_private/test_utils.py @@ -7,25 +7,24 @@ import pathlib import subprocess import sys +import tempfile import time import timeit import math import traceback +import datetime from typing import Optional, Any, List, Dict from contextlib import redirect_stdout, redirect_stderr import yaml import socket import pytest -import tempfile import ray import ray._private.services import ray._private.utils import ray._private.gcs_utils as gcs_utils -from ray._private.tls_utils import generate_self_signed_tls_certs from ray.util.queue import Queue, _QueueActor, Empty from ray.scripts.scripts import main as ray_main - try: from prometheus_client.parser import text_string_to_metric_families except (ImportError, ModuleNotFoundError): @@ -691,11 +690,57 @@ async def get_batch(self, return batch -def is_placement_group_removed(pg): - table = ray.util.placement_group_table(pg) - if "state" not in table: - return False - return table["state"] == "REMOVED" +def generate_self_signed_tls_certs(): + """Create self-signed key/cert pair for testing. + + This method requires the library ``cryptography`` be installed. + """ + try: + from cryptography import x509 + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import hashes, serialization + from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.x509.oid import NameOID + except ImportError: + raise ImportError( + "Using `Security.temporary` requires `cryptography`, please " + "install it using either pip or conda") + key = rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend()) + key_contents = key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ).decode() + + ray_interal = x509.Name( + [x509.NameAttribute(NameOID.COMMON_NAME, "ray-internal")]) + # This is the same logic used by the GCS server to acquire a + # private/interal IP address to listen on. 
If we just use localhost + + # 127.0.0.1 then we won't be able to connect to the GCS and will get + # an error like "No match found for server name: 192.168.X.Y" + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.connect(("8.8.8.8", 80)) + private_ip_address = s.getsockname()[0] + s.close() + altnames = x509.SubjectAlternativeName([ + x509.DNSName(socket.gethostbyname( + socket.gethostname())), # Probably 127.0.0.1 + x509.DNSName("127.0.0.1"), + x509.DNSName(private_ip_address), # 192.168.*.* + x509.DNSName("localhost"), + ]) + now = datetime.datetime.utcnow() + cert = (x509.CertificateBuilder() + .subject_name(ray_interal).issuer_name(ray_interal).add_extension( + altnames, critical=False).public_key(key.public_key()) + .serial_number(x509.random_serial_number()).not_valid_before(now) + .not_valid_after(now + datetime.timedelta(days=365)).sign( + key, hashes.SHA256(), default_backend())) + + cert_contents = cert.public_bytes(serialization.Encoding.PEM).decode() + + return cert_contents, key_contents def setup_tls(): @@ -727,3 +772,10 @@ def teardown_tls(key_filepath, cert_filepath, temp_dir): del os.environ["RAY_TLS_SERVER_CERT"] del os.environ["RAY_TLS_SERVER_KEY"] del os.environ["RAY_TLS_CA_CERT"] + + +def is_placement_group_removed(pg): + table = ray.util.placement_group_table(pg) + if "state" not in table: + return False + return table["state"] == "REMOVED" diff --git a/python/ray/_private/tls_utils.py b/python/ray/_private/tls_utils.py deleted file mode 100644 index 8344d86c30c4b..0000000000000 --- a/python/ray/_private/tls_utils.py +++ /dev/null @@ -1,85 +0,0 @@ -import datetime -import os -import socket - -import grpc - - -def generate_self_signed_tls_certs(): - """Create self-signed key/cert pair for testing. - - This method requires the library ``cryptography`` be installed. - """ - try: - from cryptography import x509 - from cryptography.hazmat.backends import default_backend - from cryptography.hazmat.primitives import hashes, serialization - from cryptography.hazmat.primitives.asymmetric import rsa - from cryptography.x509.oid import NameOID - except ImportError: - raise ImportError( - "Using `Security.temporary` requires `cryptography`, please " - "install it using either pip or conda") - key = rsa.generate_private_key( - public_exponent=65537, key_size=2048, backend=default_backend()) - key_contents = key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption(), - ).decode() - - ray_interal = x509.Name( - [x509.NameAttribute(NameOID.COMMON_NAME, "ray-internal")]) - # This is the same logic used by the GCS server to acquire a - # private/interal IP address to listen on. 
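# For reference, a sketch of the environment contract that setup_tls and
# teardown_tls above rely on: generated certs are written to disk and
# advertised through the RAY_TLS_* variables that load_certs_from_env
# reads. The helper and paths are illustrative, not a Ray API.
import os

def export_tls_env(cert_path: str, key_path: str, ca_path: str) -> None:
    os.environ["RAY_USE_TLS"] = "1"
    os.environ["RAY_TLS_SERVER_CERT"] = cert_path  # server cert chain
    os.environ["RAY_TLS_SERVER_KEY"] = key_path    # server private key
    os.environ["RAY_TLS_CA_CERT"] = ca_path        # CA used to verify peers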
If we just use localhost + - # 127.0.0.1 then we won't be able to connect to the GCS and will get - # an error like "No match found for server name: 192.168.X.Y" - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - s.connect(("8.8.8.8", 80)) - private_ip_address = s.getsockname()[0] - s.close() - altnames = x509.SubjectAlternativeName([ - x509.DNSName(socket.gethostbyname( - socket.gethostname())), # Probably 127.0.0.1 - x509.DNSName("127.0.0.1"), - x509.DNSName(private_ip_address), # 192.168.*.* - x509.DNSName("localhost"), - ]) - now = datetime.datetime.utcnow() - cert = (x509.CertificateBuilder().subject_name(ray_interal).issuer_name( - ray_interal).add_extension(altnames, critical=False).public_key( - key.public_key()).serial_number( - x509.random_serial_number()).not_valid_before(now) - .not_valid_after(now + datetime.timedelta(days=365)).sign( - key, hashes.SHA256(), default_backend())) - - cert_contents = cert.public_bytes(serialization.Encoding.PEM).decode() - - return cert_contents, key_contents - - -def add_port_to_grpc_server(server, address): - if os.environ.get("RAY_USE_TLS", "0") == "1": - server_cert_chain, private_key, ca_cert = load_certs_from_env() - credentials = grpc.ssl_server_credentials( - [(private_key, server_cert_chain)], - root_certificates=ca_cert, - require_client_auth=ca_cert is not None) - return server.add_secure_port(address, credentials) - else: - return server.add_insecure_port(address) - - -def load_certs_from_env(): - if os.environ.get("RAY_USE_TLS", "0") == "1": - with open(os.environ["RAY_TLS_SERVER_CERT"], "rb") as f: - server_cert_chain = f.read() - with open(os.environ["RAY_TLS_SERVER_KEY"], "rb") as f: - private_key = f.read() - if "RAY_TLS_CA_CERT" in os.environ: - with open(os.environ["RAY_TLS_CA_CERT"], "rb") as f: - ca_cert = f.read() - else: - ca_cert = None - - return server_cert_chain, private_key, ca_cert diff --git a/python/ray/_private/utils.py b/python/ray/_private/utils.py index 50fe38ed65f74..37430d928dd92 100644 --- a/python/ray/_private/utils.py +++ b/python/ray/_private/utils.py @@ -27,7 +27,6 @@ import ray import ray._private.gcs_utils as gcs_utils import ray.ray_constants as ray_constants -from ray._private.tls_utils import load_certs_from_env # Import psutil after ray so the packaged version is used. 
import psutil @@ -1112,6 +1111,21 @@ def validate_namespace(namespace: str): "Pass None to not specify a namespace.") +def load_certs_from_env(): + if os.environ.get("RAY_USE_TLS", "0") == "1": + with open(os.environ["RAY_TLS_SERVER_CERT"], "rb") as f: + server_cert_chain = f.read() + with open(os.environ["RAY_TLS_SERVER_KEY"], "rb") as f: + private_key = f.read() + if "RAY_TLS_CA_CERT" in os.environ: + with open(os.environ["RAY_TLS_CA_CERT"], "rb") as f: + ca_cert = f.read() + else: + ca_cert = None + + return server_cert_chain, private_key, ca_cert + + def init_grpc_channel(address: str, options: Optional[Sequence[Tuple[str, Any]]] = None, asynchronous: bool = False): @@ -1128,3 +1142,15 @@ def init_grpc_channel(address: str, channel = grpc_module.insecure_channel(address, options=options) return channel + + +def add_port_to_grpc_server(server, address): + if os.environ.get("RAY_USE_TLS", "0") == "1": + server_cert_chain, private_key, ca_cert = load_certs_from_env() + credentials = grpc.ssl_server_credentials( + [(private_key, server_cert_chain)], + root_certificates=ca_cert, + require_client_auth=ca_cert is not None) + return server.add_secure_port(address, credentials) + else: + return server.add_insecure_port(address) diff --git a/python/ray/_raylet.pxd b/python/ray/_raylet.pxd index 4326d6cf943a9..5c79c3b796459 100644 --- a/python/ray/_raylet.pxd +++ b/python/ray/_raylet.pxd @@ -114,7 +114,7 @@ cdef class CoreWorker: object async_event_loop object plasma_event_handler object job_config - object current_runtime_env + object current_runtime_env_dict c_bool is_local_mode cdef _create_put_buffer(self, shared_ptr[CBuffer] &metadata, diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index bbc064ec92938..fdb9a7f51fef0 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -7,18 +7,18 @@ from cpython.exc cimport PyErr_CheckSignals import asyncio +import copy import gc import inspect +import threading +import traceback +import time import logging -import msgpack import os import pickle -import setproctitle import sys -import threading -import time -import traceback import _thread +import setproctitle from libc.stdint cimport ( int32_t, @@ -100,6 +100,13 @@ from ray.includes.ray_config cimport RayConfig from ray.includes.global_state_accessor cimport CGlobalStateAccessor import ray +import ray._private.gcs_utils as gcs_utils +from ray import external_storage +from ray._private.async_compat import ( + sync_to_async, get_new_event_loop) +import ray._private.memory_monitor as memory_monitor +import ray.ray_constants as ray_constants +import ray._private.profiling as profiling from ray.exceptions import ( RayActorError, RayError, @@ -110,15 +117,11 @@ from ray.exceptions import ( TaskCancelledError, AsyncioActorExit, ) -from ray import external_storage -import ray.ray_constants as ray_constants -from ray._private.async_compat import sync_to_async, get_new_event_loop -from ray._private.client_mode_hook import disable_client_hook -import ray._private.gcs_utils as gcs_utils -from ray._private.runtime_env.validation import ParsedRuntimeEnv -import ray._private.memory_monitor as memory_monitor -import ray._private.profiling as profiling from ray._private.utils import decode +from ray._private.client_mode_hook import ( + disable_client_hook, +) +import msgpack cimport cpython @@ -1350,8 +1353,8 @@ cdef class CoreWorker: int64_t placement_group_bundle_index, c_bool placement_group_capture_child_tasks, c_string debugger_breakpoint, - c_string serialized_runtime_env, - 
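# For reference, a client-side counterpart to add_port_to_grpc_server
# above, following the same RAY_USE_TLS convention; a hedged sketch, not
# the init_grpc_channel implementation itself.
import os
import grpc

def make_channel(address: str) -> grpc.Channel:
    if os.environ.get("RAY_USE_TLS", "0") == "1":
        with open(os.environ["RAY_TLS_CA_CERT"], "rb") as f:
            ca_cert = f.read()
        credentials = grpc.ssl_channel_credentials(root_certificates=ca_cert)
        return grpc.secure_channel(address, credentials)
    return grpc.insecure_channel(address)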
runtime_env_uris, + runtime_env_dict, + override_environment_variables ): cdef: unordered_map[c_string, double] c_resources @@ -1359,10 +1362,15 @@ cdef class CoreWorker: c_vector[unique_ptr[CTaskArg]] args_vector CPlacementGroupID c_placement_group_id = \ placement_group_id.native() - c_vector[c_string] c_runtime_env_uris = runtime_env_uris + c_string c_serialized_runtime_env + unordered_map[c_string, c_string] \ + c_override_environment_variables = \ + override_environment_variables c_vector[CObjectReference] return_refs with self.profile_event(b"submit_task"): + c_serialized_runtime_env = \ + self.prepare_runtime_env(runtime_env_dict) prepare_resources(resources, &c_resources) ray_function = CRayFunction( language.lang, function_descriptor.descriptor) @@ -1375,8 +1383,8 @@ cdef class CoreWorker: ray_function, args_vector, CTaskOptions( name, num_returns, c_resources, b"", - serialized_runtime_env, - c_runtime_env_uris), + c_serialized_runtime_env, + c_override_environment_variables), max_retries, retry_exceptions, c_pair[CPlacementGroupID, int64_t]( c_placement_group_id, placement_group_bundle_index), @@ -1402,8 +1410,8 @@ cdef class CoreWorker: int64_t placement_group_bundle_index, c_bool placement_group_capture_child_tasks, c_string extension_data, - c_string serialized_runtime_env, - runtime_env_uris, + runtime_env_dict, + override_environment_variables ): cdef: CRayFunction ray_function @@ -1414,9 +1422,14 @@ cdef class CoreWorker: CActorID c_actor_id CPlacementGroupID c_placement_group_id = \ placement_group_id.native() - c_vector[c_string] c_runtime_env_uris = runtime_env_uris + c_string c_serialized_runtime_env + unordered_map[c_string, c_string] \ + c_override_environment_variables = \ + override_environment_variables with self.profile_event(b"submit_task"): + c_serialized_runtime_env = \ + self.prepare_runtime_env(runtime_env_dict) prepare_resources(resources, &c_resources) prepare_resources(placement_resources, &c_placement_resources) ray_function = CRayFunction( @@ -1436,8 +1449,8 @@ cdef class CoreWorker: c_placement_group_id, placement_group_bundle_index), placement_group_capture_child_tasks, - serialized_runtime_env, - c_runtime_env_uris), + c_serialized_runtime_env, + c_override_environment_variables), extension_data, &c_actor_id)) @@ -1712,11 +1725,12 @@ cdef class CoreWorker: return CCoreWorkerProcess.GetCoreWorker().GetOwnerAddress( c_object_id).SerializeAsString() - def serialize_object_ref(self, ObjectRef object_ref): + def serialize_and_promote_object_ref(self, ObjectRef object_ref): cdef: CObjectID c_object_id = object_ref.native() CAddress c_owner_address = CAddress() c_string serialized_object_status + CCoreWorkerProcess.GetCoreWorker().PromoteObjectToPlasma(c_object_id) CCoreWorkerProcess.GetCoreWorker().GetOwnershipInfo( c_object_id, &c_owner_address, &serialized_object_status) return (object_ref, @@ -1847,20 +1861,19 @@ cdef class CoreWorker: return (CCoreWorkerProcess.GetCoreWorker().GetWorkerContext() .CurrentActorIsAsync()) - def get_current_runtime_env(self) -> ParsedRuntimeEnv: + def get_current_runtime_env_dict(self): # This should never change, so we can safely cache it to avoid ser/de - if self.current_runtime_env is None: + if self.current_runtime_env_dict is None: if self.is_driver: - job_config = self.get_job_config() - serialized_env = job_config.runtime_env.serialized_runtime_env + self.current_runtime_env_dict = \ + json.loads(self.get_job_config().serialized_runtime_env) else: - serialized_env = CCoreWorkerProcess.GetCoreWorker() \ - 
.GetWorkerContext().GetCurrentSerializedRuntimeEnv() - - self.current_runtime_env = ParsedRuntimeEnv.deserialize( - serialized_env) - - return self.current_runtime_env + self.current_runtime_env_dict = json.loads( + CCoreWorkerProcess.GetCoreWorker() + .GetWorkerContext() + .GetCurrentSerializedRuntimeEnv() + ) + return self.current_runtime_env_dict def is_exiting(self): return CCoreWorkerProcess.GetCoreWorker().IsExiting() @@ -1888,26 +1901,6 @@ cdef class CoreWorker: return ref_counts - def get_actor_call_stats(self): - cdef: - unordered_map[c_string, c_vector[uint64_t]] c_tasks_count - - c_tasks_count = ( - CCoreWorkerProcess.GetCoreWorker().GetActorCallStats()) - it = c_tasks_count.begin() - - tasks_count = dict() - while it != c_tasks_count.end(): - func_name = dereference(it).first - counters = dereference(it).second - tasks_count[func_name] = { - "pending": counters[0], - "running": counters[1], - "finished": counters[2], - } - postincrement(it) - return tasks_count - def set_get_async_callback(self, ObjectRef object_ref, callback): cpython.Py_INCREF(callback) CCoreWorkerProcess.GetCoreWorker().GetAsync( @@ -1932,6 +1925,45 @@ cdef class CoreWorker: self.job_config.ParseFromString(c_job_config.SerializeAsString()) return self.job_config + def prepare_runtime_env(self, runtime_env_dict: dict) -> str: + """Merge the given new runtime env with the current runtime env. + + If running in a driver, the current runtime env comes from the + JobConfig. Otherwise, we are running in a worker for an actor or + task, and the current runtime env comes from the current TaskSpec. + + The child's runtime env dict is merged with the parents via a simple + dict update, except for runtime_env["env_vars"], which is merged + with runtime_env["env_vars"] of the parent rather than overwriting it. + This is so that env vars set in the parent propagate to child actors + and tasks even if a new env var is set in the child. + + Args: + runtime_env_dict (dict): A runtime env for a child actor or task. + Returns: + The resulting merged JSON-serialized runtime env. + """ + + result_dict = copy.deepcopy(self.get_current_runtime_env_dict()) + + result_env_vars = copy.deepcopy(result_dict.get("env_vars") or {}) + child_env_vars = runtime_env_dict.get("env_vars") or {} + result_env_vars.update(child_env_vars) + + result_dict.update(runtime_env_dict) + result_dict["env_vars"] = result_env_vars + + # NOTE(architkulkarni): This allows worker caching code in C++ to + # check if a runtime env is empty without deserializing it. 
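# For reference, a pure-Python sketch of the merge rule that
# prepare_runtime_env documents above: the child dict overrides the parent
# via a plain update, except env_vars, which are merged key-by-key so
# parent variables still propagate when the child only adds new ones.
import copy

def merge_runtime_envs(parent: dict, child: dict) -> dict:
    result = copy.deepcopy(parent)
    env_vars = copy.deepcopy(parent.get("env_vars") or {})
    env_vars.update(child.get("env_vars") or {})
    result.update(child)
    result["env_vars"] = env_vars
    return result

merged = merge_runtime_envs({"env_vars": {"A": "1"}},
                            {"env_vars": {"B": "2"}})
assert merged["env_vars"] == {"A": "1", "B": "2"}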
+ if result_dict["env_vars"] == {}: + result_dict["env_vars"] = None + if all(val is None for val in result_dict.values()): + result_dict = {} + + # TODO(architkulkarni): We should just use RuntimeEnvDict here + # so all the serialization and validation is done in one place + return json.dumps(result_dict, sort_keys=True) + def get_task_submission_stats(self): cdef: int64_t num_tasks_submitted diff --git a/python/ray/actor.py b/python/ray/actor.py index f228389da72e0..faec5fccc7dd7 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -5,8 +5,7 @@ import ray.ray_constants as ray_constants import ray._raylet import ray._private.signature as signature -from ray._private.runtime_env.validation import ( - override_task_or_actor_runtime_env, ParsedRuntimeEnv) +import ray._private.runtime_env as runtime_support import ray.worker from ray.util.annotations import PublicAPI from ray.util.placement_group import ( @@ -32,7 +31,7 @@ @PublicAPI -@client_mode_hook(auto_init=False) +@client_mode_hook def method(*args, **kwargs): """Annotate an actor method. @@ -389,17 +388,11 @@ class DerivedActorClass(cls, modified_class): PythonFunctionDescriptor.from_class( modified_class.__ray_actor_class__) - # Parse local pip/conda config files here. If we instead did it in - # .remote(), it would get run in the Ray Client server, which runs on - # a remote node where the files aren't available. - new_runtime_env = ParsedRuntimeEnv( - runtime_env or {}, is_task_or_actor=True) - self.__ray_metadata__ = ActorClassMetadata( Language.PYTHON, modified_class, actor_creation_function_descriptor, class_id, max_restarts, max_task_retries, num_cpus, num_gpus, memory, object_store_memory, - resources, accelerator_type, new_runtime_env) + resources, accelerator_type, runtime_env) return self @@ -410,15 +403,10 @@ def _ray_from_function_descriptor( resources, accelerator_type, runtime_env): self = ActorClass.__new__(ActorClass) - # Parse local pip/conda config files here. If we instead did it in - # .remote(), it would get run in the Ray Client server, which runs on - # a remote node where the files aren't available. - new_runtime_env = ParsedRuntimeEnv( - runtime_env or {}, is_task_or_actor=True) self.__ray_metadata__ = ActorClassMetadata( language, None, actor_creation_function_descriptor, None, max_restarts, max_task_retries, num_cpus, num_gpus, memory, - object_store_memory, resources, accelerator_type, new_runtime_env) + object_store_memory, resources, accelerator_type, runtime_env) return self @@ -454,7 +442,8 @@ def options(self, placement_group="default", placement_group_bundle_index=-1, placement_group_capture_child_tasks=None, - runtime_env=None): + runtime_env=None, + override_environment_variables=None): """Configures and overrides the actor instantiation parameters. The arguments are the same as those that can be passed @@ -475,12 +464,6 @@ def method(self): actor_cls = self - # Parse local pip/conda config files here. If we instead did it in - # .remote(), it would get run in the Ray Client server, which runs on - # a remote node where the files aren't available. 
- new_runtime_env = ParsedRuntimeEnv( - runtime_env or {}, is_task_or_actor=True) - class ActorOptionWrapper: def remote(self, *args, **kwargs): return actor_cls._remote( @@ -502,7 +485,9 @@ def remote(self, *args, **kwargs): placement_group_bundle_index=placement_group_bundle_index, placement_group_capture_child_tasks=( placement_group_capture_child_tasks), - runtime_env=new_runtime_env) + runtime_env=runtime_env, + override_environment_variables=( + override_environment_variables)) return ActorOptionWrapper() @@ -525,7 +510,8 @@ def _remote(self, placement_group="default", placement_group_bundle_index=-1, placement_group_capture_child_tasks=None, - runtime_env=None): + runtime_env=None, + override_environment_variables=None): """Create an actor. This method allows more flexibility than the remote method because @@ -571,6 +557,9 @@ def _remote(self, this actor or task and its children (see :ref:`runtime-environments` for details). This API is in beta and may change before becoming stable. + override_environment_variables: Environment variables to override + and/or introduce for this actor. This is a dictionary mapping + variable names to their values. Returns: A handle to the newly created actor. @@ -595,7 +584,7 @@ def _remote(self, if max_concurrency < 1: raise ValueError("max_concurrency must be >= 1") - if client_mode_should_convert(auto_init=True): + if client_mode_should_convert(): return client_mode_convert_actor( self, args, @@ -616,7 +605,9 @@ def _remote(self, placement_group_bundle_index=placement_group_bundle_index, placement_group_capture_child_tasks=( placement_group_capture_child_tasks), - runtime_env=runtime_env) + runtime_env=runtime_env, + override_environment_variables=( + override_environment_variables)) worker = ray.worker.global_worker worker.check_connected() @@ -732,16 +723,18 @@ def _remote(self, creation_args = signature.flatten_args(function_signature, args, kwargs) - if runtime_env and not isinstance(runtime_env, ParsedRuntimeEnv): - runtime_env = ParsedRuntimeEnv(runtime_env) - elif isinstance(runtime_env, ParsedRuntimeEnv): - pass - else: + if runtime_env is None: runtime_env = meta.runtime_env - parent_runtime_env = worker.core_worker.get_current_runtime_env() - parsed_runtime_env = override_task_or_actor_runtime_env( - runtime_env, parent_runtime_env) + job_runtime_env = worker.core_worker.get_current_runtime_env_dict() + runtime_env_dict = runtime_support.override_task_or_actor_runtime_env( + runtime_env, job_runtime_env) + + if override_environment_variables: + logger.warning("override_environment_variables is deprecated and " + "will be removed in Ray 1.6. Please use " + ".options(runtime_env={'env_vars': {...}}).remote()" + "instead.") actor_id = worker.core_worker.create_actor( meta.language, @@ -761,8 +754,9 @@ def _remote(self, placement_group_capture_child_tasks, # Store actor_method_cpu in actor handle's extension data. 
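# For reference, a usage sketch of the replacement API named in the
# deprecation warning above, assuming a local ray.init() and the
# runtime_env "env_vars" support this patch adds; the actor is illustrative.
import ray

ray.init()

@ray.remote
class Greeter:
    def greet(self):
        import os
        return os.environ.get("GREETING")

# Instead of override_environment_variables={"GREETING": "hi"}:
handle = Greeter.options(runtime_env={"env_vars": {"GREETING": "hi"}}).remote()
print(ray.get(handle.greet.remote()))  # "hi"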
extension_data=str(actor_method_cpu), - serialized_runtime_env=parsed_runtime_env.serialize(), - runtime_env_uris=parsed_runtime_env.get("uris") or []) + runtime_env_dict=runtime_env_dict, + override_environment_variables=override_environment_variables + or dict()) actor_handle = ActorHandle( meta.language, diff --git a/python/ray/autoscaler/_private/autoscaler.py b/python/ray/autoscaler/_private/autoscaler.py index b153cff37d259..3b26b845e8070 100644 --- a/python/ray/autoscaler/_private/autoscaler.py +++ b/python/ray/autoscaler/_private/autoscaler.py @@ -165,12 +165,6 @@ def read_fn(): self.disable_node_updaters = self.config["provider"].get( "disable_node_updaters", False) - # Disable launch config checking if true. - # This is set in the fake_multinode situations where there isn't any - # meaningful node "type" to enforce. - self.disable_launch_config_check = self.config["provider"].get( - "disable_launch_config_check", False) - # Node launchers self.launch_queue = queue.Queue() self.pending_launches = ConcurrentCounter() @@ -491,8 +485,7 @@ def _report_pending_infeasible(self, unfulfilled: List[ResourceDict]): pending = [] infeasible = [] for bundle in unfulfilled: - placement_group = any( - "_group_" in k or k == "bundle" for k in bundle) + placement_group = any("_group_" in k for k in bundle) if placement_group: continue if self.resource_demand_scheduler.is_feasible(bundle): @@ -634,6 +627,7 @@ def _keep_worker_of_node_type(self, node_id: NodeID, Return KeepOrTerminate.decide_later otherwise. + Args: node_type_counts(Dict[NodeType, int]): The non_terminated node types counted so far. @@ -763,8 +757,6 @@ def reset(self, errors_fatal=False): "Error parsing config.") def launch_config_ok(self, node_id): - if self.disable_launch_config_check: - return True node_tags = self.provider.node_tags(node_id) tag_launch_conf = node_tags.get(TAG_RAY_LAUNCH_CONFIG) node_type = node_tags.get(TAG_RAY_USER_NODE_TYPE) diff --git a/python/ray/autoscaler/_private/docker.py b/python/ray/autoscaler/_private/docker.py index 92dd16ad5001f..8d94759549217 100644 --- a/python/ray/autoscaler/_private/docker.py +++ b/python/ray/autoscaler/_private/docker.py @@ -18,7 +18,7 @@ def _check_docker_file_mounts(file_mounts: Dict[str, str]) -> None: if Path(local).is_file(): cli_logger.warning( f"File Mount: ({remote}:{local}) refers to a file.\n To ensure" - " this mount updates properly, please use a directory.") + "this mount updates properly, please use a directory.") def validate_docker_config(config: Dict[str, Any]) -> None: diff --git a/python/ray/autoscaler/_private/fake_multi_node/__init__.py b/python/ray/autoscaler/_private/fake_multi_node/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/python/ray/autoscaler/_private/fake_multi_node/example.yaml b/python/ray/autoscaler/_private/fake_multi_node/example.yaml deleted file mode 100644 index 497f89647c25f..0000000000000 --- a/python/ray/autoscaler/_private/fake_multi_node/example.yaml +++ /dev/null @@ -1,55 +0,0 @@ -# Example command to start a cluster with this config: -# -# RAY_FAKE_CLUSTER=1 ray start --autoscaling-config=example.yaml --head --block -# -# Alternatively, you can programmatically create a fake autoscaling cluster -# using ray.cluster_utils.AutoscalingCluster. 
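# For reference, a standalone sketch of the placement-group heuristic in
# the autoscaler.py hunk above: bundle demands coming from placement groups
# carry resource keys containing "_group_" (e.g. "CPU_group_<id>"); the
# example key below is illustrative.
def is_placement_group_demand(bundle: dict) -> bool:
    return any("_group_" in key for key in bundle)

assert is_placement_group_demand({"CPU_group_4482dec0faaf5ed8fa966772": 1})
assert not is_placement_group_demand({"CPU": 1})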
-cluster_name: fake_multinode -max_workers: 8 -provider: - type: fake_multinode - use_node_id_as_ip: True - disable_node_updaters: True - disable_launch_config_check: True -available_node_types: - ray.head.default: - # You must set this manually to your "head" node resources!! The head - # node is launched via `ray start` and hence the autoscaler cannot - # configure its resources. The resources specified for its node type - # must line up with what Ray detects/is configured with on start. - resources: - CPU: 8 # <-- set this to num CPUs used/detected in `ray start` - GPU: 0 # <-- set this to num GPUs used/detected in `ray start` - node_config: {} - max_workers: 0 - ray.worker.cpu: - resources: - CPU: 1 - object_store_memory: 1000000000 - node_config: {} - min_workers: 0 - max_workers: 4 - ray.worker.gpu: - resources: - CPU: 4 - GPU: 1 - object_store_memory: 1000000000 - node_config: {} - min_workers: 0 - max_workers: 2 -head_node_type: ray.head.default -auth: {} -upscaling_speed: 1.0 -idle_timeout_minutes: 0.1 -docker: {} -initialization_commands: [] -setup_commands: [] -head_setup_commands: [] -worker_setup_commands: [] -head_start_ray_commands: [] -worker_start_ray_commands: [] -file_mounts: {} -cluster_synced_files: [] -file_mounts_sync_continuously: false -rsync_exclude: [] -rsync_filter: [] diff --git a/python/ray/autoscaler/_private/fake_multi_node/node_provider.py b/python/ray/autoscaler/_private/fake_multi_node/node_provider.py deleted file mode 100644 index 71650d845b1e5..0000000000000 --- a/python/ray/autoscaler/_private/fake_multi_node/node_provider.py +++ /dev/null @@ -1,114 +0,0 @@ -import logging -import os -import json - -import ray -from ray.autoscaler.node_provider import NodeProvider -from ray.autoscaler.tags import (TAG_RAY_NODE_KIND, NODE_KIND_HEAD, - NODE_KIND_WORKER, TAG_RAY_USER_NODE_TYPE, - TAG_RAY_NODE_NAME, TAG_RAY_NODE_STATUS, - STATUS_UP_TO_DATE) - -logger = logging.getLogger(__name__) - -# We generate the node ids deterministically in the fake node provider, so that -# we can associate launched nodes with their resource reports. IDs increment -# starting with fffff*00000 for the head node, fffff*00001, etc. for workers. -FAKE_HEAD_NODE_ID = "fffffffffffffffffffffffffffffffffffffffffffffffffff00000" -FAKE_HEAD_NODE_TYPE = "ray.head.default" - - -class FakeMultiNodeProvider(NodeProvider): - """A node provider that implements multi-node on a single machine. 
- - This is used for laptop mode testing of autoscaling functionality.""" - - def __init__(self, provider_config, cluster_name): - NodeProvider.__init__(self, provider_config, cluster_name) - if "RAY_FAKE_CLUSTER" not in os.environ: - raise RuntimeError( - "FakeMultiNodeProvider requires ray to be started with " - "RAY_FAKE_CLUSTER=1 ray start ...") - self._nodes = { - FAKE_HEAD_NODE_ID: { - "tags": { - TAG_RAY_NODE_KIND: NODE_KIND_HEAD, - TAG_RAY_USER_NODE_TYPE: FAKE_HEAD_NODE_TYPE, - TAG_RAY_NODE_NAME: FAKE_HEAD_NODE_ID, - TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, - } - }, - } - self._next_node_id = 0 - - def _next_hex_node_id(self): - self._next_node_id += 1 - base = "fffffffffffffffffffffffffffffffffffffffffffffffffff" - return base + str(self._next_node_id).zfill(5) - - def non_terminated_nodes(self, tag_filters): - nodes = [] - for node_id in self._nodes: - tags = self.node_tags(node_id) - ok = True - for k, v in tag_filters.items(): - if tags.get(k) != v: - ok = False - if ok: - nodes.append(node_id) - return nodes - - def is_running(self, node_id): - return node_id in self._nodes - - def is_terminated(self, node_id): - return node_id not in self._nodes - - def node_tags(self, node_id): - return self._nodes[node_id]["tags"] - - def external_ip(self, node_id): - return node_id - - def internal_ip(self, node_id): - return node_id - - def set_node_tags(self, node_id, tags): - raise AssertionError("Readonly node provider cannot be updated") - - def create_node_with_resources(self, node_config, tags, count, resources): - node_type = tags[TAG_RAY_USER_NODE_TYPE] - next_id = self._next_hex_node_id() - ray_params = ray._private.parameter.RayParams( - min_worker_port=0, - max_worker_port=0, - dashboard_port=None, - num_cpus=resources.pop("CPU", 0), - num_gpus=resources.pop("GPU", 0), - object_store_memory=resources.pop("object_store_memory", None), - resources=resources, - redis_address="{}:6379".format( - ray._private.services.get_node_ip_address()), - env_vars={ - "RAY_OVERRIDE_NODE_ID_FOR_TESTING": next_id, - "RAY_OVERRIDE_RESOURCES": json.dumps(resources), - }) - node = ray.node.Node( - ray_params, head=False, shutdown_at_exit=False, spawn_reaper=False) - self._nodes[next_id] = { - "tags": { - TAG_RAY_NODE_KIND: NODE_KIND_WORKER, - TAG_RAY_USER_NODE_TYPE: node_type, - TAG_RAY_NODE_NAME: next_id, - TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, - }, - "node": node - } - - def terminate_node(self, node_id): - node = self._nodes.pop(node_id)["node"] - node.kill_all_processes(check_alive=False, allow_graceful=True) - - @staticmethod - def bootstrap_config(cluster_config): - return cluster_config diff --git a/python/ray/autoscaler/_private/gcp/node.py b/python/ray/autoscaler/_private/gcp/node.py index 69a456ac56c0e..93a9933ddc186 100644 --- a/python/ray/autoscaler/_private/gcp/node.py +++ b/python/ray/autoscaler/_private/gcp/node.py @@ -437,26 +437,8 @@ def create_instance(self, "name": name }) - # Allow Google Compute Engine instance templates. - # - # Config example: - # - # ... - # node_config: - # sourceInstanceTemplate: global/instanceTemplates/worker-16 - # machineType: e2-standard-16 - # ... - # - # node_config parameters override matching template parameters, if any. 
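# For reference, the deterministic node-id scheme of FakeMultiNodeProvider
# above in isolation: ids share the fixed 51-character "f" prefix and a
# zero-padded counter, so launched nodes can be matched to their resource
# reports. Sketch only.
FAKE_NODE_ID_BASE = "f" * 51  # same prefix as FAKE_HEAD_NODE_ID

def fake_node_id(index: int) -> str:
    return FAKE_NODE_ID_BASE + str(index).zfill(5)

assert fake_node_id(0).endswith("00000")  # the head node
assert fake_node_id(1).endswith("00001")  # the first worker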
- # - # https://cloud.google.com/compute/docs/instance-templates - # https://cloud.google.com/compute/docs/reference/rest/v1/instances/insert - source_instance_template = config.pop("sourceInstanceTemplate", None) - operation = self.resource.instances().insert( - project=self.project_id, - zone=self.availability_zone, - sourceInstanceTemplate=source_instance_template, + project=self.project_id, zone=self.availability_zone, body=config).execute() if wait_for_operation: diff --git a/python/ray/autoscaler/_private/monitor.py b/python/ray/autoscaler/_private/monitor.py index b19a8d04c7032..172bd5b74b57d 100644 --- a/python/ray/autoscaler/_private/monitor.py +++ b/python/ray/autoscaler/_private/monitor.py @@ -27,8 +27,6 @@ from ray.autoscaler._private.load_metrics import LoadMetrics from ray.autoscaler._private.constants import \ AUTOSCALER_MAX_RESOURCE_DEMAND_VECTOR_SIZE -from ray.autoscaler._private.fake_multi_node.node_provider import \ - FAKE_HEAD_NODE_ID from ray.autoscaler._private.util import DEBUG_AUTOSCALING_STATUS, \ DEBUG_AUTOSCALING_ERROR, format_readonly_node_type @@ -164,10 +162,7 @@ def __init__(self, head_node_ip = redis_address.split(":")[0] self.redis_address = redis_address self.redis_password = redis_password - if os.environ.get("RAY_FAKE_CLUSTER"): - self.load_metrics = LoadMetrics(local_ip=FAKE_HEAD_NODE_ID) - else: - self.load_metrics = LoadMetrics(local_ip=head_node_ip) + self.load_metrics = LoadMetrics(local_ip=head_node_ip) self.last_avail_resources = None self.event_summarizer = EventSummarizer() self.prefix_cluster_info = prefix_cluster_info @@ -228,7 +223,7 @@ def update_load_metrics(self): request = gcs_service_pb2.GetAllResourceUsageRequest() response = self.gcs_node_resources_stub.GetAllResourceUsage( - request, timeout=60) + request, timeout=4) resources_batch_data = response.resource_usage_data # Tell the readonly node provider what nodes to report. 
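The FAKE_HEAD_NODE_ID that monitor.py stops special-casing here comes from the deterministic ID scheme of the removed FakeMultiNodeProvider above. For reference, a minimal standalone sketch of that scheme (illustrative only, mirroring the deleted _next_hex_node_id):

BASE = "f" * 51  # 51 'f's plus a 5-digit tail gives the 56-hex-char node ID width

FAKE_HEAD_NODE_ID = BASE + "00000"  # the head node

def next_hex_node_id(next_node_id: int) -> str:
    # Workers count up from the head node: ...00001, ...00002, and so on,
    # which is what lets launched nodes be matched to their resource reports.
    return BASE + str(next_node_id).zfill(5)

print(next_hex_node_id(1))  # fffff...00001, same width as the head node ID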
@@ -249,7 +244,8 @@ def update_load_metrics(self): resource_message.node_id.hex()) resources = {} for k, v in resource_message.resources_total.items(): - resources[k] = v + if not k.startswith("node:"): + resources[k] = v mirror_node_types[node_type] = { "resources": resources, "node_config": {}, diff --git a/python/ray/autoscaler/_private/node_launcher.py b/python/ray/autoscaler/_private/node_launcher.py index 803b8df37d881..aad9cee326973 100644 --- a/python/ray/autoscaler/_private/node_launcher.py +++ b/python/ray/autoscaler/_private/node_launcher.py @@ -47,8 +47,6 @@ def _launch_node(self, config: Dict[str, Any], count: int, if node_type: launch_config.update( config["available_node_types"][node_type]["node_config"]) - resources = copy.deepcopy( - config["available_node_types"][node_type]["resources"]) launch_hash = hash_launch_conf(launch_config, config["auth"]) self.log("Launching {} nodes, type {}.".format(count, node_type)) node_config = copy.deepcopy(config.get("worker_nodes", {})) @@ -66,8 +64,7 @@ def _launch_node(self, config: Dict[str, Any], count: int, node_tags[TAG_RAY_USER_NODE_TYPE] = node_type node_config.update(launch_config) launch_start_time = time.time() - self.provider.create_node_with_resources(node_config, node_tags, count, - resources) + self.provider.create_node(node_config, node_tags, count) launch_time = time.time() - launch_start_time for _ in range(count): # Note: when launching multiple nodes we observe the time it diff --git a/python/ray/autoscaler/_private/providers.py b/python/ray/autoscaler/_private/providers.py index 343350817f512..e60eb441e1414 100644 --- a/python/ray/autoscaler/_private/providers.py +++ b/python/ray/autoscaler/_private/providers.py @@ -56,12 +56,6 @@ def _import_readonly(provider_config): return ReadOnlyNodeProvider -def _import_fake_multinode(provider_config): - from ray.autoscaler._private.fake_multi_node.node_provider import \ - FakeMultiNodeProvider - return FakeMultiNodeProvider - - def _import_kubernetes(provider_config): from ray.autoscaler._private._kubernetes.node_provider import \ KubernetesNodeProvider @@ -123,7 +117,6 @@ def _import_external(provider_config): _NODE_PROVIDERS = { "local": _import_local, - "fake_multinode": _import_fake_multinode, "readonly": _import_readonly, "aws": _import_aws, "gcp": _import_gcp, @@ -136,7 +129,6 @@ def _import_external(provider_config): _PROVIDER_PRETTY_NAMES = { "readonly": "Readonly (Manual Cluster Setup)", - "fake_multinode": "Fake Multinode", "local": "Local", "aws": "AWS", "gcp": "GCP", diff --git a/python/ray/autoscaler/_private/resource_demand_scheduler.py b/python/ray/autoscaler/_private/resource_demand_scheduler.py index f055a01769714..517f49f63281c 100644 --- a/python/ray/autoscaler/_private/resource_demand_scheduler.py +++ b/python/ray/autoscaler/_private/resource_demand_scheduler.py @@ -116,8 +116,7 @@ def is_feasible(self, bundle: ResourceDict) -> bool: for node_type, config in self.node_types.items(): max_of_type = config.get("max_workers", 0) node_resources = config["resources"] - if (node_type == self.head_node_type or max_of_type > 0) and _fits( - node_resources, bundle): + if max_of_type > 0 and _fits(node_resources, bundle): return True return False @@ -765,11 +764,7 @@ def _utilization_score(node_resources: ResourceDict, return None fittable = [] - resource_types = set() for r in resources: - for k, v in r.items(): - if v > 0: - resource_types.add(k) if _fits(remaining, r): fittable.append(r) _inplace_subtract(remaining, r) @@ -777,15 +772,12 @@ def 
_utilization_score(node_resources: ResourceDict, return None util_by_resources = [] - num_matching_resource_types = 0 for k, v in node_resources.items(): # Don't divide by zero. if v < 1: # Could test v == 0 on the nose, but v < 1 feels safer. # (Note that node resources are integers.) continue - if k in resource_types: - num_matching_resource_types += 1 util = (v - remaining[k]) / v util_by_resources.append(v * (util**3)) @@ -793,11 +785,9 @@ def _utilization_score(node_resources: ResourceDict, if not util_by_resources: return None - # Prioritize matching multiple resource types first, then prioritize - # using all resources, then prioritize overall balance + # Prioritize using all resources first, then prioritize overall balance # of multiple resources. - return (num_matching_resource_types, min(util_by_resources), - np.mean(util_by_resources)) + return (min(util_by_resources), np.mean(util_by_resources)) def get_bin_pack_residual(node_resources: List[ResourceDict], @@ -828,16 +818,7 @@ def get_bin_pack_residual(node_resources: List[ResourceDict], nodes = copy.deepcopy(node_resources) # List of nodes that cannot be used again due to strict spread. used = [] - # We order the resource demands in the following way: - # More complex demands first. - # Break ties: heavier demands first. - # Break ties: lexicographically (to ensure stable ordering). - for demand in sorted( - resource_demands, - key=lambda demand: (len(demand.values()), - sum(demand.values()), - sorted(demand.items())), - reverse=True): + for demand in resource_demands: found = False node = None for i in range(len(nodes)): diff --git a/python/ray/autoscaler/gcp/tpu.yaml b/python/ray/autoscaler/gcp/tpu.yaml index a963e62c1898d..34726cb2205b4 100644 --- a/python/ray/autoscaler/gcp/tpu.yaml +++ b/python/ray/autoscaler/gcp/tpu.yaml @@ -32,9 +32,9 @@ available_node_types: # Support for TPU pods will be added in the future. acceleratorType: v2-8 runtimeVersion: v2-alpha - schedulingConfig: - # Set to false to use non-preemptible TPUs - preemptible: true + # Uncomment to use preemptible TPUs + # schedulingConfig: + # preemptible: true provider: type: gcp @@ -51,21 +51,15 @@ head_node_type: ray_head_default # Compute instances have python 3.7, but TPUs have 3.8 - need to update # Install Jax and other dependencies on the Compute head node head_setup_commands: - # Two first lines are a workaround for ssh timing out - - sleep 2 - - sleep 2 - - sudo chown -R $(whoami) /opt/conda/* - - conda create -y -n "ray" python=3.8.5 - - conda activate ray && echo 'conda activate ray' >> ~/.bashrc - - python -m pip install --upgrade pip - - python -m pip install --upgrade "jax[cpu]==0.2.14" + - conda create -y -n "ray" python=3.8.5 && sudo update-alternatives --install /opt/conda/bin/python python /opt/conda/envs/ray/bin/python 10 && sudo update-alternatives --install /opt/conda/bin/pip pip /opt/conda/envs/ray/bin/pip 10 + - export PATH="$PATH:/opt/conda/envs/ray/bin" && echo 'export PATH="$PATH:/opt/conda/envs/ray/bin"' >> ~/.bashrc + - python -m pip install --upgrade "jax[cpu]==0.2.14" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - python -m pip install --upgrade fabric dataclasses optax==0.0.6 git+https://github.com/deepmind/dm-haiku google-api-python-client cryptography tensorboardX ray[default] - python -m pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-2.0.0.dev0-cp38-cp38-manylinux2014_x86_64.whl - git clone https://github.com/Yard1/swarm-jax.git && cd swarm-jax && python -m pip install . 
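Stepping back to the _utilization_score change earlier in this patch: the post-change scoring is easiest to see as a standalone function. This is a sketch, not the patch's code; the real function first bin-packs the feasible demands to compute `remaining` before scoring.

import numpy as np

def utilization_score(node_resources, remaining):
    # Mirrors the simplified scoring above: prioritize using all resources
    # first (the min term), then overall balance (the mean term).
    util_by_resources = []
    for k, v in node_resources.items():
        if v < 1:  # don't divide by zero; node resources are integers
            continue
        util = (v - remaining[k]) / v
        util_by_resources.append(v * (util ** 3))
    if not util_by_resources:
        return None
    return (min(util_by_resources), np.mean(util_by_resources))

# A 4-CPU/1-GPU node whose GPU is fully used and which has 1 CPU left:
print(utilization_score({"CPU": 4, "GPU": 1}, {"CPU": 1, "GPU": 0}))
# -> (1.0, 1.34375); the saturated GPU drives the min term to 1.0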
# Install Jax and other dependencies on TPU worker_setup_commands: - - pip3 install --upgrade pip - pip3 install --upgrade "jax[tpu]==0.2.14" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - pip3 install --upgrade fabric dataclasses optax==0.0.6 git+https://github.com/deepmind/dm-haiku tensorboardX ray[default] - python3 -c "import jax; jax.device_count(); jax.numpy.add(1, 1)" # test if Jax has been installed correctly diff --git a/python/ray/autoscaler/node_provider.py b/python/ray/autoscaler/node_provider.py index 3340910592ba3..c912cd772456d 100644 --- a/python/ray/autoscaler/node_provider.py +++ b/python/ray/autoscaler/node_provider.py @@ -124,18 +124,6 @@ def create_node(self, node_config: Dict[str, Any], tags: Dict[str, str], """ raise NotImplementedError - def create_node_with_resources( - self, node_config: Dict[str, Any], tags: Dict[str, str], - count: int, - resources: Dict[str, float]) -> Optional[Dict[str, Any]]: - """Create nodes with a given resource config. - - This is the method actually called by the autoscaler. Prefer to - implement this when possible directly, otherwise it delegates to the - create_node() implementation. - """ - return self.create_node(node_config, tags, count) - def set_node_tags(self, node_id: str, tags: Dict[str, str]) -> None: """Sets the tag values (string dict) for the specified node.""" raise NotImplementedError diff --git a/python/ray/autoscaler/ray-schema.json b/python/ray/autoscaler/ray-schema.json index 3e9c791720f04..64e67c7c8f42b 100644 --- a/python/ray/autoscaler/ray-schema.json +++ b/python/ray/autoscaler/ray-schema.json @@ -56,7 +56,7 @@ }, "idle_timeout_minutes": { "description": "If a node is idle for this many minutes, it will be removed.", - "type": "number", + "type": "integer", "minimum": 0 }, "provider": { diff --git a/python/ray/cluster_utils.py b/python/ray/cluster_utils.py index 80ef1f0ecd40a..965492b6eafcf 100644 --- a/python/ray/cluster_utils.py +++ b/python/ray/cluster_utils.py @@ -1,9 +1,4 @@ import logging -import json -import yaml -import os -import subprocess -import tempfile import time import ray @@ -13,70 +8,6 @@ logger = logging.getLogger(__name__) -class AutoscalingCluster: - """Create a local autoscaling cluster for testing. - - See test_autoscaler_fake_multinode.py for an end-to-end example. - """ - - def __init__(self, head_resources: dict, worker_node_types: dict): - """Create the cluster. - - Args: - head_resources: resources of the head node, including CPU. - worker_node_types: autoscaler node types config for worker nodes. - """ - base_config = yaml.safe_load( - open( - os.path.join( - os.path.dirname(ray.__file__), - "autoscaler/_private/fake_multi_node/example.yaml"))) - base_config["available_node_types"] = worker_node_types - base_config["available_node_types"]["ray.head.default"] = { - "resources": head_resources, - "node_config": {}, - "max_workers": 0, - } - self._head_resources = head_resources - self._config = base_config - self._process = None - - def start(self): - """Start the cluster. - - After this call returns, you can connect to the cluster with - ray.init("auto"). 
- """ - subprocess.check_call(["ray", "stop", "--force"]) - fake_config = tempfile.mktemp() - with open(fake_config, "w") as f: - f.write(json.dumps(self._config)) - cmd = [ - "ray", "start", "--autoscaling-config={}".format(fake_config), - "--head", "--block" - ] - if "CPU" in self._head_resources: - cmd.append("--num-cpus={}".format(self._head_resources.pop("CPU"))) - if "GPU" in self._head_resources: - cmd.append("--num-gpus={}".format(self._head_resources.pop("GPU"))) - if self._head_resources: - cmd.append("--resources='{}'".format( - json.dumps(self._head_resources))) - env = os.environ.copy() - env.update({ - "AUTOSCALER_UPDATE_INTERVAL_S": "1", - "RAY_FAKE_CLUSTER": "1" - }) - self._process = subprocess.Popen(cmd, env=env) - time.sleep(5) # TODO(ekl) wait for it properly - - def shutdown(self): - """Terminate the cluster.""" - if self._process: - self._process.kill() - subprocess.check_call(["ray", "stop", "--force"]) - - class Cluster: def __init__(self, initialize_head=False, diff --git a/python/ray/cross_language.py b/python/ray/cross_language.py index f3ad1de0a3903..eae3939ffa11e 100644 --- a/python/ray/cross_language.py +++ b/python/ray/cross_language.py @@ -79,8 +79,7 @@ def java_function(class_name, function_name): None, # max_calls, None, # max_retries, None, # retry_exceptions, - None, # runtime_env - None) # placement_group + None) # runtime_env @PublicAPI(stability="beta") diff --git a/python/ray/data/__init__.py b/python/ray/data/__init__.py index c6e411fadf86b..521add717220c 100644 --- a/python/ray/data/__init__.py +++ b/python/ray/data/__init__.py @@ -1,8 +1,7 @@ from ray.data.read_api import from_items, range, range_arrow, \ range_tensor, read_parquet, read_json, read_csv, read_binary_files, \ - from_dask, from_modin, from_mars, from_pandas, from_pandas_refs, \ - from_numpy, from_arrow, from_arrow_refs, from_spark, read_datasource, \ - read_numpy, read_text + from_dask, from_modin, from_mars, from_pandas, from_numpy, from_arrow, \ + from_spark, read_datasource, read_numpy, read_text from ray.data.datasource import Datasource, ReadTask from ray.data.dataset import Dataset from ray.data.impl.progress_bar import set_progress_bars @@ -19,12 +18,10 @@ "from_dask", "from_items", "from_arrow", - "from_arrow_refs", "from_mars", "from_modin", "from_numpy", "from_pandas", - "from_pandas_refs", "from_spark", "range", "range_arrow", diff --git a/python/ray/data/block.py b/python/ray/data/block.py index e7edab74863ad..35b99780c5e0d 100644 --- a/python/ray/data/block.py +++ b/python/ray/data/block.py @@ -16,8 +16,8 @@ # Represents a batch of records to be stored in the Ray object store. # # Block data can be accessed in a uniform way via ``BlockAccessors`` such as -# ``SimpleBlockAccessor`` and ``ArrowBlockAccessor``. -Block = Union[List[T], "pyarrow.Table", bytes] +# ``SimpleBlockAccessor``, ``ArrowBlockAccessor``, and ``TensorBlockAccessor``. +Block = Union[List[T], np.ndarray, "pyarrow.Table", bytes] @DeveloperAPI @@ -52,8 +52,8 @@ class BlockAccessor(Generic[T]): as a top-level Ray object, without a wrapping class (issue #17186). There are three types of block accessors: ``SimpleBlockAccessor``, which - operates over a plain Python list, and ``ArrowBlockAccessor`` for - ``pyarrow.Table`` type blocks. + operates over a plain Python list, ``ArrowBlockAccessor``, for + ``pyarrow.Table`` type blocks, and ``TensorBlockAccessor``, for tensors. 
""" def num_rows(self) -> int: @@ -85,16 +85,12 @@ def to_pandas(self) -> "pandas.DataFrame": """Convert this block into a Pandas dataframe.""" raise NotImplementedError - def to_numpy(self, column: str = None) -> np.ndarray: - """Convert this block (or column of block) into a NumPy ndarray. - - Args: - column: Name of column to convert, or None. - """ + def to_numpy(self) -> np.ndarray: + """Convert this block into a NumPy ndarray.""" raise NotImplementedError - def to_arrow(self) -> "pyarrow.Table": - """Convert this block into an Arrow table.""" + def to_arrow(self) -> Union["pyarrow.Table", "pyarrow.Tensor"]: + """Convert this block into an Arrow table or tensor.""" raise NotImplementedError def size_bytes(self) -> int: @@ -140,6 +136,10 @@ def for_block(block: Block) -> "BlockAccessor[T]": from ray.data.impl.simple_block import \ SimpleBlockAccessor return SimpleBlockAccessor(block) + elif isinstance(block, np.ndarray): + from ray.data.impl.tensor_block import \ + TensorBlockAccessor + return TensorBlockAccessor(block) else: raise TypeError("Not a block type: {}".format(block)) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 0b8a7fad6ca50..11d0a13c9cbae 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -51,7 +51,8 @@ class Dataset(Generic[T]): Datasets are implemented as a list of ``ObjectRef[Block]``. The block also determines the unit of parallelism. The default block type is the - ``pyarrow.Table``. Arrow-incompatible objects are held in ``list`` blocks. + ``pyarrow.Table``. Tensor objects are held in ``np.ndarray`` blocks, + and other Arrow-incompatible objects are held in ``list`` blocks. Since Datasets are just lists of Ray object refs, they can be passed between Ray tasks and actors just like any other object. Datasets support @@ -168,7 +169,7 @@ def map_batches(self, tasks, or "actors" to use an autoscaling Ray actor pool. batch_format: Specify "native" to use the native block format, "pandas" to select ``pandas.DataFrame`` as the batch format, - or "pyarrow" to select ``pyarrow.Table``. + or "pyarrow" to select ``pyarrow.Table/Tensor``. ray_remote_args: Additional resource requirements to request from ray (e.g., num_gpus=1 to request GPUs for the map tasks). """ @@ -204,15 +205,19 @@ def transform(block: Block) -> Block: "or 'pyarrow', got: {}".format(batch_format)) applied = fn(view) - if isinstance(applied, list) or isinstance(applied, pa.Table): + if (isinstance(applied, list) or isinstance(applied, pa.Table) + or isinstance(applied, np.ndarray)): applied = applied elif isinstance(applied, pd.core.frame.DataFrame): applied = pa.Table.from_pandas(applied) + elif isinstance(applied, pa.Tensor): + applied = applied.to_numpy() else: raise ValueError("The map batches UDF returned a type " f"{type(applied)}, which is not allowed. " "The return type must be either list, " - "pandas.DataFrame, or pyarrow.Table") + "pandas.DataFrame, np.ndarray, " + "pyarrow.Tensor, or pyarrow.Table") builder.add_block(applied) return builder.build() @@ -347,13 +352,8 @@ def random_shuffle( Returns: The shuffled dataset. """ - curr_num_blocks = self.num_blocks() - # Handle empty dataset. 
- if curr_num_blocks == 0: - return self - if num_blocks is None: - num_blocks = curr_num_blocks + num_blocks = self.num_blocks() new_blocks = simple_shuffle( self._move_blocks() if _move else self._blocks, num_blocks, @@ -402,150 +402,24 @@ def split(self, if n <= 0: raise ValueError(f"The number of splits {n} is not positive.") + if n > self.num_blocks() and equal: + raise NotImplementedError( + f"The number of splits {n} > the number of dataset blocks " + f"{self.num_blocks()}, yet an equal split was requested.") + if locality_hints and len(locality_hints) != n: raise ValueError( f"The length of locality_hints {len(locality_hints)} " "doesn't equal the number of splits {n}.") - def _partition_splits(splits: List[Dataset[T]], part_size: int, - counts_cache: Dict[str, int]): - """Partition splits into two sets: splits that are smaller than the - target size and splits that are larger than the target size. - """ - splits = sorted(splits, key=lambda s: counts_cache[s._get_uuid()]) - idx = next(i for i, split in enumerate(splits) - if counts_cache[split._get_uuid()] >= part_size) - return splits[:idx], splits[idx:] - - def _equalize_larger_splits(splits: List[Dataset[T]], target_size: int, - counts_cache: Dict[str, int], - num_splits_required: int): - """Split each split into one or more subsplits that are each the - target size, with at most one leftover split that's smaller - than the target size. - - This assume that the given splits are sorted in ascending order. - """ - new_splits = [] - leftovers = [] - for split in splits: - size = counts_cache[split._get_uuid()] - if size == target_size: - new_splits.append(split) - continue - split_indices = list(range(target_size, size, target_size)) - split_splits = split.split_at_indices(split_indices) - last_split_size = split_splits[-1].count() - if last_split_size < target_size: - # Last split is smaller than the target size, save it for - # our unioning of small splits. - leftover = split_splits.pop() - leftovers.append(leftover) - counts_cache[leftover._get_uuid()] = leftover.count() - if len(new_splits) + len(split_splits) >= num_splits_required: - # Short-circuit if the new splits will make us reach the - # desired number of splits. - new_splits.extend( - split_splits[:num_splits_required - len(new_splits)]) - break - new_splits.extend(split_splits) - return new_splits, leftovers - - def _equalize_smaller_splits( - splits: List[Dataset[T]], target_size: int, - counts_cache: Dict[str, int], num_splits_required: int): - """Union small splits up to the target split size. - - This assume that the given splits are sorted in ascending order. - """ - new_splits = [] - union_buffer = [] - union_buffer_size = 0 - low = 0 - high = len(splits) - 1 - while low <= high: - # Union small splits up to the target split size. - low_split = splits[low] - low_count = counts_cache[low_split._get_uuid()] - high_split = splits[high] - high_count = counts_cache[high_split._get_uuid()] - if union_buffer_size + high_count <= target_size: - # Try to add the larger split to the union buffer first. - union_buffer.append(high_split) - union_buffer_size += high_count - high -= 1 - elif union_buffer_size + low_count <= target_size: - union_buffer.append(low_split) - union_buffer_size += low_count - low += 1 - else: - # Neither the larger nor smaller split fit in the union - # buffer, so we split the smaller split into a subsplit - # that will fit into the union buffer and a leftover - # subsplit that we add back into the candidate split list. 
- diff = target_size - union_buffer_size - diff_split, new_low_split = low_split.split_at_indices( - [diff]) - union_buffer.append(diff_split) - union_buffer_size += diff - # We overwrite the old low split and don't advance the low - # pointer since (1) the old low split can be discarded, - # (2) the leftover subsplit is guaranteed to be smaller - # than the old low split, and (3) the low split should be - # the smallest split in the candidate split list, which is - # this subsplit. - splits[low] = new_low_split - counts_cache[new_low_split._get_uuid()] = low_count - diff - if union_buffer_size == target_size: - # Once the union buffer is full, we union together the - # splits. - assert len(union_buffer) > 1, union_buffer - first_ds = union_buffer[0] - new_split = first_ds.union(*union_buffer[1:]) - new_splits.append(new_split) - # Clear the union buffer. - union_buffer = [] - union_buffer_size = 0 - if len(new_splits) == num_splits_required: - # Short-circuit if we've reached the desired number of - # splits. - break - return new_splits - - def equalize(splits: List[Dataset[T]], - num_splits: int) -> List[Dataset[T]]: + # TODO(ekl) we could do better than truncation here. This could be a + # problem if block sizes are very skewed. + def equalize(splits: List[Dataset[T]]) -> List[Dataset[T]]: if not equal: return splits - counts = {s._get_uuid(): s.count() for s in splits} - total_rows = sum(counts.values()) - # Number of rows for each split. - target_size = total_rows // num_splits - - # Partition splits. - smaller_splits, larger_splits = _partition_splits( - splits, target_size, counts) - if len(smaller_splits) == 0 and num_splits < len(splits): - # All splits are already equal. - return splits - - # Split larger splits. - new_splits, leftovers = _equalize_larger_splits( - larger_splits, target_size, counts, num_splits) - # Short-circuit if we've already reached the desired number of - # splits. - if len(new_splits) == num_splits: - return new_splits - # Add leftovers to small splits and re-sort. - smaller_splits += leftovers - smaller_splits = sorted( - smaller_splits, key=lambda s: counts[s._get_uuid()]) - - # Union smaller splits. - new_splits_small = _equalize_smaller_splits( - smaller_splits, target_size, counts, - num_splits - len(new_splits)) - new_splits.extend(new_splits_small) - return new_splits + lower_bound = min([s.count() for s in splits]) + assert lower_bound > 0, splits + return [s.limit(lower_bound) for s in splits] block_refs = list(self._blocks) metadata_mapping = { @@ -559,8 +433,7 @@ def equalize(splits: List[Dataset[T]], BlockList( list(blocks), [metadata_mapping[b] for b in blocks])) for blocks in np.array_split(block_refs, n) - if not equal or len(blocks) > 0 - ], n) + ]) # If the locality_hints is set, we use a two-round greedy algorithm # to co-locate the blocks with the actors based on block @@ -659,7 +532,7 @@ def build_node_id_by_actor(actors: List[Any]) -> Dict[Any, str]: [metadata_mapping[b] for b in allocation_per_actor[actor]])) for actor in locality_hints - ], n) + ]) def split_at_indices(self, indices: List[int]) -> List["Dataset[T]"]: """Split the dataset at the given indices (like np.split). @@ -707,9 +580,6 @@ def split_at_indices(self, indices: List[int]) -> List["Dataset[T]"]: def union(self, *other: List["Dataset[T]"]) -> "Dataset[T]": """Combine this dataset with others of the same type. - The order of the blocks in the datasets is preserved, as is the - relative ordering between the datasets passed in the argument list. 
- Args: other: List of datasets to combine with this one. The datasets must have the same schema as this dataset, otherwise the @@ -719,21 +589,35 @@ def union(self, *other: List["Dataset[T]"]) -> "Dataset[T]": A new dataset holding the union of their data. """ - calls: List[Callable[[], ObjectRef[Block]]] = [] - metadata: List[BlockMetadata] = [] blocks: List[ObjectRef[Block]] = [] + metadata: List[BlockMetadata] = [] + pending_blocks: List[Callable[[], ObjectRef[Block]]] = [] + pending_metadata: List[BlockMetadata] = [] datasets = [self] + list(other) for ds in datasets: bl = ds._blocks if isinstance(bl, LazyBlockList): - calls.extend(bl._calls) + for block, meta in zip(bl._blocks, bl._metadata): + blocks.append(block) + metadata.append(meta) + lim = len(bl._blocks) + for call, meta in zip(bl._calls[lim:], bl._metadata[lim:]): + pending_blocks.append(call) + pending_metadata.append(meta) else: - calls.extend([None] * len(bl)) - metadata.extend(bl._metadata) - blocks.extend(bl._blocks) + assert isinstance(bl, BlockList), bl + blocks.extend(list(bl._blocks)) + metadata.extend(bl.get_metadata()) - return Dataset(LazyBlockList(calls, metadata, blocks)) + result = LazyBlockList([], []) + result._calls = ([None] * len(blocks)) + pending_blocks + result._blocks = blocks + result._metadata = metadata + pending_metadata + + assert len(result._calls) == len(result._metadata), result + assert len(result._blocks) <= len(result._calls), result + return Dataset(result) def sort(self, key: Union[None, str, List[str], Callable[[T], Any]] = None, @@ -769,9 +653,6 @@ def sort(self, Returns: A new, sorted dataset. """ - # Handle empty dataset. - if self.num_blocks() == 0: - return self return Dataset(sort_impl(self._blocks, key, descending)) def zip(self, other: "Dataset[U]") -> "Dataset[(T, U)]": @@ -797,8 +678,8 @@ def zip(self, other: "Dataset[U]") -> "Dataset[(T, U)]": comes from the first dataset and v comes from the second. """ - blocks1 = self.get_internal_block_refs() - blocks2 = other.get_internal_block_refs() + blocks1 = self.get_blocks() + blocks2 = other.get_blocks() if len(blocks1) != len(blocks2): # TODO(ekl) consider supporting if num_rows are equal. @@ -880,9 +761,6 @@ def count(self) -> int: Returns: The number of records in the dataset. """ - # Handle empty dataset. - if self.num_blocks() == 0: - return 0 # For parquet, we can return the count directly from metadata. meta_count = self._meta_count() @@ -971,8 +849,6 @@ def write_parquet(self, path: str, *, filesystem: Optional["pyarrow.fs.FileSystem"] = None, - try_create_dir: bool = True, - arrow_open_stream_args: Optional[Dict[str, Any]] = None, **arrow_parquet_args) -> None: """Write the dataset to parquet. @@ -991,10 +867,6 @@ def write_parquet(self, path: The path to the destination root directory, where Parquet files will be written to. filesystem: The filesystem implementation to write to. - try_create_dir: Try to create all directories in destination path - if True. Does nothing if all directories already exist. - arrow_open_stream_args: kwargs passed to - pyarrow.fs.FileSystem.open_output_stream arrow_parquet_args: Options to pass to pyarrow.parquet.write_table(), which is used to write out each block to a file. 
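As an aside on the write path these docstrings describe: each block becomes one file under the destination root. The following sketch uses only public pyarrow APIs and is not the patch's code; the path and file name are made up (cf. the {self._uuid}_{block_idx} naming described for write_numpy below).

import pyarrow as pa
import pyarrow.parquet as pq
from pyarrow import fs

table = pa.table({"value": [1, 2, 3]})  # stands in for one dataset block
filesystem = fs.LocalFileSystem()
filesystem.create_dir("/tmp/out", recursive=True)
# Illustrative file name only.
with filesystem.open_output_stream("/tmp/out/example_000000.parquet") as f:
    pq.write_table(table, f)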
@@ -1004,16 +876,12 @@ def write_parquet(self, path=path, dataset_uuid=self._uuid, filesystem=filesystem, - try_create_dir=try_create_dir, - open_stream_args=arrow_open_stream_args, **arrow_parquet_args) def write_json(self, path: str, *, filesystem: Optional["pyarrow.fs.FileSystem"] = None, - try_create_dir: bool = True, - arrow_open_stream_args: Optional[Dict[str, Any]] = None, **pandas_json_args) -> None: """Write the dataset to json. @@ -1032,10 +900,6 @@ def write_json(self, path: The path to the destination root directory, where json files will be written to. filesystem: The filesystem implementation to write to. - try_create_dir: Try to create all directories in destination path - if True. Does nothing if all directories already exist. - arrow_open_stream_args: kwargs passed to - pyarrow.fs.FileSystem.open_output_stream pandas_json_args: These args will be passed to pandas.DataFrame.to_json(), which we use under the hood to write out each Datasets block. These @@ -1046,16 +910,12 @@ def write_json(self, path=path, dataset_uuid=self._uuid, filesystem=filesystem, - try_create_dir=try_create_dir, - open_stream_args=arrow_open_stream_args, **pandas_json_args) def write_csv(self, path: str, *, filesystem: Optional["pyarrow.fs.FileSystem"] = None, - try_create_dir: bool = True, - arrow_open_stream_args: Optional[Dict[str, Any]] = None, **arrow_csv_args) -> None: """Write the dataset to csv. @@ -1074,10 +934,6 @@ def write_csv(self, path: The path to the destination root directory, where csv files will be written to. filesystem: The filesystem implementation to write to. - try_create_dir: Try to create all directories in destination path - if True. Does nothing if all directories already exist. - arrow_open_stream_args: kwargs passed to - pyarrow.fs.FileSystem.open_output_stream arrow_csv_args: Other CSV write options to pass to pyarrow. """ self.write_datasource( @@ -1085,23 +941,17 @@ def write_csv(self, path=path, dataset_uuid=self._uuid, filesystem=filesystem, - try_create_dir=try_create_dir, - open_stream_args=arrow_open_stream_args, **arrow_csv_args) def write_numpy( self, path: str, *, - column: str = "value", - filesystem: Optional["pyarrow.fs.FileSystem"] = None, - try_create_dir: bool = True, - arrow_open_stream_args: Optional[Dict[str, Any]] = None) -> None: - """Write a tensor column of the dataset to npy files. + filesystem: Optional["pyarrow.fs.FileSystem"] = None) -> None: + """Write the dataset to npy files. - This is only supported for datasets convertible to Arrow records that - contain a TensorArray column. To control the number of files, use - ``.repartition()``. + This is only supported for datasets of Tensor records. + To control the number of files, use ``.repartition()``. The format of the output files will be {self._uuid}_{block_idx}.npy, where ``uuid`` is an unique id for the dataset. @@ -1114,22 +964,13 @@ def write_numpy( Args: path: The path to the destination root directory, where npy files will be written to. - column: The name of the table column that contains the tensor to - be written. This defaults to "value". filesystem: The filesystem implementation to write to. - try_create_dir: Try to create all directories in destination path - if True. Does nothing if all directories already exist. 
- arrow_open_stream_args: kwargs passed to - pyarrow.fs.FileSystem.open_output_stream """ self.write_datasource( NumpyDatasource(), path=path, dataset_uuid=self._uuid, - column=column, - filesystem=filesystem, - try_create_dir=try_create_dir, - open_stream_args=arrow_open_stream_args) + filesystem=filesystem) def write_datasource(self, datasource: Datasource[T], **write_args) -> None: @@ -1201,7 +1042,7 @@ def iter_batches(self, batch_format: The format in which to return each batch. Specify "native" to use the current block format, "pandas" to select ``pandas.DataFrame`` or "pyarrow" to select - ``pyarrow.Table``. Default is "native". + ``pyarrow.Table/Tensor``. Default is "native". drop_last: Whether to drop the last batch if it's incomplete. Returns: @@ -1469,15 +1310,14 @@ def to_modin(self) -> "modin.DataFrame": """Convert this dataset into a Modin dataframe. This works by first converting this dataset into a distributed set of - Pandas dataframes (using ``.to_pandas_refs()``). Please see caveats - there. Then the individual dataframes are used to create the modin - DataFrame using + Pandas dataframes (using ``.to_pandas()``). Please see caveats there. + Then the individual dataframes are used to create the modin DataFrame + using ``modin.distributed.dataframe.pandas.partitions.from_partitions()``. This is only supported for datasets convertible to Arrow records. This function induces a copy of the data. For zero-copy access to the - underlying data, consider using ``.to_arrow()`` or - ``.get_internal_block_refs()``. + underlying data, consider using ``.to_arrow()`` or ``.get_blocks()``. Time complexity: O(dataset size / parallelism) @@ -1487,7 +1327,7 @@ def to_modin(self) -> "modin.DataFrame": from modin.distributed.dataframe.pandas.partitions import ( from_partitions) - pd_objs = self.to_pandas_refs() + pd_objs = self.to_pandas() return from_partitions(pd_objs, axis=0) def to_spark(self, @@ -1503,45 +1343,17 @@ def to_spark(self, core_worker = ray.worker.global_worker.core_worker locations = [ core_worker.get_owner_address(block) - for block in self.get_internal_block_refs() + for block in self.get_blocks() ] return raydp.spark.ray_dataset_to_spark_dataframe( - spark, self.schema(), self.get_internal_block_refs(), locations) - - def to_pandas(self, limit: int = 1000) -> "pandas.DataFrame": - """Convert this dataset into a single Pandas DataFrame. - - This is only supported for datasets convertible to Arrow records. This - limits the number of records returned to the provided limit. - - Time complexity: O(limit) - - Args: - limit: The maximum number of records to return. + spark, self.schema(), self.get_blocks(), locations) - Returns: - A Pandas DataFrame created from this dataset, containing a limited - number of records. - """ - - if self.count() > limit: - logger.warning(f"Only returning the first {limit} records from " - "to_pandas()") - limited_ds = self.limit(limit) - blocks = limited_ds.get_internal_block_refs() - output = DelegatingArrowBlockBuilder() - for block in ray.get(blocks): - output.add_block(block) - return output.build().to_pandas() - - @DeveloperAPI - def to_pandas_refs(self) -> List[ObjectRef["pandas.DataFrame"]]: + def to_pandas(self) -> List[ObjectRef["pandas.DataFrame"]]: """Convert this dataset into a distributed set of Pandas dataframes. This is only supported for datasets convertible to Arrow records. This function induces a copy of the data. For zero-copy access to the - underlying data, consider using ``.to_arrow()`` or - ``.get_internal_block_refs()``. 
+ underlying data, consider using ``.to_arrow()`` or ``.get_blocks()``. Time complexity: O(dataset size / parallelism) @@ -1552,48 +1364,23 @@ def to_pandas_refs(self) -> List[ObjectRef["pandas.DataFrame"]]: block_to_df = cached_remote_fn(_block_to_df) return [block_to_df.remote(block) for block in self._blocks] - def to_numpy(self, *, - column: Optional[str] = None) -> List[ObjectRef[np.ndarray]]: + def to_numpy(self) -> List[ObjectRef[np.ndarray]]: """Convert this dataset into a distributed set of NumPy ndarrays. This is only supported for datasets convertible to NumPy ndarrays. This function induces a copy of the data. For zero-copy access to the - underlying data, consider using ``.to_arrow()`` or - ``.get_internal_block_refs()``. + underlying data, consider using ``.to_arrow()`` or ``.get_blocks()``. Time complexity: O(dataset size / parallelism) - Args: - column: The name of the column to convert to numpy, or None to - specify the entire row. Required for Arrow tables. - Returns: A list of remote NumPy ndarrays created from this dataset. """ block_to_ndarray = cached_remote_fn(_block_to_ndarray) - return [ - block_to_ndarray.remote(block, column=column) - for block in self._blocks - ] - - def to_arrow(self) -> List["pyarrow.Table"]: - """Convert this dataset into a list of Arrow tables. - - This is only supported for datasets convertible to Arrow records. - This function is zero-copy if the existing data is already in Arrow - format. Otherwise, the data will be converted to Arrow format. - - Time complexity: O(1) unless conversion is required. + return [block_to_ndarray.remote(block) for block in self._blocks] - Returns: - A list of Arrow tables created from this dataset. - """ - - return ray.get(self.to_arrow_refs()) - - @DeveloperAPI - def to_arrow_refs(self) -> List[ObjectRef["pyarrow.Table"]]: + def to_arrow(self) -> List[ObjectRef["pyarrow.Table"]]: """Convert this dataset into a distributed set of Arrow tables. This is only supported for datasets convertible to Arrow records. @@ -1663,32 +1450,28 @@ def __init__(self, ds: "Dataset[T]"): def __iter__(self): return Iterator(self._ds) - return DatasetPipeline(Iterable(self), length=times or float("inf")) + return DatasetPipeline(Iterable(self), length=times) def pipeline(self, *, parallelism: int = 10) -> "DatasetPipeline[T]": - raise DeprecationWarning("Use .window(blocks_per_window=n) instead of " - ".pipeline(parallelism=n)") + """Pipeline the dataset execution by splitting its blocks into groups. - def window(self, *, blocks_per_window: int = 10) -> "DatasetPipeline[T]": - """Convert this into a DatasetPipeline by windowing over data blocks. - - Transformations prior to the call to ``window()`` are evaluated in + Transformations prior to the call to ``pipeline()`` are evaluated in bulk on the entire dataset. Transformations done on the returned - pipeline are evaluated incrementally per window of blocks as data is + pipeline are evaluated incrementally per group of blocks as data is read from the output of the pipeline. - Windowing execution allows for output to be read sooner without + Pipelining execution allows for output to be read sooner without waiting for all transformations to fully execute, and can also improve efficiency if transforms use different resources (e.g., GPUs). - Without windowing:: + Without pipelining:: [preprocessing......] [inference.......] [write........] 
Time -----------------------------------------------------------> - With windowing:: + With pipelining:: [prep1] [prep2] [prep3] [infer1] [infer2] [infer3] @@ -1698,20 +1481,20 @@ def window(self, *, blocks_per_window: int = 10) -> "DatasetPipeline[T]": Examples: >>> # Create an inference pipeline. >>> ds = ray.data.read_binary_files(dir) - >>> pipe = ds.window(blocks_per_window=10).map(infer) - DatasetPipeline(num_windows=40, num_stages=2) + >>> pipe = ds.pipeline(parallelism=10).map(infer) + DatasetPipeline(num_stages=2, length=40) >>> # The higher the stage parallelism, the shorter the pipeline. - >>> pipe = ds.window(blocks_per_window=20).map(infer) - DatasetPipeline(num_windows=20, num_stages=2) + >>> pipe = ds.pipeline(parallelism=20).map(infer) + DatasetPipeline(num_stages=2, length=20) >>> # Outputs can be incrementally read from the pipeline. >>> for item in pipe.iter_rows(): ... print(item) Args: - blocks_per_window: The window size (parallelism) in blocks. - Increasing window size increases pipeline throughput, but also + parallelism: The parallelism (number of blocks) per stage. + Increasing parallelism increases pipeline throughput, but also increases the latency to initial output, since it decreases the length of the pipeline. Setting this to infinity effectively disables pipelining. @@ -1735,7 +1518,7 @@ def gen(): class Iterable: def __init__(self, blocks): - self._splits = blocks.split(split_size=blocks_per_window) + self._splits = blocks.split(split_size=parallelism) def __iter__(self): return Iterator(self._splits) @@ -1744,7 +1527,7 @@ def __iter__(self): return DatasetPipeline(it, length=len(it._splits)) @DeveloperAPI - def get_internal_block_refs(self) -> List[ObjectRef[Block]]: + def get_blocks(self) -> List[ObjectRef[Block]]: """Get a list of references to the underlying blocks of this dataset. This function can be used for zero-copy access to the data. @@ -1798,14 +1581,13 @@ def _split(self, index: int, right = None return left, right - def _divide(self, block_idx: int) -> ("Dataset[T]", "Dataset[T]"): - left, right = self._blocks.divide(block_idx) - return Dataset(left), Dataset(right) - def __repr__(self) -> str: schema = self.schema() if schema is None: schema_str = "Unknown schema" + elif isinstance(schema, dict): + schema_str = "".format( + schema["shape"], schema["dtype"]) elif isinstance(schema, type): schema_str = str(schema) else: @@ -1817,6 +1599,8 @@ def __repr__(self) -> str: schema_str = ", ".join(schema_str) schema_str = "{" + schema_str + "}" count = self._meta_count() + if count is None: + count = "?" 
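# (Illustrative note, not part of the diff.) With the "?" fallback above, a
# dataset whose block metadata carries no row counts now renders as, e.g.:
#   Dataset(num_blocks=10, num_rows=?, schema=Unknown schema)
# instead of formatting the raw Python None into the repr.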
return "Dataset(num_blocks={}, num_rows={}, schema={})".format( len(self._blocks), count, schema_str) @@ -1856,9 +1640,9 @@ def _block_to_df(block: Block): return block.to_pandas() -def _block_to_ndarray(block: Block, column: Optional[str]): +def _block_to_ndarray(block: Block): block = BlockAccessor.for_block(block) - return block.to_numpy(column) + return block.to_numpy() def _block_to_arrow(block: Block): diff --git a/python/ray/data/dataset_pipeline.py b/python/ray/data/dataset_pipeline.py index 962961105f895..158905e70e9f9 100644 --- a/python/ray/data/dataset_pipeline.py +++ b/python/ray/data/dataset_pipeline.py @@ -1,7 +1,7 @@ import functools import time from typing import Any, Callable, List, Iterator, Iterable, Generic, Union, \ - Optional, TYPE_CHECKING + TYPE_CHECKING import ray from ray.data.dataset import Dataset, T, U, BatchType @@ -13,15 +13,13 @@ if TYPE_CHECKING: import pyarrow -# Operations that can be naively applied per dataset row in the pipeline. +# Operations that can be naively applied per dataset in the pipeline. PER_DATASET_OPS = [ - "map", "map_batches", "flat_map", "filter", "write_json", "write_csv", - "write_parquet", "write_datasource" + "map", "map_batches", "flat_map", "filter", "repartition", + "random_shuffle", "sort", "write_json", "write_csv", "write_parquet", + "write_datasource" ] -# Operations that apply to each dataset holistically in the pipeline. -HOLISTIC_PER_DATASET_OPS = ["repartition", "random_shuffle", "sort"] - # Similar to above but we should force evaluation immediately. PER_DATASET_OUTPUT_OPS = [ "write_json", "write_csv", "write_parquet", "write_datasource" @@ -42,7 +40,7 @@ class DatasetPipeline(Generic[T]): A DatasetPipeline can be created by either repeating a Dataset (``ds.repeat(times=None)``), by turning a single Dataset into a pipeline - (``ds.window(blocks_per_window=10)``), or defined explicitly using + (``ds.pipeline(parallelism=10)``), or defined explicitly using ``DatasetPipeline.from_iterable()``. DatasetPipeline supports the all the per-record transforms of Datasets @@ -59,7 +57,7 @@ def __init__(self, """Construct a DatasetPipeline (internal API). The constructor is not part of the DatasetPipeline API. Use the - ``Dataset.repeat()``, ``Dataset.window()``, or + ``Dataset.repeat()``, ``Dataset.pipeline()``, or ``DatasetPipeline.from_iterable()`` methods to construct a pipeline. """ self._base_iterable = base_iterable @@ -242,124 +240,6 @@ def __next__(self): for idx in range(n) ] - def rewindow(self, *, blocks_per_window: int) -> "DatasetPipeline[T]": - """Change the windowing (blocks per dataset) of this pipeline. - - Changes the windowing of this pipeline to the specified size. For - example, if the current pipeline has two blocks per dataset, and - `.rewindow(blocks_per_window=4)` is requested, adjacent datasets will - be merged until each dataset is 4 blocks. If - `.rewindow(blocks_per_window)` was requested the datasets will be - split into smaller windows. - - Args: - blocks_per_window: The new target blocks per window. - """ - - class WindowIterator: - def __init__(self, original_iter): - self._original_iter = original_iter - self._buffer: Optional[Dataset[T]] = None - - def __next__(self) -> Dataset[T]: - try: - # Merge windows until we meet the requested window size. - if self._buffer is None: - self._buffer = next(self._original_iter) - while self._buffer.num_blocks() < blocks_per_window: - self._buffer = self._buffer.union( - next(self._original_iter)) - # Slice off the left-most chunk and return it. 
- res, self._buffer = self._buffer._divide(blocks_per_window) - assert res.num_blocks() <= blocks_per_window, res - return lambda: res - except StopIteration: - # Return the left-over data as a single window. - if self._buffer and self._buffer.num_blocks() > 0: - res = self._buffer - assert res.num_blocks() <= blocks_per_window, res - self._buffer = None - return lambda: res - else: - raise - - class WindowIterable: - def __init__(self, original_iter): - self._original_iter = original_iter - - def __iter__(self): - return WindowIterator(self._original_iter) - - return DatasetPipeline( - WindowIterable(self.iter_datasets()), length=None) - - def repeat(self, times: int = None) -> "DatasetPipeline[T]": - """Repeat this pipeline a given number or times, or indefinitely. - - This operation is only allowed for pipelines of a finite length. An - error will be raised for pipelines of infinite length. - - Transformations prior to the call to ``repeat()`` are evaluated once. - Transformations done on the repeated pipeline are evaluated on each - loop of the pipeline over the base pipeline. - - Args: - times: The number of times to loop over this pipeline, or None - to repeat indefinitely. - """ - - if self._length == float("inf"): - raise ValueError("Cannot repeat a pipeline of infinite length.") - - class RepeatIterator: - def __init__(self, original_iter): - self._original_iter = original_iter - # Holds results to repeat. - self._results = [] - # Incrementing cursor over results. - self._i = 0 - # This is calculated later. - self._max_i = None - - def __next__(self) -> Dataset[T]: - # Still going through the original pipeline. - if self._original_iter: - try: - res = next(self._original_iter) - self._results.append(res) - return lambda: res - except StopIteration: - self._original_iter = None - # Calculate the cursor limit. - if times: - self._max_i = len(self._results) * (times - 1) - else: - self._max_i = float("inf") - # Going through a repeat of the pipeline. - if self._i < self._max_i: - res = self._results[self._i % len(self._results)] - self._i += 1 - return lambda: res - else: - raise StopIteration - - class RepeatIterable: - def __init__(self, original_iter): - self._original_iter = original_iter - - def __iter__(self): - return RepeatIterator(self._original_iter) - - if not times: - length = float("inf") - elif times and self._length: - length = times * self._length - else: - length = None - - return DatasetPipeline( - RepeatIterable(self.iter_datasets()), length=length) - def schema(self) -> Union[type, "pyarrow.lib.Schema"]: """Return the schema of the dataset pipeline. @@ -407,19 +287,6 @@ def sum(self) -> int: total += elem return total - def show_windows(self, limit_per_dataset: int = 10) -> None: - """Print up to the given number of records from each window/dataset. - - This is helpful as a debugging tool for understanding the structure of - dataset pipelines. - - Args: - limit_per_dataset: Rows to print per window/dataset. - """ - for i, ds in enumerate(self.iter_datasets()): - print("=== Window {} ===".format(i)) - ds.show(limit_per_dataset) - @DeveloperAPI def iter_datasets(self) -> Iterator[Dataset[T]]: """Iterate over the output datasets of this pipeline. @@ -433,9 +300,9 @@ def iter_datasets(self) -> Iterator[Dataset[T]]: return PipelineExecutor(self) @DeveloperAPI - def foreach_window(self, fn: Callable[[Dataset[T]], Dataset[U]] - ) -> "DatasetPipeline[U]": - """Apply a transform to each dataset/window in this pipeline. 
+ def foreach_dataset(self, fn: Callable[[Dataset[T]], Dataset[U]] + ) -> "DatasetPipeline[U]": + """Apply a transform to each dataset in this pipeline. Args: fn: The function to transform each dataset with. @@ -452,10 +319,6 @@ def foreach_window(self, fn: Callable[[Dataset[T]], Dataset[U]] self._progress_bars, _executed=self._executed) - def foreach_dataset(self, *a, **kw) -> None: - raise DeprecationWarning( - "`foreach_dataset` has been renamed to `foreach_window`.") - @staticmethod def from_iterable(iterable: Iterable[Callable[[], Dataset[T]]], ) -> "DatasetPipeline[T]": @@ -472,7 +335,7 @@ def from_iterable(iterable: Iterable[Callable[[], Dataset[T]]], return DatasetPipeline(iterable, length=length) def __repr__(self) -> str: - return "DatasetPipeline(num_windows={}, num_stages={})".format( + return "DatasetPipeline(length={}, num_stages={})".format( self._length, 1 + len(self._stages)) def __str__(self) -> str: @@ -492,7 +355,7 @@ def make_impl(method): @functools.wraps(delegate) def impl(self, *args, **kwargs): - return self.foreach_window( + return self.foreach_dataset( lambda ds: getattr(ds, method)(*args, **kwargs)) if impl.__annotations__.get("return"): @@ -503,33 +366,6 @@ def impl(self, *args, **kwargs): setattr(DatasetPipeline, method, make_impl(method)) -for method in HOLISTIC_PER_DATASET_OPS: - - def make_impl(method): - delegate = getattr(Dataset, method) - - @functools.wraps(delegate) - def impl(self, *args, **kwargs): - return self.foreach_window( - lambda ds: getattr(ds, method)(*args, **kwargs)) - - if impl.__annotations__.get("return"): - impl.__annotations__["return"] = impl.__annotations__[ - "return"].replace("Dataset", "DatasetPipeline") - - return impl - - def deprecation_warning(method: str): - def impl(*a, **kw): - raise DeprecationWarning( - "`{}` has been renamed to `{}_each_window`.".format( - method, method)) - - return impl - - setattr(DatasetPipeline, method, deprecation_warning(method)) - setattr(DatasetPipeline, method + "_each_window", make_impl(method)) - for method in PER_DATASET_OUTPUT_OPS: def make_impl(method): diff --git a/python/ray/data/datasource/__init__.py b/python/ray/data/datasource/__init__.py index fd3b5e21c6eb0..b09aa9acb0c0d 100644 --- a/python/ray/data/datasource/__init__.py +++ b/python/ray/data/datasource/__init__.py @@ -1,6 +1,5 @@ from ray.data.datasource.datasource import (Datasource, RangeDatasource, - DummyOutputDatasource, ReadTask, - RandomIntRowDatasource) + DummyOutputDatasource, ReadTask) from ray.data.datasource.json_datasource import JSONDatasource from ray.data.datasource.csv_datasource import CSVDatasource from ray.data.datasource.numpy_datasource import NumpyDatasource @@ -19,7 +18,6 @@ "_S3FileSystemWrapper", "Datasource", "RangeDatasource", - "RandomIntRowDatasource", "DummyOutputDatasource", "ReadTask", ] diff --git a/python/ray/data/datasource/datasource.py b/python/ray/data/datasource/datasource.py index 0945120306ebf..46b313ab3bfd0 100644 --- a/python/ray/data/datasource/datasource.py +++ b/python/ray/data/datasource/datasource.py @@ -130,11 +130,10 @@ def make_block(start: int, count: int) -> Block: return pyarrow.Table.from_arrays( [np.arange(start, start + count)], names=["value"]) elif block_format == "tensor": - tensor = TensorArray( - np.ones(tensor_shape, dtype=np.int64) * np.expand_dims( + return np.ones( + tensor_shape, dtype=np.int64) * np.expand_dims( np.arange(start, start + count), - tuple(range(1, 1 + len(tensor_shape))))) - return pyarrow.Table.from_pydict({"value": tensor}) + tuple(range(1, 1 + 
len(tensor_shape)))) else: return list(builtins.range(start, start + count)) @@ -146,14 +145,7 @@ def make_block(start: int, count: int) -> Block: import pyarrow schema = pyarrow.Table.from_pydict({"value": [0]}).schema elif block_format == "tensor": - _check_pyarrow_version() - from ray.data.extensions import TensorArray - import pyarrow - tensor = TensorArray( - np.ones(tensor_shape, dtype=np.int64) * np.expand_dims( - np.arange(0, 10), tuple( - range(1, 1 + len(tensor_shape))))) - schema = pyarrow.Table.from_pydict({"value": tensor}).schema + schema = {"dtype": "int64", "shape": (None, ) + tensor_shape} elif block_format == "list": schema = int else: @@ -221,50 +213,3 @@ def on_write_complete(self, write_results: List[WriteResult]) -> None: def on_write_failed(self, write_results: List[ObjectRef[WriteResult]], error: Exception) -> None: self.num_failed += 1 - - -class RandomIntRowDatasource(Datasource[ArrowRow]): - """An example datasource that generates rows with random int64 columns. - - Examples: - >>> source = RandomIntRowDatasource() - >>> ray.data.read_datasource(source, n=10, num_columns=2).take() - ... ArrowRow({'c_0': 1717767200176864416, 'c_1': 999657309586757214}) - ... ArrowRow({'c_0': 4983608804013926748, 'c_1': 1160140066899844087}) - """ - - def prepare_read(self, parallelism: int, n: int, - num_columns: int) -> List[ReadTask]: - _check_pyarrow_version() - import pyarrow - - read_tasks: List[ReadTask] = [] - block_size = max(1, n // parallelism) - - def make_block(count: int, num_columns: int) -> Block: - return pyarrow.Table.from_arrays( - np.random.randint( - np.iinfo(np.int64).max, - size=(num_columns, count), - dtype=np.int64), - names=[f"c_{i}" for i in range(num_columns)]) - - schema = pyarrow.Table.from_pydict( - {f"c_{i}": [0] - for i in range(num_columns)}).schema - - i = 0 - while i < n: - count = min(block_size, n - i) - read_tasks.append( - ReadTask( - lambda count=count, num_columns=num_columns: - make_block(count, num_columns), - BlockMetadata( - num_rows=count, - size_bytes=8 * count * num_columns, - schema=schema, - input_files=None))) - i += block_size - - return read_tasks diff --git a/python/ray/data/datasource/file_based_datasource.py b/python/ray/data/datasource/file_based_datasource.py index 054f18b0436f7..9a326ebdcf62d 100644 --- a/python/ray/data/datasource/file_based_datasource.py +++ b/python/ray/data/datasource/file_based_datasource.py @@ -1,7 +1,6 @@ import logging import os -from typing import Callable, Optional, List, Tuple, Union, Any, Dict, \ - TYPE_CHECKING +from typing import Callable, Optional, List, Tuple, Union, Any, TYPE_CHECKING import urllib.parse if TYPE_CHECKING: @@ -37,7 +36,6 @@ def prepare_read( paths: Union[str, List[str]], filesystem: Optional["pyarrow.fs.FileSystem"] = None, schema: Optional[Union[type, "pyarrow.lib.Schema"]] = None, - open_stream_args: Optional[Dict[str, Any]] = None, _block_udf: Optional[Callable[[Block], Block]] = None, **reader_args) -> List[ReadTask]: """Creates and returns read tasks for a file-based datasource. 
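The removed RandomIntRowDatasource above reduces to this block factory, shown standalone for reference (a sketch mirroring the deleted make_block):

import numpy as np
import pyarrow

def make_block(count: int, num_columns: int) -> "pyarrow.Table":
    # One block: `count` rows of independent random int64 columns c_0, c_1, ...
    return pyarrow.Table.from_arrays(
        np.random.randint(
            np.iinfo(np.int64).max,
            size=(num_columns, count),
            dtype=np.int64),
        names=[f"c_{i}" for i in range(num_columns)])

block = make_block(count=10, num_columns=2)
print(block.num_rows, block.column_names)  # 10 ['c_0', 'c_1']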
@@ -54,9 +52,6 @@ def prepare_read( filesystem = _wrap_s3_serialization_workaround(filesystem) - if open_stream_args is None: - open_stream_args = {} - def read_files( read_paths: List[str], fs: Union["pyarrow.fs.FileSystem", _S3FileSystemWrapper]): @@ -65,7 +60,7 @@ def read_files( fs = fs.unwrap() builder = DelegatingArrowBlockBuilder() for read_path in read_paths: - with fs.open_input_stream(read_path, **open_stream_args) as f: + with fs.open_input_stream(read_path) as f: data = read_file(f, read_path, **reader_args) if isinstance(data, pa.Table) or isinstance( data, np.ndarray): @@ -120,22 +115,16 @@ def do_write(self, path: str, dataset_uuid: str, filesystem: Optional["pyarrow.fs.FileSystem"] = None, - try_create_dir: bool = True, - open_stream_args: Optional[Dict[str, Any]] = None, _block_udf: Optional[Callable[[Block], Block]] = None, **write_args) -> List[ObjectRef[WriteResult]]: """Creates and returns write tasks for a file-based datasource.""" path, filesystem = _resolve_paths_and_filesystem(path, filesystem) path = path[0] - if try_create_dir: - filesystem.create_dir(path, recursive=True) + filesystem.create_dir(path, recursive=True) filesystem = _wrap_s3_serialization_workaround(filesystem) _write_block_to_file = self._write_block - if open_stream_args is None: - open_stream_args = {} - def write_block(write_path: str, block: Block): logger.debug(f"Writing {write_path} file.") fs = filesystem @@ -143,10 +132,8 @@ def write_block(write_path: str, block: Block): fs = fs.unwrap() if _block_udf is not None: block = _block_udf(block) - - with fs.open_output_stream(write_path, **open_stream_args) as f: - _write_block_to_file(f, BlockAccessor.for_block(block), - **write_args) + with fs.open_output_stream(write_path) as f: + _write_block_to_file(f, BlockAccessor.for_block(block)) write_block = cached_remote_fn(write_block) @@ -201,8 +188,9 @@ def _resolve_paths_and_filesystem( compatibility. """ import pyarrow as pa - from pyarrow.fs import FileSystem, PyFileSystem, FSSpecHandler, \ - _resolve_filesystem_and_path + from pyarrow.fs import (FileSystem, PyFileSystem, FSSpecHandler, + _resolve_filesystem_and_path) + import fsspec if isinstance(paths, str): paths = [paths] @@ -214,20 +202,11 @@ def _resolve_paths_and_filesystem( raise ValueError("Must provide at least one path.") if filesystem and not isinstance(filesystem, FileSystem): - err_msg = f"The filesystem passed must either conform to " \ - f"pyarrow.fs.FileSystem, or " \ - f"fsspec.spec.AbstractFileSystem. The provided " \ - f"filesystem was: {filesystem}" - try: - import fsspec - except ModuleNotFoundError: - # If filesystem is not a pyarrow filesystem and fsspec isn't - # installed, then filesystem is neither a pyarrow filesystem nor - # an fsspec filesystem, so we raise a TypeError. - raise TypeError(err_msg) if not isinstance(filesystem, fsspec.spec.AbstractFileSystem): - raise TypeError(err_msg) - + raise TypeError(f"The filesystem passed must either conform to " + f"pyarrow.fs.FileSystem, or " + f"fsspec.spec.AbstractFileSystem. 
diff --git a/python/ray/data/datasource/numpy_datasource.py b/python/ray/data/datasource/numpy_datasource.py
index 8ba02e9d40cc5..08bc7f2c0916e 100644
--- a/python/ray/data/datasource/numpy_datasource.py
+++ b/python/ray/data/datasource/numpy_datasource.py
@@ -7,7 +7,7 @@
 import pyarrow
 
 from ray.data.block import BlockAccessor
-from ray.data.datasource.file_based_datasource import FileBasedDatasource
+from ray.data.datasource.file_based_datasource import (FileBasedDatasource)
 
 
 class NumpyDatasource(FileBasedDatasource):
@@ -21,22 +21,17 @@ class NumpyDatasource(FileBasedDatasource):
     """
 
     def _read_file(self, f: "pyarrow.NativeFile", path: str, **reader_args):
-        from ray.data.extensions import TensorArray
-        import pyarrow as pa
         # TODO(ekl) Ideally numpy can read directly from the file, but it
         # seems like it requires the file to be seekable.
         buf = BytesIO()
         data = f.readall()
         buf.write(data)
         buf.seek(0)
-        return pa.Table.from_pydict({
-            "value": TensorArray(np.load(buf, allow_pickle=True))
-        })
+        return np.load(buf)
 
     def _write_block(self, f: "pyarrow.NativeFile", block: BlockAccessor,
-                     column: str, **writer_args):
-        value = block.to_numpy(column)
-        np.save(f, value)
+                     **writer_args):
+        np.save(f, block.to_arrow())
 
     def _file_format(self):
         return "npy"
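The buffering workaround in `_read_file` above is self-contained enough to
demonstrate outside Ray. A small sketch of the same drain-rewind-load pattern
(the helper name is ours, not part of the patch):

    from io import BytesIO

    import numpy as np

    # np.load() needs a seekable file, so a streaming source is first drained
    # into an in-memory buffer, then rewound before parsing.
    def load_from_stream(stream) -> np.ndarray:
        buf = BytesIO(stream.read())
        buf.seek(0)
        return np.load(buf)

    # Round-trip through an in-memory "file" to exercise the pattern.
    out = BytesIO()
    np.save(out, np.arange(10))
    out.seek(0)
    print(load_from_stream(out))  # -> [0 1 2 ... 9]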
diff --git a/python/ray/data/examples/demo_infer.py b/python/ray/data/examples/demo_infer.py
index 352d8ddf31ec6..18237f7898541 100644
--- a/python/ray/data/examples/demo_infer.py
+++ b/python/ray/data/examples/demo_infer.py
@@ -18,7 +18,7 @@ def __call__(self, x):
         return x
 
 
-ds = ds.window(blocks_per_window=10) \
+ds = ds.pipeline(parallelism=10) \
        .map(preprocess) \
        .map(Model, compute="actors", num_gpus=1)
 
diff --git a/python/ray/data/extensions/tensor_extension.py b/python/ray/data/extensions/tensor_extension.py
index 9872cf7e225ef..3c80fed64242f 100644
--- a/python/ray/data/extensions/tensor_extension.py
+++ b/python/ray/data/extensions/tensor_extension.py
@@ -140,7 +140,7 @@ class TensorDtype(pd.api.extensions.ExtensionDtype):
         one: int64
         two: extension>
 
-        >>> read_df = ray.get(read_ds.to_pandas_refs())[0]
+        >>> read_df = ray.get(read_ds.to_pandas())[0]
         >>> read_df.dtypes
         one          int64
         two     TensorDtype
@@ -422,7 +422,7 @@ class TensorArray(pd.api.extensions.ExtensionArray, TensorOpsMixin):
         one: int64
         two: extension>
 
-        >>> read_df = ray.get(read_ds.to_pandas_refs())[0]
+        >>> read_df = ray.get(read_ds.to_pandas())[0]
         >>> read_df.dtypes
         one          int64
         two     TensorDtype
@@ -1155,10 +1155,6 @@ def __arrow_ext_class__(self):
         """
         return ArrowTensorArray
 
-    def __str__(self):
-        return "".format(
-            self.shape, self.storage_type.value_type)
-
 
 @PublicAPI(stability="beta")
 class ArrowTensorArray(pa.ExtensionArray):
diff --git a/python/ray/data/impl/arrow_block.py b/python/ray/data/impl/arrow_block.py
index a9d0634930a49..41c5875bb6c16 100644
--- a/python/ray/data/impl/arrow_block.py
+++ b/python/ray/data/impl/arrow_block.py
@@ -13,6 +13,7 @@
 from ray.data.block import Block, BlockAccessor, BlockMetadata
 from ray.data.impl.block_builder import BlockBuilder
 from ray.data.impl.simple_block import SimpleBlockBuilder
+from ray.data.impl.tensor_block import TensorBlockBuilder
 
 if TYPE_CHECKING:
     import pandas
@@ -77,6 +78,8 @@ def add(self, item: Any) -> None:
                 self._builder = ArrowBlockBuilder()
             except (TypeError, pyarrow.lib.ArrowInvalid):
                 self._builder = SimpleBlockBuilder()
+        elif isinstance(item, np.ndarray):
+            self._builder = TensorBlockBuilder()
         else:
             self._builder = SimpleBlockBuilder()
         self._builder.add(item)
@@ -185,21 +188,8 @@ def schema(self) -> "pyarrow.lib.Schema":
     def to_pandas(self) -> "pandas.DataFrame":
         return self._table.to_pandas()
 
-    def to_numpy(self, column: str = None) -> np.ndarray:
-        if not column:
-            raise ValueError(
-                "`column` must be specified when calling .to_numpy() "
-                "on Arrow blocks.")
-        if column not in self._table.column_names:
-            raise ValueError(
-                "Cannot find column {}, available columns: {}".format(
-                    column, self._table.column_names))
-        array = self._table[column]
-        if array.num_chunks > 1:
-            # TODO(ekl) combine fails since we can't concat ArrowTensorType?
-            array = array.combine_chunks()
-            assert array.num_chunks == 1, array
-        return self._table[column].chunk(0).to_numpy()
+    def to_numpy(self) -> np.ndarray:
+        return np.array(self._table)
 
     def to_arrow(self) -> "pyarrow.Table":
         return self._table
diff --git a/python/ray/data/impl/block_list.py b/python/ray/data/impl/block_list.py
index 691a710b5faa6..b6e88c8fe4fc0 100644
--- a/python/ray/data/impl/block_list.py
+++ b/python/ray/data/impl/block_list.py
@@ -42,13 +42,6 @@ def split(self, split_size: int) -> List["BlockList"]:
             output.append(BlockList(b.tolist(), m.tolist()))
         return output
 
-    def divide(self, block_idx: int) -> ("BlockList", "BlockList"):
-        self._check_if_cleared()
-        return (BlockList(self._blocks[:block_idx],
-                          self._metadata[:block_idx]),
-                BlockList(self._blocks[block_idx:],
-                          self._metadata[block_idx:]))
-
     def __len__(self):
         self._check_if_cleared()
         return len(self._blocks)
diff --git a/python/ray/data/impl/compute.py b/python/ray/data/impl/compute.py
index 8f0a7fb8e41f0..e52aa3bce13d1 100644
--- a/python/ray/data/impl/compute.py
+++ b/python/ray/data/impl/compute.py
@@ -35,10 +35,6 @@ def _map_block(block: Block, meta: BlockMetadata,
 class TaskPool(ComputeStrategy):
     def apply(self, fn: Any, remote_args: dict,
               blocks: BlockList[Any]) -> BlockList[Any]:
-        # Handle empty datasets.
-        if len(blocks) == 0:
-            return blocks
-
         map_bar = ProgressBar("Map Progress", total=len(blocks))
 
         kwargs = remote_args.copy()
@@ -51,23 +47,8 @@ def apply(self, fn: Any, remote_args: dict,
         ]
 
         new_blocks, new_metadata = zip(*refs)
-        new_metadata = list(new_metadata)
-        try:
-            new_metadata = map_bar.fetch_until_complete(new_metadata)
-        except (ray.exceptions.RayTaskError, KeyboardInterrupt) as e:
-            # One or more mapper tasks failed, or we received a SIGINT signal
-            # while waiting; either way, we cancel all map tasks.
-            for ref in new_metadata:
-                ray.cancel(ref)
-            # Wait until all tasks have failed or been cancelled.
-            for ref in new_metadata:
-                try:
-                    ray.get(ref)
-                except (ray.exceptions.RayTaskError,
-                        ray.exceptions.TaskCancelledError):
-                    pass
-            # Reraise the original task failure exception.
-            raise e from None
+        map_bar.block_until_complete(list(new_blocks))
+        new_metadata = ray.get(list(new_metadata))
 
         return BlockList(list(new_blocks), list(new_metadata))
 
diff --git a/python/ray/data/impl/lazy_block_list.py b/python/ray/data/impl/lazy_block_list.py
index 0bfd1e0ac1093..7ccf8e58295ae 100644
--- a/python/ray/data/impl/lazy_block_list.py
+++ b/python/ray/data/impl/lazy_block_list.py
@@ -9,25 +9,19 @@
 
 
 class LazyBlockList(BlockList[T]):
-    def __init__(self,
-                 calls: Callable[[], ObjectRef[Block]],
-                 metadata: List[BlockMetadata],
-                 blocks: List[ObjectRef[Block]] = None):
+    def __init__(self, calls: Callable[[], ObjectRef[Block]],
+                 metadata: List[BlockMetadata]):
+        assert len(calls) == len(metadata), (calls, metadata)
         self._calls = calls
+        self._blocks = [calls[0]()] if calls else []
         self._metadata = metadata
-        if blocks:
-            self._blocks = blocks
-        else:
-            self._blocks = [None] * len(calls)
-            # Immediately compute the first block at least.
-            if calls:
-                self._blocks[0] = calls[0]()
-        assert len(calls) == len(metadata), (calls, metadata)
-        assert len(calls) == len(self._blocks), (calls, self._blocks)
 
     def copy(self) -> "LazyBlockList":
-        return LazyBlockList(self._calls.copy(), self._metadata.copy(),
-                             self._blocks.copy())
+        new_list = LazyBlockList.__new__(LazyBlockList)
+        new_list._calls = self._calls
+        new_list._blocks = self._blocks
+        new_list._metadata = self._metadata
+        return new_list
 
     def clear(self):
         super().clear()
@@ -38,22 +32,11 @@ def split(self, split_size: int) -> List["LazyBlockList"]:
         num_splits = math.ceil(len(self._calls) / split_size)
         calls = np.array_split(self._calls, num_splits)
         meta = np.array_split(self._metadata, num_splits)
-        blocks = np.array_split(self._blocks, num_splits)
         output = []
-        for c, m, b in zip(calls, meta, blocks):
-            output.append(LazyBlockList(c.tolist(), m.tolist(), b.tolist()))
+        for c, m in zip(calls, meta):
+            output.append(LazyBlockList(c.tolist(), m.tolist()))
         return output
 
-    def divide(self, block_idx: int) -> ("BlockList", "BlockList"):
-        self._check_if_cleared()
-        left = LazyBlockList(self._calls[:block_idx],
-                             self._metadata[:block_idx],
-                             self._blocks[:block_idx])
-        right = LazyBlockList(self._calls[block_idx:],
-                              self._metadata[block_idx:],
-                              self._blocks[block_idx:])
-        return left, right
-
     def __len__(self):
         self._check_if_cleared()
         return len(self._calls)
@@ -81,19 +64,9 @@ def _get_or_compute(self, i: int) -> ObjectRef[Block]:
         self._check_if_cleared()
         assert i < len(self._calls), i
         # Check if we need to compute more blocks.
-        if not self._blocks[i]:
+        if i >= len(self._blocks):
+            start = len(self._blocks)
             # Exponentially increase the number of blocks computed per batch.
-            for j in range(max(i + 1, i * 2)):
-                if j >= len(self._blocks):
-                    break
-                if not self._blocks[j]:
-                    self._blocks[j] = self._calls[j]()
-            assert self._blocks[i], self._blocks
+            for c in self._calls[start:max(i + 1, start * 2)]:
+                self._blocks.append(c())
         return self._blocks[i]
-
-    def _num_computed(self):
-        i = 0
-        for b in self._blocks:
-            if b is not None:
-                i += 1
-        return i
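For readers following the new `_get_or_compute` above: blocks are now
materialized in batches that double in size, so reading only the first few
blocks of a large dataset triggers a logarithmic number of batches rather
than all the read tasks. A toy model of the same schedule (plain thunks stand
in for the remote read calls; not Ray code):

    # Compute blocks in exponentially growing batches, mirroring the
    # calls[start:max(i + 1, start * 2)] slice in the patch.
    calls = [lambda j=j: f"block-{j}" for j in range(20)]
    blocks = [calls[0]()]  # the first block is computed eagerly

    def get_or_compute(i):
        if i >= len(blocks):
            start = len(blocks)
            for c in calls[start:max(i + 1, start * 2)]:
                blocks.append(c())
        return blocks[i]

    get_or_compute(1); print(len(blocks))  # 2
    get_or_compute(2); print(len(blocks))  # 4
    get_or_compute(4); print(len(blocks))  # 8

This is exactly the ramp-up the updated `test_lazy_loading_exponential_rampup`
below asserts on (1, 2, 4, 8, 16, 20 computed blocks).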
diff --git a/python/ray/data/impl/pipeline_executor.py b/python/ray/data/impl/pipeline_executor.py
index 7eeacc0a8cac1..c02b04ffdabb4 100644
--- a/python/ray/data/impl/pipeline_executor.py
+++ b/python/ray/data/impl/pipeline_executor.py
@@ -10,7 +10,7 @@
 from ray.data.dataset_pipeline import DatasetPipeline
 
 
-@ray.remote(num_cpus=0, placement_group=None)
+@ray.remote
 def pipeline_stage(fn: Callable[[], Dataset[T]]) -> Dataset[T]:
     try:
         prev = set_progress_bars(False)
@@ -27,15 +27,12 @@ def __init__(self, pipeline: "DatasetPipeline[T]"):
         self._iter = iter(self._pipeline._base_iterable)
         self._stages[0] = pipeline_stage.remote(next(self._iter))
 
-        if self._pipeline._length and self._pipeline._length != float("inf"):
-            length = self._pipeline._length
-        else:
-            length = 1
-
         if self._pipeline._progress_bars:
             self._bars = [
-                ProgressBar("Stage {}".format(i), length, position=i)
-                for i in range(len(self._stages))
+                ProgressBar(
+                    "Stage {}".format(i),
+                    self._pipeline._length or 1,
+                    position=i) for i in range(len(self._stages))
             ]
         else:
             self._bars = None
@@ -87,7 +84,7 @@ def __next__(self):
         return output
 
 
-@ray.remote(num_cpus=0, placement_group=None)
+@ray.remote
 class PipelineSplitExecutorCoordinator:
     def __init__(self, pipeline: "DatasetPipeline[T]", n: int,
                  splitter: Callable[[Dataset], "DatasetPipeline[T]"]):
diff --git a/python/ray/data/impl/progress_bar.py b/python/ray/data/impl/progress_bar.py
index fc28da681f3ee..c9c1caa43cb5b 100644
--- a/python/ray/data/impl/progress_bar.py
+++ b/python/ray/data/impl/progress_bar.py
@@ -1,4 +1,4 @@
-from typing import List, Any
+from typing import List
 
 import ray
 from ray.types import ObjectRef
@@ -50,16 +50,6 @@ def block_until_complete(self, remaining: List[ObjectRef]) -> None:
             done, remaining = ray.wait(remaining, fetch_local=False)
             self.update(len(done))
 
-    def fetch_until_complete(self, refs: List[ObjectRef]) -> List[Any]:
-        ref_to_result = {}
-        remaining = refs
-        while remaining:
-            done, remaining = ray.wait(remaining, fetch_local=True)
-            for ref, result in zip(done, ray.get(done)):
-                ref_to_result[ref] = result
-            self.update(len(done))
-        return [ref_to_result[ref] for ref in refs]
-
     def set_description(self, name: str) -> None:
         if self._bar:
             self._bar.set_description(name)
diff --git a/python/ray/data/impl/remote_fn.py b/python/ray/data/impl/remote_fn.py
index a6b4eb06d0f46..968380e187c50 100644
--- a/python/ray/data/impl/remote_fn.py
+++ b/python/ray/data/impl/remote_fn.py
@@ -13,10 +13,7 @@ def cached_remote_fn(fn: Any, **ray_remote_args) -> Any:
     which means ray.remote cannot be used top-level in ray.data).
     """
     if fn not in CACHED_FUNCTIONS:
-        default_ray_remote_args = {
-            "retry_exceptions": True,
-            "placement_group": None,
-        }
+        default_ray_remote_args = {"retry_exceptions": True}
         CACHED_FUNCTIONS[fn] = ray.remote(**{
             **default_ray_remote_args,
             **ray_remote_args
diff --git a/python/ray/data/impl/simple_block.py b/python/ray/data/impl/simple_block.py
index f609c65bd28b8..ba20d1334b06b 100644
--- a/python/ray/data/impl/simple_block.py
+++ b/python/ray/data/impl/simple_block.py
@@ -58,9 +58,7 @@ def to_pandas(self) -> "pandas.DataFrame":
         import pandas
         return pandas.DataFrame(self._items)
 
-    def to_numpy(self, column: str = None) -> np.ndarray:
-        if column:
-            raise ValueError("`column` arg not supported for list block")
+    def to_numpy(self) -> np.ndarray:
         return np.array(self._items)
 
     def to_arrow(self) -> "pyarrow.Table":
diff --git a/python/ray/data/impl/tensor_block.py b/python/ray/data/impl/tensor_block.py
new file mode 100644
index 0000000000000..3ad8d8afad71b
--- /dev/null
+++ b/python/ray/data/impl/tensor_block.py
@@ -0,0 +1,80 @@
+from typing import Iterator, List, TypeVar, Dict, TYPE_CHECKING
+
+import numpy as np
+
+if TYPE_CHECKING:
+    import pandas
+    import pyarrow
+
+from ray.data.block import Block, BlockAccessor
+from ray.data.impl.block_builder import BlockBuilder
+
+T = TypeVar("T")
+
+
+# TODO(ekl) switch to pyarrow.Tensor as the block type; currently there is a
+# serialization issue with pyarrow tensors.
+class TensorBlockBuilder(BlockBuilder[T]):
+    def __init__(self):
+        self._rows = []
+        self._tensors: List[np.ndarray] = []
+        self._num_rows = 0
+
+    def add(self, row: np.ndarray) -> None:
+        self._rows.append(row)
+        self._num_rows += 1
+
+    def add_block(self, block: np.ndarray) -> None:
+        assert isinstance(block, np.ndarray), block
+        self._tensors.append(block)
+        self._num_rows += len(block)
+
+    def build(self) -> Block:
+        tensors = self._tensors.copy()
+        if self._rows:
+            tensors.append(np.stack(self._rows, axis=0))
+        return np.concatenate(tensors, axis=0)
+
+    def num_rows(self) -> int:
+        return self._num_rows
+
+
+class TensorBlockAccessor(BlockAccessor):
+    def __init__(self, tensor: np.ndarray):
+        self._tensor = tensor
+
+    def iter_rows(self) -> Iterator[np.ndarray]:
+        return iter(self._tensor)
+
+    def slice(self, start: int, end: int,
+              copy: bool) -> "TensorBlockAccessor[T]":
+        view = self._tensor[start:end]
+        if copy:
+            view = view.copy()
+        return view
+
+    def to_pandas(self) -> "pandas.DataFrame":
+        import pandas
+        return pandas.DataFrame(self._tensor)
+
+    def to_numpy(self) -> np.ndarray:
+        return self._tensor
+
+    def to_arrow(self) -> "pyarrow.Tensor":
+        import pyarrow
+        return pyarrow.Tensor.from_numpy(self._tensor)
+
+    def schema(self) -> Dict:
+        shape = self._tensor.shape
+        shape = (None, ) + shape[1:]
+        return {"shape": shape, "dtype": self._tensor.dtype.name}
+
+    def num_rows(self) -> int:
+        return len(self._tensor)
+
+    def size_bytes(self) -> int:
+        return self._tensor.nbytes
+
+    @staticmethod
+    def builder() -> TensorBlockBuilder[T]:
+        return TensorBlockBuilder()
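A standalone illustration of what the `TensorBlockBuilder` above does with
mixed add()/add_block() input (pure NumPy, no Ray objects involved; the
variable names are ours):

    import numpy as np

    # Mirror of the builder flow: single rows are stacked, whole blocks are
    # concatenated, and build() fuses both into one ndarray block.
    rows = [np.array([1, 2]), np.array([3, 4])]      # added via add()
    whole_block = np.zeros((3, 2), dtype=np.int64)   # added via add_block()

    tensors = [whole_block, np.stack(rows, axis=0)]
    block = np.concatenate(tensors, axis=0)
    print(block.shape)  # (5, 2): 3 rows from the block + 2 stacked rows
    # The accessor's schema() would report {"shape": (None, 2),
    # "dtype": "int64"}, i.e. the row dimension is left unspecified.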
diff --git a/python/ray/data/read_api.py b/python/ray/data/read_api.py
index fb98561489987..887e08baa1495 100644
--- a/python/ray/data/read_api.py
+++ b/python/ray/data/read_api.py
@@ -14,7 +14,7 @@
 
 import ray
 from ray.types import ObjectRef
-from ray.util.annotations import PublicAPI, DeveloperAPI
+from ray.util.annotations import PublicAPI
 from ray.data.block import Block, BlockAccessor, BlockMetadata
 from ray.data.dataset import Dataset
 from ray.data.datasource import Datasource, RangeDatasource, \
@@ -283,7 +283,6 @@ def read_json(paths: Union[str, List[str]],
               filesystem: Optional["pyarrow.fs.FileSystem"] = None,
               parallelism: int = 200,
               ray_remote_args: Dict[str, Any] = None,
-              arrow_open_stream_args: Optional[Dict[str, Any]] = None,
               **arrow_json_args) -> Dataset[ArrowRow]:
     """Create an Arrow dataset from json files.
 
@@ -303,8 +302,6 @@ def read_json(paths: Union[str, List[str]],
         filesystem: The filesystem implementation to read from.
         parallelism: The amount of parallelism to use for the dataset.
         ray_remote_args: kwargs passed to ray.remote in the read tasks.
-        arrow_open_stream_args: kwargs passed to
-            pyarrow.fs.FileSystem.open_input_stream
         arrow_json_args: Other json read options to pass to pyarrow.
 
     Returns:
@@ -316,7 +313,6 @@ def read_json(paths: Union[str, List[str]],
         paths=paths,
         filesystem=filesystem,
         ray_remote_args=ray_remote_args,
-        open_stream_args=arrow_open_stream_args,
         **arrow_json_args)
 
 
@@ -326,7 +322,6 @@ def read_csv(paths: Union[str, List[str]],
              filesystem: Optional["pyarrow.fs.FileSystem"] = None,
              parallelism: int = 200,
             ray_remote_args: Dict[str, Any] = None,
-             arrow_open_stream_args: Optional[Dict[str, Any]] = None,
             **arrow_csv_args) -> Dataset[ArrowRow]:
     """Create an Arrow dataset from csv files.
 
@@ -346,8 +341,6 @@ def read_csv(paths: Union[str, List[str]],
         filesystem: The filesystem implementation to read from.
         parallelism: The amount of parallelism to use for the dataset.
         ray_remote_args: kwargs passed to ray.remote in the read tasks.
-        arrow_open_stream_args: kwargs passed to
-            pyarrow.fs.FileSystem.open_input_stream
         arrow_csv_args: Other csv read options to pass to pyarrow.
 
     Returns:
@@ -359,7 +352,6 @@ def read_csv(paths: Union[str, List[str]],
         paths=paths,
         filesystem=filesystem,
         ray_remote_args=ray_remote_args,
-        open_stream_args=arrow_open_stream_args,
         **arrow_csv_args)
 
 
@@ -370,7 +362,6 @@ def read_text(
         encoding: str = "utf-8",
         filesystem: Optional["pyarrow.fs.FileSystem"] = None,
         parallelism: int = 200,
-        arrow_open_stream_args: Optional[Dict[str, Any]] = None,
 ) -> Dataset[str]:
     """Create a dataset from lines stored in text files.
 
@@ -386,18 +377,13 @@ def read_text(
         encoding: The encoding of the files (e.g., "utf-8" or "ascii").
         filesystem: The filesystem implementation to read from.
         parallelism: The amount of parallelism to use for the dataset.
-        arrow_open_stream_args: kwargs passed to
-            pyarrow.fs.FileSystem.open_input_stream
 
     Returns:
         Dataset holding lines of text read from the specified paths.
     """
 
     return read_binary_files(
-        paths,
-        filesystem=filesystem,
-        parallelism=parallelism,
-        arrow_open_stream_args=arrow_open_stream_args).flat_map(
+        paths, filesystem=filesystem, parallelism=parallelism).flat_map(
             lambda x: x.decode(encoding).split("\n"))
 
 
@@ -406,8 +392,7 @@ def read_numpy(paths: Union[str, List[str]],
                *,
                filesystem: Optional["pyarrow.fs.FileSystem"] = None,
               parallelism: int = 200,
-               arrow_open_stream_args: Optional[Dict[str, Any]] = None,
-               **numpy_load_args) -> Dataset[ArrowRow]:
+               **numpy_load_args) -> Dataset[np.ndarray]:
     """Create an Arrow dataset from csv files.
 
     Examples:
@@ -425,8 +410,6 @@ def read_numpy(paths: Union[str, List[str]],
             A list of paths can contain both files and directories.
         filesystem: The filesystem implementation to read from.
         parallelism: The amount of parallelism to use for the dataset.
-        arrow_open_stream_args: kwargs passed to
-            pyarrow.fs.FileSystem.open_input_stream
         numpy_load_args: Other options to pass to np.load.
 
     Returns:
@@ -437,7 +420,6 @@ def read_numpy(paths: Union[str, List[str]],
         parallelism=parallelism,
         paths=paths,
         filesystem=filesystem,
-        open_stream_args=arrow_open_stream_args,
         **numpy_load_args)
 
 
@@ -449,7 +431,6 @@ def read_binary_files(
         filesystem: Optional["pyarrow.fs.FileSystem"] = None,
         parallelism: int = 200,
         ray_remote_args: Dict[str, Any] = None,
-        arrow_open_stream_args: Optional[Dict[str, Any]] = None,
 ) -> Dataset[Union[Tuple[str, bytes], bytes]]:
     """Create a dataset from binary files of arbitrary contents.
 
@@ -468,8 +449,6 @@ def read_binary_files(
         filesystem: The filesystem implementation to read from.
         ray_remote_args: kwargs passed to ray.remote in the read tasks.
         parallelism: The amount of parallelism to use for the dataset.
-        arrow_open_stream_args: kwargs passed to
-            pyarrow.fs.FileSystem.open_input_stream
 
     Returns:
         Dataset holding Arrow records read from the specified paths.
@@ -481,7 +460,6 @@ def read_binary_files(
         include_paths=include_paths,
         filesystem=filesystem,
         ray_remote_args=ray_remote_args,
-        open_stream_args=arrow_open_stream_args,
         schema=bytes)
 
 
@@ -531,27 +509,12 @@ def from_modin(df: "modin.DataFrame") -> Dataset[ArrowRow]:
     from modin.distributed.dataframe.pandas.partitions import unwrap_partitions
 
     parts = unwrap_partitions(df, axis=0)
-    return from_pandas_refs(parts)
+    return from_pandas(parts)
 
 
 @PublicAPI(stability="beta")
-def from_pandas(dfs: List["pandas.DataFrame"]) -> Dataset[ArrowRow]:
-    """Create a dataset from a list of Pandas dataframes.
-
-    Args:
-        dfs: A list of Pandas dataframes.
-
-    Returns:
-        Dataset holding Arrow records read from the dataframes.
-    """
-    return from_pandas_refs([ray.put(df) for df in dfs])
-
-
-@DeveloperAPI
-def from_pandas_refs(
-        dfs: List[ObjectRef["pandas.DataFrame"]]) -> Dataset[ArrowRow]:
-    """Create a dataset from a list of Ray object references to Pandas
-    dataframes.
+def from_pandas(dfs: List[ObjectRef["pandas.DataFrame"]]) -> Dataset[ArrowRow]:
+    """Create a dataset from a set of Pandas dataframes.
 
     Args:
         dfs: A list of Ray object references to pandas dataframes.
@@ -566,7 +529,7 @@ def from_pandas_refs(
     return Dataset(BlockList(blocks, ray.get(list(metadata))))
 
 
-def from_numpy(ndarrays: List[ObjectRef[np.ndarray]]) -> Dataset[ArrowRow]:
+def from_numpy(ndarrays: List[ObjectRef[np.ndarray]]) -> Dataset[np.ndarray]:
     """Create a dataset from a set of NumPy ndarrays.
 
     Args:
@@ -583,23 +546,8 @@
 
 
 @PublicAPI(stability="beta")
-def from_arrow(
-        tables: List[Union["pyarrow.Table", bytes]]) -> Dataset[ArrowRow]:
-    """Create a dataset from a list of Arrow tables.
-
-    Args:
-        tables: A list of Ray object references to Arrow tables,
-            or its streaming format in bytes.
-
-    Returns:
-        Dataset holding Arrow records from the tables.
-    """
-    return from_arrow_refs([ray.put(t) for t in tables])
-
-
-@DeveloperAPI
-def from_arrow_refs(tables: List[ObjectRef[Union["pyarrow.Table", bytes]]]
-                    ) -> Dataset[ArrowRow]:
+def from_arrow(tables: List[ObjectRef[Union["pyarrow.Table", bytes]]]
+               ) -> Dataset[ArrowRow]:
     """Create a dataset from a set of Arrow tables.
 
     Args:
@@ -642,11 +590,8 @@ def _df_to_block(df: "pandas.DataFrame") -> Block[ArrowRow]:
 
 
 def _ndarray_to_block(ndarray: np.ndarray) -> Block[np.ndarray]:
-    import pyarrow as pa
-    from ray.data.extensions import TensorArray
-    table = pa.Table.from_pydict({"value": TensorArray(ndarray)})
-    return (table,
-            BlockAccessor.for_block(table).get_metadata(input_files=None))
+    return (ndarray,
+            BlockAccessor.for_block(ndarray).get_metadata(input_files=None))
 
 
 def _get_schema(block: Block) -> Any:
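With the hunks above, `from_pandas` takes object refs rather than raw
dataframes, and `to_pandas` symmetrically returns refs to per-block frames.
A minimal sketch of the new calling convention (assumes a local Ray session;
the frames are illustrative):

    import pandas as pd
    import ray

    ray.init(ignore_reinit_error=True)

    df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
    df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]})

    # Post-patch API: pass ObjectRefs, not bare DataFrames.
    ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)])

    # to_pandas() hands back refs to per-block DataFrames.
    dfds = pd.concat(ray.get(ds.to_pandas()))
    print(dfds.shape)  # (6, 2)

The test updates that follow exercise exactly this shape of call site.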
diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py
index 91aa91e5c2eb2..7562e2c5a7105 100644
--- a/python/ray/data/tests/test_dataset.py
+++ b/python/ray/data/tests/test_dataset.py
@@ -17,11 +17,9 @@
 
 import ray
 from ray.tests.conftest import *  # noqa
-from ray.data.dataset import Dataset
 from ray.data.datasource import DummyOutputDatasource
 from ray.data.datasource.csv_datasource import CSVDatasource
 from ray.data.block import BlockAccessor
-from ray.data.impl.block_list import BlockList
 from ray.data.datasource.file_based_datasource import _unwrap_protocol
 from ray.data.extensions.tensor_extension import (
     TensorArray, TensorDtype, ArrowTensorType, ArrowTensorArray)
@@ -31,7 +29,7 @@
 
 def maybe_pipeline(ds, enabled):
     if enabled:
-        return ds.window(blocks_per_window=1)
+        return ds.pipeline(parallelism=1)
     else:
         return ds
 
@@ -60,10 +58,7 @@ def run():
         assert sorted(ds.iter_rows()) == [0, 1, 2, 3, 4]
 
     pg = ray.util.placement_group([{"CPU": 1}])
-    ray.get(
-        run.options(
-            placement_group=pg,
-            placement_group_capture_child_tasks=True).remote())
+    ray.get(run.options(placement_group=pg).remote())
 
 
 @pytest.mark.parametrize("pipelined", [False, True])
@@ -147,102 +142,6 @@ def __call__(self, x):
     assert len(actor_reuse) == 10, actor_reuse
 
 
-def test_transform_failure(shutdown_only):
-    ray.init(num_cpus=2)
-    ds = ray.data.from_items([0, 10], parallelism=2)
-
-    def mapper(x):
-        time.sleep(x)
-        raise ValueError("oops")
-        return x
-
-    with pytest.raises(ray.exceptions.RayTaskError):
-        ds.map(mapper)
-
-
-@pytest.mark.parametrize(
-    "block_sizes,num_splits",
-    [
-        (  # Test baseline.
-            [3, 6, 3], 3),
-        (  # Already balanced.
-            [3, 3, 3], 3),
-        (  # Row truncation.
-            [3, 6, 4], 3),
-        (  # Row truncation, smaller number of blocks.
-            [3, 6, 2, 3], 3),
-        (  # Row truncation, larger number of blocks.
-            [5, 6, 2, 5], 5),
-        (  # All smaller but one.
-            [1, 1, 1, 1, 6], 5),
-        (  # All larger but one.
-            [4, 4, 4, 4, 1], 5),
-        (  # Single block.
-            [2], 2),
-        (  # Single split.
-            [2, 5], 1),
-    ])
-def test_equal_split_balanced(ray_start_regular_shared, block_sizes,
-                              num_splits):
-    _test_equal_split_balanced(block_sizes, num_splits)
-
-
-def _test_equal_split_balanced(block_sizes, num_splits):
-    blocks = []
-    metadata = []
-    total_rows = 0
-    for block_size in block_sizes:
-        block = list(range(total_rows, total_rows + block_size))
-        blocks.append(ray.put(block))
-        metadata.append(BlockAccessor.for_block(block).get_metadata(None))
-        total_rows += block_size
-    block_list = BlockList(blocks, metadata)
-    ds = Dataset(block_list)
-
-    splits = ds.split(num_splits, equal=True)
-    split_counts = [split.count() for split in splits]
-    assert len(split_counts) == num_splits
-    expected_block_size = total_rows // num_splits
-    # Check that all splits are the expected size.
-    assert all([count == expected_block_size for count in split_counts])
-    expected_total_rows = sum(split_counts)
-    # Check that the expected number of rows were dropped.
-    assert total_rows - expected_total_rows == total_rows % num_splits
-    # Check that all rows are unique (content check).
-    split_rows = [row for split in splits for row in split.take(total_rows)]
-    assert len(set(split_rows)) == len(split_rows)
-
-
-def test_equal_split_balanced_grid(ray_start_regular_shared):
-
-    # Tests balanced equal splitting over a grid of configurations.
-    # Grid: num_blocks x num_splits x num_rows_block_1 x ... x num_rows_block_n
-    seed = int(time.time())
-    print(f"Seeding RNG for test_equal_split_balanced_grid with: {seed}")
-    random.seed(seed)
-    max_num_splits = 20
-    num_splits_samples = 5
-    max_num_blocks = 50
-    max_num_rows_per_block = 100
-    num_blocks_samples = 5
-    block_sizes_samples = 5
-    for num_splits in np.random.randint(
-            2, max_num_splits + 1, size=num_splits_samples):
-        for num_blocks in np.random.randint(
-                1, max_num_blocks + 1, size=num_blocks_samples):
-            block_sizes_list = [
-                np.random.randint(
-                    1, max_num_rows_per_block + 1, size=num_blocks)
-                for _ in range(block_sizes_samples)
-            ]
-            for block_sizes in block_sizes_list:
-                if sum(block_sizes) < num_splits:
-                    min_ = math.ceil(num_splits / num_blocks)
-                    block_sizes = np.random.randint(
-                        min_, max_num_rows_per_block + 1, size=num_blocks)
-                _test_equal_split_balanced(block_sizes, num_splits)
-
-
 @pytest.mark.parametrize("pipelined", [False, True])
 def test_basic(ray_start_regular_shared, pipelined):
     ds0 = ray.data.range(5)
@@ -296,15 +195,30 @@ def test_batch_tensors(ray_start_regular_shared):
 def test_tensors(ray_start_regular_shared):
     # Create directly.
     ds = ray.data.range_tensor(5, shape=(3, 5))
-    assert str(ds) == (
-        "Dataset(num_blocks=5, num_rows=5, "
-        "schema={value: })")
+    assert str(ds) == ("Dataset(num_blocks=5, num_rows=5, "
+                       "schema=)")
+
+    # Transform.
+    ds = ds.map_batches(lambda t: np.expand_dims(t, 3))
+    assert str(ds) == ("Dataset(num_blocks=5, num_rows=5, "
+                       "schema=)")
 
     # Pandas conversion.
     res = ray.data.range_tensor(10).map_batches(
         lambda t: t + 2, batch_format="pandas").take(2)
-    assert str(res) == \
-        "[ArrowRow({'value': array([2])}), ArrowRow({'value': array([3])})]"
+    assert str(res) == "[ArrowRow({'0': 2}), ArrowRow({'0': 3})]", res
+
+    # From other formats.
+    ds = ray.data.range(10).map_batches(lambda x: np.array(x))
+    assert str(ds) == ("Dataset(num_blocks=10, num_rows=10, "
+                       "schema=)")
+    ds = ray.data.range(10).map(lambda x: np.array(x))
+    assert str(ds) == ("Dataset(num_blocks=10, num_rows=10, "
+                       "schema=)")
+    ds = ray.data.from_items([np.zeros(shape=(2, 2, 2)) for _ in range(4)])
+    assert str(ds) == (
+        "Dataset(num_blocks=4, num_rows=4, "
+        "schema=)"), ds
 
 
 def test_tensor_array_ops(ray_start_regular_shared):
@@ -394,7 +308,7 @@ def test_tensors_in_tables_from_pandas(ray_start_regular_shared):
     df = pd.DataFrame({"one": list(range(outer_dim)), "two": list(arr)})
     # Cast column to tensor extension dtype.
     df["two"] = df["two"].astype(TensorDtype())
-    ds = ray.data.from_pandas([df])
+    ds = ray.data.from_pandas([ray.put(df)])
     values = [[s["one"], s["two"]] for s in ds.take()]
     expected = list(zip(list(range(outer_dim)), arr))
     for v, e in zip(sorted(values), expected):
@@ -408,8 +322,8 @@ def test_tensors_in_tables_pandas_roundtrip(ray_start_regular_shared):
     num_items = np.prod(np.array(shape))
     arr = np.arange(num_items).reshape(shape)
     df = pd.DataFrame({"one": list(range(outer_dim)), "two": TensorArray(arr)})
-    ds = ray.data.from_pandas([df])
-    ds_df = ds.to_pandas()
+    ds = ray.data.from_pandas([ray.put(df)])
+    ds_df = ray.get(ds.to_pandas())[0]
     assert ds_df.equals(df)
 
 
@@ -421,7 +335,7 @@ def test_tensors_in_tables_parquet_roundtrip(ray_start_regular_shared,
     num_items = np.prod(np.array(shape))
     arr = np.arange(num_items).reshape(shape)
     df = pd.DataFrame({"one": list(range(outer_dim)), "two": TensorArray(arr)})
-    ds = ray.data.from_pandas([df])
+    ds = ray.data.from_pandas([ray.put(df)])
     ds.write_parquet(str(tmp_path))
     ds = ray.data.read_parquet(str(tmp_path))
     values = [[s["one"], s["two"]] for s in ds.take()]
@@ -438,7 +352,7 @@ def test_tensors_in_tables_parquet_with_schema(ray_start_regular_shared,
     num_items = np.prod(np.array(shape))
     arr = np.arange(num_items).reshape(shape)
     df = pd.DataFrame({"one": list(range(outer_dim)), "two": TensorArray(arr)})
-    ds = ray.data.from_pandas([df])
+    ds = ray.data.from_pandas([ray.put(df)])
     ds.write_parquet(str(tmp_path))
     schema = pa.schema([
         ("one", pa.int32()),
@@ -464,7 +378,7 @@ def test_tensors_in_tables_parquet_pickle_manual_serde(
         "one": list(range(outer_dim)),
         "two": [pickle.dumps(a) for a in arr]
     })
-    ds = ray.data.from_pandas([df])
+    ds = ray.data.from_pandas([ray.put(df)])
     ds.write_parquet(str(tmp_path))
     ds = ray.data.read_parquet(str(tmp_path))
 
@@ -507,7 +421,7 @@ def test_tensors_in_tables_parquet_bytes_manual_serde(ray_start_regular_shared,
         "one": list(range(outer_dim)),
         "two": [a.tobytes() for a in arr]
     })
-    ds = ray.data.from_pandas([df])
+    ds = ray.data.from_pandas([ray.put(df)])
     ds.write_parquet(str(tmp_path))
     ds = ray.data.read_parquet(str(tmp_path))
 
@@ -546,7 +460,7 @@ def test_tensors_in_tables_parquet_bytes_manual_serde_udf(
         "one": list(range(outer_dim)),
         tensor_col_name: [a.tobytes() for a in arr]
     })
-    ds = ray.data.from_pandas([df])
+    ds = ray.data.from_pandas([ray.put(df)])
     ds.write_parquet(str(tmp_path))
 
     # Manually deserialize the tensor bytes and cast to a TensorArray.
@@ -585,7 +499,7 @@ def test_tensors_in_tables_parquet_bytes_manual_serde_col_schema(
         "one": list(range(outer_dim)),
         tensor_col_name: [a.tobytes() for a in arr]
     })
-    ds = ray.data.from_pandas([df])
+    ds = ray.data.from_pandas([ray.put(df)])
     ds.write_parquet(str(tmp_path))
 
     def _block_udf(block: pa.Table):
@@ -622,7 +536,7 @@ def test_tensors_in_tables_parquet_bytes_with_schema(ray_start_regular_shared,
         "one": list(range(outer_dim)),
         "two": [a.tobytes() for a in arr]
     })
-    ds = ray.data.from_pandas([df])
+    ds = ray.data.from_pandas([ray.put(df)])
     ds.write_parquet(str(tmp_path))
     schema = pa.schema([
         ("one", pa.int32()),
@@ -660,7 +574,7 @@ def test_tensors_in_tables_to_torch(ray_start_regular_shared, pipelined):
         "label": [4.0, 5.0, 6.0]
     })
     df = pd.concat([df1, df2])
-    ds = ray.data.from_pandas([df1, df2])
+    ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)])
     ds = maybe_pipeline(ds, pipelined)
     torchd = ds.to_torch(label_column="label", batch_size=2)
 
@@ -700,7 +614,7 @@ def test_tensors_in_tables_to_tf(ray_start_regular_shared, pipelined):
         "label": TensorArray(arr2),
     })
     df = pd.concat([df1, df2])
-    ds = ray.data.from_pandas([df1, df2])
+    ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)])
     ds = maybe_pipeline(ds, pipelined)
     tfd = ds.to_tf(
         label_column="label",
@@ -725,11 +639,13 @@ def test_numpy_roundtrip(ray_start_regular_shared, fs, data_path):
     ds = ray.data.range_tensor(10, parallelism=2)
     ds.write_numpy(data_path, filesystem=fs)
     ds = ray.data.read_numpy(data_path, filesystem=fs)
-    assert str(ds) == (
-        "Dataset(num_blocks=2, num_rows=None, "
-        "schema={value: })")
-    assert str(ds.take(2)) == \
-        "[ArrowRow({'value': array([0])}), ArrowRow({'value': array([1])})]"
+    assert str(ds) == ("Dataset(num_blocks=2, num_rows=?, "
+                       "schema=)")
+
+    assert str(
+        ds.take()) == ("[array([0]), array([1]), array([2]), "
+                       "array([3]), array([4]), array([5]), array([6]), "
+                       "array([7]), array([8]), array([9])]"), ds.take()
 
 
 def test_numpy_read(ray_start_regular_shared, tmp_path):
@@ -738,11 +654,13 @@ def test_numpy_read(ray_start_regular_shared, tmp_path):
     np.save(
         os.path.join(path, "test.npy"), np.expand_dims(np.arange(0, 10), 1))
     ds = ray.data.read_numpy(path)
-    assert str(ds) == (
-        "Dataset(num_blocks=1, num_rows=None, "
-        "schema={value: })")
-    assert str(ds.take(2)) == \
-        "[ArrowRow({'value': array([0])}), ArrowRow({'value': array([1])})]"
+    assert str(ds) == ("Dataset(num_blocks=1, num_rows=?, "
+                       "schema=)")
+
+    assert str(
+        ds.take()) == ("[array([0]), array([1]), array([2]), "
+                       "array([3]), array([4]), array([5]), array([6]), "
+                       "array([7]), array([8]), array([9])]"), ds.take()
 
 
 @pytest.mark.parametrize("fs,data_path,endpoint_url", [
@@ -764,12 +682,7 @@ def test_numpy_write(ray_start_regular_shared, fs, data_path, endpoint_url):
     s3 = S3FileSystem(client_kwargs={"endpoint_url": endpoint_url})
     arr1 = np.load(s3.open(file_path1))
     arr2 = np.load(s3.open(file_path2))
-    assert ds.count() == 10
-    assert len(arr1) == 5
-    assert len(arr2) == 5
-    assert arr1.sum() == 10
-    assert arr2.sum() == 35
-    assert str(ds.take(1)) == "[ArrowRow({'value': array([0])})]"
+    np.testing.assert_equal(np.concatenate((arr1, arr2)), ds.take())
 
 
 def test_read_text(ray_start_regular_shared, tmp_path):
@@ -820,16 +733,6 @@ def test_empty_dataset(ray_start_regular_shared):
     assert str(ds) == \
         "Dataset(num_blocks=1, num_rows=0, schema=Unknown schema)"
 
-    # Test map on empty dataset.
-    ds = ray.data.from_items([])
-    ds = ds.map(lambda x: x)
-    assert ds.count() == 0
-
-    # Test filter on empty dataset.
-    ds = ray.data.from_items([])
-    ds = ds.filter(lambda: True)
-    assert ds.count() == 0
-
 
 def test_schema(ray_start_regular_shared):
     ds = ray.data.range(10)
@@ -848,17 +751,17 @@ def test_schema(ray_start_regular_shared):
 
 def test_lazy_loading_exponential_rampup(ray_start_regular_shared):
     ds = ray.data.range(100, parallelism=20)
-    assert ds._blocks._num_computed() == 1
+    assert len(ds._blocks._blocks) == 1
     assert ds.take(10) == list(range(10))
-    assert ds._blocks._num_computed() == 2
+    assert len(ds._blocks._blocks) == 2
     assert ds.take(20) == list(range(20))
-    assert ds._blocks._num_computed() == 4
+    assert len(ds._blocks._blocks) == 4
     assert ds.take(30) == list(range(30))
-    assert ds._blocks._num_computed() == 8
+    assert len(ds._blocks._blocks) == 8
     assert ds.take(50) == list(range(50))
-    assert ds._blocks._num_computed() == 16
+    assert len(ds._blocks._blocks) == 16
     assert ds.take(100) == list(range(100))
-    assert ds._blocks._num_computed() == 20
+    assert len(ds._blocks._blocks) == 20
 
 
 def test_limit(ray_start_regular_shared):
@@ -931,16 +834,7 @@ def test_repartition_arrow(ray_start_regular_shared):
 def test_from_pandas(ray_start_regular_shared):
     df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
     df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]})
-    ds = ray.data.from_pandas([df1, df2])
-    values = [(r["one"], r["two"]) for r in ds.take(6)]
-    rows = [(r.one, r.two) for _, r in pd.concat([df1, df2]).iterrows()]
-    assert values == rows
-
-
-def test_from_pandas_refs(ray_start_regular_shared):
-    df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
-    df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]})
-    ds = ray.data.from_pandas_refs([ray.put(df1), ray.put(df2)])
+    ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)])
     values = [(r["one"], r["two"]) for r in ds.take(6)]
     rows = [(r.one, r.two) for _, r in pd.concat([df1, df2]).iterrows()]
     assert values == rows
@@ -951,27 +845,13 @@ def test_from_numpy(ray_start_regular_shared):
     arr2 = np.expand_dims(np.arange(4, 8), 1)
     ds = ray.data.from_numpy([ray.put(arr1), ray.put(arr2)])
     values = np.array(ds.take(8))
-    for i in range(4):
-        assert values[i]["value"] == arr1[i]
-    for i in range(4, 8):
-        assert values[i]["value"] == arr2[i - 4]
+    np.testing.assert_equal(np.concatenate((arr1, arr2)), values)
 
 
 def test_from_arrow(ray_start_regular_shared):
     df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
     df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]})
-    ds = ray.data.from_arrow(
-        [pa.Table.from_pandas(df1),
-         pa.Table.from_pandas(df2)])
-    values = [(r["one"], r["two"]) for r in ds.take(6)]
-    rows = [(r.one, r.two) for _, r in pd.concat([df1, df2]).iterrows()]
-    assert values == rows
-
-
-def test_from_arrow_refs(ray_start_regular_shared):
-    df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
-    df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]})
-    ds = ray.data.from_arrow_refs([
+    ds = ray.data.from_arrow([
         ray.put(pa.Table.from_pandas(df1)),
         ray.put(pa.Table.from_pandas(df2))
     ])
@@ -984,36 +864,20 @@ def test_to_pandas(ray_start_regular_shared):
     n = 5
     df = pd.DataFrame({"value": list(range(n))})
     ds = ray.data.range_arrow(n)
-    dfds = ds.to_pandas()
-    assert df.equals(dfds)
-
-    # Test limit.
-    dfds = ds.to_pandas(limit=3)
-    assert df[:3].equals(dfds)
-
-    # Test limit greater than number of rows.
-    dfds = ds.to_pandas(limit=6)
-    assert df.equals(dfds)
-
-
-def test_to_pandas_refs(ray_start_regular_shared):
-    n = 5
-    df = pd.DataFrame({"value": list(range(n))})
-    ds = ray.data.range_arrow(n)
-    dfds = pd.concat(ray.get(ds.to_pandas_refs()), ignore_index=True)
+    dfds = pd.concat(ray.get(ds.to_pandas()), ignore_index=True)
     assert df.equals(dfds)
 
 
 def test_to_numpy(ray_start_regular_shared):
     # Tensor Dataset
     ds = ray.data.range_tensor(10, parallelism=2)
-    arr = np.concatenate(ray.get(ds.to_numpy(column="value")))
+    arr = np.concatenate(ray.get(ds.to_numpy()))
     np.testing.assert_equal(arr, np.expand_dims(np.arange(0, 10), 1))
 
     # Table Dataset
     ds = ray.data.range_arrow(10)
-    arr = np.concatenate(ray.get(ds.to_numpy(column="value")))
-    np.testing.assert_equal(arr, np.arange(0, 10))
+    arr = np.concatenate(ray.get(ds.to_numpy()))
+    np.testing.assert_equal(arr, np.expand_dims(np.arange(0, 10), 1))
 
     # Simple Dataset
     ds = ray.data.range(10)
@@ -1024,41 +888,23 @@ def test_to_numpy(ray_start_regular_shared):
 
 def test_to_arrow(ray_start_regular_shared):
     n = 5
 
-    # Zero-copy.
-    df = pd.DataFrame({"value": list(range(n))})
-    ds = ray.data.range_arrow(n)
-    dfds = pd.concat([t.to_pandas() for t in ds.to_arrow()], ignore_index=True)
-    assert df.equals(dfds)
-
-    # Conversion.
-    df = pd.DataFrame({0: list(range(n))})
-    ds = ray.data.range(n)
-    dfds = pd.concat([t.to_pandas() for t in ds.to_arrow()], ignore_index=True)
-    assert df.equals(dfds)
-
-
-def test_to_arrow_refs(ray_start_regular_shared):
-    n = 5
-
     # Zero-copy.
     df = pd.DataFrame({"value": list(range(n))})
     ds = ray.data.range_arrow(n)
     dfds = pd.concat(
-        [t.to_pandas() for t in ray.get(ds.to_arrow_refs())],
-        ignore_index=True)
+        [t.to_pandas() for t in ray.get(ds.to_arrow())], ignore_index=True)
     assert df.equals(dfds)
 
     # Conversion.
     df = pd.DataFrame({0: list(range(n))})
     ds = ray.data.range(n)
     dfds = pd.concat(
-        [t.to_pandas() for t in ray.get(ds.to_arrow_refs())],
-        ignore_index=True)
+        [t.to_pandas() for t in ray.get(ds.to_arrow())], ignore_index=True)
     assert df.equals(dfds)
 
 
-def test_get_internal_block_refs(ray_start_regular_shared):
-    blocks = ray.data.range(10).get_internal_block_refs()
+def test_get_blocks(ray_start_regular_shared):
+    blocks = ray.data.range(10).get_blocks()
     assert len(blocks) == 10
     out = []
     for b in ray.get(blocks):
@@ -1070,9 +916,9 @@ def test_pandas_roundtrip(ray_start_regular_shared, tmp_path):
     df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
     df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]})
-    ds = ray.data.from_pandas([df1, df2])
-    dfds = ds.to_pandas()
-    assert pd.concat([df1, df2], ignore_index=True).equals(dfds)
+    ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)])
+    dfds = pd.concat(ray.get(ds.to_pandas()))
+    assert pd.concat([df1, df2]).equals(dfds)
 
 
 def test_fsspec_filesystem(ray_start_regular_shared, tmp_path):
@@ -1096,7 +942,7 @@ def test_fsspec_filesystem(ray_start_regular_shared, tmp_path):
     ds = ray.data.read_parquet([path1, path2], filesystem=fs)
 
     # Test metadata-only parquet ops.
-    assert ds._blocks._num_computed() == 1
+    assert len(ds._blocks._blocks) == 1
     assert ds.count() == 6
 
     out_path = os.path.join(tmp_path, "out")
@@ -1135,7 +981,7 @@ def test_parquet_read(ray_start_regular_shared, fs, data_path):
     ds = ray.data.read_parquet(data_path, filesystem=fs)
 
     # Test metadata-only parquet ops.
-    assert ds._blocks._num_computed() == 1
+    assert len(ds._blocks._blocks) == 1
     assert ds.count() == 6
     assert ds.size_bytes() > 0
     assert ds.schema() is not None
@@ -1149,11 +995,11 @@ def test_parquet_read(ray_start_regular_shared, fs, data_path):
     assert repr(ds) == \
         "Dataset(num_blocks=2, num_rows=6, " \
         "schema={one: int64, two: string})", ds
-    assert ds._blocks._num_computed() == 1
+    assert len(ds._blocks._blocks) == 1
 
     # Forces a data read.
     values = [[s["one"], s["two"]] for s in ds.take()]
-    assert ds._blocks._num_computed() == 2
+    assert len(ds._blocks._blocks) == 2
     assert sorted(values) == [[1, "a"], [2, "b"], [3, "c"], [4, "e"],
                               [5, "f"], [6, "g"]]
 
@@ -1184,7 +1030,7 @@ def test_parquet_read_partitioned(ray_start_regular_shared, fs, data_path):
     ds = ray.data.read_parquet(data_path, filesystem=fs)
 
     # Test metadata-only parquet ops.
-    assert ds._blocks._num_computed() == 1
+    assert len(ds._blocks._blocks) == 1
     assert ds.count() == 6
     assert ds.size_bytes() > 0
     assert ds.schema() is not None
@@ -1198,11 +1044,11 @@ def test_parquet_read_partitioned(ray_start_regular_shared, fs, data_path):
         "Dataset(num_blocks=2, num_rows=6, " \
         "schema={two: string, " \
         "one: dictionary})", ds
-    assert ds._blocks._num_computed() == 1
+    assert len(ds._blocks._blocks) == 1
 
     # Forces a data read.
     values = [[s["one"], s["two"]] for s in ds.take()]
-    assert ds._blocks._num_computed() == 2
+    assert len(ds._blocks._blocks) == 2
     assert sorted(values) == [[1, "a"], [1, "b"], [1, "c"], [3, "e"],
                               [3, "f"], [3, "g"]]
 
@@ -1231,7 +1077,7 @@ def test_parquet_read_partitioned_with_filter(ray_start_regular_shared,
         str(tmp_path), parallelism=1, filter=(pa.dataset.field("two") == "a"))
 
     values = [[s["one"], s["two"]] for s in ds.take()]
-    assert ds._blocks._num_computed() == 1
+    assert len(ds._blocks._blocks) == 1
     assert sorted(values) == [[1, "a"], [1, "a"]]
 
     # 2 partitions, 1 empty partition, 2 block/read tasks, 1 empty block
@@ -1240,7 +1086,7 @@ def test_parquet_read_partitioned_with_filter(ray_start_regular_shared,
         str(tmp_path), parallelism=2, filter=(pa.dataset.field("two") == "a"))
 
     values = [[s["one"], s["two"]] for s in ds.take()]
-    assert ds._blocks._num_computed() == 2
+    assert len(ds._blocks._blocks) == 2
    assert sorted(values) == [[1, "a"], [1, "a"]]
 
 
@@ -1268,7 +1114,7 @@ def _block_udf(block: pa.Table):
         str(tmp_path), parallelism=1, _block_udf=_block_udf)
 
     ones, twos = zip(*[[s["one"], s["two"]] for s in ds.take()])
-    assert ds._blocks._num_computed() == 1
+    assert len(ds._blocks._blocks) == 1
     np.testing.assert_array_equal(sorted(ones), np.array(one_data) + 1)
 
     # 2 blocks/read tasks
@@ -1277,7 +1123,7 @@ def _block_udf(block: pa.Table):
         str(tmp_path), parallelism=2, _block_udf=_block_udf)
 
     ones, twos = zip(*[[s["one"], s["two"]] for s in ds.take()])
-    assert ds._blocks._num_computed() == 2
+    assert len(ds._blocks._blocks) == 2
     np.testing.assert_array_equal(sorted(ones), np.array(one_data) + 1)
 
     # 2 blocks/read tasks, 1 empty block
@@ -1289,7 +1135,7 @@ def _block_udf(block: pa.Table):
         _block_udf=_block_udf)
 
     ones, twos = zip(*[[s["one"], s["two"]] for s in ds.take()])
-    assert ds._blocks._num_computed() == 2
+    assert len(ds._blocks._blocks) == 2
     np.testing.assert_array_equal(sorted(ones), np.array(one_data[:2]) + 1)
 
 
@@ -1306,7 +1152,7 @@ def test_parquet_write(ray_start_regular_shared, fs, data_path, endpoint_url):
     df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
     df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]})
     df = pd.concat([df1, df2])
-    ds = ray.data.from_pandas([df1, df2])
+    ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)])
     path = os.path.join(data_path, "test_parquet_dir")
     if fs is None:
         os.mkdir(path)
@@ -1341,7 +1187,7 @@ def test_parquet_write_create_dir(ray_start_regular_shared, fs, data_path,
     df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
     df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]})
     df = pd.concat([df1, df2])
-    ds = ray.data.from_pandas([df1, df2])
+    ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)])
     path = os.path.join(data_path, "test_parquet_dir")
     ds._set_uuid("data")
     ds.write_parquet(path, filesystem=fs)
@@ -1395,7 +1241,7 @@ def test_parquet_write_with_udf(ray_start_regular_shared, tmp_path):
     df1 = pd.DataFrame({"one": one_data[:3], "two": ["a", "b", "c"]})
     df2 = pd.DataFrame({"one": one_data[3:], "two": ["e", "f", "g"]})
     df = pd.concat([df1, df2])
-    ds = ray.data.from_pandas([df1, df2])
+    ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)])
 
     def _block_udf(block: pa.Table):
         df = block.to_pandas()
@@ -1420,7 +1266,7 @@ def _block_udf(block: pa.Table):
 def test_parquet_roundtrip(ray_start_regular_shared, fs, data_path):
     df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
     df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]})
-    ds = ray.data.from_pandas([df1, df2])
+    ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)])
     ds._set_uuid("data")
     path = os.path.join(data_path, "test_parquet_dir")
     if fs is None:
@@ -1429,8 +1275,8 @@ def test_parquet_roundtrip(ray_start_regular_shared, fs, data_path):
         fs.create_dir(_unwrap_protocol(path))
     ds.write_parquet(path, filesystem=fs)
     ds2 = ray.data.read_parquet(path, parallelism=2, filesystem=fs)
-    ds2df = ds2.to_pandas()
-    assert pd.concat([df1, df2], ignore_index=True).equals(ds2df)
+    ds2df = pd.concat(ray.get(ds2.to_pandas()))
+    assert pd.concat([df1, df2]).equals(ds2df)
     # Test metadata ops.
     for block, meta in zip(ds2._blocks, ds2._blocks.get_metadata()):
         BlockAccessor.for_block(ray.get(block)).size_bytes() == meta.size_bytes
@@ -1511,7 +1357,9 @@ def test_iter_batches_basic(ray_start_regular_shared):
     df3 = pd.DataFrame({"one": [7, 8, 9], "two": [8, 9, 10]})
     df4 = pd.DataFrame({"one": [10, 11, 12], "two": [11, 12, 13]})
     dfs = [df1, df2, df3, df4]
-    ds = ray.data.from_pandas(dfs)
+    ds = ray.data.from_pandas(
+        [ray.put(df1), ray.put(df2),
+         ray.put(df3), ray.put(df4)])
 
     # Default.
     for batch, df in zip(ds.iter_batches(batch_format="pandas"), dfs):
@@ -1621,7 +1469,7 @@ def test_iter_batches_grid(ray_start_regular_shared):
                 }))
             running_size += block_size
     num_rows = running_size
-    ds = ray.data.from_pandas(dfs)
+    ds = ray.data.from_pandas([ray.put(df) for df in dfs])
     for batch_size in np.random.randint(
             1, num_rows + 1, size=batch_size_samples):
         for drop_last in (False, True):
@@ -1637,7 +1485,10 @@ def test_iter_batches_grid(ray_start_regular_shared):
                     # Concatenated batches should equal the DataFrame
                     # representation of the entire dataset.
                     assert pd.concat(
-                        batches, ignore_index=True).equals(ds.to_pandas())
+                        batches, ignore_index=True).equals(
+                            pd.concat(
+                                ray.get(ds.to_pandas()),
+                                ignore_index=True))
                 else:
                     # Number of batches should be equal to
                     # num_rows / batch_size, rounded down.
@@ -1647,8 +1498,9 @@ def test_iter_batches_grid(ray_start_regular_shared):
                     # remainder sliced off.
                     assert pd.concat(
                         batches, ignore_index=True).equals(
-                            ds.to_pandas()[:batch_size *
-                                           (num_rows // batch_size)])
+                            pd.concat(
+                                ray.get(ds.to_pandas()), ignore_index=True)
+                            [:batch_size * (num_rows // batch_size)])
                 if num_rows % batch_size == 0 or drop_last:
                     assert all(
                         len(batch) == batch_size for batch in batches)
@@ -1663,7 +1515,7 @@ def test_lazy_loading_iter_batches_exponential_rampup(
     ds = ray.data.range(32, parallelism=8)
     expected_num_blocks = [1, 2, 4, 4, 8, 8, 8, 8]
     for _, expected in zip(ds.iter_batches(), expected_num_blocks):
-        assert ds._blocks._num_computed() == expected
+        assert len(ds._blocks._blocks) == expected
 
 
 def test_map_batch(ray_start_regular_shared, tmp_path):
@@ -1917,7 +1769,7 @@ def test_from_dask(ray_start_regular_shared):
     df = pd.DataFrame({"one": list(range(100)), "two": list(range(100))})
     ddf = dd.from_pandas(df, npartitions=10)
     ds = ray.data.from_dask(ddf)
-    dfds = ds.to_pandas()
+    dfds = pd.concat(ray.get(ds.to_pandas()))
     assert df.equals(dfds)
 
 
@@ -1926,7 +1778,7 @@ def test_to_dask(ray_start_regular_shared):
     df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
     df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]})
     df = pd.concat([df1, df2])
-    ds = ray.data.from_pandas([df1, df2])
+    ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)])
     ddf = ds.to_dask()
     # Explicit Dask-on-Ray
     assert df.equals(ddf.compute(scheduler=ray_dask_get))
@@ -1939,7 +1791,7 @@ def test_from_modin(ray_start_regular_shared):
     df = pd.DataFrame({"one": list(range(100)), "two": list(range(100))}, )
     modf = mopd.DataFrame(df)
     ds = ray.data.from_modin(modf)
-    dfds = ds.to_pandas()
+    dfds = pd.concat(ray.get(ds.to_pandas()))
     assert df.equals(dfds)
 
 
@@ -1971,7 +1823,7 @@ def test_to_tf(ray_start_regular_shared, pipelined):
     })
     df3 = pd.DataFrame({"one": [7, 8], "two": [7.0, 8.0], "label": [7.0, 8.0]})
     df = pd.concat([df1, df2, df3])
-    ds = ray.data.from_pandas([df1, df2, df3])
+    ds = ray.data.from_pandas([ray.put(df1), ray.put(df2), ray.put(df3)])
     ds = maybe_pipeline(ds, pipelined)
     tfd = ds.to_tf(
         label_column="label",
@@ -1999,7 +1851,7 @@ def test_to_tf_feature_columns(ray_start_regular_shared):
     })
     df3 = pd.DataFrame({"one": [7, 8], "two": [7.0, 8.0], "label": [7.0, 8.0]})
     df = pd.concat([df1, df2, df3]).drop("two", axis=1)
-    ds = ray.data.from_pandas([df1, df2, df3])
+    ds = ray.data.from_pandas([ray.put(df1), ray.put(df2), ray.put(df3)])
     tfd = ds.to_tf(
         label_column="label",
         feature_columns=["one"],
@@ -2028,7 +1880,7 @@ def test_to_torch(ray_start_regular_shared, pipelined):
     })
     df3 = pd.DataFrame({"one": [7, 8], "two": [7.0, 8.0], "label": [7.0, 8.0]})
     df = pd.concat([df1, df2, df3])
-    ds = ray.data.from_pandas([df1, df2, df3])
+    ds = ray.data.from_pandas([ray.put(df1), ray.put(df2), ray.put(df3)])
     ds = maybe_pipeline(ds, pipelined)
     torchd = ds.to_torch(label_column="label", batch_size=3)
 
@@ -2055,7 +1907,7 @@ def test_to_torch_feature_columns(ray_start_regular_shared):
     })
     df3 = pd.DataFrame({"one": [7, 8], "two": [7.0, 8.0], "label": [7.0, 8.0]})
     df = pd.concat([df1, df2, df3]).drop("two", axis=1)
-    ds = ray.data.from_pandas([df1, df2, df3])
+    ds = ray.data.from_pandas([ray.put(df1), ray.put(df2), ray.put(df3)])
     torchd = ds.to_torch(
         label_column="label", feature_columns=["one"], batch_size=3)
     iterations = []
@@ -2082,7 +1934,7 @@ def test_json_read(ray_start_regular_shared, fs, data_path, endpoint_url):
     df1.to_json(
         path1, orient="records", lines=True, storage_options=storage_options)
     ds = ray.data.read_json(path1, filesystem=fs)
-    dsdf = ds.to_pandas()
+    dsdf = ray.get(ds.to_pandas())[0]
     assert df1.equals(dsdf)
     # Test metadata ops.
     assert ds.count() == 3
@@ -2095,8 +1947,8 @@ def test_json_read(ray_start_regular_shared, fs, data_path, endpoint_url):
     df2.to_json(
         path2, orient="records", lines=True, storage_options=storage_options)
     ds = ray.data.read_json([path1, path2], parallelism=2, filesystem=fs)
-    dsdf = ds.to_pandas()
-    df = pd.concat([df1, df2], ignore_index=True)
+    dsdf = pd.concat(ray.get(ds.to_pandas()))
+    df = pd.concat([df1, df2])
     assert df.equals(dsdf)
     # Test metadata ops.
     for block, meta in zip(ds._blocks, ds._blocks.get_metadata()):
@@ -2110,7 +1962,7 @@ def test_json_read(ray_start_regular_shared, fs, data_path, endpoint_url):
     ds = ray.data.read_json(
         [path1, path2, path3], parallelism=2, filesystem=fs)
     df = pd.concat([df1, df2, df3], ignore_index=True)
-    dsdf = ds.to_pandas()
+    dsdf = pd.concat(ray.get(ds.to_pandas()), ignore_index=True)
     assert df.equals(dsdf)
 
     # Directory, two files.
@@ -2128,8 +1980,8 @@ def test_json_read(ray_start_regular_shared, fs, data_path, endpoint_url):
     df2.to_json(
         path2, orient="records", lines=True, storage_options=storage_options)
     ds = ray.data.read_json(path, filesystem=fs)
-    df = pd.concat([df1, df2], ignore_index=True)
-    dsdf = ds.to_pandas()
+    df = pd.concat([df1, df2])
+    dsdf = pd.concat(ray.get(ds.to_pandas()))
     assert df.equals(dsdf)
     if fs is None:
         shutil.rmtree(path)
@@ -2167,8 +2019,8 @@ def test_json_read(ray_start_regular_shared, fs, data_path, endpoint_url):
         lines=True,
         storage_options=storage_options)
     ds = ray.data.read_json([path1, path2], filesystem=fs)
-    df = pd.concat([df1, df2, df3], ignore_index=True)
-    dsdf = ds.to_pandas()
+    df = pd.concat([df1, df2, df3])
+    dsdf = pd.concat(ray.get(ds.to_pandas()))
     assert df.equals(dsdf)
     if fs is None:
         shutil.rmtree(path1)
@@ -2192,8 +2044,8 @@ def test_json_read(ray_start_regular_shared, fs, data_path, endpoint_url):
     df2.to_json(
         path2, orient="records", lines=True, storage_options=storage_options)
     ds = ray.data.read_json([dir_path, path2], filesystem=fs)
-    df = pd.concat([df1, df2], ignore_index=True)
-    dsdf = ds.to_pandas()
+    df = pd.concat([df1, df2])
+    dsdf = pd.concat(ray.get(ds.to_pandas()))
     assert df.equals(dsdf)
     if fs is None:
         shutil.rmtree(dir_path)
@@ -2207,7 +2059,7 @@ def test_zipped_json_read(ray_start_regular_shared, tmp_path):
     path1 = os.path.join(tmp_path, "test1.json.gz")
     df1.to_json(path1, compression="gzip", orient="records", lines=True)
     ds = ray.data.read_json(path1)
-    assert df1.equals(ds.to_pandas())
+    assert df1.equals(ray.get(ds.to_pandas())[0])
     # Test metadata ops.
     assert ds.count() == 3
     assert ds.input_files() == [path1]
@@ -2217,8 +2069,8 @@ def test_zipped_json_read(ray_start_regular_shared, tmp_path):
     path2 = os.path.join(tmp_path, "test2.json.gz")
     df2.to_json(path2, compression="gzip", orient="records", lines=True)
     ds = ray.data.read_json([path1, path2], parallelism=2)
-    dsdf = ds.to_pandas()
-    assert pd.concat([df1, df2], ignore_index=True).equals(dsdf)
+    dsdf = pd.concat(ray.get(ds.to_pandas()))
+    assert pd.concat([df1, df2]).equals(dsdf)
     # Test metadata ops.
     for block, meta in zip(ds._blocks, ds._blocks.get_metadata()):
         BlockAccessor.for_block(ray.get(block)).size_bytes()
@@ -2233,8 +2085,8 @@ def test_zipped_json_read(ray_start_regular_shared, tmp_path):
     path2 = os.path.join(tmp_path, "data1.json.gz")
     df2.to_json(path2, compression="gzip", orient="records", lines=True)
     ds = ray.data.read_json([dir_path, path2])
-    df = pd.concat([df1, df2], ignore_index=True)
-    dsdf = ds.to_pandas()
+    df = pd.concat([df1, df2])
+    dsdf = pd.concat(ray.get(ds.to_pandas()))
     assert df.equals(dsdf)
     shutil.rmtree(dir_path)
 
@@ -2251,7 +2103,7 @@ def test_json_write(ray_start_regular_shared, fs, data_path, endpoint_url):
     storage_options = dict(client_kwargs=dict(endpoint_url=endpoint_url))
     # Single block.
     df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
-    ds = ray.data.from_pandas([df1])
+    ds = ray.data.from_pandas([ray.put(df1)])
     ds._set_uuid("data")
     ds.write_json(data_path, filesystem=fs)
     file_path = os.path.join(data_path, "data_000000.json")
@@ -2264,7 +2116,7 @@ def test_json_write(ray_start_regular_shared, fs, data_path, endpoint_url):
 
     # Two blocks.
     df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]})
-    ds = ray.data.from_pandas([df1, df2])
+    ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)])
     ds._set_uuid("data")
     ds.write_json(data_path, filesystem=fs)
     file_path2 = os.path.join(data_path, "data_000001.json")
@@ -2291,12 +2143,12 @@ def test_json_roundtrip(ray_start_regular_shared, fs, data_path):
     # Single block.
     df = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
-    ds = ray.data.from_pandas([df])
+    ds = ray.data.from_pandas([ray.put(df)])
     ds._set_uuid("data")
     ds.write_json(data_path, filesystem=fs)
     file_path = os.path.join(data_path, "data_000000.json")
     ds2 = ray.data.read_json([file_path], filesystem=fs)
-    ds2df = ds2.to_pandas()
+    ds2df = pd.concat(ray.get(ds2.to_pandas()))
     assert ds2df.equals(df)
     # Test metadata ops.
     for block, meta in zip(ds2._blocks, ds2._blocks.get_metadata()):
@@ -2309,12 +2161,12 @@ def test_json_roundtrip(ray_start_regular_shared, fs, data_path):
 
     # Two blocks.
     df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]})
-    ds = ray.data.from_pandas([df, df2])
+    ds = ray.data.from_pandas([ray.put(df), ray.put(df2)])
     ds._set_uuid("data")
     ds.write_json(data_path, filesystem=fs)
     ds2 = ray.data.read_json(data_path, parallelism=2, filesystem=fs)
-    ds2df = ds2.to_pandas()
-    assert pd.concat([df, df2], ignore_index=True).equals(ds2df)
+    ds2df = pd.concat(ray.get(ds2.to_pandas()))
+    assert pd.concat([df, df2]).equals(ds2df)
     # Test metadata ops.
     for block, meta in zip(ds2._blocks, ds2._blocks.get_metadata()):
         BlockAccessor.for_block(ray.get(block)).size_bytes() == meta.size_bytes
@@ -2338,7 +2190,7 @@ def test_csv_read(ray_start_regular_shared, fs, data_path, endpoint_url):
     path1 = os.path.join(data_path, "test1.csv")
     df1.to_csv(path1, index=False, storage_options=storage_options)
     ds = ray.data.read_csv(path1, filesystem=fs)
-    dsdf = ds.to_pandas()
+    dsdf = ray.get(ds.to_pandas())[0]
     assert df1.equals(dsdf)
     # Test metadata ops.
     assert ds.count() == 3
@@ -2350,8 +2202,8 @@ def test_csv_read(ray_start_regular_shared, fs, data_path, endpoint_url):
     path2 = os.path.join(data_path, "test2.csv")
     df2.to_csv(path2, index=False, storage_options=storage_options)
     ds = ray.data.read_csv([path1, path2], parallelism=2, filesystem=fs)
-    dsdf = ds.to_pandas()
-    df = pd.concat([df1, df2], ignore_index=True)
+    dsdf = pd.concat(ray.get(ds.to_pandas()))
+    df = pd.concat([df1, df2])
     assert df.equals(dsdf)
     # Test metadata ops.
     for block, meta in zip(ds._blocks, ds._blocks.get_metadata()):
@@ -2363,7 +2215,7 @@ def test_csv_read(ray_start_regular_shared, fs, data_path, endpoint_url):
     df3.to_csv(path3, index=False, storage_options=storage_options)
     ds = ray.data.read_csv([path1, path2, path3], parallelism=2, filesystem=fs)
     df = pd.concat([df1, df2, df3], ignore_index=True)
-    dsdf = ds.to_pandas()
+    dsdf = pd.concat(ray.get(ds.to_pandas()), ignore_index=True)
     assert df.equals(dsdf)
 
     # Directory, two files.
@@ -2379,8 +2231,8 @@ def test_csv_read(ray_start_regular_shared, fs, data_path, endpoint_url):
     path2 = os.path.join(path, "data1.csv")
     df2.to_csv(path2, index=False, storage_options=storage_options)
     ds = ray.data.read_csv(path, filesystem=fs)
-    df = pd.concat([df1, df2], ignore_index=True)
-    dsdf = ds.to_pandas()
+    df = pd.concat([df1, df2])
+    dsdf = pd.concat(ray.get(ds.to_pandas()))
     assert df.equals(dsdf)
     if fs is None:
         shutil.rmtree(path)
@@ -2406,8 +2258,8 @@ def test_csv_read(ray_start_regular_shared, fs, data_path, endpoint_url):
     file_path3 = os.path.join(path2, "data2.csv")
     df3.to_csv(file_path3, index=False, storage_options=storage_options)
     ds = ray.data.read_csv([path1, path2], filesystem=fs)
-    df = pd.concat([df1, df2, df3], ignore_index=True)
-    dsdf = ds.to_pandas()
+    df = pd.concat([df1, df2, df3])
+    dsdf = pd.concat(ray.get(ds.to_pandas()))
     assert df.equals(dsdf)
     if fs is None:
         shutil.rmtree(path1)
@@ -2429,8 +2281,8 @@ def test_csv_read(ray_start_regular_shared, fs, data_path, endpoint_url):
     path2 = os.path.join(data_path, "data1.csv")
     df2.to_csv(path2, index=False, storage_options=storage_options)
     ds = ray.data.read_csv([dir_path, path2], filesystem=fs)
-    df = pd.concat([df1, df2], ignore_index=True)
-    dsdf = ds.to_pandas()
+    df = pd.concat([df1, df2])
+    dsdf = pd.concat(ray.get(ds.to_pandas()))
     assert df.equals(dsdf)
     if fs is None:
         shutil.rmtree(dir_path)
@@ -2450,7 +2302,7 @@ def test_csv_write(ray_start_regular_shared, fs, data_path, endpoint_url):
     storage_options = dict(client_kwargs=dict(endpoint_url=endpoint_url))
     # Single block.
     df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
-    ds = ray.data.from_pandas([df1])
+    ds = ray.data.from_pandas([ray.put(df1)])
     ds._set_uuid("data")
     ds.write_csv(data_path, filesystem=fs)
     file_path = os.path.join(data_path, "data_000000.csv")
@@ -2458,7 +2310,7 @@ def test_csv_write(ray_start_regular_shared, fs, data_path, endpoint_url):
 
     # Two blocks.
     df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]})
-    ds = ray.data.from_pandas([df1, df2])
+    ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)])
     ds._set_uuid("data")
     ds.write_csv(data_path, filesystem=fs)
     file_path2 = os.path.join(data_path, "data_000001.csv")
@@ -2477,12 +2329,12 @@ def test_csv_roundtrip(ray_start_regular_shared, fs, data_path):
     # Single block.
     df = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
-    ds = ray.data.from_pandas([df])
+    ds = ray.data.from_pandas([ray.put(df)])
     ds._set_uuid("data")
     ds.write_csv(data_path, filesystem=fs)
     file_path = os.path.join(data_path, "data_000000.csv")
     ds2 = ray.data.read_csv([file_path], filesystem=fs)
-    ds2df = ds2.to_pandas()
+    ds2df = pd.concat(ray.get(ds2.to_pandas()))
     assert ds2df.equals(df)
     # Test metadata ops.
     for block, meta in zip(ds2._blocks, ds2._blocks.get_metadata()):
@@ -2490,12 +2342,12 @@ def test_csv_roundtrip(ray_start_regular_shared, fs, data_path):
 
     # Two blocks.
     df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]})
-    ds = ray.data.from_pandas([df, df2])
+    ds = ray.data.from_pandas([ray.put(df), ray.put(df2)])
     ds._set_uuid("data")
     ds.write_csv(data_path, filesystem=fs)
     ds2 = ray.data.read_csv(data_path, parallelism=2, filesystem=fs)
-    ds2df = ds2.to_pandas()
-    assert pd.concat([df, df2], ignore_index=True).equals(ds2df)
+    ds2df = pd.concat(ray.get(ds2.to_pandas()))
+    assert pd.concat([df, df2]).equals(ds2df)
     # Test metadata ops.
     for block, meta in zip(ds2._blocks, ds2._blocks.get_metadata()):
         BlockAccessor.for_block(ray.get(block)).size_bytes() == meta.size_bytes
@@ -2513,21 +2365,13 @@ def test_sort_simple(ray_start_regular_shared):
     assert ds.sort(key=lambda x: -x).take(num_items) == list(
         reversed(range(num_items)))
 
-    # Test empty dataset.
-    ds = ray.data.from_items([])
-    s1 = ds.sort()
-    assert s1.count() == 0
-    assert s1 == ds
-
 
 @pytest.mark.parametrize("pipelined", [False, True])
 def test_random_shuffle(shutdown_only, pipelined):
     def range(n, parallelism=200):
         ds = ray.data.range(n, parallelism=parallelism)
         if pipelined:
-            pipe = ds.repeat(2)
-            pipe.random_shuffle = pipe.random_shuffle_each_window
-            return pipe
+            return ds.repeat(2)
         else:
             return ds
 
@@ -2572,12 +2416,6 @@ def range(n, parallelism=200):
     r2 = range(100).random_shuffle(_move=True).take(999)
     assert r1 != r2, (r1, r2)
 
-    # Test empty dataset.
-    ds = ray.data.from_items([])
-    r1 = ds.random_shuffle()
-    assert r1.count() == 0
-    assert r1 == ds
-
 
 def test_random_shuffle_spread(ray_start_cluster):
     cluster = ray_start_cluster
@@ -2599,7 +2437,7 @@ def get_node_id():
 
     ds = ray.data.range(
         100, parallelism=2).random_shuffle(_spread_resource_prefix="bar:")
-    blocks = ds.get_internal_block_refs()
+    blocks = ds.get_blocks()
     ray.wait(blocks, num_returns=len(blocks), fetch_local=False)
     location_data = ray.experimental.get_object_locations(blocks)
     locations = []
@@ -2640,7 +2478,7 @@ def get_node_id():
     ds = ray.data.read_parquet(data_path, _spread_resource_prefix="bar:")
 
     # Force reads.
- blocks = ds.get_internal_block_refs() + blocks = ds.get_blocks() assert len(blocks) == 2 ray.wait(blocks, num_returns=len(blocks), fetch_local=False) @@ -2667,7 +2505,7 @@ def test_sort_arrow(ray_start_regular, num_items, parallelism): offset += shard if offset < num_items: dfs.append(pd.DataFrame({"a": a[offset:], "b": b[offset:]})) - ds = ray.data.from_pandas(dfs) + ds = ray.data.from_pandas([ray.put(df) for df in dfs]) def assert_sorted(sorted_ds, expected_rows): assert [tuple(row.values()) @@ -2697,7 +2535,7 @@ def __init__(self): def _read_file(self, f: "pa.NativeFile", path: str, **reader_args): count = self.counter.increment.remote() if ray.get(count) == 1: - raise ValueError("oops") + raise ValueError() else: return CSVDatasource._read_file(self, f, path, **reader_args) @@ -2705,7 +2543,7 @@ def _write_block(self, f: "pa.NativeFile", block: BlockAccessor, **writer_args): count = self.counter.increment.remote() if ray.get(count) == 1: - raise ValueError("oops") + raise ValueError() else: CSVDatasource._write_block(self, f, block, **writer_args) @@ -2725,7 +2563,7 @@ def _write_block(self, f: "pa.NativeFile", block: BlockAccessor, def flaky_mapper(x): count = counter.increment.remote() if ray.get(count) == 1: - raise ValueError("oops") + raise ValueError() else: return ray.get(count) diff --git a/python/ray/data/tests/test_dataset_pipeline.py b/python/ray/data/tests/test_dataset_pipeline.py index cffb378f36861..b199374f80437 100644 --- a/python/ray/data/tests/test_dataset_pipeline.py +++ b/python/ray/data/tests/test_dataset_pipeline.py @@ -30,14 +30,14 @@ def block_on_ones(x: int) -> int: time.sleep(999999) return x - pipe = ray.data.range(2).window(blocks_per_window=1) + pipe = ray.data.range(2).pipeline(parallelism=1) pipe = pipe.map(block_on_ones) assert pipe.take(1) == [0] def test_cannot_read_twice(ray_start_regular_shared): ds = ray.data.range(10) - pipe = ds.window(blocks_per_window=1) + pipe = ds.pipeline(parallelism=1) assert pipe.count() == 10 with pytest.raises(RuntimeError): pipe.count() @@ -52,70 +52,25 @@ def test_cannot_read_twice(ray_start_regular_shared): def test_basic_pipeline(ray_start_regular_shared): ds = ray.data.range(10) - pipe = ds.window(blocks_per_window=1) - assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=1)" + pipe = ds.pipeline(parallelism=1) + assert str(pipe) == "DatasetPipeline(length=10, num_stages=1)" assert pipe.count() == 10 - pipe = ds.window(blocks_per_window=1).map(lambda x: x).map(lambda x: x) - assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=3)" + pipe = ds.pipeline(parallelism=1).map(lambda x: x).map(lambda x: x) + assert str(pipe) == "DatasetPipeline(length=10, num_stages=3)" assert pipe.take() == list(range(10)) - pipe = ds.window(blocks_per_window=999) - assert str(pipe) == "DatasetPipeline(num_windows=1, num_stages=1)" + pipe = ds.pipeline(parallelism=999) + assert str(pipe) == "DatasetPipeline(length=1, num_stages=1)" assert pipe.count() == 10 pipe = ds.repeat(10) - assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=1)" + assert str(pipe) == "DatasetPipeline(length=10, num_stages=1)" assert pipe.count() == 100 pipe = ds.repeat(10) assert pipe.sum() == 450 -def test_window(ray_start_regular_shared): - ds = ray.data.range(10) - pipe = ds.window(blocks_per_window=1) - assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=1)" - pipe = pipe.rewindow(blocks_per_window=3) - assert str(pipe) == "DatasetPipeline(num_windows=None, num_stages=1)" - datasets = list(pipe.iter_datasets()) - 
assert len(datasets) == 4 - assert datasets[0].take() == [0, 1, 2] - assert datasets[1].take() == [3, 4, 5] - assert datasets[2].take() == [6, 7, 8] - assert datasets[3].take() == [9] - - ds = ray.data.range(10) - pipe = ds.window(blocks_per_window=5) - assert str(pipe) == "DatasetPipeline(num_windows=2, num_stages=1)" - pipe = pipe.rewindow(blocks_per_window=3) - assert str(pipe) == "DatasetPipeline(num_windows=None, num_stages=1)" - datasets = list(pipe.iter_datasets()) - assert len(datasets) == 4 - assert datasets[0].take() == [0, 1, 2] - assert datasets[1].take() == [3, 4, 5] - assert datasets[2].take() == [6, 7, 8] - assert datasets[3].take() == [9] - - -def test_repeat(ray_start_regular_shared): - ds = ray.data.range(5) - pipe = ds.window(blocks_per_window=1) - assert str(pipe) == "DatasetPipeline(num_windows=5, num_stages=1)" - pipe = pipe.repeat(2) - assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=1)" - assert pipe.take() == (list(range(5)) + list(range(5))) - - ds = ray.data.range(5) - pipe = ds.window(blocks_per_window=1) - pipe = pipe.repeat() - assert str(pipe) == "DatasetPipeline(num_windows=inf, num_stages=1)" - assert len(pipe.take(99)) == 99 - - pipe = ray.data.range(5).repeat() - with pytest.raises(ValueError): - pipe.repeat() - - def test_from_iterable(ray_start_regular_shared): pipe = DatasetPipeline.from_iterable( [lambda: ray.data.range(3), lambda: ray.data.range(2)]) @@ -125,7 +80,7 @@ def test_from_iterable(ray_start_regular_shared): def test_repeat_forever(ray_start_regular_shared): ds = ray.data.range(10) pipe = ds.repeat() - assert str(pipe) == "DatasetPipeline(num_windows=inf, num_stages=1)" + assert str(pipe) == "DatasetPipeline(length=None, num_stages=1)" for i, v in enumerate(pipe.iter_rows()): assert v == i % 10, (v, i, i % 10) if i > 1000: @@ -134,38 +89,38 @@ def test_repeat_forever(ray_start_regular_shared): def test_repartition(ray_start_regular_shared): pipe = ray.data.range(10).repeat(10) - assert pipe.repartition_each_window(1).sum() == 450 + assert pipe.repartition(1).sum() == 450 pipe = ray.data.range(10).repeat(10) - assert pipe.repartition_each_window(10).sum() == 450 + assert pipe.repartition(10).sum() == 450 pipe = ray.data.range(10).repeat(10) - assert pipe.repartition_each_window(100).sum() == 450 + assert pipe.repartition(100).sum() == 450 def test_iter_batches(ray_start_regular_shared): - pipe = ray.data.range(10).window(blocks_per_window=2) + pipe = ray.data.range(10).pipeline(parallelism=2) batches = list(pipe.iter_batches()) assert len(batches) == 10 assert all(len(e) == 1 for e in batches) def test_iter_datasets(ray_start_regular_shared): - pipe = ray.data.range(10).window(blocks_per_window=2) + pipe = ray.data.range(10).pipeline(parallelism=2) ds = list(pipe.iter_datasets()) assert len(ds) == 5 - pipe = ray.data.range(10).window(blocks_per_window=5) + pipe = ray.data.range(10).pipeline(parallelism=5) ds = list(pipe.iter_datasets()) assert len(ds) == 2 -def test_foreach_window(ray_start_regular_shared): - pipe = ray.data.range(5).window(blocks_per_window=2) - pipe = pipe.foreach_window(lambda ds: ds.map(lambda x: x * 2)) +def test_foreach_dataset(ray_start_regular_shared): + pipe = ray.data.range(5).pipeline(parallelism=2) + pipe = pipe.foreach_dataset(lambda ds: ds.map(lambda x: x * 2)) assert pipe.take() == [0, 2, 4, 6, 8] def test_schema(ray_start_regular_shared): - pipe = ray.data.range(5).window(blocks_per_window=2) + pipe = ray.data.range(5).pipeline(parallelism=2) assert pipe.schema() == int @@ -223,8 +178,8 @@ def 
test_parquet_write(ray_start_regular_shared, tmp_path): df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]}) df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]}) df = pd.concat([df1, df2]) - ds = ray.data.from_pandas([df1, df2]) - ds = ds.window(blocks_per_window=1) + ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)]) + ds = ds.pipeline(parallelism=1) path = os.path.join(tmp_path, "test_parquet_dir") os.mkdir(path) ds._set_uuid("data") diff --git a/python/ray/data/tests/test_raydp_dataset.py b/python/ray/data/tests/test_raydp_dataset.py index c23b672f97e38..c86c6a0803c13 100644 --- a/python/ray/data/tests/test_raydp_dataset.py +++ b/python/ray/data/tests/test_raydp_dataset.py @@ -16,10 +16,6 @@ def stop_all(): return spark -@pytest.mark.skip( - reason=( - "raydp.spark.spark_dataframe_to_ray_dataset needs to be updated to " - "use ray.data.from_arrow_refs.")) def test_raydp_roundtrip(spark_on_ray_small): spark = spark_on_ray_small spark_df = spark.createDataFrame([(1, "a"), (2, "b"), (3, "c")], diff --git a/python/ray/exceptions.py b/python/ray/exceptions.py index f46be7c0a1a15..58d1706549d2a 100644 --- a/python/ray/exceptions.py +++ b/python/ray/exceptions.py @@ -28,11 +28,7 @@ def from_bytes(b): ray_exception = RayException() ray_exception.ParseFromString(b) if ray_exception.language == PYTHON: - try: - return pickle.loads(ray_exception.serialized_exception) - except Exception as e: - msg = "Failed to unpickle serialized exception" - raise RuntimeError(msg) from e + return pickle.loads(ray_exception.serialized_exception) else: return CrossLanguageError(ray_exception) diff --git a/python/ray/experimental/array/remote/core.py b/python/ray/experimental/array/remote/core.py index 7b6d24f75b283..f4572da82babe 100644 --- a/python/ray/experimental/array/remote/core.py +++ b/python/ray/experimental/array/remote/core.py @@ -68,8 +68,8 @@ def diag(v, k=0): @ray.remote -def transpose(a, axes=None): - axes = None if (axes == [] or axes is None) else axes +def transpose(a, axes=[]): + axes = None if axes == [] else axes return np.transpose(a, axes=axes) diff --git a/python/ray/experimental/internal_kv.py b/python/ray/experimental/internal_kv.py index 456adabcb66ca..e434c3cf5f979 100644 --- a/python/ray/experimental/internal_kv.py +++ b/python/ray/experimental/internal_kv.py @@ -35,7 +35,7 @@ def _initialize_internal_kv(gcs_client: "ray._raylet.GcsClient" = None): return global_gcs_client -@client_mode_hook(auto_init=False) +@client_mode_hook def _internal_kv_initialized(): gcs_client = _initialize_internal_kv() @@ -46,7 +46,7 @@ def _internal_kv_initialized(): return hasattr(worker, "mode") and worker.mode is not None -@client_mode_hook(auto_init=False) +@client_mode_hook def _internal_kv_get(key: Union[str, bytes]) -> bytes: """Fetch the value of a binary key.""" gcs_client = _initialize_internal_kv() @@ -57,7 +57,7 @@ def _internal_kv_get(key: Union[str, bytes]) -> bytes: return ray.worker.global_worker.redis_client.hget(key, "value") -@client_mode_hook(auto_init=False) +@client_mode_hook def _internal_kv_exists(key: Union[str, bytes]) -> bool: """Check key exists or not.""" gcs_client = _initialize_internal_kv() @@ -67,7 +67,7 @@ def _internal_kv_exists(key: Union[str, bytes]) -> bool: return ray.worker.global_worker.redis_client.hexists(key, "value") -@client_mode_hook(auto_init=False) +@client_mode_hook def _internal_kv_put(key: Union[str, bytes], value: Union[str, bytes], overwrite: bool = True) -> bool: @@ -91,7 +91,7 @@ def _internal_kv_put(key: Union[str, bytes], 
return updated == 0 # already exists -@client_mode_hook(auto_init=False) +@client_mode_hook def _internal_kv_del(key: Union[str, bytes]): gcs_client = _initialize_internal_kv() if gcs_client is not None: @@ -100,7 +100,7 @@ def _internal_kv_del(key: Union[str, bytes]): return ray.worker.global_worker.redis_client.delete(key) -@client_mode_hook(auto_init=False) +@client_mode_hook def _internal_kv_list(prefix: Union[str, bytes]) -> List[bytes]: """List all keys in the internal KV store that start with the prefix. """ diff --git a/python/ray/experimental/raysort/constants.py b/python/ray/experimental/raysort/constants.py index 9c32b5f07330e..5ab3b2df29831 100644 --- a/python/ray/experimental/raysort/constants.py +++ b/python/ray/experimental/raysort/constants.py @@ -1,15 +1,12 @@ import os -from ray.experimental.raysort.types import ByteCount, PartId, RecordCount +from ray.experimental.raysort.types import ByteCount, RecordCount __DIR__ = os.path.dirname(os.path.abspath(__file__)) # Basics RECORD_SIZE = 100 # bytes -# Progress Tracker Actor -PROGRESS_TRACKER_ACTOR = "ProgressTrackerActor" - # Executable locations GENSORT_PATH = os.path.join(__DIR__, "bin/gensort/64/gensort") VALSORT_PATH = os.path.join(__DIR__, "bin/gensort/64/valsort") @@ -21,12 +18,10 @@ DATA_DIR_FMT = { "input": "{mnt}/tmp/input/", "output": "{mnt}/tmp/output/", - "temp": "{mnt}/tmp/temp/", } FILENAME_FMT = { "input": "input-{part_id:08}", "output": "output-{part_id:08}", - "temp": "temp-{part_id:08}", } # Prometheus config @@ -38,7 +33,3 @@ def bytes_to_records(n_bytes: ByteCount) -> RecordCount: assert n_bytes % RECORD_SIZE == 0 return int(n_bytes / RECORD_SIZE) - - -def merge_part_ids(reducer_id: PartId, mapper_id: PartId) -> PartId: - return reducer_id * 1_000_000 + mapper_id diff --git a/python/ray/experimental/raysort/main.py b/python/ray/experimental/raysort/main.py index 0df5bfb59ec75..1cc8d0df1c5af 100644 --- a/python/ray/experimental/raysort/main.py +++ b/python/ray/experimental/raysort/main.py @@ -1,12 +1,10 @@ import argparse -import contextlib import csv import logging import os import random import subprocess -import tempfile -from typing import Callable, Dict, Iterable, List +from typing import Iterable, List import numpy as np import ray @@ -15,17 +13,14 @@ from ray.experimental.raysort import logging_utils from ray.experimental.raysort import sortlib from ray.experimental.raysort import tracing_utils -from ray.experimental.raysort.types import \ - BlockInfo, ByteCount, RecordCount, PartId, PartInfo, Path - -Args = argparse.Namespace +from ray.experimental.raysort.types import BlockInfo, ByteCount, RecordCount, PartId, PartitionInfo, Path # noqa: E501 # ------------------------------------------------------------ # Parse Arguments # ------------------------------------------------------------ -def get_args(*args, **kwargs): +def get_args(): parser = argparse.ArgumentParser() parser.add_argument( "--ray_address", @@ -35,39 +30,27 @@ def get_args(*args, **kwargs): ) parser.add_argument( "--total_data_size", - default=1 * 1000 * 1024 * 1024 * 1024, + default=1_000_000_000, type=ByteCount, - help="total data size in bytes", + help="partition size in bytes", ) parser.add_argument( "--num_mappers", - default=256, + default=4, type=int, help="number of map tasks", ) - parser.add_argument( - "--num_mappers_per_round", - default=16, - type=int, - help="number of map tasks per first-stage merge tasks", - ) parser.add_argument( "--num_reducers", - default=16, - type=int, - help="number of second-stage reduce tasks", - 
) - parser.add_argument( - "--num_concurrent_rounds", default=4, type=int, - help="max number of rounds of map/merge tasks in flight", + help="number of reduce tasks", ) parser.add_argument( - "--reducer_input_chunk", - default=100 * 1024 * 1024, - type=ByteCount, - help="bytes to read from each file in reduce tasks", + "--reducer_batch_num_records", + default=1_000_000, + type=RecordCount, + help="number of bytes to buffer before writing the output to EBS", ) parser.add_argument( "--skip_sorting", @@ -92,13 +75,13 @@ def get_args(*args, **kwargs): "tasks to run", "if no task is specified, will run all tasks") tasks = ["generate_input", "sort", "validate_output"] for task in tasks: - tasks_group.add_argument(f"--{task}", action="store_true") + tasks_group.add_argument( + f"--{task}", action="store_true", help=f"run task {task}") - args = parser.parse_args(*args, **kwargs) + args = parser.parse_args() # Derive additional arguments. args.input_part_size = ByteCount(args.total_data_size / args.num_mappers) - assert args.num_mappers % args.num_mappers_per_round == 0 - args.num_rounds = int(args.num_mappers / args.num_mappers_per_round) + args.output_part_size = ByteCount(args.total_data_size / args.num_reducers) args.mount_points = _get_mount_points() # If no tasks are specified, run all tasks. args_dict = vars(args) @@ -109,29 +92,28 @@ def get_args(*args, **kwargs): def _get_mount_points(): - default_ret = [tempfile.gettempdir()] mnt = "/mnt" - if os.path.exists(mnt): - ret = [os.path.join(mnt, d) for d in os.listdir(mnt)] - if len(ret) > 0: - return ret - return default_ret + if not os.path.exists(mnt): + return [] + return [os.path.join(mnt, d) for d in os.listdir(mnt)] +args = None + # ------------------------------------------------------------ # Generate Input # ------------------------------------------------------------ -def _part_info(args: Args, part_id: PartId, kind="input") -> PartInfo: +def _make_partition_info(part_id: PartId, kind="input") -> PartitionInfo: node = ray.worker.global_worker.node_ip_address mnt = random.choice(args.mount_points) filepath = _get_part_path(mnt, part_id, kind) - return PartInfo(part_id, node, filepath) + return PartitionInfo(part_id, node, filepath) def _get_part_path(mnt: Path, part_id: PartId, kind="input") -> Path: - assert kind in {"input", "output", "temp"} + assert kind in {"input", "output"} dir_fmt = constants.DATA_DIR_FMT[kind] dirpath = dir_fmt.format(mnt=mnt) os.makedirs(dirpath, exist_ok=True) @@ -142,25 +124,26 @@ def _get_part_path(mnt: Path, part_id: PartId, kind="input") -> Path: @ray.remote -def generate_part(args: Args, part_id: PartId, size: RecordCount, - offset: RecordCount) -> PartInfo: +def generate_part(part_id: PartId, size: RecordCount, + offset: RecordCount) -> PartitionInfo: logging_utils.init() - pinfo = _part_info(args, part_id) - subprocess.run( - [constants.GENSORT_PATH, f"-b{offset}", f"{size}", pinfo.path], - check=True) - logging.info(f"Generated input {pinfo}") + pinfo = _make_partition_info(part_id) + if not args.skip_input: + subprocess.run( + [constants.GENSORT_PATH, f"-b{offset}", f"{size}", pinfo.path], + check=True) + logging.info(f"Generated input {pinfo}") return pinfo -def generate_input(args: Args): +def generate_input(): if args.skip_input: return size = constants.bytes_to_records(args.input_part_size) offset = 0 tasks = [] for part_id in range(args.num_mappers): - tasks.append(generate_part.remote(args, part_id, size, offset)) + tasks.append(generate_part.remote(part_id, size, offset)) offset += size 
assert offset == constants.bytes_to_records(args.total_data_size), args logging.info(f"Generating {len(tasks)} partitions") @@ -175,21 +158,22 @@ def generate_input(args: Args): # ------------------------------------------------------------ -def _load_manifest(args: Args, path: Path) -> List[PartInfo]: +def _load_manifest(path: Path) -> List[PartitionInfo]: if args.skip_input: - return [PartInfo(i, None, None) for i in range(args.num_mappers)] + return _load_dummy_manifest() with open(path) as fin: reader = csv.reader(fin) return [ - PartInfo(int(part_id), node, path) + PartitionInfo(int(part_id), node, path) for part_id, node, path in reader ] -def _load_partition(args: Args, path: Path) -> np.ndarray: - if args.skip_input: - return np.frombuffer( - np.random.bytes(args.input_part_size), dtype=np.uint8).copy() +def _load_dummy_manifest() -> List[PartitionInfo]: + return [PartitionInfo(i, "", "") for i in range(args.num_mappers)] + + +def _load_partition(path: Path) -> np.ndarray: return np.fromfile(path, dtype=np.uint8) @@ -206,214 +190,115 @@ def _dummy_sort_and_partition(part: np.ndarray, @ray.remote -@tracing_utils.timeit("map") -def mapper(args: Args, mapper_id: PartId, boundaries: List[int], - path: Path) -> List[np.ndarray]: +def mapper(boundaries: List[int], mapper_id: PartId, + path: Path) -> List[ray.ObjectRef]: logging_utils.init() - part = _load_partition(args, path) + task_id = f"M-{mapper_id} Mapper" + logging.info(f"{task_id} starting {args}") + if args.skip_input: + block_size = int(np.ceil(args.input_part_size / args.num_reducers)) + return [ + ray.put( + np.frombuffer(np.random.bytes(block_size), dtype=np.uint8)) + for _ in range(args.num_reducers) + ] + + part = _load_partition(path) sort_fn = _dummy_sort_and_partition \ if args.skip_sorting else sortlib.sort_and_partition blocks = sort_fn(part, boundaries) - return [part[offset:offset + size] for offset, size in blocks] + logging.info(f"{task_id} saving to object store") + return [ray.put(part[offset:offset + size]) for offset, size in blocks] -def _dummy_merge( - num_blocks: int, _n: int, - get_block: Callable[[int, int], np.ndarray]) -> Iterable[np.ndarray]: - blocks = [((i, 0), get_block(i, 0)) for i in range(num_blocks)] - while len(blocks) > 0: - (m, d), block = blocks.pop(random.randrange(len(blocks))) +def _dummy_merge(blocks: List[np.ndarray], _n: int) -> Iterable[memoryview]: + for block in blocks: yield block - d_ = d + 1 - block = get_block(m, d_) - if block is None: - continue - blocks.append(((m, d_), block)) - - -def _merge_impl(args: Args, - M: int, - pinfo: PartInfo, - get_block: Callable[[int, int], np.ndarray], - skip_output=False): - merge_fn = _dummy_merge if args.skip_sorting else sortlib.merge_partitions - merger = merge_fn(M, get_block) - if skip_output: + +@ray.remote +def reducer(reducer_id: PartId, *blocks: List[ray.ObjectRef]) -> PartitionInfo: + logging_utils.init() + task_id = f"R-{reducer_id} Reducer" + logging.info(f"{task_id} starting") + blocks = [np.copy(ray.get(block)) for block in blocks] + merge_fn = _dummy_merge if args.skip_sorting else sortlib.merge_partitions + merger = merge_fn(blocks, args.reducer_batch_num_records) + if args.skip_output: for datachunk in merger: del datachunk + logging.info(f"{task_id} done") + return None else: + pinfo = _make_partition_info(reducer_id, "output") with open(pinfo.path, "wb") as fout: for datachunk in merger: fout.write(datachunk) - return pinfo + logging.info(f"{task_id} done") + return pinfo -# See worker_placement_groups() for why 
`num_cpus=0`. -@ray.remote(num_cpus=0, resources={"worker": 1}) -@tracing_utils.timeit("merge") -def merge_mapper_blocks(args: Args, reducer_id: PartId, mapper_id: PartId, - *blocks: List[np.ndarray]) -> PartInfo: - part_id = constants.merge_part_ids(reducer_id, mapper_id) - pinfo = _part_info(args, part_id, kind="temp") - M = len(blocks) - - def get_block(i, d): - if i >= M or d > 0: - return None - return blocks[i] - - return _merge_impl(args, M, pinfo, get_block) - - -# See worker_placement_groups() for why `num_cpus=0`. -@ray.remote(num_cpus=0, resources={"worker": 1}) -@tracing_utils.timeit("reduce") -def final_merge(args: Args, reducer_id: PartId, - *merged_parts: List[PartInfo]) -> PartInfo: - M = len(merged_parts) - - def _load_block_chunk(pinfo: PartInfo, d: int) -> np.ndarray: - return np.fromfile( - pinfo.path, - dtype=np.uint8, - count=args.reducer_input_chunk, - offset=d * args.reducer_input_chunk) - - def get_block(i, d): - ret = _load_block_chunk(merged_parts[i], d) - if ret.size == 0: - return None - return ret - - pinfo = _part_info(args, reducer_id, "output") - return _merge_impl(args, M, pinfo, get_block, args.skip_output) - - -def _node_res(node: str) -> Dict[str, float]: - return {"resources": {f"node:{node}": 1e-3}} - - -@contextlib.contextmanager -def worker_placement_groups(args: Args) -> List[ray.PlacementGroupID]: - """ - Returns one placement group per node with a `worker` resource. To run - tasks in the placement group, use - `@ray.remote(num_cpus=0, resources={"worker": 1})`. Ray does not - automatically reserve CPU resources, so tasks must specify `num_cpus=0` - in order to run in a placement group. - """ - pgs = [ - ray.util.placement_group([{ - "worker": 1 - }]) for _ in range(args.num_reducers) - ] - ray.get([pg.ready() for pg in pgs]) - try: - yield pgs - finally: - for pg in pgs: - ray.util.remove_placement_group(pg) - - -@tracing_utils.timeit("sort", report_time=True) -def sort_main(args: Args): - parts = _load_manifest(args, constants.INPUT_MANIFEST_FILE) - assert len(parts) == args.num_mappers +@tracing_utils.timeit("sorting") +def sort_main(): + partitions = _load_manifest(constants.INPUT_MANIFEST_FILE) boundaries = sortlib.get_boundaries(args.num_reducers) - - mapper_opt = { - "num_returns": args.num_reducers, - "num_cpus": os.cpu_count() / args.num_concurrent_rounds, - } # Load balance across worker nodes by setting `num_cpus`. - merge_results = np.empty( - (args.num_rounds, args.num_reducers), dtype=object) - - part_id = 0 - with worker_placement_groups(args) as pgs: - for round in range(args.num_rounds): - # Limit the number of in-flight rounds. - num_extra_rounds = round - args.num_concurrent_rounds + 1 - if num_extra_rounds > 0: - ray.wait( - [f for f in merge_results.flatten() if f is not None], - num_returns=num_extra_rounds * args.num_reducers) - - # Submit map tasks. - mapper_results = np.empty( - (args.num_mappers_per_round, args.num_reducers), dtype=object) - for _ in range(args.num_mappers_per_round): - _, node, path = parts[part_id] - m = part_id % args.num_mappers_per_round - mapper_results[m, :] = mapper.options(**mapper_opt).remote( - args, part_id, boundaries, path) - part_id += 1 - - # Submit merge tasks. - merge_results[round, :] = [ - merge_mapper_blocks.options(placement_group=pgs[r]).remote( - args, r, round, *mapper_results[:, r].tolist()) - for r in range(args.num_reducers) - ] - - # Delete local references to mapper results. - mapper_results = None - - # Submit second-stage reduce tasks. 
- reducer_results = [ - final_merge.options(placement_group=pgs[r]).remote( - args, r, *merge_results[:, r].tolist()) - for r in range(args.num_reducers) - ] - reducer_results = ray.get(reducer_results) - + mapper_results = np.empty( + (args.num_mappers, args.num_reducers), dtype=object) + for part_id, node, path in partitions: + opt = {} if args.skip_input else { + "resources": { + f"node:{node}": 1 / args.num_mappers + }, + "memory": args.input_part_size * 1.2, + } + opt.update(num_returns=args.num_reducers) + mapper_results[part_id, :] = mapper.options(**opt).remote( + boundaries, part_id, path) + + reducer_results = [] + for r in range(args.num_reducers): + opt = { + "memory": args.output_part_size * 1.0, + } + blocks = mapper_results[:, r].tolist() + ret = reducer.options(**opt).remote(r, *blocks) + reducer_results.append(ret) + + reducer_results = ray.get(reducer_results) if not args.skip_output: with open(constants.OUTPUT_MANIFEST_FILE, "w") as fout: writer = csv.writer(fout) writer.writerows(reducer_results) - logging.info(ray.internal.internal_api.memory_summary(stats_only=True)) - # ------------------------------------------------------------ # Validate Output # ------------------------------------------------------------ -def _run_valsort(args: List[str]): - proc = subprocess.run([constants.VALSORT_PATH] + args, capture_output=True) - if proc.returncode != 0: - logging.critical("\n" + proc.stderr.decode("ascii")) - raise RuntimeError(f"Validation failed: {args}") - - @ray.remote def validate_part(path: Path): logging_utils.init() - sum_path = path + ".sum" - _run_valsort(["-o", sum_path, path]) + proc = subprocess.run([constants.VALSORT_PATH, path], capture_output=True) + if proc.returncode != 0: + logging.critical("\n" + proc.stderr.decode("ascii")) + raise RuntimeError(f"Validation failed: {path}") logging.info(f"Validated output {path}") - with open(sum_path, "rb") as fin: - return os.path.getsize(path), fin.read() -def validate_output(args: Args): - if args.skip_sorting or args.skip_output: +def validate_output(): + if args.skip_output: return - partitions = _load_manifest(args, constants.OUTPUT_MANIFEST_FILE) - results = [] + partitions = _load_manifest(constants.OUTPUT_MANIFEST_FILE) + tasks = [] for _, node, path in partitions: - results.append(validate_part.options(**_node_res(node)).remote(path)) - logging.info(f"Validating {len(results)} partitions") - results = ray.get(results) - total = sum(s for s, _ in results) - assert total == args.total_data_size, total - args.total_data_size - all_checksum = b"".join(c for _, c in results) - with tempfile.NamedTemporaryFile() as fout: - fout.write(all_checksum) - fout.flush() - _run_valsort(["-s", fout.name]) - logging.info("All OK!") + tasks.append( + validate_part.options(resources={ + f"node:{node}": 1 / args.num_reducers + }).remote(path)) + logging.info(f"Validating {len(tasks)} partitions") + ray.get(tasks) + logging.info("All done!") # ------------------------------------------------------------ @@ -421,34 +306,30 @@ def validate_output(args: Args): # ------------------------------------------------------------ -def init(args: Args): - if not args.ray_address: - ray.init(resources={"worker": os.cpu_count()}) +def init(): + if args.ray_address is None: + ray.init() else: ray.init(address=args.ray_address) logging_utils.init() logging.info(args) + logging.info(ray.available_resources()) os.makedirs(constants.WORK_DIR, exist_ok=True) - resources = ray.cluster_resources() - logging.info(resources) - args.num_workers = 
resources["worker"] - progress_tracker = tracing_utils.create_progress_tracker(args) - return progress_tracker -def main(args: Args): - # Keep the actor handle in scope for the duration of the program. - _progress_tracker = init(args) # noqa F841 +def main(): + init() if args.generate_input: - generate_input(args) + generate_input() if args.sort: - sort_main(args) + sort_main() if args.validate_output: - validate_output(args) + validate_output() if __name__ == "__main__": - main(get_args()) + args = get_args() + main() diff --git a/python/ray/experimental/raysort/sortlib.py b/python/ray/experimental/raysort/sortlib.py index 6242867286d5f..ea79ec7168de4 100644 --- a/python/ray/experimental/raysort/sortlib.py +++ b/python/ray/experimental/raysort/sortlib.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterable, List +from typing import Iterable, List import numpy as np @@ -21,9 +21,7 @@ def sort_and_partition(part: np.ndarray, return blocks -def merge_partitions( - num_blocks: int, - get_block: Callable[[int, int], np.ndarray]) -> Iterable[memoryview]: - blocks = [get_block(i, 0) for i in range(num_blocks)] +def merge_partitions(blocks: List[np.ndarray], + _n: int) -> Iterable[memoryview]: for block in blocks: yield block diff --git a/python/ray/experimental/raysort/tracing_utils.py b/python/ray/experimental/raysort/tracing_utils.py index e67584b62588c..e75bae8297429 100644 --- a/python/ray/experimental/raysort/tracing_utils.py +++ b/python/ray/experimental/raysort/tracing_utils.py @@ -1,122 +1,13 @@ -import datetime -import functools +import contextlib import logging import time -from typing import List, Tuple -import ray -from ray.util.metrics import Gauge, Histogram -from ray.experimental.raysort import constants -from ray.experimental.raysort import logging_utils - -HISTOGRAM_BOUNDARIES = list(range(50, 200, 50)) - - -def timeit( - event: str, - report_time=False, - report_in_progress=True, - report_completed=True, -): - def decorator(f): - @functools.wraps(f) - def wrapped_f(*args, **kwargs): - progress_tracker = ray.get_actor(constants.PROGRESS_TRACKER_ACTOR) - progress_tracker.inc.remote( - f"{event}_in_progress", echo=report_in_progress) - try: - start = time.time() - ret = f(*args, **kwargs) - end = time.time() - duration = end - start - progress_tracker.observe.remote( - f"{event}_time", - duration, - echo=report_time, - ) - progress_tracker.inc.remote( - f"{event}_completed", echo=report_completed) - return ret - finally: - progress_tracker.dec.remote(f"{event}_in_progress") - - return wrapped_f - - return decorator - - -def get_metrics(_args): - return { - "gauges": [ - "map_in_progress", - "merge_in_progress", - "reduce_in_progress", - "sort_in_progress", - "map_completed", - "merge_completed", - "reduce_completed", - "sort_completed", - ], - "histograms": [ - ("map_time", HISTOGRAM_BOUNDARIES), - ("merge_time", HISTOGRAM_BOUNDARIES), - ("reduce_time", HISTOGRAM_BOUNDARIES), - ("sort_time", HISTOGRAM_BOUNDARIES), - ], - } - - -def create_progress_tracker(args): - return ProgressTracker.options( - name=constants.PROGRESS_TRACKER_ACTOR).remote(**get_metrics(args)) - - -@ray.remote -class ProgressTracker: - def __init__( - self, - gauges: List[str], - histograms: List[Tuple[str, List[int]]], - ): - self.counts = {m: 0 for m in gauges} - self.gauges = {m: Gauge(m) for m in gauges} - self.reset_gauges() - self.histograms = { - m: Histogram(m, boundaries=b) - for m, b in histograms - } - logging_utils.init() - - def reset_gauges(self): - for g in self.gauges.values(): - g.set(0) - - 
def inc(self, metric_name, value=1, echo=False): - gauge = self.gauges.get(metric_name) - if gauge is None: - logging.warning(f"No such Gauge: {metric_name}") - return - self.counts[metric_name] += value - gauge.set(self.counts[metric_name]) - if echo: - logging.info(f"{metric_name} {self.counts[metric_name]}") - - def dec(self, metric_name, value=1, echo=False): - return self.inc(metric_name, -value, echo) - - def observe(self, metric_name, value, echo=False): - histogram = self.histograms.get(metric_name) - if histogram is None: - logging.warning(f"No such Histogram: {metric_name}") - return - histogram.observe(value) - if echo: - logging.info(f"{metric_name} {value}") - - -def export_timeline(): - timestr = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - filename = f"/tmp/ray-timeline-{timestr}.json" - ray.timeline(filename=filename) - logging.info(f"Exported Ray timeline to {filename}") +@contextlib.contextmanager +def timeit(event="operation", args={}): + start = time.time() + yield + end = time.time() + duration = end - start + args = {"duration": duration} + logging.info(f"{event} {args}") diff --git a/python/ray/experimental/raysort/types.py b/python/ray/experimental/raysort/types.py index 5d1c39a33a521..02c6f70e5004a 100644 --- a/python/ray/experimental/raysort/types.py +++ b/python/ray/experimental/raysort/types.py @@ -7,12 +7,6 @@ RecordCount = int BlockInfo = Tuple[int, int] - - -class PartInfo(NamedTuple): - part_id: PartId - node: NodeAddress - path: Path - - def __repr__(self): - return f"Part({self.node}:{self.path})" +PartitionInfo = NamedTuple("PartitionInfo", + [("part_id", PartId), ("node", NodeAddress), + ("path", Path)]) diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd index 17b2f3879f05b..33d9ce1a92fd4 100644 --- a/python/ray/includes/common.pxd +++ b/python/ray/includes/common.pxd @@ -260,7 +260,8 @@ cdef extern from "ray/core_worker/common.h" nogil: unordered_map[c_string, double] &resources, c_string concurrency_group_name, c_string serialized_runtime_env, - c_vector[c_string] runtime_env_uris) + const unordered_map[c_string, c_string] + &override_environment_variables) cdef cppclass CActorCreationOptions "ray::core::ActorCreationOptions": CActorCreationOptions() @@ -276,7 +277,8 @@ cdef extern from "ray/core_worker/common.h" nogil: c_pair[CPlacementGroupID, int64_t] placement_options, c_bool placement_group_capture_child_tasks, c_string serialized_runtime_env, - c_vector[c_string] runtime_env_uris) + const unordered_map[c_string, c_string] + &override_environment_variables) cdef cppclass CPlacementGroupCreationOptions \ "ray::core::PlacementGroupCreationOptions": diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index a95eaee2c228f..7e56dab60965a 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -2,7 +2,7 @@ # distutils: language = c++ # cython: embedsignature = True -from libc.stdint cimport int64_t, uint64_t +from libc.stdint cimport int64_t from libcpp cimport bool as c_bool from libcpp.memory cimport shared_ptr, unique_ptr from libcpp.pair cimport pair as c_pair @@ -177,6 +177,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: c_vector[CObjectReference] GetObjectRefs( const c_vector[CObjectID] &object_ids) const + void PromoteObjectToPlasma(const CObjectID &object_id) void GetOwnershipInfo(const CObjectID &object_id, CAddress *owner_address, c_string *object_status) @@ -253,8 +254,6 @@ cdef extern from 
"ray/core_worker/core_worker.h" nogil: int64_t GetNumLeasesRequested() const - unordered_map[c_string, c_vector[uint64_t]] GetActorCallStats() const - cdef cppclass CCoreWorkerOptions "ray::core::CoreWorkerOptions": CWorkerType worker_type CLanguage language diff --git a/python/ray/job_config.py b/python/ray/job_config.py index e9dc6b3d7cd7d..9ba513f71195e 100644 --- a/python/ray/job_config.py +++ b/python/ray/job_config.py @@ -1,13 +1,17 @@ from typing import Any, Dict, Optional import uuid +import json import ray._private.gcs_utils as gcs_utils +from ray.core.generated.common_pb2 import RuntimeEnv as RuntimeEnvPB class JobConfig: """A class used to store the configurations of a job. Attributes: + worker_env (dict): Environment variables to be set on worker + processes. num_java_workers_per_process (int): The number of java workers per worker process. jvm_options (str[]): The jvm options for java workers of the job. @@ -20,6 +24,7 @@ class JobConfig: """ def __init__(self, + worker_env=None, num_java_workers_per_process=1, jvm_options=None, code_search_path=None, @@ -27,6 +32,10 @@ def __init__(self, client_job=False, metadata=None, ray_namespace=None): + if worker_env is None: + self.worker_env = dict() + else: + self.worker_env = worker_env self.num_java_workers_per_process = num_java_workers_per_process self.jvm_options = jvm_options or [] self.code_search_path = code_search_path or [] @@ -45,23 +54,21 @@ def set_metadata(self, key: str, value: str) -> None: def serialize(self): """Serialize the struct into protobuf string""" - return self.get_proto_job_config().SerializeToString() + job_config = self.get_proto_job_config() + return job_config.SerializeToString() def set_runtime_env(self, runtime_env: Optional[Dict[str, Any]]) -> None: - # TODO(edoakes): this is really unfortunate, but JobConfig is imported - # all over the place so this causes circular imports. We should remove - # this dependency and pass in a validated runtime_env instead. - from ray._private.runtime_env.validation import ParsedRuntimeEnv - self._parsed_runtime_env = ParsedRuntimeEnv(runtime_env or {}) + # Lazily import this to avoid circular dependencies. + import ray._private.runtime_env as runtime_support + if runtime_env: + self._parsed_runtime_env = runtime_support.RuntimeEnvDict( + runtime_env) + self.worker_env.update( + self._parsed_runtime_env.get_parsed_dict().get("env_vars") + or {}) + else: + self._parsed_runtime_env = runtime_support.RuntimeEnvDict({}) self.runtime_env = runtime_env or dict() - eager_install = False - if runtime_env and "eager_install" in runtime_env: - eager_install = runtime_env["eager_install"] - self.runtime_env_eager_install = eager_install - assert isinstance(self.runtime_env_eager_install, bool), \ - f"The type of eager_install is incorrect: " \ - f"{type(self.runtime_env_eager_install)}" \ - f", the bool type is needed." 
self._cached_pb = None def set_ray_namespace(self, ray_namespace: str) -> None: @@ -77,27 +84,35 @@ def get_proto_job_config(self): self._cached_pb.ray_namespace = str(uuid.uuid4()) else: self._cached_pb.ray_namespace = self.ray_namespace + for key in self.worker_env: + self._cached_pb.worker_env[key] = self.worker_env[key] self._cached_pb.num_java_workers_per_process = ( self.num_java_workers_per_process) self._cached_pb.jvm_options.extend(self.jvm_options) self._cached_pb.code_search_path.extend(self.code_search_path) - self._cached_pb.runtime_env.uris[:] = self.get_runtime_env_uris() - serialized_env = self.get_serialized_runtime_env() - self._cached_pb.runtime_env.serialized_runtime_env = serialized_env + self._cached_pb.runtime_env.CopyFrom(self._get_proto_runtime()) + self._cached_pb.serialized_runtime_env = \ + self.get_serialized_runtime_env() for k, v in self.metadata.items(): self._cached_pb.metadata[k] = v - self._cached_pb.runtime_env.runtime_env_eager_install = \ - self.runtime_env_eager_install return self._cached_pb def get_runtime_env_uris(self): """Get the uris of runtime environment""" - return self._parsed_runtime_env.get("uris") or [] + if self.runtime_env.get("uris"): + return self.runtime_env.get("uris") + return [] + + def set_runtime_env_uris(self, uris): + self.runtime_env["uris"] = uris + self._parsed_runtime_env.set_uris(uris) def get_serialized_runtime_env(self) -> str: """Return the JSON-serialized parsed runtime env dict""" return self._parsed_runtime_env.serialize() - def set_runtime_env_uris(self, uris): - self.runtime_env["uris"] = uris - self._parsed_runtime_env["uris"] = uris + def _get_proto_runtime(self) -> RuntimeEnvPB: + runtime_env = RuntimeEnvPB() + runtime_env.uris[:] = self.get_runtime_env_uris() + runtime_env.raw_json = json.dumps(self.runtime_env) + return runtime_env diff --git a/python/ray/node.py b/python/ray/node.py index 0bc731e815bc2..cee0f8bfebeac 100644 --- a/python/ray/node.py +++ b/python/ray/node.py @@ -356,11 +356,7 @@ def merge_resources(env_dict, params_dict): env_string = os.getenv( ray_constants.RESOURCES_ENVIRONMENT_VARIABLE) if env_string: - try: - env_resources = json.loads(env_string) - except Exception: - logger.exception("Failed to load {}".format(env_string)) - raise + env_resources = json.loads(env_string) logger.debug( f"Autoscaler overriding resources: {env_resources}.") num_cpus, num_gpus, memory, object_store_memory, resources = \ @@ -576,10 +572,7 @@ def _get_log_file_names(self, name, unique=False): log_stderr = os.path.join(self._logs_dir, f"{name}.err") return log_stdout, log_stderr - def _get_unused_port(self, allocated_ports=None): - if allocated_ports is None: - allocated_ports = set() - + def _get_unused_port(self, close_on_exit=True): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.bind(("", 0)) port = s.getsockname()[1] @@ -589,10 +582,6 @@ def _get_unused_port(self, allocated_ports=None): # from this method has been used by a different process. for _ in range(NUM_PORT_RETRIES): new_port = random.randint(port, 65535) - if new_port in allocated_ports: - # This port is allocated for other usage already, - # so we shouldn't use it even if it's not in use right now. 
- continue new_s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) try: new_s.bind(("", new_port)) @@ -600,11 +589,13 @@ def _get_unused_port(self, allocated_ports=None): new_s.close() continue s.close() - new_s.close() - return new_port + if close_on_exit: + new_s.close() + return new_port, new_s logger.error("Unable to succeed in selecting a random port.") - s.close() - return port + if close_on_exit: + s.close() + return port, s def _prepare_socket_file(self, socket_path, default_prefix): """Prepare the socket file for raylet and plasma. @@ -622,7 +613,7 @@ def _prepare_socket_file(self, socket_path, default_prefix): if sys.platform == "win32": if socket_path is None: result = (f"tcp://{self._localhost}" - f":{self._get_unused_port()}") + f":{self._get_unused_port()[0]}") else: if socket_path is None: result = self._make_inc_temp( @@ -674,8 +665,7 @@ def _get_cached_port(self, port = int(ports_by_node[self.unique_id][port_name]) else: # Pick a new port to use and cache it at this node. - port = (default_port or self._get_unused_port( - set(ports_by_node[self.unique_id].values()))) + port = (default_port or self._get_unused_port()[0]) ports_by_node[self.unique_id][port_name] = port with open(file_path, "w") as f: json.dump(ports_by_node, f) @@ -846,7 +836,6 @@ def start_raylet(self, start_initial_python_workers_for_first_job=self._ray_params. start_initial_python_workers_for_first_job, ray_debugger_external=self._ray_params.ray_debugger_external, - env_updates=self._ray_params.env_vars, ) assert ray_constants.PROCESS_TYPE_RAYLET not in self.all_processes self.all_processes[ray_constants.PROCESS_TYPE_RAYLET] = [process_info] diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py index ea3df3acb2f9a..6854a93535b9e 100644 --- a/python/ray/remote_function.py +++ b/python/ray/remote_function.py @@ -1,7 +1,7 @@ -from functools import wraps -import inspect -import logging import uuid +import logging +import inspect +from functools import wraps from ray import cloudpickle as pickle from ray._raylet import PythonFunctionDescriptor @@ -14,8 +14,7 @@ get_current_placement_group, ) import ray._private.signature -from ray._private.runtime_env.validation import ( - override_task_or_actor_runtime_env, ParsedRuntimeEnv) +import ray._private.runtime_env as runtime_support from ray.util.tracing.tracing_helper import (_tracing_task_invocation, _inject_tracing_into_function) @@ -79,7 +78,7 @@ class RemoteFunction: def __init__(self, language, function, function_descriptor, num_cpus, num_gpus, memory, object_store_memory, resources, accelerator_type, num_returns, max_calls, max_retries, - retry_exceptions, runtime_env, placement_group): + retry_exceptions, runtime_env): if inspect.iscoroutinefunction(function): raise ValueError("'async def' should not be used for remote " "tasks. You can wrap the async function with " @@ -109,12 +108,7 @@ def __init__(self, language, function, function_descriptor, num_cpus, self._retry_exceptions = (DEFAULT_REMOTE_FUNCTION_RETRY_EXCEPTIONS if retry_exceptions is None else retry_exceptions) - # Parse local pip/conda config files here. If we instead did it in - # .remote(), it would get run in the Ray Client server, which runs on - # a remote node where the files aren't available. 
- self._runtime_env = ParsedRuntimeEnv( - runtime_env or {}, is_task_or_actor=True) - self._placement_group = placement_group + self._runtime_env = runtime_env self._decorator = getattr(function, "__ray_invocation_decorator__", None) self._function_signature = ray._private.signature.extract_signature( @@ -151,6 +145,7 @@ def options(self, placement_group_bundle_index=-1, placement_group_capture_child_tasks=None, runtime_env=None, + override_environment_variables=None, name=""): """Configures and overrides the task invocation parameters. @@ -169,11 +164,6 @@ def f(): """ func_cls = self - # Parse local pip/conda config files here. If we instead did it in - # .remote(), it would get run in the Ray Client server, which runs on - # a remote node where the files aren't available. - new_runtime_env = ParsedRuntimeEnv( - runtime_env or {}, is_task_or_actor=True) class FuncWrapper: def remote(self, *args, **kwargs): @@ -193,7 +183,9 @@ def remote(self, *args, **kwargs): placement_group_bundle_index=placement_group_bundle_index, placement_group_capture_child_tasks=( placement_group_capture_child_tasks), - runtime_env=new_runtime_env, + runtime_env=runtime_env, + override_environment_variables=( + override_environment_variables), name=name) return FuncWrapper() @@ -215,10 +207,10 @@ def _remote(self, placement_group_bundle_index=-1, placement_group_capture_child_tasks=None, runtime_env=None, + override_environment_variables=None, name=""): """Submit the remote function for execution.""" - - if client_mode_should_convert(auto_init=True): + if client_mode_should_convert(): return client_mode_convert_function( self, args, @@ -237,6 +229,7 @@ def _remote(self, placement_group_capture_child_tasks=( placement_group_capture_child_tasks), runtime_env=runtime_env, + override_environment_variables=override_environment_variables, name=name) worker = ray.worker.global_worker @@ -277,12 +270,7 @@ def _remote(self, placement_group_capture_child_tasks = ( worker.should_capture_child_tasks_in_placement_group) - if self._placement_group != "default": - if self._placement_group: - placement_group = self._placement_group - else: - placement_group = PlacementGroup.empty() - elif placement_group == "default": + if placement_group == "default": if placement_group_capture_child_tasks: placement_group = get_current_placement_group() else: @@ -300,16 +288,18 @@ def _remote(self, num_cpus, num_gpus, memory, object_store_memory, resources, accelerator_type) - if runtime_env and not isinstance(runtime_env, ParsedRuntimeEnv): - runtime_env = ParsedRuntimeEnv(runtime_env) - elif isinstance(runtime_env, ParsedRuntimeEnv): - pass - else: + if runtime_env is None: runtime_env = self._runtime_env - parent_runtime_env = worker.core_worker.get_current_runtime_env() - parsed_runtime_env = override_task_or_actor_runtime_env( - runtime_env, parent_runtime_env) + job_runtime_env = worker.core_worker.get_current_runtime_env_dict() + runtime_env_dict = runtime_support.override_task_or_actor_runtime_env( + runtime_env, job_runtime_env) + + if override_environment_variables: + logger.warning("override_environment_variables is deprecated and " + "will be removed in Ray 1.6. Please use " + ".options(runtime_env={'env_vars': {...}}).remote()" + "instead.") def invocation(args, kwargs): if self._is_cross_language: @@ -325,12 +315,21 @@ def invocation(args, kwargs): "Cross language remote function " \ "cannot be executed locally." 
object_refs = worker.core_worker.submit_task( - self._language, self._function_descriptor, list_args, name, - num_returns, resources, max_retries, retry_exceptions, - placement_group.id, placement_group_bundle_index, + self._language, + self._function_descriptor, + list_args, + name, + num_returns, + resources, + max_retries, + retry_exceptions, + placement_group.id, + placement_group_bundle_index, placement_group_capture_child_tasks, - worker.debugger_breakpoint, parsed_runtime_env.serialize(), - parsed_runtime_env.get("uris") or []) + worker.debugger_breakpoint, + runtime_env_dict, + override_environment_variables=override_environment_variables + or dict()) # Reset worker's debug context from the last "remote" command # (which applies only to this .remote call). worker.debugger_breakpoint = b"" diff --git a/python/ray/runtime_context.py b/python/ray/runtime_context.py index 64bee3fc7cf79..750e213cc12b0 100644 --- a/python/ray/runtime_context.py +++ b/python/ray/runtime_context.py @@ -152,7 +152,7 @@ def should_capture_child_tasks_in_placement_group(self): @property def runtime_env(self): - """Get the runtime env used for the current driver or worker. + """Get the runtime env passed to job_config Returns: The runtime env currently using by this worker. @@ -172,24 +172,12 @@ def current_actor(self): worker.check_connected() return worker.core_worker.get_actor_handle(self.actor_id) - def _get_actor_call_stats(self): - """Get the current worker's task counters. - - Returns: - A dictionary keyed by the function name. The values are - dictionaries with form ``{"received": 0, "executing": 1, - "exectued": 2}``. - """ - worker = self.worker - worker.check_connected() - return worker.core_worker.get_actor_call_stats() - _runtime_context = None @PublicAPI(stability="beta") -@client_mode_hook(auto_init=False) +@client_mode_hook def get_runtime_context(): """Get the runtime context of the current driver/worker. 
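Note on the `@client_mode_hook` changes in the internal_kv.py and runtime_context.py hunks above, which swap the parameterized `@client_mode_hook(auto_init=False)` form for the bare `@client_mode_hook` form: a decorator that accepts both call styles is commonly structured as in the sketch below. This is an illustrative sketch only, not Ray's actual implementation; the wrapper body is a placeholder, and the `auto_init` keyword is included purely to mirror the argument seen in the diff.

import functools


def client_mode_hook(func=None, *, auto_init=True):
    # Sketch: support both bare @client_mode_hook and
    # parameterized @client_mode_hook(auto_init=...) usage.
    if func is None:
        # Parameterized form: no function yet, so return a decorator
        # that remembers auto_init and expects the function next.
        return functools.partial(client_mode_hook, auto_init=auto_init)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # A real hook would forward the call to the Ray client when client
        # mode is active (auto-initializing a connection if auto_init is
        # True); this placeholder simply runs the local implementation.
        return func(*args, **kwargs)

    return wrapper

With this shape, both spellings that appear in the hunks above are valid at the call site: `@client_mode_hook` applies the wrapper directly, while `@client_mode_hook(auto_init=False)` first builds a configured decorator and then applies it.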
diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index dec530cb3022b..08465f7a422e0 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -25,8 +25,6 @@ get_local_dump_archive, get_cluster_dump_archive, debug_status, RUN_ENV_TYPES) from ray.autoscaler._private.constants import RAY_PROCESSES -from ray.autoscaler._private.fake_multi_node.node_provider import \ - FAKE_HEAD_NODE_ID from ray.autoscaler._private.util import DEBUG_AUTOSCALING_ERROR, \ DEBUG_AUTOSCALING_STATUS @@ -434,6 +432,12 @@ def debug(address): hidden=True, type=json.loads, help="Override system configuration defaults.") +@click.option( + "--lru-evict", + is_flag=True, + hidden=True, + default=False, + help="Specify whether LRU evict will be used for this cluster.") @click.option( "--enable-object-reconstruction", is_flag=True, @@ -479,9 +483,9 @@ def start(node_ip_address, address, port, redis_password, redis_shard_ports, dashboard_agent_listen_port, block, plasma_directory, autoscaling_config, no_redirect_worker_output, no_redirect_output, plasma_store_socket_name, raylet_socket_name, temp_dir, - system_config, enable_object_reconstruction, metrics_export_port, - no_monitor, tracing_startup_hook, ray_debugger_external, log_style, - log_color, verbose): + system_config, lru_evict, enable_object_reconstruction, + metrics_export_port, no_monitor, tracing_startup_hook, + ray_debugger_external, log_style, log_color, verbose): """Start Ray processes manually on the local machine.""" cli_logger.configure(log_style, log_color, verbose) if gcs_server_port and not head: @@ -536,6 +540,7 @@ def start(node_ip_address, address, port, redis_password, redis_shard_ports, dashboard_port=dashboard_port, dashboard_agent_listen_port=dashboard_agent_listen_port, _system_config=system_config, + lru_evict=lru_evict, enable_object_reconstruction=enable_object_reconstruction, metrics_export_port=metrics_export_port, no_monitor=no_monitor, @@ -551,11 +556,6 @@ def start(node_ip_address, address, port, redis_password, redis_shard_ports, s.bind(("", 0)) port = s.getsockname()[1] - if os.environ.get("RAY_FAKE_CLUSTER"): - ray_params.env_vars = { - "RAY_OVERRIDE_NODE_ID_FOR_TESTING": FAKE_HEAD_NODE_ID - } - num_redis_shards = None # Start Ray on the head node. if redis_shard_ports is not None and address is None: @@ -598,9 +598,10 @@ def start(node_ip_address, address, port, redis_password, redis_shard_ports, "password.", cf.bold("--redis-password"), cf.bold("--address")) + node_ip_address = services.get_node_ip_address() + # Get the node IP address if one is not provided. 
- ray_params.update_if_absent( - node_ip_address=services.get_node_ip_address()) + ray_params.update_if_absent(node_ip_address=node_ip_address) cli_logger.labeled_value("Local node IP", ray_params.node_ip_address) ray_params.update_if_absent( redis_port=port, @@ -613,7 +614,7 @@ def start(node_ip_address, address, port, redis_password, redis_shard_ports, # Fail early when starting a new cluster when one is already running if address is None: - default_address = f"{ray_params.node_ip_address}:{port}" + default_address = f"{node_ip_address}:{port}" redis_addresses = services.find_redis_address(default_address) if len(redis_addresses) > 0: raise ConnectionError( diff --git a/python/ray/serialization.py b/python/ray/serialization.py index 5bf8c0d1437f3..bc335e4a8c539 100644 --- a/python/ray/serialization.py +++ b/python/ray/serialization.py @@ -91,7 +91,7 @@ def object_ref_reducer(obj): worker = ray.worker.global_worker worker.check_connected() obj, owner_address, object_status = ( - worker.core_worker.serialize_object_ref(obj)) + worker.core_worker.serialize_and_promote_object_ref(obj)) return _object_ref_deserializer, \ (obj.binary(), obj.call_site(), owner_address, object_status) diff --git a/python/ray/serve/BUILD b/python/ray/serve/BUILD index 9417f7c3798a3..cfab567726e77 100644 --- a/python/ray/serve/BUILD +++ b/python/ray/serve/BUILD @@ -106,7 +106,7 @@ py_test( py_test( name = "test_ray_client", - size = "medium", + size = "small", srcs = serve_tests_srcs, tags = ["exclusive", "team:serverless"], deps = [":serve_lib"], @@ -338,11 +338,3 @@ py_test( tags = ["exclusive", "team:serve"], deps = [":serve_lib"] ) - -py_test( - name = "conda_env", - size = "medium", - srcs = glob(["examples/doc/*.py"]), - tags = ["exclusive", "post_wheel_build", "team:serve"], - deps = [":serve_lib"] -) diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index 05a40dc34df0a..cd0ea1b033816 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -7,7 +7,8 @@ import time from dataclasses import dataclass from functools import wraps -from typing import Any, Callable, Dict, Optional, Tuple, Type, Union, overload +from typing import (Any, Callable, Dict, List, Optional, Tuple, Type, Union, + overload) from weakref import WeakValueDictionary from fastapi import APIRouter, FastAPI @@ -188,8 +189,7 @@ def _wait_for_goal(self, def deploy(self, name: str, backend_def: Union[Callable, Type[Callable], str], - init_args: Tuple[Any], - init_kwargs: Dict[Any, Any], + *init_args: Any, ray_actor_options: Optional[Dict] = None, config: Optional[Union[BackendConfig, Dict[str, Any]]] = None, version: Optional[str] = None, @@ -213,10 +213,7 @@ def deploy(self, del ray_actor_options["runtime_env"]["working_dir"] replica_config = ReplicaConfig( - backend_def, - init_args=init_args, - init_kwargs=init_kwargs, - ray_actor_options=ray_actor_options) + backend_def, *init_args, ray_actor_options=ray_actor_options) if isinstance(config, dict): backend_config = BackendConfig.parse_obj(config) @@ -225,10 +222,16 @@ def deploy(self, else: raise TypeError("config must be a BackendConfig or a dictionary.") + python_methods = [] + if inspect.isclass(backend_def): + for method_name, _ in inspect.getmembers(backend_def, + inspect.isfunction): + python_methods.append(method_name) + goal_id, updating = ray.get( self._controller.deploy.remote( - name, backend_config.to_proto_bytes(), replica_config, version, - prev_version, route_prefix, + name, backend_config.to_proto_bytes(), replica_config, + python_methods, version, 
prev_version, route_prefix, ray.get_runtime_context().job_id)) tag = f"component=serve deployment={name}" @@ -315,16 +318,27 @@ def get_handle( "to create sync handle. Learn more at https://docs.ray.io/en/" "master/serve/http-servehandle.html#sync-and-async-handles") + if endpoint_name in all_endpoints: + this_endpoint = all_endpoints[endpoint_name] + python_methods: List[str] = this_endpoint["python_methods"] + else: + # This can happen in the missing_ok=True case. + # handle.method_name.remote won't work and user must + # use the legacy handle.options(method).remote(). + python_methods: List[str] = [] + if sync: handle = RayServeSyncHandle( self._controller, endpoint_name, + known_python_methods=python_methods, _internal_pickled_http_request=_internal_pickled_http_request, ) else: handle = RayServeHandle( self._controller, endpoint_name, + known_python_methods=python_methods, _internal_pickled_http_request=_internal_pickled_http_request, ) @@ -605,7 +619,6 @@ def __init__(self, version: Optional[str] = None, prev_version: Optional[str] = None, init_args: Optional[Tuple[Any]] = None, - init_kwargs: Optional[Tuple[Any]] = None, route_prefix: Optional[str] = None, ray_actor_options: Optional[Dict] = None, _internal=False) -> None: @@ -631,8 +644,6 @@ def __init__(self, raise TypeError("prev_version must be a string.") if not (init_args is None or isinstance(init_args, tuple)): raise TypeError("init_args must be a tuple.") - if not (init_kwargs is None or isinstance(init_kwargs, dict)): - raise TypeError("init_kwargs must be a dict.") if route_prefix is not None: if not isinstance(route_prefix, str): raise TypeError("route_prefix must be a string.") @@ -649,16 +660,6 @@ def __init__(self, if init_args is None: init_args = () - if init_kwargs is None: - init_kwargs = {} - - # TODO(architkulkarni): Enforce that autoscaling_config and - # user-provided num_replicas should be mutually exclusive. - if version is None and config.autoscaling_config is not None: - # TODO(architkulkarni): Remove this restriction. - raise ValueError( - "Currently autoscaling is only supported for " - "versioned deployments. Try @serve.deployment(version=...).") self._func_or_class = func_or_class self._name = name @@ -666,7 +667,6 @@ def __init__(self, self._prev_version = prev_version self._config = config self._init_args = init_args - self._init_kwargs = init_kwargs self._route_prefix = route_prefix self._ray_actor_options = ray_actor_options @@ -724,12 +724,7 @@ def ray_actor_options(self) -> Optional[Dict]: @property def init_args(self) -> Tuple[Any]: - """Positional args passed to the underlying class's constructor.""" - return self._init_args - - @property - def init_kwargs(self) -> Tuple[Any]: - """Keyword args passed to the underlying class's constructor.""" + """Arguments passed to the underlying class's constructor.""" return self._init_args @property @@ -743,25 +738,20 @@ def __call__(self): "Use `deployment.deploy() instead.`") @PublicAPI - def deploy(self, *init_args, _blocking=True, **init_kwargs): + def deploy(self, *init_args, _blocking=True): """Deploy or update this deployment. Args: init_args (optional): args to pass to the class __init__ method. Not valid if this deployment wraps a function. - init_kwargs (optional): kwargs to pass to the class __init__ - method. Not valid if this deployment wraps a function. 
""" if len(init_args) == 0 and self._init_args is not None: init_args = self._init_args - if len(init_kwargs) == 0 and self._init_kwargs is not None: - init_kwargs = self._init_kwargs return _get_global_client().deploy( self._name, self._func_or_class, - init_args, - init_kwargs, + *init_args, ray_actor_options=self._ray_actor_options, config=self._config, version=self._version, @@ -793,23 +783,19 @@ def get_handle(self, sync: Optional[bool] = True self._name, missing_ok=True, sync=sync) @PublicAPI - def options(self, - func_or_class: Optional[Callable] = None, - name: Optional[str] = None, - version: Optional[str] = None, - prev_version: Optional[str] = None, - init_args: Optional[Tuple[Any]] = None, - init_kwargs: Optional[Dict[Any, Any]] = None, - route_prefix: Optional[str] = None, - num_replicas: Optional[int] = None, - ray_actor_options: Optional[Dict] = None, - user_config: Optional[Any] = None, - max_concurrent_queries: Optional[int] = None, - _autoscaling_config: Optional[Union[Dict, - AutoscalingConfig]] = None, - _graceful_shutdown_wait_loop_s: Optional[float] = None, - _graceful_shutdown_timeout_s: Optional[float] = None - ) -> "Deployment": + def options( + self, + func_or_class: Optional[Callable] = None, + name: Optional[str] = None, + version: Optional[str] = None, + prev_version: Optional[str] = None, + init_args: Optional[Tuple[Any]] = None, + route_prefix: Optional[str] = None, + num_replicas: Optional[int] = None, + ray_actor_options: Optional[Dict] = None, + user_config: Optional[Any] = None, + max_concurrent_queries: Optional[int] = None, + ) -> "Deployment": """Return a copy of this deployment with updated options. Only those options passed in will be updated, all others will remain @@ -835,9 +821,6 @@ def options(self, if init_args is None: init_args = self._init_args - if init_kwargs is None: - init_kwargs = self._init_kwargs - if route_prefix is None: if self._route_prefix == f"/{self._name}": route_prefix = None @@ -847,17 +830,6 @@ def options(self, if ray_actor_options is None: ray_actor_options = self._ray_actor_options - if _autoscaling_config is None: - new_config.autoscaling_config = _autoscaling_config - - if _graceful_shutdown_wait_loop_s is not None: - new_config.graceful_shutdown_wait_loop_s = ( - _graceful_shutdown_wait_loop_s) - - if _graceful_shutdown_timeout_s is not None: - new_config.graceful_shutdown_timeout_s = ( - _graceful_shutdown_timeout_s) - return Deployment( func_or_class, name, @@ -865,7 +837,6 @@ def options(self, version=version, prev_version=prev_version, init_args=init_args, - init_kwargs=init_kwargs, route_prefix=route_prefix, ray_actor_options=ray_actor_options, _internal=True, @@ -877,7 +848,6 @@ def __eq__(self, other): self._version == other._version, self._config == other._config, self._init_args == other._init_args, - self._init_kwargs == other._init_kwargs, self._route_prefix == other._route_prefix, self._ray_actor_options == self._ray_actor_options, ]) @@ -901,20 +871,16 @@ def deployment(func_or_class: Callable) -> Deployment: @overload -def deployment( - name: Optional[str] = None, - version: Optional[str] = None, - prev_version: Optional[str] = None, - num_replicas: Optional[int] = None, - init_args: Optional[Tuple[Any]] = None, - init_kwargs: Optional[Dict[Any, Any]] = None, - ray_actor_options: Optional[Dict] = None, - user_config: Optional[Any] = None, - max_concurrent_queries: Optional[int] = None, - _autoscaling_config: Optional[Union[Dict, AutoscalingConfig]] = None, - _graceful_shutdown_wait_loop_s: Optional[float] 
= None, - _graceful_shutdown_timeout_s: Optional[float] = None -) -> Callable[[Callable], Deployment]: +def deployment(name: Optional[str] = None, + version: Optional[str] = None, + prev_version: Optional[str] = None, + num_replicas: Optional[int] = None, + init_args: Optional[Tuple[Any]] = None, + ray_actor_options: Optional[Dict] = None, + user_config: Optional[Any] = None, + max_concurrent_queries: Optional[int] = None, + _autoscaling_config: Optional[dict] = None + ) -> Callable[[Callable], Deployment]: pass @@ -926,14 +892,11 @@ def deployment( prev_version: Optional[str] = None, num_replicas: Optional[int] = None, init_args: Optional[Tuple[Any]] = None, - init_kwargs: Optional[Dict[Any, Any]] = None, route_prefix: Optional[str] = None, ray_actor_options: Optional[Dict] = None, user_config: Optional[Any] = None, max_concurrent_queries: Optional[int] = None, - _autoscaling_config: Optional[Union[Dict, AutoscalingConfig]] = None, - _graceful_shutdown_wait_loop_s: Optional[float] = None, - _graceful_shutdown_timeout_s: Optional[float] = None + _autoscaling_config: Optional[dict] = None, ) -> Callable[[Callable], Deployment]: """Define a Serve deployment. @@ -952,10 +915,7 @@ def deployment( not check the existing deployment's version. num_replicas (Optional[int]): The number of processes to start up that will handle requests to this deployment. Defaults to 1. - init_args (Optional[Tuple]): Positional args to be passed to the class - constructor when starting up deployment replicas. These can also be - passed when you call `.deploy()` on the returned Deployment. - init_kwargs (Optional[Dict]): Keyword args to be passed to the class + init_args (Optional[Tuple]): Arguments to be passed to the class constructor when starting up deployment replicas. These can also be passed when you call `.deploy()` on the returned Deployment. 
route_prefix (Optional[str]): Requests to paths under this HTTP path @@ -1002,13 +962,8 @@ class MyDeployment: config.max_concurrent_queries = max_concurrent_queries if _autoscaling_config is not None: - config.autoscaling_config = _autoscaling_config - - if _graceful_shutdown_wait_loop_s is not None: - config.graceful_shutdown_wait_loop_s = _graceful_shutdown_wait_loop_s - - if _graceful_shutdown_timeout_s is not None: - config.graceful_shutdown_timeout_s = _graceful_shutdown_timeout_s + config.autoscaling_config = AutoscalingConfig.parse_obj( + _autoscaling_config) def decorator(_func_or_class): return Deployment( @@ -1018,7 +973,6 @@ def decorator(_func_or_class): version=version, prev_version=prev_version, init_args=init_args, - init_kwargs=init_kwargs, route_prefix=route_prefix, ray_actor_options=ray_actor_options, _internal=True, @@ -1060,7 +1014,6 @@ def get_deployment(name: str) -> Deployment: backend_info.backend_config, version=backend_info.version, init_args=backend_info.replica_config.init_args, - init_kwargs=backend_info.replica_config.init_kwargs, route_prefix=route_prefix, ray_actor_options=backend_info.replica_config.ray_actor_options, _internal=True, @@ -1084,7 +1037,6 @@ def list_deployments() -> Dict[str, Deployment]: backend_info.backend_config, version=backend_info.version, init_args=backend_info.replica_config.init_args, - init_kwargs=backend_info.replica_config.init_kwargs, route_prefix=route_prefix, ray_actor_options=backend_info.replica_config.ray_actor_options, _internal=True, diff --git a/python/ray/serve/autoscaling_metrics.py b/python/ray/serve/autoscaling_metrics.py index 4b0d030d700cc..084996760d297 100644 --- a/python/ray/serve/autoscaling_metrics.py +++ b/python/ray/serve/autoscaling_metrics.py @@ -74,7 +74,7 @@ def add_metrics_point(self, data_points: Dict[str, float], Args: data_points(dict): dictionary containing the metrics values. The - key should be a string that uniquely identifies this time series + key should be a string that uniquely identitify this time series and to be used to perform aggregation. timestamp(float): the unix epoch timestamp the metrics are collected at. @@ -98,9 +98,6 @@ def window_average(self, do_compact(bool): whether or not to delete the datapoints that's before `window_start_timestamp_s` to save memory. Default is true. - Returns: - The average of all the datapoints for the key on and after time - window_start_timestamp_s, or None if there are no such points. """ datapoints = self.data[key] diff --git a/python/ray/serve/autoscaling_policy.py b/python/ray/serve/autoscaling_policy.py index 23dbbf65159e9..6a9887fb7497c 100644 --- a/python/ray/serve/autoscaling_policy.py +++ b/python/ray/serve/autoscaling_policy.py @@ -16,6 +16,7 @@ def calculate_desired_num_replicas(autoscaling_config: AutoscalingConfig, current_num_ongoing_requests (List[float]): A list of the number of ongoing requests for each replica. Assumes each entry has already been time-averaged over the desired lookback window. + current_num_replicas (int): The current number of active replicas. 
Returns: desired_num_replicas: The desired number of replicas to scale to, based diff --git a/python/ray/serve/backend_state.py b/python/ray/serve/backend_state.py index 6068887f9bd7a..2ab4c5e41d99d 100644 --- a/python/ray/serve/backend_state.py +++ b/python/ray/serve/backend_state.py @@ -76,7 +76,6 @@ def __init__(self, actor_name: str, detached: bool, controller_name: str, self._ready_obj_ref = None self._graceful_shutdown_ref = None - self._graceful_shutdown_timeout_s = None self._actor_resources = None self._health_check_ref = None @@ -148,8 +147,6 @@ def start(self, backend_info: BackendInfo, version: BackendVersion): Start a new actor for current BackendReplica instance. """ self._actor_resources = backend_info.replica_config.resource_dict - self._graceful_shutdown_timeout_s = ( - backend_info.backend_config.graceful_shutdown_timeout_s) if USE_PLACEMENT_GROUP: self._placement_group = self.create_placement_group( self._placement_group_name, self._actor_resources) @@ -167,7 +164,6 @@ def start(self, backend_info: BackendInfo, version: BackendVersion): **backend_info.replica_config.ray_actor_options).remote( self.backend_tag, self.replica_tag, backend_info.replica_config.init_args, - backend_info.replica_config.init_kwargs, backend_info.backend_config.to_proto_bytes(), version, self._controller_name, self._detached) @@ -247,19 +243,14 @@ def actor_resources(self) -> Dict[str, float]: def available_resources(self) -> Dict[str, float]: return ray.available_resources() - def graceful_stop(self) -> Duration: - """Request the actor to exit gracefully. - - Returns the timeout after which to kill the actor. - """ + def graceful_stop(self) -> None: + """Request the actor to exit gracefully.""" try: handle = ray.get_actor(self._actor_name) self._graceful_shutdown_ref = handle.prepare_for_shutdown.remote() except ValueError: pass - return self._graceful_shutdown_timeout_s - def check_stopped(self) -> bool: """Check if the actor has exited.""" try: @@ -395,15 +386,14 @@ def check_started(self) -> ReplicaStartupStatus: return status - def stop(self, graceful: bool = True) -> None: + def stop(self, graceful_shutdown_timeout_s: Duration = 0) -> None: """Stop the replica. Should handle the case where the replica is already stopped. """ - timeout_s = self._actor.graceful_stop() - if not graceful: - timeout_s = 0 - self._shutdown_deadline = time.time() + timeout_s + self._actor.graceful_stop() + self._graceful_shutdown_timeout_s = graceful_shutdown_timeout_s + self._shutdown_deadline = time.time() + graceful_shutdown_timeout_s def check_stopped(self) -> bool: """Check if the replica has finished stopping.""" @@ -412,13 +402,14 @@ def check_stopped(self) -> bool: self._actor.cleanup() return True - timeout_passed = time.time() > self._shutdown_deadline + timeout_passed = time.time() >= self._shutdown_deadline + if timeout_passed: # Graceful period passed, kill it forcefully. # This will be called repeatedly until the replica shuts down. logger.debug( - f"Replica {self.replica_tag} did not shut down after grace " - "period, force-killing it. " + f"Replica {self.replica_tag} did not shutdown after " + f"{self._graceful_shutdown_timeout_s}s, force-killing. " f"component=serve deployment={self.backend_tag} " f"replica={self.replica_tag}") @@ -731,9 +722,9 @@ def deploy(self, backend_info: BackendInfo) -> Tuple[Optional[GoalId], bool]: """Deploy the backend. 
- If the backend already exists with the same version and BackendConfig, - this is a no-op and returns the GoalId corresponding to the existing - update if there is one. + If the backend already exists with the same version, this is a no-op + and returns the GoalId corresponding to the existing update if there + is one. Returns: GoalId, bool: The GoalId for the client to wait for and whether or @@ -769,8 +760,11 @@ def deploy(self, self._goal_manager.complete_goal(existing_goal_id) return new_goal_id, True - def delete(self) -> Optional[GoalId]: + def delete(self, force_kill: bool = False) -> Optional[GoalId]: new_goal_id, existing_goal_id = self._set_backend_goal(None) + if force_kill: + self._target_info.backend_config.\ + experimental_graceful_shutdown_timeout_s = 0 self._save_checkpoint_func() self._notify_backend_configs_changed() @@ -828,6 +822,9 @@ def _stop_wrong_version_replicas(self) -> int: states=[ReplicaState.STARTING, ReplicaState.RUNNING], max_replicas=max_to_stop) + graceful_shutdown_timeout_s = ( + self._target_info.backend_config. + experimental_graceful_shutdown_timeout_s) code_version_changes = 0 user_config_changes = 0 for replica in replicas_to_update: @@ -837,7 +834,8 @@ def _stop_wrong_version_replicas(self) -> int: if (replica.version.code_version != self._target_version.code_version): code_version_changes += 1 - replica.stop() + replica.stop( + graceful_shutdown_timeout_s=graceful_shutdown_timeout_s) self._replicas.add(ReplicaState.STOPPING, replica) # If only the user_config is a mismatch, we update it dynamically # without restarting the replica. @@ -871,6 +869,10 @@ def _scale_backend_replicas(self) -> bool: assert self._target_replicas >= 0, ("Number of replicas must be" " greater than or equal to 0.") + graceful_shutdown_timeout_s = ( + self._target_info.backend_config. + experimental_graceful_shutdown_timeout_s) + self._stop_wrong_version_replicas() current_replicas = self._replicas.count(states=[ @@ -922,7 +924,8 @@ def _scale_backend_replicas(self) -> bool: for replica in replicas_to_stop: logger.debug(f"Adding STOPPING to replica_tag: {replica}, " f"backend_tag: {self._name}") - replica.stop() + replica.stop( + graceful_shutdown_timeout_s=graceful_shutdown_timeout_s) self._replicas.add(ReplicaState.STOPPING, replica) return True @@ -1011,7 +1014,7 @@ def _check_startup_replicas(self, # Increase startup failure counter if we're tracking it self._replica_constructor_retry_counter += 1 - replica.stop(graceful=False) + replica.stop(graceful_shutdown_timeout_s=0) self._replicas.add(ReplicaState.STOPPING, replica) transitioned = True elif start_status == ReplicaStartupStatus.PENDING: @@ -1023,7 +1026,7 @@ def _check_startup_replicas(self, if not stop_on_slow: self._replicas.add(original_state, replica) else: - replica.stop(graceful=False) + replica.stop(graceful_shutdown_timeout_s=0) self._replicas.add(ReplicaState.STOPPING, replica) transitioned = True slow_replicas.append(replica) @@ -1046,7 +1049,7 @@ def _check_and_update_replicas(self) -> bool: f"{self._name} failed health check, stopping it. " f"component=serve deployment={self._name} " f"replica={replica.replica_tag}") - replica.stop(graceful=False) + replica.stop(graceful_shutdown_timeout_s=0) self._replicas.add(ReplicaState.STOPPING, replica) slow_start_replicas = [] @@ -1070,9 +1073,8 @@ def _check_and_update_replicas(self) -> bool: f"Deployment '{self._name}' has " f"{len(slow_start_replicas)} replicas that have taken " f"more than {SLOW_STARTUP_WARNING_S}s to start up. 
This " - "may be caused by waiting for the cluster to auto-scale, " - "waiting for a runtime environment to install, or a slow " - "constructor. Resources required " + "may be caused by waiting for the cluster to auto-scale " + "or because the constructor is slow. Resources required " f"for each replica: {required}, resources available: " f"{available}. component=serve deployment={self._name}") @@ -1234,7 +1236,7 @@ def shutdown(self) -> List[GoalId]: shutdown_goals = [] for backend_state in self._backend_states.values(): - goal = backend_state.delete() + goal = backend_state.delete(force_kill=True) if goal is not None: shutdown_goals.append(goal) @@ -1300,9 +1302,9 @@ def deploy_backend(self, backend_tag: BackendTag, backend_info: BackendInfo ) -> Tuple[Optional[GoalId], bool]: """Deploy the backend. - If the backend already exists with the same version and BackendConfig, - this is a no-op and returns the GoalId corresponding to the existing - update if there is one. + If the backend already exists with the same version, this is a no-op + and returns the GoalId corresponding to the existing update if there + is one. Returns: GoalId, bool: The GoalId for the client to wait for and whether or @@ -1317,14 +1319,15 @@ def deploy_backend(self, backend_tag: BackendTag, backend_info: BackendInfo return self._backend_states[backend_tag].deploy(backend_info) - def delete_backend(self, backend_tag: BackendTag) -> Optional[GoalId]: + def delete_backend(self, backend_tag: BackendTag, + force_kill: bool = False) -> Optional[GoalId]: # This method must be idempotent. We should validate that the # specified backend exists on the client. if backend_tag not in self._backend_states: return None backend_state = self._backend_states[backend_tag] - return backend_state.delete() + return backend_state.delete(force_kill=force_kill) def update(self) -> bool: """Updates the state of all backends to match their goal state.""" diff --git a/python/ray/serve/replica.py b/python/ray/serve/backend_worker.py similarity index 91% rename from python/ray/serve/replica.py rename to python/ray/serve/backend_worker.py index cc90ada23fd7e..a049bdfac3a84 100644 --- a/python/ray/serve/replica.py +++ b/python/ray/serve/backend_worker.py @@ -15,11 +15,12 @@ from ray.serve.autoscaling_metrics import start_metrics_pusher from ray.serve.common import BackendTag, ReplicaTag -from ray.serve.config import BackendConfig from ray.serve.http_util import ASGIHTTPSender from ray.serve.utils import parse_request_item, _get_logger from ray.serve.exceptions import RayServeException from ray.util import metrics +from ray.serve.config import BackendConfig +from ray.serve.long_poll import LongPollClient, LongPollNamespace from ray.serve.router import Query, RequestMetadata from ray.serve.constants import ( BACKEND_RECONFIGURE_METHOD, @@ -31,7 +32,7 @@ logger = _get_logger() -def create_replica_wrapper(name: str, serialized_backend_def: bytes): +def create_backend_replica(name: str, serialized_backend_def: bytes): """Creates a replica class wrapping the provided function or class. 
This approach is picked over inheritance to avoid conflict between user @@ -42,7 +43,7 @@ def create_replica_wrapper(name: str, serialized_backend_def: bytes): # TODO(architkulkarni): Add type hints after upgrading cloudpickle class RayServeWrappedReplica(object): async def __init__(self, backend_tag, replica_tag, init_args, - init_kwargs, backend_config_proto_bytes: bytes, + backend_config_proto_bytes: bytes, version: BackendVersion, controller_name: str, detached: bool): backend = cloudpickle.loads(serialized_backend_def) @@ -71,8 +72,7 @@ async def __init__(self, backend_tag, replica_tag, init_args, # This allows backends to define an async __init__ method # (required for FastAPI backend definition). _callable = backend.__new__(backend) - await sync_to_async(_callable.__init__)(*init_args, - **init_kwargs) + await sync_to_async(_callable.__init__)(*init_args) # Setting the context again to update the servable_object. ray.serve.api._set_internal_replica_context( backend_tag, @@ -149,6 +149,8 @@ def __init__(self, _callable: Callable, backend_tag: BackendTag, self.replica_tag = replica_tag self.callable = _callable self.is_function = is_function + + self.backend_config = backend_config self.user_config = user_config self.version = version @@ -164,6 +166,16 @@ def __init__(self, _callable: Callable, backend_tag: BackendTag, "replica": self.replica_tag }) + self.loop = asyncio.get_event_loop() + self.long_poll_client = LongPollClient( + controller_handle, + { + (LongPollNamespace.BACKEND_CONFIGS, self.backend_tag): self. + _update_backend_configs, + }, + call_in_event_loop=self.loop, + ) + self.error_counter = metrics.Counter( "serve_deployment_error_counter", description=("The number of exceptions that have " @@ -205,9 +217,6 @@ def __init__(self, _callable: Callable, backend_tag: BackendTag, self.restart_counter.inc() - self._shutdown_wait_loop_s = ( - backend_config.graceful_shutdown_wait_loop_s) - if backend_config.autoscaling_config: config = backend_config.autoscaling_config start_metrics_pusher( @@ -231,19 +240,10 @@ def _collect_autoscaling_metrics(self): def get_runner_method(self, request_item: Query) -> Callable: method_name = request_item.metadata.call_method if not hasattr(self.callable, method_name): - # Filter to methods that don't start with '__' prefix. - def callable_method_filter(attr): - if attr.startswith("__"): - return False - elif not callable(getattr(self.callable, attr)): - return False - - return True - - methods = list(filter(callable_method_filter, dir(self.callable))) - raise RayServeException(f"Tried to call a method '{method_name}' " - "that does not exist. Available methods: " - f"{methods}.") + raise RayServeException("Backend doesn't have method {} " + "which is specified in the request. 
" + "The available methods are {}".format( + method_name, dir(self.callable))) if self.is_function: return self.callable return getattr(self.callable, method_name) @@ -309,6 +309,9 @@ async def reconfigure(self, getattr(self.callable, BACKEND_RECONFIGURE_METHOD)) await reconfigure_method(user_config) + def _update_backend_configs(self, new_config_bytes: bytes) -> None: + self.backend_config = BackendConfig.from_proto_bytes(new_config_bytes) + async def handle_request(self, request: Query) -> asyncio.Future: request.tick_enter_replica = time.time() logger.debug("Replica {} received request {}".format( @@ -338,17 +341,18 @@ async def prepare_for_shutdown(self): Trigger a graceful shutdown protocol that will wait for all the queued tasks to be completed and return to the controller. """ + sleep_time = self.backend_config.experimental_graceful_shutdown_wait_loop_s # noqa: E501 while True: # Sleep first because we want to make sure all the routers receive # the notification to remove this replica first. - await asyncio.sleep(self._shutdown_wait_loop_s) + await asyncio.sleep(sleep_time) if self.num_ongoing_requests == 0: break else: logger.info( - "Waiting for an additional " - f"{self._shutdown_wait_loop_s}s to shut down because " - f"there are {self.num_ongoing_requests} ongoing requests.") + f"Waiting for an additional {sleep_time}s to shut down " + f"because there are {self.num_ongoing_requests} " + "ongoing requests.") # Explicitly call the del method to trigger clean up. # We set the del method to noop after succssifully calling it so the diff --git a/python/ray/serve/common.py b/python/ray/serve/common.py index 0e97b5cf98eb7..be13503c97334 100644 --- a/python/ray/serve/common.py +++ b/python/ray/serve/common.py @@ -1,7 +1,7 @@ import ray -from dataclasses import dataclass -from typing import Optional +from dataclasses import dataclass, field +from typing import List, Optional from uuid import UUID from ray.actor import ActorClass @@ -17,6 +17,7 @@ @dataclass class EndpointInfo: + python_methods: Optional[List[str]] = field(default_factory=list) route: Optional[str] = None diff --git a/python/ray/serve/config.py b/python/ray/serve/config.py index b7d5c08457691..4002550ae109f 100644 --- a/python/ray/serve/config.py +++ b/python/ray/serve/config.py @@ -1,7 +1,7 @@ import inspect import pickle from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, List, Optional import pydantic from google.protobuf.json_format import MessageToDict @@ -24,10 +24,8 @@ class AutoscalingConfig(BaseModel): # Private options below # Metrics scraping options - - # How often to scrape for metrics metrics_interval_s: float = 10.0 - # Time window to average over for metrics. + loop_period_s: float = 30.0 look_back_period_s: float = 30.0 # Internal autoscaling configuration options @@ -36,7 +34,6 @@ class AutoscalingConfig(BaseModel): smoothing_factor: float = 1.0 # TODO(architkulkarni): implement below - # loop_period_s = 30 # How frequently to make autoscaling decisions # How long to wait before scaling down replicas # downscale_delay_s: float = 600.0 # How long to wait before scaling up replicas @@ -55,6 +52,8 @@ class AutoscalingConfig(BaseModel): class BackendConfig(BaseModel): """Configuration options for a backend, to be set by the user. + DEPRECATED. Will be removed in Ray 1.5. See docs for details. + Args: num_replicas (Optional[int]): The number of processes to start up that will handle requests to this backend. Defaults to 1. 
@@ -64,10 +63,10 @@ class BackendConfig(BaseModel): user_config (Optional[Any]): Arguments to pass to the reconfigure method of the backend. The reconfigure method is called if user_config is not None. - graceful_shutdown_wait_loop_s (Optional[float]): Duration + experimental_graceful_shutdown_wait_loop_s (Optional[float]): Duration that backend workers will wait until there is no more work to be done before shutting down. Defaults to 2s. - graceful_shutdown_timeout_s (Optional[float]): + experimental_graceful_shutdown_timeout_s (Optional[float]): Controller waits for this duration to forcefully kill the replica for shutdown. Defaults to 20s. """ @@ -76,8 +75,8 @@ class BackendConfig(BaseModel): max_concurrent_queries: Optional[int] = None user_config: Any = None - graceful_shutdown_wait_loop_s: NonNegativeFloat = 2.0 - graceful_shutdown_timeout_s: NonNegativeFloat = 20.0 + experimental_graceful_shutdown_wait_loop_s: NonNegativeFloat = 2.0 + experimental_graceful_shutdown_timeout_s: NonNegativeFloat = 20.0 autoscaling_config: Optional[AutoscalingConfig] = None @@ -122,23 +121,16 @@ def from_proto_bytes(cls, proto_bytes: bytes): class ReplicaConfig: - def __init__(self, - backend_def: Callable, - init_args: Optional[Tuple[Any]] = None, - init_kwargs: Optional[Dict[Any, Any]] = None, - ray_actor_options=None): + def __init__(self, backend_def, *init_args, ray_actor_options=None): # Validate that backend_def is an import path, function, or class. if isinstance(backend_def, str): self.func_or_class_name = backend_def pass elif inspect.isfunction(backend_def): self.func_or_class_name = backend_def.__name__ - if init_args: + if len(init_args) != 0: raise ValueError( "init_args not supported for function backend.") - if init_kwargs: - raise ValueError( - "init_kwargs not supported for function backend.") elif inspect.isclass(backend_def): self.func_or_class_name = backend_def.__name__ else: @@ -147,8 +139,7 @@ def __init__(self, format(type(backend_def))) self.serialized_backend_def = cloudpickle.dumps(backend_def) - self.init_args = init_args if init_args is not None else () - self.init_kwargs = init_kwargs if init_kwargs is not None else {} + self.init_args = init_args if ray_actor_options is None: self.ray_actor_options = {} else: @@ -167,13 +158,12 @@ def _validate(self): raise TypeError("ray_actor_options must be a dictionary.") elif "lifetime" in self.ray_actor_options: raise ValueError( - "Specifying lifetime in ray_actor_options is not allowed.") + "Specifying lifetime in init_args is not allowed.") elif "name" in self.ray_actor_options: - raise ValueError( - "Specifying name in ray_actor_options is not allowed.") + raise ValueError("Specifying name in init_args is not allowed.") elif "max_restarts" in self.ray_actor_options: raise ValueError("Specifying max_restarts in " - "ray_actor_options is not allowed.") + "init_args is not allowed.") else: # Ray defaults to zero CPUs for placement, we default to one here. 
if "num_cpus" not in self.ray_actor_options: diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py index cdaf1cf008151..c367dc4232b81 100644 --- a/python/ray/serve/controller.py +++ b/python/ray/serve/controller.py @@ -8,8 +8,8 @@ import ray from ray.actor import ActorHandle from ray.serve.async_goal_manager import AsyncGoalManager -from ray.serve.autoscaling_policy import calculate_desired_num_replicas from ray.serve.backend_state import ReplicaState, BackendStateManager +from ray.serve.backend_worker import create_backend_replica from ray.serve.common import ( BackendInfo, BackendTag, @@ -20,10 +20,9 @@ ReplicaTag, ) from ray.serve.config import BackendConfig, HTTPOptions, ReplicaConfig -from ray.serve.constants import CONTROL_LOOP_PERIOD_S, SERVE_ROOT_URL_ENV_KEY +from ray.serve.constants import (CONTROL_LOOP_PERIOD_S, SERVE_ROOT_URL_ENV_KEY) from ray.serve.endpoint_state import EndpointState from ray.serve.http_state import HTTPState -from ray.serve.replica import create_replica_wrapper from ray.serve.storage.checkpoint_path import make_kv_store from ray.serve.long_poll import LongPollHost from ray.serve.utils import logger @@ -105,10 +104,6 @@ def record_autoscaling_metrics(self, data: Dict[str, float], def _dump_autoscaling_metrics_for_testing(self): return self.autoscaling_metrics_store.data - def _dump_replica_states_for_testing(self, deployment_name): - return self.backend_state_manager._backend_states[ - deployment_name]._replicas - async def wait_for_goal(self, goal_id: GoalId) -> Optional[Exception]: return await self.goal_manager.wait_for_goal(goal_id) @@ -134,55 +129,8 @@ def get_http_proxies(self) -> Dict[NodeId, ActorHandle]: """Returns a dictionary of node ID to http_proxy actor handles.""" return self.http_state.get_http_proxy_handles() - def autoscale(self) -> None: - """Update autoscaling deployments with calculated num_replicas.""" - for deployment_name, (backend_info, - route_prefix) in self.list_deployments().items(): - backend_config = backend_info.backend_config - autoscaling_config = backend_config.autoscaling_config - - if autoscaling_config is None: - continue - - replicas = self.backend_state_manager._backend_states[ - deployment_name]._replicas - running_replicas = replicas.get([ReplicaState.RUNNING]) - - current_num_ongoing_requests = [] - for replica in running_replicas: - replica_tag = replica.replica_tag - num_ongoing_requests = ( - self.autoscaling_metrics_store.window_average( - replica_tag, - time.time() - autoscaling_config.look_back_period_s)) - if num_ongoing_requests is not None: - current_num_ongoing_requests.append(num_ongoing_requests) - - if len(current_num_ongoing_requests) == 0: - continue - - new_backend_config = backend_config.copy() - new_backend_config.num_replicas = calculate_desired_num_replicas( - autoscaling_config, current_num_ongoing_requests) - - replica_config = backend_info.replica_config - deployer_job_id = backend_info.deployer_job_id - backend_config_proto_bytes = new_backend_config.to_proto_bytes() - goal_id, updating = self.deploy( - deployment_name, - backend_config_proto_bytes, - replica_config, - version=backend_info.version, - prev_version=backend_info.version, - route_prefix=route_prefix, - deployer_job_id=deployer_job_id) - async def run_control_loop(self) -> None: while True: - try: - self.autoscale() - except Exception: - logger.exception("Exception while autoscaling deployments.") async with self.write_lock: try: self.http_state.update() @@ -270,56 +218,57 @@ async def shutdown(self) -> 
List[GoalId]: return goal_ids - def deploy(self, - name: str, - backend_config_proto_bytes: bytes, - replica_config: ReplicaConfig, - version: Optional[str], - prev_version: Optional[str], - route_prefix: Optional[str], - deployer_job_id: "Optional[ray._raylet.JobID]" = None - ) -> Tuple[Optional[GoalId], bool]: + async def deploy(self, + name: str, + backend_config_proto_bytes: bytes, + replica_config: ReplicaConfig, + python_methods: List[str], + version: Optional[str], + prev_version: Optional[str], + route_prefix: Optional[str], + deployer_job_id: "Optional[ray._raylet.JobID]" = None + ) -> Tuple[Optional[GoalId], bool]: if route_prefix is not None: assert route_prefix.startswith("/") backend_config = BackendConfig.from_proto_bytes( backend_config_proto_bytes) - if prev_version is not None: - existing_backend_info = self.backend_state_manager.get_backend( - name) - if (existing_backend_info is None - or not existing_backend_info.version): - raise ValueError( - f"prev_version '{prev_version}' is specified but " - "there is no existing deployment.") - if existing_backend_info.version != prev_version: - raise ValueError(f"prev_version '{prev_version}' " - "does not match with the existing " - f"version '{existing_backend_info.version}'.") - backend_info = BackendInfo( - actor_def=ray.remote( - create_replica_wrapper(name, - replica_config.serialized_backend_def)), - version=version, - backend_config=backend_config, - replica_config=replica_config, - deployer_job_id=deployer_job_id, - start_time_ms=int(time.time() * 1000)) - # TODO(architkulkarni): When a deployment is redeployed, even if - # the only change was num_replicas, the start_time_ms is refreshed. - # This is probably not the desired behavior for an autoscaling - # deployment, which redeploys very often to change num_replicas. - - goal_id, updating = self.backend_state_manager.deploy_backend( - name, backend_info) - endpoint_info = EndpointInfo(route=route_prefix) - self.endpoint_state.update_endpoint(name, endpoint_info) - return goal_id, updating + async with self.write_lock: + if prev_version is not None: + existing_backend_info = self.backend_state_manager.get_backend( + name) + if (existing_backend_info is None + or not existing_backend_info.version): + raise ValueError( + f"prev_version '{prev_version}' is specified but " + "there is no existing deployment.") + if existing_backend_info.version != prev_version: + raise ValueError( + f"prev_version '{prev_version}' " + "does not match with the existing " + f"version '{existing_backend_info.version}'.") + backend_info = BackendInfo( + actor_def=ray.remote( + create_backend_replica( + name, replica_config.serialized_backend_def)), + version=version, + backend_config=backend_config, + replica_config=replica_config, + deployer_job_id=deployer_job_id, + start_time_ms=int(time.time() * 1000)) + + goal_id, updating = self.backend_state_manager.deploy_backend( + name, backend_info) + endpoint_info = EndpointInfo( + route=route_prefix, python_methods=python_methods) + self.endpoint_state.update_endpoint(name, endpoint_info) + return goal_id, updating def delete_deployment(self, name: str) -> Optional[GoalId]: self.endpoint_state.delete_endpoint(name) - return self.backend_state_manager.delete_backend(name) + return self.backend_state_manager.delete_backend( + name, force_kill=False) def get_deployment_info(self, name: str) -> Tuple[BackendInfo, str]: """Get the current information about a deployment. 
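Note that deploy() becomes a coroutine above so that prev_version validation, BackendInfo construction, and the endpoint update all execute atomically under the controller's write lock. A sketch of that pattern under stated assumptions (ControllerSketch is an illustrative stand-in, not the real controller actor):

```python
import asyncio

class ControllerSketch:
    def __init__(self):
        self.write_lock = asyncio.Lock()
        self.backends = {}

    async def deploy(self, name: str, info: dict) -> bool:
        # All state mutations happen under one lock, mirroring the
        # `async with self.write_lock` block added in the diff.
        async with self.write_lock:
            updating = self.backends.get(name) != info
            self.backends[name] = info
            return updating

async def main():
    c = ControllerSketch()
    print(await c.deploy("d", {"v": 1}))  # True: new deployment
    print(await c.deploy("d", {"v": 1}))  # False: no change

asyncio.run(main())
```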
diff --git a/python/ray/serve/endpoint_state.py b/python/ray/serve/endpoint_state.py index 5bba277001c54..6483f7355ff0e 100644 --- a/python/ray/serve/endpoint_state.py +++ b/python/ray/serve/endpoint_state.py @@ -79,6 +79,7 @@ def get_endpoints(self) -> Dict[EndpointTag, Dict[str, Any]]: for endpoint, info in self._endpoints.items(): endpoints[endpoint] = { "route": info.route, + "python_methods": info.python_methods, } return endpoints diff --git a/python/ray/serve/examples/doc/conda_env.py b/python/ray/serve/examples/doc/conda_env.py index 1607cf7d60e37..c431964bb0772 100644 --- a/python/ray/serve/examples/doc/conda_env.py +++ b/python/ray/serve/examples/doc/conda_env.py @@ -1,28 +1,27 @@ import requests from ray import serve +import tensorflow as tf serve.start() @serve.deployment -def requests_version(request): - return requests.__version__ +def tf_version(request): + return ("Tensorflow " + tf.__version__) -requests_version.options( - name="25", - ray_actor_options={ +tf_version.options( + name="tf1", ray_actor_options={ "runtime_env": { - "pip": ["ray[serve]", "requests==2.25.1"] + "conda": "ray-tf1" } }).deploy() -requests_version.options( - name="26", - ray_actor_options={ +tf_version.options( + name="tf2", ray_actor_options={ "runtime_env": { - "pip": ["ray[serve]", "requests==2.26.0"] + "conda": "ray-tf2" } }).deploy() -assert requests.get("http://127.0.0.1:8000/25").text == "2.25.1" -assert requests.get("http://127.0.0.1:8000/26").text == "2.26.0" +print(requests.get("http://127.0.0.1:8000/tf1").text) # Tensorflow 1.15.0 +print(requests.get("http://127.0.0.1:8000/tf2").text) # Tensorflow 2.3.0 diff --git a/python/ray/serve/handle.py b/python/ray/serve/handle.py index 340be1f987a7c..7c315f66605f4 100644 --- a/python/ray/serve/handle.py +++ b/python/ray/serve/handle.py @@ -1,7 +1,7 @@ import asyncio import concurrent.futures from dataclasses import dataclass, field -from typing import Dict, Optional, Union, Coroutine +from typing import Dict, List, Optional, Union, Coroutine import threading from enum import Enum @@ -75,12 +75,14 @@ def __init__( endpoint_name: EndpointTag, handle_options: Optional[HandleOptions] = None, *, + known_python_methods: List[str] = [], _router: Optional[Router] = None, _internal_pickled_http_request: bool = False, ): self.controller_handle = controller_handle self.endpoint_name = endpoint_name self.handle_options = handle_options or HandleOptions() + self.known_python_methods = known_python_methods self.handle_tag = f"{self.endpoint_name}#{get_random_letters()}" self._pickled_http_request = _internal_pickled_http_request @@ -179,11 +181,21 @@ def __reduce__(self): "controller_handle": self.controller_handle, "endpoint_name": self.endpoint_name, "handle_options": self.handle_options, + "known_python_methods": self.known_python_methods, "_internal_pickled_http_request": self._pickled_http_request, } return lambda kwargs: RayServeHandle(**kwargs), (serialized_data, ) def __getattr__(self, name): + if name not in self.known_python_methods: + raise AttributeError( + f"ServeHandle for endpoint {self.endpoint_name} doesn't have " + f"python method {name}. If you used the " + f"get_handle('{self.endpoint_name}', missing_ok=True) flag, " + f"Serve cannot know all methods for {self.endpoint_name}. 
" + "You can set the method manually via " + f"handle.options(method_name='{name}').remote().") + return self.options(method_name=name) @@ -225,6 +237,7 @@ def __reduce__(self): "controller_handle": self.controller_handle, "endpoint_name": self.endpoint_name, "handle_options": self.handle_options, + "known_python_methods": self.known_python_methods, "_internal_pickled_http_request": self._pickled_http_request, } return lambda kwargs: RayServeSyncHandle(**kwargs), (serialized_data, ) diff --git a/python/ray/serve/http_proxy.py b/python/ray/serve/http_proxy.py index e129f5d60cab5..7eedc17fcfd5a 100644 --- a/python/ray/serve/http_proxy.py +++ b/python/ray/serve/http_proxy.py @@ -259,11 +259,8 @@ def __init__(self, port: int, controller_name: str, controller_namespace: str, - http_middlewares: Optional[List[ - "starlette.middleware.Middleware"]] = None): # noqa: F821 - if http_middlewares is None: - http_middlewares = [] - + http_middlewares: List[ + "starlette.middleware.Middleware"] = []): # noqa: F821 self.host = host self.port = port diff --git a/python/ray/serve/long_poll.py b/python/ray/serve/long_poll.py index 9d5a31bf86e6b..b1133adb5a251 100644 --- a/python/ray/serve/long_poll.py +++ b/python/ray/serve/long_poll.py @@ -103,14 +103,13 @@ def _process_update(self, updates: Dict[str, UpdatedObject]): "Shutting down.") return - if isinstance(updates, ConnectionError): - logger.warning("LongPollClient connection failed, shutting down.") - return - if isinstance(updates, (ray.exceptions.RayTaskError)): - # Some error happened in the controller. It could be a bug or some - # undesired state. - logger.error("LongPollHost errored\n" + updates.traceback_str) + # This can happen during shutdown where the controller doesn't + # contain this key, we will just repull. + # NOTE(simon): should we repull or just wait in the long poll + # host? + if not isinstance(updates.as_instanceof_cause(), ValueError): + logger.error("LongPollHost errored\n" + updates.traceback_str) self._poll_next() return @@ -168,21 +167,22 @@ async def listen_for_change( until there's one updates. """ watched_keys = keys_to_snapshot_ids.keys() - existent_keys = set(watched_keys).intersection( - set(self.snapshot_ids.keys())) + nonexistent_keys = set(watched_keys) - set(self.snapshot_ids.keys()) + if len(nonexistent_keys) > 0: + raise ValueError(f"Keys not found: {nonexistent_keys}.") - # If there are any outdated keys (by comparing snapshot ids) - # return immediately. + # 2. If there are any outdated keys (by comparing snapshot ids) + # return immediately. client_outdated_keys = { key: UpdatedObject(self.object_snapshots[key], self.snapshot_ids[key]) - for key in existent_keys + for key in watched_keys if self.snapshot_ids[key] != keys_to_snapshot_ids[key] } if len(client_outdated_keys) > 0: return client_outdated_keys - # Otherwise, register asyncio events to be waited. + # 3. Otherwise, register asyncio events to be waited. async_task_to_watched_keys = {} for key in watched_keys: # Create a new asyncio event for this key diff --git a/python/ray/serve/router.py b/python/ray/serve/router.py index 90fce03b03ceb..fa18456546fa2 100644 --- a/python/ray/serve/router.py +++ b/python/ray/serve/router.py @@ -1,4 +1,3 @@ -import sys import asyncio import pickle import itertools @@ -67,14 +66,7 @@ def __init__( # Used to unblock this replica set waiting for free replicas. A newly # added replica or updated max_concurrent_queries value means the # query that waits on a free replica might be unblocked on. 
- - # Python 3.8 has deprecated the 'loop' parameter, and Python 3.10 has - # removed it alltogether. Call accordingly. - if sys.version_info.major >= 3 and sys.version_info.minor >= 10: - self.config_updated_event = asyncio.Event() - else: - self.config_updated_event = asyncio.Event(loop=event_loop) - + self.config_updated_event = asyncio.Event(loop=event_loop) self.num_queued_queries = 0 self.num_queued_queries_gauge = metrics.Gauge( "serve_deployment_queued_queries", diff --git a/python/ray/serve/storage/checkpoint_path.py b/python/ray/serve/storage/checkpoint_path.py index f6abc8da22566..de892f0728978 100644 --- a/python/ray/serve/storage/checkpoint_path.py +++ b/python/ray/serve/storage/checkpoint_path.py @@ -32,9 +32,7 @@ def make_kv_store(checkpoint_path, namespace): if parsed_url.scheme == "s3": bucket = parsed_url.netloc - # We need to strip leading "/" in path as right key to use in - # boto3. Ex: s3://bucket/folder/file.zip -> key = "folder/file.zip" - prefix = parsed_url.path.lstrip("/") + prefix = parsed_url.path logger.info( "Using Ray S3 KVStore for controller checkpoint and recovery: " f"bucket={bucket} checkpoint_path={checkpoint_path}") diff --git a/python/ray/serve/storage/kv_store.py b/python/ray/serve/storage/kv_store.py index ea24de4541abe..74b17d7a75932 100644 --- a/python/ray/serve/storage/kv_store.py +++ b/python/ray/serve/storage/kv_store.py @@ -186,7 +186,7 @@ def __init__( ): self._namespace = namepsace self._bucket = bucket - self._prefix = prefix + self._prefix = prefix + "/" if prefix else "" if not boto3: raise ImportError( "You tried to use S3KVstore client without boto3 installed." @@ -199,7 +199,7 @@ def __init__( aws_session_token=aws_session_token) def get_storage_key(self, key: str) -> str: - return f"{self._prefix}/{self._namespace}-{key}" + return f"{self._prefix}{self._namespace}-{key}" def put(self, key: str, val: bytes) -> bool: """Put the key-value pair into the store. 
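The two storage hunks above interact: make_kv_store now forwards parsed_url.path with its leading slash intact, and S3KVStore appends a trailing slash only when a prefix is present. A small worked check of the resulting key layout (bucket, path, namespace, and key names are hypothetical):

```python
from urllib.parse import urlparse

parsed_url = urlparse("s3://my-bucket/ckpt")
prefix = parsed_url.path                 # "/ckpt" (leading slash kept)
prefix = prefix + "/" if prefix else ""  # "/ckpt/"

namespace, key = "serve", "checkpoint"
print(f"{prefix}{namespace}-{key}")      # "/ckpt/serve-checkpoint"
```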
diff --git a/python/ray/serve/tests/conftest.py b/python/ray/serve/tests/conftest.py
index 36fdba0d5b7cc..1e635dd1b647a 100644
--- a/python/ray/serve/tests/conftest.py
+++ b/python/ray/serve/tests/conftest.py
@@ -9,13 +9,6 @@
 serve.controller._CRASH_AFTER_CHECKPOINT_PROBABILITY = 0.5
 
 
-@pytest.fixture
-def ray_shutdown():
-    yield
-    serve.shutdown()
-    ray.shutdown()
-
-
 @pytest.fixture(scope="session")
 def _shared_serve_instance():
     # Note(simon):
diff --git a/python/ray/serve/tests/test_advanced.py b/python/ray/serve/tests/test_advanced.py
index 03f606e58fcbd..74287ed358bef 100644
--- a/python/ray/serve/tests/test_advanced.py
+++ b/python/ray/serve/tests/test_advanced.py
@@ -9,11 +9,12 @@
 
 def test_serve_forceful_shutdown(serve_instance):
-    @serve.deployment(_graceful_shutdown_timeout_s=0.1)
+    @serve.deployment
     def sleeper():
         while True:
             time.sleep(1000)
 
+    sleeper._config.experimental_graceful_shutdown_timeout_s = 0.1
     sleeper.deploy()
 
     handle = sleeper.get_handle()
@@ -27,15 +28,14 @@
 def test_serve_graceful_shutdown(serve_instance):
     signal = SignalActor.remote()
 
-    @serve.deployment(
-        name="wait",
-        max_concurrent_queries=10,
-        _graceful_shutdown_timeout_s=1000,
-        _graceful_shutdown_wait_loop_s=0.5)
+    @serve.deployment(name="wait", max_concurrent_queries=10)
     class Wait:
         async def __call__(self, signal_actor):
             await signal_actor.wait.remote()
+            return ""
 
+    Wait._config.experimental_graceful_shutdown_wait_loop_s = 0.5
+    Wait._config.experimental_graceful_shutdown_timeout_s = 1000
     Wait.deploy()
     handle = Wait.get_handle()
     refs = [handle.remote(signal) for _ in range(10)]
diff --git a/python/ray/serve/tests/test_autoscaling_metrics.py b/python/ray/serve/tests/test_autoscaling_metrics.py
index d8a92d8a28b7a..e641f515d372d 100644
--- a/python/ray/serve/tests/test_autoscaling_metrics.py
+++ b/python/ray/serve/tests/test_autoscaling_metrics.py
@@ -59,20 +59,20 @@ def test_e2e(serve_instance):
             "min_replicas": 1,
             "max_replicas": 1
         },
-        # We will send over a lot of queries. This will make sure replicas are
-        # killed quickly during cleanup.
-        _graceful_shutdown_timeout_s=1,
-        max_concurrent_queries=1000,
-        version="v1")
+        max_concurrent_queries=1000)
     class A:
         def __call__(self):
             time.sleep(0.5)
 
+    # We will send over a lot of queries. This will make sure replicas are
+    # killed quickly during cleanup.
+    A._config.experimental_graceful_shutdown_timeout_s = 1
+
     A.deploy()
     handle = A.get_handle()
     [handle.remote() for _ in range(100)]
 
     # Wait for metrics to propagate
     def get_data():
         return ray.get(serve_instance._controller.
_dump_autoscaling_metrics_for_testing.remote()) diff --git a/python/ray/serve/tests/test_autoscaling_policy.py b/python/ray/serve/tests/test_autoscaling_policy.py index e72c2f68b65ce..56fb72ac4eea1 100644 --- a/python/ray/serve/tests/test_autoscaling_policy.py +++ b/python/ray/serve/tests/test_autoscaling_policy.py @@ -1,11 +1,3 @@ -import sys -import time -import pytest - -import ray -from ray import serve -from ray._private.test_utils import wait_for_condition -from ray.serve.backend_state import ReplicaState from ray.serve.config import AutoscalingConfig from ray.serve.autoscaling_policy import calculate_desired_num_replicas @@ -79,47 +71,3 @@ def test_smoothing_factor(self): autoscaling_config=config, current_num_ongoing_requests=num_ongoing_requests) assert 5 <= desired_num_replicas <= 8 # 10 + 0.5 * (2.5 - 10) = 6.25 - - -@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.") -def test_e2e_basic_scale_up_down(serve_instance): - """Send 100 requests and check that we autoscale up, and then back down.""" - - @serve.deployment( - _autoscaling_config={ - "metrics_interval_s": 0.1, - "min_replicas": 1, - "max_replicas": 2, - "look_back_period_s": 0.2 - }, - # We will send over a lot of queries. This will make sure replicas are - # killed quickly during cleanup. - _graceful_shutdown_timeout_s=1, - max_concurrent_queries=1000, - version="v1") - class A: - def __call__(self): - time.sleep(1) - - A.deploy() - handle = A.get_handle() - [handle.remote() for _ in range(100)] - - controller = serve_instance._controller - - def get_num_running_replicas(): - replicas = ray.get( - controller._dump_replica_states_for_testing.remote("A")) - running_replicas = replicas.get([ReplicaState.RUNNING]) - return len(running_replicas) - - wait_for_condition(lambda: get_num_running_replicas() >= 2) - - # As the queue is drained, we should scale back down. - wait_for_condition(lambda: get_num_running_replicas() <= 1) - - -if __name__ == "__main__": - import sys - import pytest - sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_backend_state.py b/python/ray/serve/tests/test_backend_state.py index 0112868821388..aa31dc6a9d82a 100644 --- a/python/ray/serve/tests/test_backend_state.py +++ b/python/ray/serve/tests/test_backend_state.py @@ -1,5 +1,3 @@ -import os -import sys import time from typing import Any, Dict, List, Optional, Tuple from unittest.mock import patch, Mock @@ -183,7 +181,6 @@ def set_starting_version(self, version: BackendVersion): def start(self, backend_info: BackendInfo, version: BackendVersion): self.started = True self.version = version - self.backend_info = backend_info def update_user_config(self, user_config: Any): self.started = True @@ -221,7 +218,6 @@ def available_resources(self) -> Dict[str, float]: def graceful_stop(self) -> None: assert self.started self.stopped = True - return self.backend_info.backend_config.graceful_shutdown_timeout_s def check_stopped(self) -> bool: return self.done_stopping @@ -530,6 +526,9 @@ def test_create_delete_single_replica(mock_backend_state): # Now the replica should be marked running. backend_state.update() check_counts(backend_state, total=1, by_state=[(ReplicaState.RUNNING, 1)]) + + # TODO(edoakes): can we remove this extra update period for completing it? + backend_state.update() assert goal_manager.check_complete(create_goal) # Removing the replica should transition it to stopping. 
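A recurring edit in these tests: an extra backend_state.update() tick is required before goal completion becomes observable, because goals complete one control-loop pass after the replica states settle. A hypothetical helper capturing the pattern the tests now inline:

```python
def settle(backend_state, ticks: int = 2):
    """Run enough control-loop ticks for goal completion to be observed."""
    result = None
    for _ in range(ticks):
        result = backend_state.update()
    return result

class FakeBackendState:
    """Toy stand-in: reports deletion only on the second tick."""

    def __init__(self):
        self.ticks = 0

    def update(self):
        self.ticks += 1
        return self.ticks >= 2

assert settle(FakeBackendState()) is True
```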
@@ -543,9 +542,12 @@ def test_create_delete_single_replica(mock_backend_state): # Once it's done stopping, replica should be removed. replica = backend_state._replicas.get()[0] replica._actor.set_done_stopping() + backend_state.update() + check_counts(backend_state, total=0) + + # TODO(edoakes): can we remove this extra update period for completing it? deleted = backend_state.update() assert deleted - check_counts(backend_state, total=0) assert goal_manager.check_complete(delete_goal) assert replica._actor.cleaned_up @@ -555,7 +557,7 @@ def test_force_kill(mock_backend_state): grace_period_s = 10 b_info_1, b_version_1 = backend_info( - graceful_shutdown_timeout_s=grace_period_s) + experimental_graceful_shutdown_timeout_s=grace_period_s) # Create and delete the backend. backend_state.deploy(b_info_1) @@ -569,8 +571,8 @@ def test_force_kill(mock_backend_state): check_counts(backend_state, total=1, by_state=[(ReplicaState.STOPPING, 1)]) assert backend_state._replicas.get()[0]._actor.stopped - for _ in range(10): - backend_state.update() + backend_state.update() + backend_state.update() # force_stop shouldn't be called until after the timer. assert not backend_state._replicas.get()[0]._actor.force_stopped_counter @@ -595,9 +597,12 @@ def test_force_kill(mock_backend_state): # Once the replica is done stopping, it should be removed. replica = backend_state._replicas.get()[0] replica._actor.set_done_stopping() + backend_state.update() + check_counts(backend_state, total=0) + + # TODO(edoakes): can we remove this extra update period for completing it? deleted = backend_state.update() assert deleted - check_counts(backend_state, total=0) assert goal_manager.check_complete(delete_goal) assert replica._actor.cleaned_up @@ -639,6 +644,8 @@ def test_redeploy_same_version(mock_backend_state): version=b_version_1, total=1, by_state=[(ReplicaState.RUNNING, 1)]) + + backend_state.update() assert goal_manager.check_complete(goal_1) # Test redeploying after the initial deployment has finished. @@ -720,10 +727,12 @@ def test_redeploy_no_version(mock_backend_state): states=[ReplicaState.STARTING])[0]._actor.set_ready() check_counts(backend_state, total=1, by_state=[(ReplicaState.STARTING, 1)]) - deleted = backend_state.update() - assert not deleted + backend_state.update() check_counts(backend_state, total=1, by_state=[(ReplicaState.RUNNING, 1)]) + + deleted = backend_state.update() assert goal_manager.check_complete(goal_3) + assert not deleted def test_redeploy_new_version(mock_backend_state): @@ -817,14 +826,16 @@ def test_redeploy_new_version(mock_backend_state): total=1, by_state=[(ReplicaState.STARTING, 1)]) - deleted = backend_state.update() - assert not deleted + backend_state.update() check_counts( backend_state, version=b_version_3, total=1, by_state=[(ReplicaState.RUNNING, 1)]) + + deleted = backend_state.update() assert goal_manager.check_complete(goal_3) + assert not deleted def test_deploy_new_config_same_version(mock_backend_state): @@ -844,6 +855,7 @@ def test_deploy_new_config_same_version(mock_backend_state): version=b_version_1, total=1, by_state=[(ReplicaState.RUNNING, 1)]) + backend_state.update() assert goal_manager.check_complete(goal_id) # Update to a new config without changing the version. 
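test_force_kill and test_shutdown above depend on a manually advanced clock so that graceful-shutdown deadlines expire deterministically rather than by wall time. A minimal sketch of such a timer; the real MockTimer fixture used by these tests may differ:

```python
class MockTimer:
    """Manual clock: time() only moves when the test calls advance()."""

    def __init__(self, start: float = 0.0):
        self._now = start

    def time(self) -> float:
        return self._now

    def advance(self, by: float) -> None:
        self._now += by

timer = MockTimer()
deadline = timer.time() + 10  # e.g. a 10s grace period
timer.advance(10.1)
assert timer.time() >= deadline  # the force-kill path now triggers
```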
@@ -874,6 +886,8 @@ def test_deploy_new_config_same_version(mock_backend_state): version=b_version_2, total=1, by_state=[(ReplicaState.RUNNING, 1)]) + + backend_state.update() assert goal_manager.check_complete(goal_id) @@ -893,6 +907,7 @@ def test_deploy_new_config_new_version(mock_backend_state): version=b_version_1, total=1, by_state=[(ReplicaState.RUNNING, 1)]) + backend_state.update() assert goal_manager.check_complete(create_goal) # Update to a new config and a new version. @@ -930,6 +945,8 @@ def test_deploy_new_config_new_version(mock_backend_state): version=b_version_2, total=1, by_state=[(ReplicaState.RUNNING, 1)]) + + backend_state.update() assert goal_manager.check_complete(update_goal) @@ -949,6 +966,8 @@ def test_initial_deploy_no_throttling(mock_backend_state): for replica in backend_state._replicas.get(): replica._actor.set_ready() + backend_state.update() + # Check that the new replicas have started. backend_state.update() check_counts( @@ -975,6 +994,8 @@ def test_new_version_deploy_throttling(mock_backend_state): for replica in backend_state._replicas.get(): replica._actor.set_ready() + backend_state.update() + # Check that the new replicas have started. backend_state.update() check_counts( @@ -1215,6 +1236,8 @@ def test_new_version_deploy_throttling(mock_backend_state): version=b_version_2, total=10, by_state=[(ReplicaState.RUNNING, 10)]) + + backend_state.update() assert goal_manager.check_complete(goal_2) @@ -1235,6 +1258,8 @@ def test_reconfigure_throttling(mock_backend_state): for replica in backend_state._replicas.get(): replica._actor.set_ready() + backend_state.update() + # Check that the new replicas have started. backend_state.update() check_counts(backend_state, total=2, by_state=[(ReplicaState.RUNNING, 2)]) @@ -1293,6 +1318,8 @@ def test_reconfigure_throttling(mock_backend_state): version=b_version_2, total=2, by_state=[(ReplicaState.RUNNING, 2)]) + + backend_state.update() assert goal_manager.check_complete(goal_1) @@ -1314,6 +1341,8 @@ def test_new_version_and_scale_down(mock_backend_state): for replica in backend_state._replicas.get(): replica._actor.set_ready() + backend_state.update() + # Check that the new replicas have started. backend_state.update() check_counts( @@ -1450,6 +1479,8 @@ def test_new_version_and_scale_down(mock_backend_state): version=b_version_2, total=2, by_state=[(ReplicaState.RUNNING, 2)]) + + backend_state.update() assert goal_manager.check_complete(goal_2) @@ -1470,6 +1501,8 @@ def test_new_version_and_scale_up(mock_backend_state): for replica in backend_state._replicas.get(): replica._actor.set_ready() + backend_state.update() + # Check that the new replicas have started. backend_state.update() check_counts(backend_state, total=2, by_state=[(ReplicaState.RUNNING, 2)]) @@ -1577,6 +1610,8 @@ def test_health_check(mock_backend_state): # Check that the new replicas have started. backend_state.update() check_counts(backend_state, total=2, by_state=[(ReplicaState.RUNNING, 2)]) + + backend_state.update() assert goal_manager.check_complete(goal_1) backend_state.update() @@ -1824,9 +1859,6 @@ def mock_backend_state_manager( yield backend_state_manager, timer, goal_manager # Clear checkpoint at the end of each test kv_store.delete(CHECKPOINT_KEY) - if sys.platform != "win32": - # This line fails on windows with a PermissionError. 
- os.remove("test_kv_store.db") def test_shutdown(mock_backend_state_manager): @@ -1838,9 +1870,7 @@ def test_shutdown(mock_backend_state_manager): tag = "test" - grace_period_s = 10 - b_info_1, b_version_1 = backend_info( - graceful_shutdown_timeout_s=grace_period_s) + b_info_1, b_version_1 = backend_info() create_goal, updating = backend_state_manager.deploy_backend(tag, b_info_1) backend_state = backend_state_manager._backend_states[tag] @@ -1859,21 +1889,25 @@ def test_shutdown(mock_backend_state_manager): shutdown_goal = backend_state_manager.shutdown()[0] - timer.advance(grace_period_s + 0.1) backend_state_manager.update() check_counts(backend_state, total=1, by_state=[(ReplicaState.STOPPING, 1)]) assert backend_state._replicas.get()[0]._actor.stopped + assert backend_state._replicas.get()[0]._actor.force_stopped_counter == 1 assert not backend_state._replicas.get()[0]._actor.cleaned_up assert not goal_manager.check_complete(shutdown_goal) # Once it's done stopping, replica should be removed. replica = backend_state._replicas.get()[0] replica._actor.set_done_stopping() - backend_state_manager.update() + backend_state.update() check_counts(backend_state, total=0) + + # TODO(edoakes): can we remove this extra update period for completing it? + backend_state_manager.update() assert goal_manager.check_complete(shutdown_goal) assert replica._actor.cleaned_up + assert len(backend_state_manager._backend_states) == 0 @@ -1940,4 +1974,5 @@ def test_resume_backend_state_from_replica_tags(mock_backend_state_manager): if __name__ == "__main__": + import sys sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_config.py b/python/ray/serve/tests/test_config.py index 8c71cf8ae3a91..dd30aeab0f77f 100644 --- a/python/ray/serve/tests/test_config.py +++ b/python/ray/serve/tests/test_config.py @@ -52,8 +52,6 @@ def function(_): # Check ray_actor_options validation. ReplicaConfig( Class, - tuple(), - dict(), ray_actor_options={ "num_cpus": 1.0, "num_gpus": 10, diff --git a/python/ray/serve/tests/test_deploy.py b/python/ray/serve/tests/test_deploy.py index ab5ab2e2f3d75..d593081a43a9b 100644 --- a/python/ray/serve/tests/test_deploy.py +++ b/python/ray/serve/tests/test_deploy.py @@ -10,7 +10,6 @@ import ray from ray._private.test_utils import SignalActor, wait_for_condition from ray import serve -from ray.serve.exceptions import RayServeException from ray.serve.utils import get_random_letters @@ -677,8 +676,8 @@ def b(self, *args): assert ray.get(handle.options(method_name="b").remote()) == "hello" # New code path assert ray.get(handle.b.remote()) == "hello" - with pytest.raises(RayServeException): - ray.get(handle.c.remote()) + with pytest.raises(AttributeError): + handle.c.remote() def test_init_args(serve_instance): @@ -734,58 +733,6 @@ def check(*args): check(10, 11, 12) -def test_init_kwargs(serve_instance): - with pytest.raises(TypeError): - - @serve.deployment(init_kwargs=[1, 2, 3]) - class BadInitArgs: - pass - - @serve.deployment(init_kwargs={"a": 1, "b": 2}) - class D: - def __init__(self, **kwargs): - self._kwargs = kwargs - - def get_kwargs(self, *args): - return self._kwargs - - D.deploy() - handle = D.get_handle() - - def check(kwargs): - assert ray.get(handle.get_kwargs.remote()) == kwargs - - # Basic sanity check. - check({"a": 1, "b": 2}) - - # Check passing args to `.deploy()`. - D.deploy(a=3, b=4) - check({"a": 3, "b": 4}) - - # Passing args to `.deploy()` shouldn't override those passed in decorator. 
- D.deploy() - check({"a": 1, "b": 2}) - - # Check setting with `.options()`. - new_D = D.options(init_kwargs={"c": 8, "d": 10}) - new_D.deploy() - check({"c": 8, "d": 10}) - - # Should not have changed old deployment object. - D.deploy() - check({"a": 1, "b": 2}) - - # Check that args are only updated on version change. - D.options(version="1").deploy() - check({"a": 1, "b": 2}) - - D.options(version="1").deploy(c=10, d=11) - check({"a": 1, "b": 2}) - - D.options(version="2").deploy(c=10, d=11) - check({"c": 10, "d": 11}) - - def test_input_validation(): name = "test" diff --git a/python/ray/serve/tests/test_get_deployment.py b/python/ray/serve/tests/test_get_deployment.py index 1f6968abe4974..cb1d6c9484e31 100644 --- a/python/ray/serve/tests/test_get_deployment.py +++ b/python/ray/serve/tests/test_get_deployment.py @@ -116,37 +116,6 @@ def __call__(self, *arg): assert pid3 != pid2 -def test_init_kwargs(serve_instance): - name = "test" - - @serve.deployment(name=name) - class D: - def __init__(self, *, val=None): - assert val is not None - self._val = val - - def __call__(self, *arg): - return self._val, os.getpid() - - D.deploy(val="1") - val1, pid1 = ray.get(D.get_handle().remote()) - assert val1 == "1" - - del D - - D2 = serve.get_deployment(name=name) - D2.deploy() - val2, pid2 = ray.get(D2.get_handle().remote()) - assert val2 == "1" - assert pid2 != pid1 - - D2 = serve.get_deployment(name=name) - D2.deploy(val="2") - val3, pid3 = ray.get(D2.get_handle().remote()) - assert val3 == "2" - assert pid3 != pid2 - - def test_scale_replicas(serve_instance): name = "test" diff --git a/python/ray/serve/tests/test_handle.py b/python/ray/serve/tests/test_handle.py index 360fb3336b247..95c55aba35b3e 100644 --- a/python/ray/serve/tests/test_handle.py +++ b/python/ray/serve/tests/test_handle.py @@ -1,10 +1,9 @@ -import concurrent.futures import pytest import requests import ray +import concurrent.futures from ray import serve -from ray.serve.exceptions import RayServeException @pytest.mark.asyncio @@ -168,30 +167,6 @@ def call(): ray.get(obj_ref) -@pytest.mark.asyncio -@pytest.mark.parametrize("sync", [True, False]) -async def test_nonexistent_method(serve_instance, sync): - @serve.deployment - class A: - def exists(self): - pass - - A.deploy() - handle = A.get_handle(sync=sync) - - if sync: - obj_ref = handle.does_not_exist.remote() - else: - obj_ref = await handle.does_not_exist.remote() - - with pytest.raises(RayServeException) as excinfo: - ray.get(obj_ref) - - exception_string = str(excinfo.value) - assert "'does_not_exist'" in exception_string - assert "Available methods: ['exists']" in exception_string - - if __name__ == "__main__": import sys import pytest diff --git a/python/ray/serve/tests/test_long_poll.py b/python/ray/serve/tests/test_long_poll.py index 2081e705d976e..79cf0c841ea35 100644 --- a/python/ray/serve/tests/test_long_poll.py +++ b/python/ray/serve/tests/test_long_poll.py @@ -37,20 +37,6 @@ def test_host_standalone(serve_instance): assert "key_2" in result -def test_long_poll_wait_for_keys(serve_instance): - # Variation of the basic case, but the keys are requests before any values - # are set. 
- host = ray.remote(LongPollHost).remote() - object_ref = host.listen_for_change.remote({"key_1": -1, "key_2": -1}) - ray.get(host.notify_changed.remote("key_1", 999)) - ray.get(host.notify_changed.remote("key_2", 999)) - - # We should be able to get the one of the result immediately - result: Dict[str, UpdatedObject] = ray.get(object_ref) - assert set(result.keys()).issubset({"key_1", "key_2"}) - assert {v.object_snapshot for v in result.values()} == {999} - - def test_long_poll_restarts(serve_instance): @ray.remote( max_restarts=-1, diff --git a/python/ray/serve/tests/test_ray_client.py b/python/ray/serve/tests/test_ray_client.py index db640970eedbe..7bc2d54aad388 100644 --- a/python/ray/serve/tests/test_ray_client.py +++ b/python/ray/serve/tests/test_ray_client.py @@ -126,7 +126,7 @@ def hello(request): @pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows") -def test_quickstart_counter(serve_with_client): +def test_quickstart_task(serve_with_client): serve.start() @serve.deployment @@ -140,13 +140,10 @@ def __call__(self, *args): # Deploy our class. Counter.deploy() - print("deploy finished") # Query our endpoint in two different ways: from HTTP and from Python. assert requests.get("http://127.0.0.1:8000/Counter").json() == {"count": 1} - print("query 1 finished") assert ray.get(Counter.get_handle().remote()) == {"count": 2} - print("query 2 finished") if __name__ == "__main__": diff --git a/python/ray/serve/tests/test_regression.py b/python/ray/serve/tests/test_regression.py index 9ac205803e492..e4d519bc06ffc 100644 --- a/python/ray/serve/tests/test_regression.py +++ b/python/ray/serve/tests/test_regression.py @@ -71,7 +71,7 @@ async def __call__(self, _request): assert result.json() == 100.0 -def test_replica_memory_growth(serve_instance): +def test_backend_worker_memory_growth(serve_instance): # https://github.com/ray-project/ray/issues/12395 @serve.deployment(name="model") def gc_unreachable_objects(*args): diff --git a/python/ray/serve/tests/test_standalone.py b/python/ray/serve/tests/test_standalone.py index 1c6df064247f5..ce8183b2d8577 100644 --- a/python/ray/serve/tests/test_standalone.py +++ b/python/ray/serve/tests/test_standalone.py @@ -27,6 +27,13 @@ import ray._private.gcs_utils as gcs_utils +@pytest.fixture +def ray_shutdown(): + yield + serve.shutdown() + ray.shutdown() + + @pytest.fixture def ray_cluster(): cluster = Cluster() @@ -95,7 +102,7 @@ def test_detached_deployment(ray_cluster): # https://github.com/ray-project/ray/issues/11437 cluster = ray_cluster - head_node = cluster.add_node(num_cpus=6) + head_node = cluster.add_node(node_ip_address="127.0.0.1", num_cpus=6) # Create first job, check we can run a simple serve endpoint ray.init(head_node.address, namespace="serve") diff --git a/python/ray/sgd/__init__.py b/python/ray/sgd/__init__.py index d5f8ec4c0d6f1..c5d4677aa041e 100644 --- a/python/ray/sgd/__init__.py +++ b/python/ray/sgd/__init__.py @@ -1 +1,2 @@ -from ray.util.sgd.v2 import * # noqa: F401, F403 +from ray.util.sgd.v2 import * # noqa: F401, F403 +from ray.util.sgd.v2.callbacks import JsonLoggerCallback, TBXLoggerCallback # noqa: E501, F401, F403 diff --git a/python/ray/sgd/callbacks.py b/python/ray/sgd/callbacks.py deleted file mode 100644 index 9b85815190b9b..0000000000000 --- a/python/ray/sgd/callbacks.py +++ /dev/null @@ -1 +0,0 @@ -from ray.util.sgd.v2.callbacks import * # noqa: E501, F401, F403 diff --git a/python/ray/state.py b/python/ray/state.py index b074fd4062641..3c2f2185caffb 100644 --- a/python/ray/state.py +++ 
b/python/ray/state.py @@ -1,6 +1,7 @@ from collections import defaultdict import json import logging +import os import ray @@ -49,6 +50,10 @@ def _check_connected(self): # _really_init_global_state should have set self.global_state_accessor if self.global_state_accessor is None: + if os.environ.get("RAY_ENABLE_AUTO_CONNECT", "") != "0": + ray.client().connect() + # Retry connect! + return self._check_connected() raise ray.exceptions.RaySystemError( "Ray has not been started yet. You can start Ray with " "'ray.init()'.") @@ -715,7 +720,6 @@ def _live_node_ids(self): def _available_resources_per_node(self): """Returns a dictionary mapping node id to available resources.""" - self._check_connected() available_resources_by_id = {} all_available_resources = \ @@ -807,7 +811,7 @@ def next_job_id(): @DeveloperAPI -@client_mode_hook(auto_init=False) +@client_mode_hook def nodes(): """Get a list of the nodes in the cluster (for debugging only). @@ -871,7 +875,7 @@ def actors(actor_id=None): @DeveloperAPI -@client_mode_hook(auto_init=False) +@client_mode_hook def timeline(filename=None): """Return a list of profiling events that can be viewed as a timeline. @@ -913,7 +917,7 @@ def object_transfer_timeline(filename=None): @DeveloperAPI -@client_mode_hook(auto_init=False) +@client_mode_hook def cluster_resources(): """Get the current total cluster resources. @@ -928,7 +932,7 @@ def cluster_resources(): @DeveloperAPI -@client_mode_hook(auto_init=False) +@client_mode_hook def available_resources(): """Get the current available cluster resources. diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 1fad95edee1e3..f854f00e560e7 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -48,7 +48,6 @@ py_test_module_list( files = [ "test_client.py", "test_client_builder.py", - "test_client_compat.py", "test_client_init.py", "test_client_multi.py", "test_client_proxy.py", @@ -78,12 +77,12 @@ py_test_module_list( "test_placement_group.py", "test_placement_group_2.py", "test_placement_group_3.py", + "test_placement_group_mini_integration.py", "test_ray_init.py", "test_reconstruction.py", "test_reference_counting.py", "test_resource_demand_scheduler.py", "test_runtime_env_env_vars.py", - "test_runtime_env_plugin.py", "test_runtime_env_fork_process.py", "test_serialization.py", "test_shuffle.py", @@ -102,7 +101,6 @@ py_test_module_list( py_test_module_list( files = [ - "test_autoscaler_fake_multinode.py", # Temporarily owned by core. "test_args.py", "test_asyncio_cluster.py", "test_asyncio.py", @@ -169,7 +167,6 @@ py_test_module_list( "test_failure_4.py", "test_object_spilling.py", "test_plasma_unlimited.py", - "test_placement_group_mini_integration.py", ], size = "large", extra_srcs = SRCS, @@ -303,14 +300,6 @@ py_test( deps = ["//:ray_lib"], ) -py_test( - name = "test_runtime_env_validation", - size = "small", - srcs = SRCS + ["test_runtime_env_validation.py"], - tags = ["exclusive", "team:serve"], - deps = ["//:ray_lib"], -) - # TODO(ekl) we can't currently support tagging these as flaky since there's # no way to filter by both flaky and client mode tests in bazel. 
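The _check_connected hunk above makes ray.state connect implicitly (via ray.client().connect()) unless RAY_ENABLE_AUTO_CONNECT is set to "0", then re-enters the check. The same control flow in isolation, as a standalone sketch rather than Ray's code; the class and method names are illustrative.

    import os

    class LazyGlobalState:
        def __init__(self, connect):
            self._connect = connect  # zero-arg callable that sets up state
            self._connected = False

        def check_connected(self):
            if not self._connected:
                # Opt-out env var: auto-connect unless explicitly disabled,
                # mirroring the RAY_ENABLE_AUTO_CONNECT != "0" guard above.
                if os.environ.get("RAY_ENABLE_AUTO_CONNECT", "") != "0":
                    self._connect()
                    self._connected = True
                    return self.check_connected()  # retry, as the hunk does
                raise RuntimeError(
                    "Not connected. Connect explicitly or unset the opt-out.")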
py_test_module_list( diff --git a/python/ray/tests/client_test_utils.py b/python/ray/tests/client_test_utils.py index 30c016d32bd3a..c7b0081d3274c 100644 --- a/python/ray/tests/client_test_utils.py +++ b/python/ray/tests/client_test_utils.py @@ -18,20 +18,3 @@ async def wait(self, should_wait=True): await self.ready_event.wait() return SignalActor - - -# See test_client::test_wrapped_actor_creation for details on usage of -# run_wrapped_actor_creation and SomeClass. -def run_wrapped_actor_creation(): - import ray - RemoteClass = ray.remote(SomeClass) - handle = RemoteClass.remote() - return ray.get(handle.ready.remote()) - - -class SomeClass: - def __init__(self): - pass - - def ready(self): - return 1 diff --git a/python/ray/tests/mock_setup_worker.py b/python/ray/tests/mock_setup_worker.py index 7cd981b9ac00f..a19a9ce22d1fd 100644 --- a/python/ray/tests/mock_setup_worker.py +++ b/python/ray/tests/mock_setup_worker.py @@ -30,9 +30,6 @@ parser.add_argument( "--session-dir", type=str, help="the directory for the current session") -parser.add_argument( - "--language", type=str, help="the language type of the worker") - args, remaining_args = parser.parse_known_args() # add worker-shim-pid argument diff --git a/python/ray/tests/test_advanced.py b/python/ray/tests/test_advanced.py index 041e5e7bb559a..b7962ff71e44b 100644 --- a/python/ray/tests/test_advanced.py +++ b/python/ray/tests/test_advanced.py @@ -777,13 +777,14 @@ def method(self): # This case tests whether RequestWorkerLeaseReply carries normal task resources # when the request is rejected (due to resource preemption by normal tasks). -@pytest.mark.skipif(sys.platform == "win32", reason="Time out on Windows") +@pytest.mark.skip( + reason="The period of pull based resource report (10ms) is hard-coded.") def test_worker_lease_reply_with_resources(ray_start_cluster): cluster = ray_start_cluster cluster.add_node( memory=2000 * 1024**2, _system_config={ - "gcs_resource_report_poll_period_ms": 1000000, + "raylet_report_resources_period_milliseconds": 1000000, "gcs_actor_scheduling_enabled": True, }) node2 = cluster.add_node(memory=1000 * 1024**2) diff --git a/python/ray/tests/test_advanced_3.py b/python/ray/tests/test_advanced_3.py index 90d3de16dd60d..a03850916328a 100644 --- a/python/ray/tests/test_advanced_3.py +++ b/python/ray/tests/test_advanced_3.py @@ -1,6 +1,5 @@ # coding: utf-8 import glob -import json import logging import os import sys @@ -727,19 +726,20 @@ def test_k8s_cpu(): def test_sync_job_config(shutdown_only): num_java_workers_per_process = 8 - runtime_env = {"env_vars": {"key": "value"}} + worker_env = { + "key": "value", + } ray.init( job_config=ray.job_config.JobConfig( num_java_workers_per_process=num_java_workers_per_process, - runtime_env=runtime_env)) + worker_env=worker_env)) # Check that the job config is synchronized at the driver side. 
job_config = ray.worker.global_worker.core_worker.get_job_config() assert (job_config.num_java_workers_per_process == num_java_workers_per_process) - job_runtime_env = json.loads(job_config.runtime_env.serialized_runtime_env) - assert job_runtime_env["env_vars"] == runtime_env["env_vars"] + assert (job_config.worker_env == worker_env) @ray.remote def get_job_config(): @@ -751,8 +751,7 @@ def get_job_config(): job_config.ParseFromString(ray.get(get_job_config.remote())) assert (job_config.num_java_workers_per_process == num_java_workers_per_process) - job_runtime_env = json.loads(job_config.runtime_env.serialized_runtime_env) - assert job_runtime_env["env_vars"] == runtime_env["env_vars"] + assert (job_config.worker_env == worker_env) def test_duplicated_arg(ray_start_cluster): diff --git a/python/ray/tests/test_autoscaler.py b/python/ray/tests/test_autoscaler.py index 4cc7ee63570fe..d428188173cbd 100644 --- a/python/ray/tests/test_autoscaler.py +++ b/python/ray/tests/test_autoscaler.py @@ -1,7 +1,6 @@ import json import jsonschema import os -import re import shutil from subprocess import CalledProcessError import tempfile @@ -14,7 +13,7 @@ from collections import defaultdict from ray.autoscaler._private.commands import get_or_create_head_node from jsonschema.exceptions import ValidationError -from typing import Dict, Callable, List, Optional +from typing import Dict, Callable import ray from ray.autoscaler._private.util import prepare_config, validate_config @@ -106,56 +105,42 @@ def check_output(self, cmd): return return_string.encode() - def assert_has_call(self, - ip: str, - pattern: Optional[str] = None, - exact: Optional[List[str]] = None): - """Checks if the given value was called by this process runner. - - NOTE: Either pattern or exact must be specified, not both! - - Args: - ip: IP address of the node that the given call was executed on. - pattern: RegEx that matches one specific call. - exact: List of strings that when joined exactly match one call. - """ + def assert_has_call(self, ip, pattern=None, exact=None): with self.lock: - assert bool(pattern) ^ bool(exact), \ + assert pattern or exact, \ "Must specify either a pattern or exact match." - debug_output = "" + out = "" if pattern is not None: for cmd in self.command_history(): if ip in cmd: - debug_output += cmd - debug_output += "\n" - if re.search(pattern, cmd): - return True + out += cmd + out += "\n" + if pattern in out: + return True else: raise Exception( - f"Did not find [{pattern}] in [{debug_output}] for " - f"ip={ip}.\n\nFull output: {self.command_history()}") + f"Did not find [{pattern}] in [{out}] for ip={ip}." + f"\n\nFull output: {self.command_history()}") elif exact is not None: exact_cmd = " ".join(exact) for cmd in self.command_history(): if ip in cmd: - debug_output += cmd - debug_output += "\n" + out += cmd + out += "\n" if cmd == exact_cmd: return True raise Exception( - f"Did not find [{exact_cmd}] in [{debug_output}] for " - f"ip={ip}.\n\nFull output: {self.command_history()}") + f"Did not find [{exact_cmd}] in [{out}] for ip={ip}." + f"\n\nFull output: {self.command_history()}") - def assert_not_has_call(self, ip: str, pattern: str): - """Ensure that the given regex pattern was never called. 
- """ + def assert_not_has_call(self, ip, pattern): with self.lock: out = "" for cmd in self.command_history(): if ip in cmd: out += cmd out += "\n" - if re.search(pattern, out): + if pattern in out: raise Exception("Found [{}] in [{}] for {}".format( pattern, out, ip)) else: @@ -464,10 +449,7 @@ def waitFor(self, condition, num_retries=50, fail_msg=None): fail_msg = fail_msg or "Timed out waiting for {}".format(condition) raise RayTestTimeoutException(fail_msg) - def waitForNodes(self, expected, comparison=None, tag_filters=None): - if tag_filters is None: - tag_filters = {} - + def waitForNodes(self, expected, comparison=None, tag_filters={}): MAX_ITER = 50 for i in range(MAX_ITER): n = len(self.provider.non_terminated_nodes(tag_filters)) @@ -2578,7 +2560,8 @@ def testContinuousFileMounts(self): for i in [0, 1]: runner.assert_not_has_call(f"172.0.0.{i}", "setup_cmd") runner.assert_has_call( - f"172.0.0.{i}", f"{file_mount_dir}/ ubuntu@172.0.0.{i}:" + f"172.0.0.{i}", f"172.0.0.{i}", + f"{file_mount_dir}/ ubuntu@172.0.0.{i}:" f"{docker_mount_prefix}/home/test-folder/") def testFileMountsNonContinuous(self): @@ -2613,7 +2596,8 @@ def testFileMountsNonContinuous(self): for i in [0, 1]: runner.assert_has_call(f"172.0.0.{i}", "setup_cmd") runner.assert_has_call( - f"172.0.0.{i}", f"{file_mount_dir}/ ubuntu@172.0.0.{i}:" + f"172.0.0.{i}", f"172.0.0.{i}", + f"{file_mount_dir}/ ubuntu@172.0.0.{i}:" f"{docker_mount_prefix}/home/test-folder/") runner.clear_history() @@ -2656,7 +2640,8 @@ def testFileMountsNonContinuous(self): for i in [0, 1]: runner.assert_has_call(f"172.0.0.{i}", "setup_cmd") runner.assert_has_call( - f"172.0.0.{i}", f"{file_mount_dir}/ ubuntu@172.0.0.{i}:" + f"172.0.0.{i}", f"172.0.0.{i}", + f"{file_mount_dir}/ ubuntu@172.0.0.{i}:" f"{docker_mount_prefix}/home/test-folder/") def testAutodetectResources(self): diff --git a/python/ray/tests/test_autoscaler_fake_multinode.py b/python/ray/tests/test_autoscaler_fake_multinode.py deleted file mode 100644 index 1f6c96b3d3c19..0000000000000 --- a/python/ray/tests/test_autoscaler_fake_multinode.py +++ /dev/null @@ -1,58 +0,0 @@ -import pytest -import platform - -import ray -from ray.cluster_utils import AutoscalingCluster - - -@pytest.mark.skipif( - platform.system() == "Windows", reason="Failing on Windows.") -def test_fake_autoscaler_basic_e2e(shutdown_only): - cluster = AutoscalingCluster( - head_resources={"CPU": 2}, - worker_node_types={ - "cpu_node": { - "resources": { - "CPU": 4, - "object_store_memory": 1024 * 1024 * 1024, - }, - "node_config": {}, - "min_workers": 0, - "max_workers": 2, - }, - "gpu_node": { - "resources": { - "CPU": 2, - "GPU": 1, - "object_store_memory": 1024 * 1024 * 1024, - }, - "node_config": {}, - "min_workers": 0, - "max_workers": 2, - }, - }) - - try: - cluster.start() - ray.init("auto") - - # Triggers the addition of a GPU node. - @ray.remote(num_gpus=1) - def f(): - print("gpu ok") - - # Triggers the addition of a CPU node. 
- @ray.remote(num_cpus=3) - def g(): - print("cpu ok") - - ray.get(f.remote()) - ray.get(g.remote()) - ray.shutdown() - finally: - cluster.shutdown() - - -if __name__ == "__main__": - import sys - sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tests/test_autoscaler_yaml.py b/python/ray/tests/test_autoscaler_yaml.py index ad4ef152acd9e..137d188c8caaf 100644 --- a/python/ray/tests/test_autoscaler_yaml.py +++ b/python/ray/tests/test_autoscaler_yaml.py @@ -91,9 +91,6 @@ def testValidateDefaultConfig(self): if "local" in config_path: # local tested in testValidateLocal continue - if "fake_multi_node" in config_path: - # not supported with ray up - continue with open(config_path) as f: config = yaml.safe_load(f) config = prepare_config(config) diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index ad4d844b7c304..d5b73ece9bf54 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -76,7 +76,8 @@ def test_omp_threads_set(shutdown_only): assert os.environ["OMP_NUM_THREADS"] == "1" -def test_submit_api(shutdown_only): +@pytest.mark.parametrize("use_tls", [False, True], indirect=True) +def test_submit_api(shutdown_only, use_tls): ray.init(num_cpus=2, num_gpus=1, resources={"Custom": 1}) @ray.remote @@ -140,7 +141,8 @@ def method(self): assert ray.get([id1, id2, id3, id4]) == [0, 1, "test", 2] -def test_invalid_arguments(shutdown_only): +@pytest.mark.parametrize("use_tls", [False, True], indirect=True) +def test_invalid_arguments(shutdown_only, use_tls): ray.init(num_cpus=2) for opt in [np.random.randint(-100, -1), np.random.uniform(0, 1)]: @@ -236,7 +238,8 @@ def check(): {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"}) -def test_put_get(shutdown_only): +@pytest.mark.parametrize("use_tls", [False, True], indirect=True) +def test_put_get(shutdown_only, use_tls): ray.init(num_cpus=0) for i in range(100): @@ -265,7 +268,8 @@ def test_put_get(shutdown_only): @pytest.mark.skipif(sys.platform != "linux", reason="Failing on Windows") -def test_wait_timing(shutdown_only): +@pytest.mark.parametrize("use_tls", [False, True], indirect=True) +def test_wait_timing(shutdown_only, use_tls): ray.init(num_cpus=2) @ray.remote @@ -299,7 +303,8 @@ def test_function_descriptor(): assert d.get(python_descriptor2) == 123 -def test_ray_options(shutdown_only): +@pytest.mark.parametrize("use_tls", [False, True], indirect=True) +def test_ray_options(shutdown_only, use_tls): ray.init(num_cpus=10, num_gpus=10, resources={"custom1": 2}) @ray.remote( diff --git a/python/ray/tests/test_basic_3.py b/python/ray/tests/test_basic_3.py index 9e050e0b04979..400f79c407b8f 100644 --- a/python/ray/tests/test_basic_3.py +++ b/python/ray/tests/test_basic_3.py @@ -168,16 +168,7 @@ def f(): @pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.") def test_fair_queueing(shutdown_only): - ray.init( - num_cpus=1, - _system_config={ - # Having parallel leases is slow in this case - # because tasks are scheduled FIFO, - # the more parallism we have, - # the more workers we need to start to execute f and g tasks - # before we can execute the first h task. 
- "max_pending_lease_requests_per_scheduling_category": 1 - }) + ray.init(num_cpus=1) @ray.remote def h(): diff --git a/python/ray/tests/test_client.py b/python/ray/tests/test_client.py index 0f6dcadb10cbc..de552b1fe2977 100644 --- a/python/ray/tests/test_client.py +++ b/python/ray/tests/test_client.py @@ -6,11 +6,9 @@ import queue import threading import _thread -from unittest.mock import patch import ray.util.client.server.server as ray_client_server from ray.tests.client_test_utils import create_remote_signal_actor -from ray.tests.client_test_utils import run_wrapped_actor_creation from ray.util.client.common import ClientObjectRef from ray.util.client.ray_client_helpers import connect_to_client_or_not from ray.util.client.ray_client_helpers import ray_start_client_server @@ -26,11 +24,11 @@ def test_client_context_manager(ray_start_regular_shared, connect_to_client): with connect_to_client_or_not(connect_to_client): if connect_to_client: # Client mode is on. - assert client_mode_should_convert(auto_init=True) + assert client_mode_should_convert() # We're connected to Ray client. assert ray.util.client.ray.is_connected() else: - assert not client_mode_should_convert(auto_init=True) + assert not client_mode_should_convert() assert not ray.util.client.ray.is_connected() @@ -72,20 +70,20 @@ def run(self): def test_client_mode_hook_thread_safe(ray_start_regular_shared): with ray_start_client_server(): with enable_client_mode(): - assert client_mode_should_convert(auto_init=True) + assert client_mode_should_convert() lock = threading.Lock() lock.acquire() q = queue.Queue() def disable(): with disable_client_hook(): - q.put(client_mode_should_convert(auto_init=True)) + q.put(client_mode_should_convert()) lock.acquire() - q.put(client_mode_should_convert(auto_init=True)) + q.put(client_mode_should_convert()) t = threading.Thread(target=disable) t.start() - assert client_mode_should_convert(auto_init=True) + assert client_mode_should_convert() lock.release() t.join() assert q.get( @@ -469,11 +467,8 @@ def print_on_stderr_and_stdout(s): time.sleep(1) print_on_stderr_and_stdout.remote("Hello world") time.sleep(1) - num_hello = 0 - for msg in log_msgs: - if "Hello world" in msg: - num_hello += 1 - assert num_hello == 2, f"Invalid logs: {log_msgs}" + assert len(log_msgs) == 2 + assert all((msg.find("Hello world") for msg in log_msgs)) @pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.") @@ -653,7 +648,6 @@ def stop_server(server): @pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.") -@patch.dict(os.environ, {"RAY_ENABLE_AUTO_CONNECT": "0"}) def test_client_gpu_ids(call_ray_stop_only): import ray ray.init(num_cpus=2) @@ -708,42 +702,7 @@ def test_object_ref_cleanup(): # See https://github.com/ray-project/ray/issues/17968 for details with ray_start_client_server(): result = run_string_as_driver(object_ref_cleanup_script) - assert "Error in sys.excepthook:" not in result - assert "AttributeError: 'NoneType' object has no " not in result - assert "Exception ignored in" not in result - - -@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.") -@pytest.mark.parametrize( - "call_ray_start", - ["ray start --head --ray-client-server-port 25552 --port 0"], - indirect=True) -def test_wrapped_actor_creation(call_ray_start): - """ - When the client schedules an actor, the server will load a separate - copy of the actor class if it's defined in a separate file. 
This - means that modifications to the client's copy of the actor class - aren't propagated to the server. Currently, tracing logic modifies - the signatures of actor methods to pass around metadata when ray.remote - is applied to an actor class. However, if a user does something like: - - class SomeActor: - def __init__(self): - pass - - def decorate_actor(): - RemoteActor = ray.remote(SomeActor) - ... - - Then the SomeActor class will have its signatures modified on the client - side, but not on the server side, since ray.remote was applied inside of - the function instead of directly on the actor. Note if it were directly - applied to the actor then the signature would be modified when the server - imports the class. - """ - import ray - ray.init("ray://localhost:25552") - run_wrapped_actor_creation() + assert result == "" if __name__ == "__main__": diff --git a/python/ray/tests/test_client_compat.py b/python/ray/tests/test_client_compat.py deleted file mode 100644 index 98f4e9f4ba43d..0000000000000 --- a/python/ray/tests/test_client_compat.py +++ /dev/null @@ -1,33 +0,0 @@ -import pytest -import sys - -import ray -try: - import pyspark # noqa -except ImportError: - pyspark = None - - -@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.") -@pytest.mark.skipif(pyspark is None, reason="PySpark dependency not found") -@pytest.mark.parametrize( - "call_ray_start", [ - "ray start --head --num-cpus=1 --min-worker-port=0 " - "--max-worker-port=0 --port 0 --ray-client-server-port 10002", - ], - indirect=True) -def test_client_data_get(call_ray_start): - """PySpark import changes NamedTuple pickling behavior, leading - to inconpatibilities with the Ray client and Ray Data. This test - makes sure that our fix in the ClientPickler works.""" - address = call_ray_start - ip = address.split(":")[0] - - ray.util.connect(f"{ip}:10002") - - ray_pipeline = ray.data.from_items(list(range(1_000))) - ray.get(ray_pipeline.to_numpy()[0]) - - -if __name__ == "__main__": - sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tests/test_client_library_integration.py b/python/ray/tests/test_client_library_integration.py index 417b31efb5e3b..774f46954d045 100644 --- a/python/ray/tests/test_client_library_integration.py +++ b/python/ray/tests/test_client_library_integration.py @@ -14,11 +14,11 @@ def test_rllib_integration(ray_start_regular_shared): import ray.rllib.agents.dqn as dqn # Confirming the behavior of this context manager. # (Client mode hook not yet enabled.) - assert not client_mode_should_convert(auto_init=True) + assert not client_mode_should_convert() # Need to enable this for client APIs to be used. with enable_client_mode(): # Confirming mode hook is enabled. - assert client_mode_should_convert(auto_init=True) + assert client_mode_should_convert() config = dqn.SIMPLE_Q_DEFAULT_CONFIG.copy() # Run locally. @@ -38,11 +38,11 @@ def test_rllib_integration_tune(ray_start_regular_shared): with ray_start_client_server(): # Confirming the behavior of this context manager. # (Client mode hook not yet enabled.) - assert not client_mode_should_convert(auto_init=True) + assert not client_mode_should_convert() # Need to enable this for client APIs to be used. with enable_client_mode(): # Confirming mode hook is enabled. 
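test_client_mode_hook_thread_safe above checks that a process-wide client-mode flag composes with a per-thread disable hook. One plausible shape for that mechanism, sketched generically under assumed semantics; this is not Ray's implementation, only the behavior the queue-based test relies on.

    import threading

    _client_mode_enabled = False  # process-wide switch
    _local = threading.local()    # per-thread override

    def client_mode_should_convert():
        # A thread inside disable_client_hook() opts out without affecting
        # other threads, which is what the test above verifies.
        return _client_mode_enabled and not getattr(_local, "disabled", False)

    class disable_client_hook:
        def __enter__(self):
            self._prev = getattr(_local, "disabled", False)
            _local.disabled = True

        def __exit__(self, *exc):
            _local.disabled = self._prev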
- assert client_mode_should_convert(auto_init=True) + assert client_mode_should_convert() tune.run( "DQN", config={"env": "CartPole-v1"}, diff --git a/python/ray/tests/test_client_proxy.py b/python/ray/tests/test_client_proxy.py index 8440268da6980..03d1f34cb6582 100644 --- a/python/ray/tests/test_client_proxy.py +++ b/python/ray/tests/test_client_proxy.py @@ -253,10 +253,7 @@ def test_prepare_runtime_init_req_no_modification(): """ Check that `prepare_runtime_init_req` properly extracts the JobConfig. """ - job_config = JobConfig( - runtime_env={"env_vars": { - "KEY": "VALUE" - }}, ray_namespace="abc") + job_config = JobConfig(worker_env={"KEY": "VALUE"}, ray_namespace="abc") init_req = ray_client_pb2.DataRequest( init=ray_client_pb2.InitRequest( job_config=pickle.dumps(job_config), @@ -276,10 +273,7 @@ def test_prepare_runtime_init_req_modified_job(): Check that `prepare_runtime_init_req` properly extracts the JobConfig and modifies it according to `ray_client_server_env_prep`. """ - job_config = JobConfig( - runtime_env={"env_vars": { - "KEY": "VALUE" - }}, ray_namespace="abc") + job_config = JobConfig(worker_env={"KEY": "VALUE"}, ray_namespace="abc") init_req = ray_client_pb2.DataRequest( init=ray_client_pb2.InitRequest( job_config=pickle.dumps(job_config), diff --git a/python/ray/tests/test_client_reconnect.py b/python/ray/tests/test_client_reconnect.py index 0672b755f9eb1..b830403449ba3 100644 --- a/python/ray/tests/test_client_reconnect.py +++ b/python/ray/tests/test_client_reconnect.py @@ -294,7 +294,6 @@ def disconnect(middleman): disconnect_thread.join() -@pytest.mark.skipif(sys.platform == "win32", reason="Flaky on windows") def test_valid_actor_state(): """ Repeatedly inject errors in the middle of mutating actor calls. Check @@ -312,28 +311,24 @@ def incr(self): return self.val i = 0 - # This is to prevent erroring in the initial connection logic. - started = False def fail_every_seven(_): # Inject an error every seventh time this method is called - nonlocal i, started + nonlocal i i += 1 - if i % 7 == 0 and started: + if i % 7 == 0: raise RuntimeError with start_middleman_server( on_data_response=fail_every_seven, on_task_request=fail_every_seven, on_task_response=fail_every_seven): - started = True actor = IncrActor.remote() for _ in range(100): ref = actor.incr.remote() assert ray.get(ref) == 100 -@pytest.mark.skipif(sys.platform == "win32", reason="Flaky on windows") def test_valid_actor_state_2(): """ Do a full disconnect (cancel channel) every 11 requests. 
Failure diff --git a/python/ray/tests/test_dashboard.py b/python/ray/tests/test_dashboard.py index 578707baebf4a..c92d9610ead84 100644 --- a/python/ray/tests/test_dashboard.py +++ b/python/ray/tests/test_dashboard.py @@ -4,34 +4,14 @@ import sys import time -import psutil import pytest import requests -from ray._private.test_utils import (run_string_as_driver, wait_for_condition, - get_error_message) +from ray._private.test_utils import run_string_as_driver, wait_for_condition import ray from ray import ray_constants -def search_agents(cluster): - all_processes = cluster.head_node.all_processes - raylet_proc_info = all_processes[ray_constants.PROCESS_TYPE_RAYLET][0] - raylet_proc = psutil.Process(raylet_proc_info.process.pid) - - def _search_agent(processes): - for p in processes: - try: - for c in p.cmdline(): - if "dashboard/agent.py" in c: - return p - except Exception: - pass - - agent_proc = _search_agent(raylet_proc.children()) - return agent_proc - - def test_ray_start_default_port_conflict(call_ray_stop_only, shutdown_only): subprocess.check_call(["ray", "start", "--head"]) ray.init(address="auto") @@ -110,6 +90,8 @@ def test_port_conflict(call_ray_stop_only, shutdown_only): sock.close() +@pytest.mark.skipif( + sys.version_info < (3, 5, 3), reason="requires python3.5.3 or higher") def test_dashboard(shutdown_only): addresses = ray.init(include_dashboard=True, num_cpus=1) dashboard_url = addresses["webui_url"] @@ -139,32 +121,8 @@ def test_dashboard(shutdown_only): f"Dashboard output log: {out_log}\n") -@pytest.mark.parametrize( - "ray_start_cluster_head", [{ - "metrics_export_port": 6379, - "_system_config": { - "agent_restart_interval_ms": 10, - "agent_max_restart_count": 5 - } - }], - indirect=True) -def test_dashboard_agent_restart(ray_start_cluster_head, error_pubsub): - """Test that when the agent fails to start many times in a row - if the error message is suppressed correctly without spamming - the driver. - """ - # Choose a duplicated port for the agent so that it will crash. - p = error_pubsub - errors = get_error_message( - p, 1, ray_constants.DASHBOARD_AGENT_DIED_ERROR, timeout=10) - for e in errors: - assert ("There are 2 possible problems " - "if you see this error." in e.error_message) - # Make sure the agent process is not started anymore. 
- cluster = ray_start_cluster_head - wait_for_condition(lambda: search_agents(cluster) is None) - - if __name__ == "__main__": + import sys + import pytest sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tests/test_distributed_sort.py b/python/ray/tests/test_distributed_sort.py index 75cb682b165e8..55cc7e37ebdfd 100644 --- a/python/ray/tests/test_distributed_sort.py +++ b/python/ray/tests/test_distributed_sort.py @@ -4,19 +4,14 @@ from ray.experimental.raysort import main -@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows") def test_distributed_sort(): - args = main.get_args([ - "--total_data_size=1_000_000_000", - "--num_mappers=4", - "--num_reducers=4", - "--num_mappers_per_round=2", - "--ray_address=", - "--skip_input", - "--skip_output", - ]) - main.main(args) + main.args = main.get_args() + main.args.ray_address = None + main.args.total_data_size = 1_000_000_000 + main.args.skip_input = True + main.args.skip_output = True + main.main() if __name__ == "__main__": - sys.exit(pytest.main(["-sv", __file__])) + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tests/test_failure_2.py b/python/ray/tests/test_failure_2.py index 3b33e1c3f173b..6bb0986e649c3 100644 --- a/python/ray/tests/test_failure_2.py +++ b/python/ray/tests/test_failure_2.py @@ -67,12 +67,11 @@ class Foo: pass # The actor creation should be infeasible. - a = Foo.remote() + Foo.remote() errors = get_error_message(p, 1, ray_constants.INFEASIBLE_TASK_ERROR) assert len(errors) == 1 assert errors[0].type == ray_constants.INFEASIBLE_TASK_ERROR p.close() - del a def test_warning_for_too_many_actors(shutdown_only): diff --git a/python/ray/tests/test_global_state.py b/python/ray/tests/test_global_state.py index 8bf964791292e..6d9c35bef37cd 100644 --- a/python/ray/tests/test_global_state.py +++ b/python/ray/tests/test_global_state.py @@ -287,9 +287,8 @@ def _read_resource_usage(self): def test_backlog_report(shutdown_only): cluster = ray.init( - num_cpus=1, - _system_config={ - "max_pending_lease_requests_per_scheduling_category": 1 + num_cpus=1, _system_config={ + "report_worker_backlog": True, }) global_state_accessor = GlobalStateAccessor( cluster["redis_address"], ray.ray_constants.REDIS_DEFAULT_PASSWORD) @@ -334,7 +333,10 @@ def backlog_size_set(): def test_heartbeat_ip(shutdown_only): - cluster = ray.init(num_cpus=1) + cluster = ray.init( + num_cpus=1, _system_config={ + "report_worker_backlog": True, + }) global_state_accessor = GlobalStateAccessor( cluster["redis_address"], ray.ray_constants.REDIS_DEFAULT_PASSWORD) global_state_accessor.connect() diff --git a/python/ray/tests/test_multi_tenancy.py b/python/ray/tests/test_multi_tenancy.py index f2913a50c05ba..1267570d3660b 100644 --- a/python/ray/tests/test_multi_tenancy.py +++ b/python/ray/tests/test_multi_tenancy.py @@ -111,14 +111,12 @@ def get_pid(): all_worker_pids.add(worker_pid) -@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.") -def test_runtime_env(shutdown_only): +def test_worker_env(shutdown_only): ray.init( - job_config=ray.job_config.JobConfig( - runtime_env={"env_vars": { - "foo1": "bar1", - "foo2": "bar2" - }})) + job_config=ray.job_config.JobConfig(worker_env={ + "foo1": "bar1", + "foo2": "bar2" + })) @ray.remote def get_env(key): diff --git a/python/ray/tests/test_object_manager.py b/python/ray/tests/test_object_manager.py index e44bf22e83187..1f2c5e5dc4944 100644 --- a/python/ray/tests/test_object_manager.py +++ b/python/ray/tests/test_object_manager.py @@ -296,6 +296,8 @@ def 
driver(): ray.get(driver.remote()) +# TODO(ekl) this sometimes takes much longer (10+s) due to a higher level +# pull retry. We should try to resolve these hangs in the chunk transfer logic. def test_pull_bundles_admission_control(shutdown_only): cluster = Cluster() object_size = int(6e6) @@ -603,52 +605,6 @@ def task(x): ray.get(t, timeout=10) -@pytest.mark.parametrize( - "ray_start_cluster_head", [{ - "num_cpus": 0, - "object_store_memory": 75 * 1024 * 1024, - "_system_config": { - "worker_lease_timeout_milliseconds": 0, - "object_manager_pull_timeout_ms": 20000, - "object_spilling_threshold": 1.0, - } - }], - indirect=True) -def test_maximize_concurrent_pull_race_condition(ray_start_cluster_head): - # Test if https://github.com/ray-project/ray/issues/18062 is mitigated - cluster = ray_start_cluster_head - cluster.add_node(num_cpus=8, object_store_memory=75 * 1024 * 1024) - - @ray.remote - class RemoteObjectCreator: - def put(self, i): - return np.random.rand(i * 1024 * 1024) # 8 MB data - - def idle(self): - pass - - @ray.remote - def f(x): - print(f"timestamp={time.time()} pulled {len(x)*8} bytes") - time.sleep(1) - return - - remote_obj_creator = RemoteObjectCreator.remote() - remote_refs = [remote_obj_creator.put.remote(1) for _ in range(7)] - print(remote_refs) - # Make sure all objects are created. - ray.get(remote_obj_creator.idle.remote()) - - local_refs = [ray.put(np.random.rand(1 * 1024 * 1024)) for _ in range(20)] - remote_tasks = [f.remote(x) for x in local_refs] - - start = time.time() - ray.get(remote_tasks) - end = time.time() - assert end - start < 20, "Too much time spent in pulling objects, " \ - "check the amount of time in retries" - - if __name__ == "__main__": import pytest import sys diff --git a/python/ray/tests/test_output.py b/python/ray/tests/test_output.py index 958cdecfe732e..93cba471ee21a 100644 --- a/python/ray/tests/test_output.py +++ b/python/ray/tests/test_output.py @@ -65,15 +65,13 @@ def test_autoscaler_no_spam(): import ray import time -# Check that there are no false positives with custom resources. -ray.init(num_cpus=1, resources={"node:x": 1}) +ray.init(num_cpus=1) -@ray.remote(num_cpus=1, resources={"node:x": 1}) +@ray.remote(num_cpus=1) def f(): time.sleep(1) - print("task done") -ray.get([f.remote() for _ in range(15)]) +ray.get([f.remote() for _ in range(5)]) """ proc = run_string_as_driver_nonblocking(script) diff --git a/python/ray/tests/test_placement_group.py b/python/ray/tests/test_placement_group.py index 345f19ff80951..55a2cc5a007e9 100644 --- a/python/ray/tests/test_placement_group.py +++ b/python/ray/tests/test_placement_group.py @@ -345,13 +345,6 @@ def test_remove_placement_group(ray_start_cluster, connect_to_client): cluster.add_node(num_cpus=4) ray.init(address=cluster.address) - @ray.remote - def warmup(): - pass - - # warm up the cluster. - ray.get([warmup.remote() for _ in range(4)]) - with connect_to_client_or_not(connect_to_client): # First try to remove a placement group that doesn't # exist. This should not do anything. 
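The placement group tests in this file and the next all follow the same lifecycle: reserve bundles, block on ready(), then free the resources with remove_placement_group. For reference, the basic shape looks roughly like this, assuming a local cluster with enough CPUs:

    import ray
    from ray.util.placement_group import placement_group, remove_placement_group

    ray.init(num_cpus=4)

    # Reserve two 1-CPU bundles, packed onto as few nodes as possible.
    pg = placement_group([{"CPU": 1}, {"CPU": 1}], strategy="PACK")
    ray.get(pg.ready())  # resolves once every bundle has been placed

    # ... run tasks or actors against the group ...

    # Removing the group must return the bundles' resources to the cluster;
    # the leak-regression test deleted in the next file's hunk checked
    # exactly this via ray.available_resources().
    remove_placement_group(pg)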
diff --git a/python/ray/tests/test_placement_group_3.py b/python/ray/tests/test_placement_group_3.py index eeb6df0f5c4bb..12afdfee47ecb 100644 --- a/python/ray/tests/test_placement_group_3.py +++ b/python/ray/tests/test_placement_group_3.py @@ -608,40 +608,5 @@ def is_usage_updated(): assert cpu_usage == expected -def test_placement_group_removal_leak_regression(ray_start_cluster): - """Related issue: - https://github.com/ray-project/ray/issues/19131 - """ - cluster = ray_start_cluster - cluster.add_node(num_cpus=5) - ray.init(address=cluster.address) - - TOTAL_CPUS = 8 - bundles = [{"CPU": 1, "GPU": 1}] - bundles += [{"CPU": 1} for _ in range(TOTAL_CPUS - 1)] - - pg = placement_group(bundles, strategy="PACK") - # Here, we simulate that the ready task is queued and - # the new node is up. As soon as the new node is up, - # the ready task is scheduled. - # See https://github.com/ray-project/ray/pull/19138 - # for more details about the test. - o = pg.ready() - # Add an artificial delay until the new node is up. - time.sleep(3) - cluster.add_node(num_cpus=5, num_gpus=1) - ray.get(o) - bundle_resource_name = f"bundle_group_{pg.id.hex()}" - expected_bundle_wildcard_val = TOTAL_CPUS * 1000 - - # This should fail if there's a leakage - # because the bundle resources are never returned properly. - def check_bundle_leaks(): - bundle_resources = ray.available_resources()[bundle_resource_name] - return expected_bundle_wildcard_val == bundle_resources - - wait_for_condition(check_bundle_leaks) - - if __name__ == "__main__": sys.exit(pytest.main(["-sv", __file__])) diff --git a/python/ray/tests/test_ray_debugger.py b/python/ray/tests/test_ray_debugger.py index 3cc980bb14026..fdc6c56da1eb3 100644 --- a/python/ray/tests/test_ray_debugger.py +++ b/python/ray/tests/test_ray_debugger.py @@ -11,7 +11,6 @@ import ray from ray.cluster_utils import Cluster from ray._private.test_utils import run_string_as_driver, wait_for_condition -from ray._private import services def test_ray_debugger_breakpoint(shutdown_only): @@ -218,7 +217,7 @@ def f(): host, port = session["pdb_address"].split(":") if ray_debugger_external: - assert host == services.get_node_ip_address(), host + assert host not in ["localhost", "127.0.0.1"], host else: assert host == "localhost", host @@ -268,13 +267,13 @@ def f(): host1, port1 = session1["pdb_address"].split(":") if ray_debugger_external: - assert host1 == services.get_node_ip_address(), host1 + assert host1 not in ["localhost", "127.0.0.1"], host1 else: assert host1 == "localhost", host1 host2, port2 = session2["pdb_address"].split(":") if ray_debugger_external: - assert host2 == services.get_node_ip_address(), host2 + assert host2 not in ["localhost", "127.0.0.1"], host2 else: assert host2 == "localhost", host2 diff --git a/python/ray/tests/test_ray_init.py b/python/ray/tests/test_ray_init.py index 3fdb6a6ea110d..5040f4bd65ef4 100644 --- a/python/ray/tests/test_ray_init.py +++ b/python/ray/tests/test_ray_init.py @@ -11,7 +11,6 @@ from ray.client_builder import ClientContext from ray.cluster_utils import Cluster from ray._private.test_utils import run_string_as_driver -from ray._raylet import ClientObjectRef from ray.util.client.worker import Worker import grpc @@ -217,7 +216,6 @@ def test_ray_address(input, call_ray_start): res = ray.init(input) # Ensure this is not a client.connect() assert not isinstance(res, ClientContext) - ray.shutdown() class Credentials(grpc.ChannelCredentials): @@ -259,47 +257,9 @@ def mock_secure_channel(conn_str, with pytest.raises(Stop) as stop: 
ray.init("ray://127.0.0.1", _credentials=Credentials("test")) - ray.util.disconnect() assert stop.value.credentials.name == "test" -def test_auto_init_non_client(call_ray_start): - address = call_ray_start - with unittest.mock.patch.dict(os.environ, {"RAY_ADDRESS": address}): - res = ray.put(300) - # Ensure this is not a client.connect() - assert not isinstance(res, ClientObjectRef) - ray.shutdown() - - addr = "localhost:{}".format(address.split(":")[-1]) - with unittest.mock.patch.dict(os.environ, {"RAY_ADDRESS": addr}): - res = ray.put(300) - # Ensure this is not a client.connect() - assert not isinstance(res, ClientObjectRef) - - -@pytest.mark.parametrize( - "call_ray_start", - ["ray start --head --ray-client-server-port 25036 --port 0"], - indirect=True) -@pytest.mark.parametrize( - "function", [lambda: ray.put(300), lambda: ray.remote(ray.nodes).remote()]) -def test_auto_init_client(call_ray_start, function): - address = call_ray_start.split(":")[0] - with unittest.mock.patch.dict(os.environ, - {"RAY_ADDRESS": f"ray://{address}:25036"}): - res = function() - # Ensure this is a client connection. - assert isinstance(res, ClientObjectRef) - ray.shutdown() - - with unittest.mock.patch.dict(os.environ, - {"RAY_ADDRESS": "ray://localhost:25036"}): - res = function() - # Ensure this is a client connection. - assert isinstance(res, ClientObjectRef) - - if __name__ == "__main__": import pytest import sys diff --git a/python/ray/tests/test_resource_demand_scheduler.py b/python/ray/tests/test_resource_demand_scheduler.py index eb1260db32aa1..add24d4a571a9 100644 --- a/python/ray/tests/test_resource_demand_scheduler.py +++ b/python/ray/tests/test_resource_demand_scheduler.py @@ -46,52 +46,24 @@ def get_nodes_for(*a, **kw): def test_util_score(): assert _utilization_score({"CPU": 64}, [{"TPU": 16}]) is None - assert _utilization_score({"GPU": 4}, [{"GPU": 2}]) == (1, 0.5, 0.5) + assert _utilization_score({"GPU": 4}, [{"GPU": 2}]) == (0.5, 0.5) assert _utilization_score({"GPU": 4}, [{"GPU": 1}, {"GPU": 1}]) == \ - (1, 0.5, 0.5) - assert _utilization_score({"GPU": 2}, [{"GPU": 2}]) == (1, 2, 2) - assert _utilization_score({ - "GPU": 2 - }, [{ - "GPU": 1 - }, { - "GPU": 1 - }]) == (1, 2, 2) - assert _utilization_score({ - "GPU": 1 - }, [{ - "GPU": 1, - "CPU": 1 - }, { - "GPU": 1 - }]) == (1, 1, 1) - assert _utilization_score({ - "GPU": 1, - "CPU": 1 - }, [{ - "GPU": 1, - "CPU": 1 - }, { - "GPU": 1 - }]) == (2, 1, 1) - assert _utilization_score({"GPU": 2, "TPU": 1}, [{"GPU": 2}]) == (1, 0, 1) - assert _utilization_score({"CPU": 64}, [{"CPU": 64}]) == (1, 64, 64) - assert _utilization_score({"CPU": 64}, [{"CPU": 32}]) == (1, 8, 8) + (0.5, 0.5) + assert _utilization_score({"GPU": 2}, [{"GPU": 2}]) == (2, 2) + assert _utilization_score({"GPU": 2}, [{"GPU": 1}, {"GPU": 1}]) == (2, 2) + assert _utilization_score({"GPU": 2, "TPU": 1}, [{"GPU": 2}]) == (0, 1) + assert _utilization_score({"CPU": 64}, [{"CPU": 64}]) == (64, 64) + assert _utilization_score({"CPU": 64}, [{"CPU": 32}]) == (8, 8) assert _utilization_score({"CPU": 64}, [{"CPU": 16}, {"CPU": 16}]) == \ - (1, 8, 8) + (8, 8) def test_gpu_node_util_score(): # Avoid scheduling CPU tasks on GPU node. 
assert _utilization_score({"GPU": 1, "CPU": 1}, [{"CPU": 1}]) is None assert _utilization_score({"GPU": 1, "CPU": 1}, [{"CPU": 1, "GPU": 1}]) \ - == (2, 1.0, 1.0) - assert _utilization_score({ - "GPU": 1, - "CPU": 1 - }, [{ - "GPU": 1 - }]) == (1, 0.0, 0.5) + == (1.0, 1.0) + assert _utilization_score({"GPU": 1, "CPU": 1}, [{"GPU": 1}]) == (0.0, 0.5) def test_zero_resource(): @@ -225,7 +197,7 @@ def test_get_nodes_packing_heuristic(): }] * 8) + ([{ "CPU": 1 }] * 64)) == { - "m4.4xlarge": 2, + "m4.16xlarge": 1, "p2.8xlarge": 1 } @@ -243,47 +215,6 @@ def test_get_nodes_packing_heuristic(): } -def test_node_packing_gpu_cpu_bundles(): - TYPES = { - "cpu": { - "resources": { - "CPU": 16, - }, - "max_workers": 10, - }, - "gpu": { - "resources": { - "CPU": 16, - "GPU": 1, - }, - "max_workers": 10, - }, - } - nodes = get_nodes_for(TYPES, {}, "cpu", 9999, ([{ - "CPU": 1 - }] * 30 + [{ - "GPU": 1, - "CPU": 1 - }])) - assert nodes == {"gpu": 1, "cpu": 1} - - nodes = get_nodes_for(TYPES, {}, "cpu", 9999, ([{ - "GPU": 1, - "CPU": 1 - }] + [{ - "CPU": 1 - }] * 30)) - assert nodes == {"gpu": 1, "cpu": 1} - - nodes = get_nodes_for(TYPES, {}, "cpu", 9999, ([{ - "GPU": 1, - "CPU": 1 - }] + [{ - "CPU": 1 - }] * 15)) - assert nodes == {"gpu": 1} - - def test_gpu_node_avoid_cpu_task(): types = { "cpu": { @@ -699,8 +630,13 @@ def test_backlog_queue_impact_on_binpacking_time_aux( "CPU": 1 }]) # If not for the max launch concurrency the next assert should be: - # {'m4.16xlarge': 1, 'p2.8xlarge': 125, 'p2.xlarge': 1} - assert to_launch == {"m4.16xlarge": 1, "p2.8xlarge": 5, "p2.xlarge": 1} + # {'m4.large': 4, 'm4.4xlarge': 2, 'm4.16xlarge': 15, 'p2.8xlarge': 125}. + assert to_launch == { + "m4.large": 4, + "m4.4xlarge": 2, + "m4.16xlarge": 5, + "p2.8xlarge": 5 + } # Check the time it takes when there are 100 nodes available and the demand # requires another 75 nodes. @@ -1386,10 +1322,7 @@ def tearDown(self): shutil.rmtree(self.tmpdir) ray.shutdown() - def waitForNodes(self, expected, comparison=None, tag_filters=None): - if tag_filters is None: - tag_filters = {} - + def waitForNodes(self, expected, comparison=None, tag_filters={}): MAX_ITER = 50 for i in range(MAX_ITER): n = len(self.provider.non_terminated_nodes(tag_filters)) @@ -1731,7 +1664,7 @@ def testScaleUpMinWorkers(self): assert cnt == 2 def testScaleUpIgnoreUsed(self): - config = copy.deepcopy(MULTI_WORKER_CLUSTER) + config = MULTI_WORKER_CLUSTER.copy() # Commenting out this line causes the test case to fail?!?! 
config["min_workers"] = 0 config["target_utilization_fraction"] = 1.0 @@ -1772,7 +1705,7 @@ def testScaleUpIgnoreUsed(self): assert self.provider.mock_nodes[1].node_type == "p2.xlarge" def testRequestBundlesAccountsForHeadNode(self): - config = copy.deepcopy(MULTI_WORKER_CLUSTER) + config = MULTI_WORKER_CLUSTER.copy() config["head_node_type"] = "p2.8xlarge" config["min_workers"] = 0 config["max_workers"] = 50 @@ -1811,7 +1744,7 @@ def testRequestBundlesAccountsForHeadNode(self): assert self.provider.mock_nodes[1].node_type == "p2.8xlarge" def testRequestBundles(self): - config = copy.deepcopy(MULTI_WORKER_CLUSTER) + config = MULTI_WORKER_CLUSTER.copy() config["min_workers"] = 0 config["max_workers"] = 50 config_path = self.write_config(config) @@ -1848,7 +1781,7 @@ def testRequestBundles(self): assert self.provider.mock_nodes[4].node_type == "m4.16xlarge" def testResourcePassing(self): - config = copy.deepcopy(MULTI_WORKER_CLUSTER) + config = MULTI_WORKER_CLUSTER.copy() config["min_workers"] = 0 config["max_workers"] = 50 config_path = self.write_config(config) @@ -1879,7 +1812,7 @@ def testResourcePassing(self): assert self.provider.mock_nodes[2].node_type == "p2.8xlarge" # TODO (Alex): Autoscaler creates the node during one update then - # starts the updater in the next update. The sleep is largely + # starts the updater in the enxt update. The sleep is largely # unavoidable because the updater runs in its own thread and we have no # good way of ensuring that the commands are sent in time. autoscaler.update() @@ -1894,7 +1827,7 @@ def testResourcePassing(self): runner.assert_has_call("172.0.0.2", "\"GPU\":8") def testScaleUpLoadMetrics(self): - config = copy.deepcopy(MULTI_WORKER_CLUSTER) + config = MULTI_WORKER_CLUSTER.copy() config["min_workers"] = 0 config["max_workers"] = 50 config_path = self.write_config(config) @@ -1925,15 +1858,16 @@ def testScaleUpLoadMetrics(self): "CPU": 16 }]) autoscaler.update() - self.waitForNodes(1, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) + self.waitForNodes(2, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) nodes = { self.provider.mock_nodes[1].node_type, + self.provider.mock_nodes[2].node_type } - assert nodes == {"p2.xlarge"} + assert nodes == {"p2.xlarge", "m4.4xlarge"} def testCommandPassing(self): t = "custom" - config = copy.deepcopy(MULTI_WORKER_CLUSTER) + config = MULTI_WORKER_CLUSTER.copy() config["available_node_types"]["p2.8xlarge"][ "worker_setup_commands"] = ["new_worker_setup_command"] config["available_node_types"]["p2.xlarge"][ @@ -1989,7 +1923,7 @@ def testCommandPassing(self): "init_cmd") def testDockerWorkers(self): - config = copy.deepcopy(MULTI_WORKER_CLUSTER) + config = MULTI_WORKER_CLUSTER.copy() config["available_node_types"]["p2.8xlarge"]["docker"] = { "worker_image": "p2.8x_image:latest", "worker_run_options": ["p2.8x-run-options"] @@ -2047,7 +1981,7 @@ def testDockerWorkers(self): }]) autoscaler.update() self.waitForNodes(5) - assert self.provider.mock_nodes[4].node_type == "m4.large" + assert self.provider.mock_nodes[4].node_type == "m4.16xlarge" autoscaler.update() sleep(0.1) runner.assert_has_call(self.provider.mock_nodes[2].internal_ip, @@ -2110,7 +2044,7 @@ def testUpdateConfig(self): self.waitForNodes(0, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) def testEmptyDocker(self): - config = copy.deepcopy(MULTI_WORKER_CLUSTER) + config = MULTI_WORKER_CLUSTER.copy() del config["docker"] config["min_workers"] = 0 config["max_workers"] = 10 diff --git a/python/ray/tests/test_runtime_context.py 
b/python/ray/tests/test_runtime_context.py index 8ce983da2085a..1c069e10066df 100644 --- a/python/ray/tests/test_runtime_context.py +++ b/python/ray/tests/test_runtime_context.py @@ -4,8 +4,6 @@ import time import sys -from ray._private.test_utils import SignalActor - def test_was_current_actor_reconstructed(shutdown_only): ray.init() @@ -115,119 +113,6 @@ def echo2(self, s): assert ray.get(ray.get(obj)) == "hello" -def test_actor_stats_normal_task(ray_start_regular): - # Because it works at the core worker level, this API works for tasks. - @ray.remote - def func(): - return ray.get_runtime_context()._get_actor_call_stats() - - assert ray.get(func.remote())["func"] == { - "pending": 0, - "running": 1, - "finished": 0, - } - - -def test_actor_stats_sync_actor(ray_start_regular): - signal = SignalActor.remote() - - @ray.remote - class SyncActor: - def run(self): - return ray.get_runtime_context()._get_actor_call_stats() - - def wait_signal(self): - ray.get(signal.wait.remote()) - return ray.get_runtime_context()._get_actor_call_stats() - - actor = SyncActor.remote() - counts = ray.get(actor.run.remote()) - assert counts == { - "SyncActor.run": { - "pending": 0, - "running": 1, - "finished": 0 - }, - "SyncActor.__init__": { - "pending": 0, - "running": 0, - "finished": 1 - } - } - - ref = actor.wait_signal.remote() - other_refs = [actor.run.remote() for _ in range(3) - ] + [actor.wait_signal.remote() for _ in range(5)] - ray.wait(other_refs, timeout=1) - signal.send.remote() - counts = ray.get(ref) - assert counts == { - "SyncActor.run": { - "pending": 3, - "running": 0, - "finished": 1, # from previous run - }, - "SyncActor.wait_signal": { - "pending": 5, - "running": 1, - "finished": 0, - }, - "SyncActor.__init__": { - "pending": 0, - "running": 0, - "finished": 1 - } - } - - -def test_actor_stats_threaded_actor(ray_start_regular): - signal = SignalActor.remote() - - @ray.remote - class ThreadedActor: - def func(self): - ray.get(signal.wait.remote()) - return ray.get_runtime_context()._get_actor_call_stats() - - actor = ThreadedActor.options(max_concurrency=3).remote() - refs = [actor.func.remote() for _ in range(6)] - ready, _ = ray.wait(refs, timeout=1) - assert len(ready) == 0 - signal.send.remote() - results = ray.get(refs) - assert max(result["ThreadedActor.func"]["running"] - for result in results) > 1 - assert max(result["ThreadedActor.func"]["pending"] - for result in results) > 1 - - -def test_actor_stats_async_actor(ray_start_regular): - signal = SignalActor.remote() - - @ray.remote - class AysncActor: - async def func(self): - await signal.wait.remote() - return ray.get_runtime_context()._get_actor_call_stats() - - actor = AysncActor.options(max_concurrency=3).remote() - refs = [actor.func.remote() for _ in range(6)] - ready, _ = ray.wait(refs, timeout=1) - assert len(ready) == 0 - signal.send.remote() - results = ray.get(refs) - assert max(result["AysncActor.func"]["running"] for result in results) == 3 - assert max(result["AysncActor.func"]["pending"] for result in results) == 3 - - -# get_runtime_context() can be called outside of Ray so it should not start -# Ray automatically. 
-def test_no_auto_init(shutdown_only): - assert not ray.is_initialized() - ray.get_runtime_context() - assert not ray.is_initialized() - - if __name__ == "__main__": import pytest sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tests/test_runtime_env.py b/python/ray/tests/test_runtime_env.py index 0f9297238c3cd..110beb4490a6b 100644 --- a/python/ray/tests/test_runtime_env.py +++ b/python/ray/tests/test_runtime_env.py @@ -13,6 +13,7 @@ from ray._private.test_utils import ( run_string_as_driver, run_string_as_driver_nonblocking, wait_for_condition) from ray._private.runtime_env import working_dir as working_dir_pkg +from ray._private.runtime_env.validation import override_task_or_actor_runtime_env # noqa: E501 from ray._private.utils import (get_wheel_filename, get_master_wheel_url, get_release_wheel_url) @@ -773,38 +774,41 @@ def test_container_option_serialize(): job_config = ray.job_config.JobConfig(runtime_env=runtime_env) job_config_serialized = job_config.serialize() # job_config_serialized is JobConfig protobuf serialized string, - # job_config.runtime_env.serialized_runtime_env has container_option info - assert job_config_serialized.count(b"image") == 1 + # job_config.runtime_env.raw_json has container_option info + # job_config.serialized_runtime_env also has container_option info + assert job_config_serialized.count(b"image") == 2 def test_working_dir_override_failure(shutdown_only): ray.init() - with pytest.raises(NotImplementedError): + @ray.remote(runtime_env={"working_dir": "."}) + def f(): + pass - @ray.remote(runtime_env={"working_dir": "."}) - def f(): - pass + with pytest.raises(NotImplementedError): + f.remote() @ray.remote def g(): pass with pytest.raises(NotImplementedError): - g.options(runtime_env={"working_dir": "."}) + g.options(runtime_env={"working_dir": "."}).remote() - with pytest.raises(NotImplementedError): + @ray.remote(runtime_env={"working_dir": "."}) + class A: + pass - @ray.remote(runtime_env={"working_dir": "."}) - class A: - pass + with pytest.raises(NotImplementedError): + A.remote() @ray.remote class B: pass with pytest.raises(NotImplementedError): - B.options(runtime_env={"working_dir": "."}) + B.options(runtime_env={"working_dir": "."}).remote() @pytest.mark.skipif( @@ -940,6 +944,46 @@ def test_large_file_error(shutdown_only): os.chdir(old_dir) +class TestOverrideTaskOrActorRuntimeEnv: + def test_working_dir_in_child_invalid(self): + child_env = {"working_dir": "some_dir"} + parent_env = {"working_dir": "other_dir", "uris": ["a", "b"]} + + with pytest.raises(NotImplementedError): + override_task_or_actor_runtime_env(child_env, parent_env) + + def test_uri_inherit(self): + child_env = {} + parent_env = {"working_dir": "other_dir", "uris": ["a", "b"]} + result_env = override_task_or_actor_runtime_env(child_env, parent_env) + assert result_env == {"uris": ["a", "b"]} + + # The dicts passed in should not be mutated. + assert child_env == {} + assert parent_env == {"working_dir": "other_dir", "uris": ["a", "b"]} + + def test_uri_override(self): + child_env = {"uris": ["c", "d"]} + parent_env = {"working_dir": "other_dir", "uris": ["a", "b"]} + result_env = override_task_or_actor_runtime_env(child_env, parent_env) + assert result_env["uris"] == ["c", "d"] + assert result_env.get("working_dir") is None + + # The dicts passed in should not be mutated. 
+        assert child_env == {"uris": ["c", "d"]}
+        assert parent_env == {"working_dir": "other_dir", "uris": ["a", "b"]}
+
+    def test_no_mutate(self):
+        child_env = {}
+        parent_env = {"working_dir": "other_dir", "uris": ["a", "b"]}
+        result_env = override_task_or_actor_runtime_env(child_env, parent_env)
+        assert result_env == {"uris": ["a", "b"]}
+
+        # The dicts passed in should not be mutated.
+        assert child_env == {}
+        assert parent_env == {"working_dir": "other_dir", "uris": ["a", "b"]}
+
+
 if __name__ == "__main__":
     import sys
     sys.exit(pytest.main(["-sv", __file__]))
diff --git a/python/ray/tests/test_runtime_env_complicated.py b/python/ray/tests/test_runtime_env_complicated.py
index e5c7047f275b5..d8c334c413606 100644
--- a/python/ray/tests/test_runtime_env_complicated.py
+++ b/python/ray/tests/test_runtime_env_complicated.py
@@ -12,16 +12,15 @@
 import yaml
 
 import ray
+from ray._private.runtime_env import RuntimeEnvDict
 from ray._private.runtime_env.conda import (
     inject_dependencies,
     _inject_ray_to_conda_site,
     _resolve_install_from_source_ray_dependencies,
     _current_py_version,
 )
-
-from ray._private.runtime_env.conda_utils import get_conda_env_list
-from ray._private.test_utils import (
-    run_string_as_driver, run_string_as_driver_nonblocking, wait_for_condition)
+from ray._private.test_utils import (run_string_as_driver,
+                                     run_string_as_driver_nonblocking)
 from ray._private.utils import get_conda_env_dir, get_conda_bin_executable
 
 if not os.environ.get("CI"):
@@ -191,39 +190,6 @@ def test_job_config_conda_env(conda_envs, shutdown_only):
     ray.shutdown()
 
 
-@pytest.mark.skipif(
-    os.environ.get("CONDA_DEFAULT_ENV") is None,
-    reason="must be run from within a conda environment")
-@pytest.mark.skipif(
-    os.environ.get("CI") and sys.platform != "linux",
-    reason="This test is only run on linux CI machines.")
-def test_job_eager_install(shutdown_only):
-    # Test enable eager install
-    runtime_env = {"conda": {"dependencies": ["toolz"]}, "eager_install": True}
-    env_count = len(get_conda_env_list())
-    ray.init(runtime_env=runtime_env)
-    wait_for_condition(
-        lambda: len(get_conda_env_list()) == env_count + 1, timeout=60)
-    ray.shutdown()
-    # Test disable eager install
-    runtime_env = {
-        "conda": {
-            "dependencies": ["toolz"]
-        },
-        "eager_install": False
-    }
-    ray.init(runtime_env=runtime_env)
-    with pytest.raises(RuntimeError):
-        wait_for_condition(
-            lambda: len(get_conda_env_list()) == env_count + 2, timeout=60)
-    ray.shutdown()
-    # Test unavailable type
-    runtime_env = {"conda": {"dependencies": ["toolz"]}, "eager_install": 123}
-    with pytest.raises(AssertionError):
-        ray.init(runtime_env=runtime_env)
-    ray.shutdown()
-
-
 def test_get_conda_env_dir(tmp_path):
     """
     Typical output of `conda env list`, for context:
@@ -483,6 +449,28 @@ def f():
     assert ray.get(f.remote())
 
 
+@pytest.mark.skipif(sys.platform == "win32", reason="Unsupported on Windows.")
+@pytest.mark.parametrize("use_working_dir", [True, False])
+def test_conda_input_filepath(use_working_dir, tmp_path):
+    conda_dict = {"dependencies": ["pip", {"pip": ["pip-install-test==0.5"]}]}
+    d = tmp_path / "pip_requirements"
+    d.mkdir()
+    p = d / "environment.yml"
+
+    p.write_text(yaml.dump(conda_dict))
+
+    if use_working_dir:
+        runtime_env_dict = RuntimeEnvDict({
+            "working_dir": str(d),
+            "conda": "environment.yml"
+        })
+    else:
+        runtime_env_dict = RuntimeEnvDict({"conda": str(p)})
+
+    output_conda_dict = runtime_env_dict.get_parsed_dict().get("conda")
+    assert output_conda_dict == conda_dict
+
+
 @skipIf(sys.platform == "win32", "Fail to
create temp dir.") def test_experimental_package(shutdown_only): ray.init(num_cpus=2) @@ -526,7 +514,7 @@ def test_experimental_package_github(shutdown_only): ["ray start --head --ray-client-server-port 24001 --port 0"], indirect=True) def test_client_working_dir_filepath(call_ray_start, tmp_path): - """Test that pip and conda filepaths work with working_dir.""" + """Test that pip and conda relative filepaths work with working_dir.""" working_dir = tmp_path / "requirements" working_dir.mkdir() @@ -536,7 +524,10 @@ def test_client_working_dir_filepath(call_ray_start, tmp_path): pip-install-test==0.5 """ pip_file.write_text(requirements_txt) - runtime_env_pip = {"working_dir": str(working_dir), "pip": str(pip_file)} + runtime_env_pip = { + "working_dir": str(working_dir), + "pip": "requirements.txt" + } conda_file = working_dir / "environment.yml" conda_dict = {"dependencies": ["pip", {"pip": ["pip-install-test==0.5"]}]} @@ -544,7 +535,7 @@ def test_client_working_dir_filepath(call_ray_start, tmp_path): conda_file.write_text(conda_str) runtime_env_conda = { "working_dir": str(working_dir), - "conda": str(conda_file) + "conda": "environment.yml" } @ray.remote @@ -566,64 +557,6 @@ def f(): assert ray.get(f.remote()) -@pytest.mark.skipif( - os.environ.get("CI") and sys.platform != "linux", - reason="This test is only run on linux CI machines.") -@pytest.mark.parametrize( - "call_ray_start", - ["ray start --head --ray-client-server-port 24001 --port 0"], - indirect=True) -def test_conda_pip_filepaths_remote(call_ray_start, tmp_path): - """Test that pip and conda filepaths work, simulating a remote cluster.""" - - working_dir = tmp_path / "requirements" - working_dir.mkdir() - - pip_file = working_dir / "requirements.txt" - requirements_txt = """ - pip-install-test==0.5 - """ - pip_file.write_text(requirements_txt) - runtime_env_pip = {"pip": str(pip_file)} - - conda_file = working_dir / "environment.yml" - conda_dict = {"dependencies": ["pip", {"pip": ["pip-install-test==0.5"]}]} - conda_str = yaml.dump(conda_dict) - conda_file.write_text(conda_str) - runtime_env_conda = {"conda": str(conda_file)} - - @ray.remote - def f(): - import pip_install_test # noqa - return True - - with ray.client("localhost:24001").connect(): - with pytest.raises(ModuleNotFoundError): - # Ensure pip-install-test is not installed in a client that doesn't - # use the runtime_env - ray.get(f.remote()) - - # pip and conda files should be parsed when the function is declared. - f_pip = f.options(runtime_env=runtime_env_pip) - f_conda = f.options(runtime_env=runtime_env_conda) - - # Remove the pip and conda files from the local filesystem. This is - # necessary to simulate the files not being present on the remote cluster, - # because in this single-machine test, the cluster has the same filesystem. - os.remove(pip_file) - os.remove(conda_file) - - # Test with and without a working_dir. - client_envs = [{}, {"working_dir": str(working_dir)}] - for runtime_env in client_envs: - with ray.client("localhost:24001").env(runtime_env).connect(): - with pytest.raises(ModuleNotFoundError): - # Ensure pip-install-test is not installed on the test machine - import pip_install_test # noqa - assert ray.get(f_pip.remote()) - assert ray.get(f_conda.remote()) - - install_env_script = """ import ray import time @@ -785,7 +718,7 @@ def test(self): # Start a new job on the same cluster using the Summit 2021 requirements. 
with ray.client(f"localhost:{CLIENT_SERVER_PORT}").env({ "working_dir": str(tmp_path), - "pip": str(requirement_path) + "pip": "requirements.txt" }).connect(): @ray.remote @@ -819,9 +752,7 @@ def test(self): return Path("./test").read_text() - a = TestActor.options(runtime_env={ - "pip": str(requirement_path) - }).remote() + a = TestActor.options(runtime_env={"pip": "requirements.txt"}).remote() assert ray.get(a.test.remote()) == "Hello" # Check that per-task pip specification works and that the job's @@ -957,7 +888,7 @@ def f(self): @pytest.mark.skipif( os.environ.get("CI") and sys.platform != "linux", reason="This test is only run on linux CI machines.") -def test_runtime_env_logging_to_driver(ray_start_regular_shared, log_pubsub): +def test_runtime_env_logging_to_dirver(ray_start_regular_shared, log_pubsub): @ray.remote(runtime_env={"pip": [f"requests=={REQUEST_VERSIONS[0]}"]}) def func(): pass diff --git a/python/ray/tests/test_runtime_env_env_vars.py b/python/ray/tests/test_runtime_env_env_vars.py index 479a7f4130bd2..22ce5d5ce59b9 100644 --- a/python/ray/tests/test_runtime_env_env_vars.py +++ b/python/ray/tests/test_runtime_env_env_vars.py @@ -7,37 +7,54 @@ import ray -def test_environment_variables_task(ray_start_regular): +@pytest.mark.parametrize("use_runtime_env", [True, False]) +def test_override_environment_variables_task(ray_start_regular, + use_runtime_env): @ray.remote def get_env(key): return os.environ.get(key) - assert (ray.get( - get_env.options(runtime_env={ - "env_vars": { + if use_runtime_env: + assert (ray.get( + get_env.options(runtime_env={ + "env_vars": { + "a": "b", + } + }).remote("a")) == "b") + else: + assert (ray.get( + get_env.options(override_environment_variables={ "a": "b", - } - }).remote("a")) == "b") + }).remote("a")) == "b") -def test_environment_variables_actor(ray_start_regular): +@pytest.mark.parametrize("use_runtime_env", [True, False]) +def test_override_environment_variables_actor(ray_start_regular, + use_runtime_env): @ray.remote class EnvGetter: def get(self, key): return os.environ.get(key) - a = EnvGetter.options(runtime_env={ - "env_vars": { + if use_runtime_env: + a = EnvGetter.options(runtime_env={ + "env_vars": { + "a": "b", + "c": "d", + } + }).remote() + else: + a = EnvGetter.options(override_environment_variables={ "a": "b", "c": "d", - } - }).remote() - + }).remote() assert (ray.get(a.get.remote("a")) == "b") assert (ray.get(a.get.remote("c")) == "d") -def test_environment_variables_nested_task(ray_start_regular): +@pytest.mark.parametrize("use_runtime_env", [True, False]) +def test_override_environment_variables_nested_task(ray_start_regular, + use_runtime_env): @ray.remote def get_env(key): return os.environ.get(key) @@ -46,19 +63,36 @@ def get_env(key): def get_env_wrapper(key): return ray.get(get_env.remote(key)) - assert (ray.get( - get_env_wrapper.options(runtime_env={ - "env_vars": { + if use_runtime_env: + assert (ray.get( + get_env_wrapper.options(runtime_env={ + "env_vars": { + "a": "b", + } + }).remote("a")) == "b") + else: + assert (ray.get( + get_env_wrapper.options(override_environment_variables={ "a": "b", - } - }).remote("a")) == "b") - - -def test_environment_variables_multitenancy(shutdown_only): - ray.init(runtime_env={"env_vars": { - "foo1": "bar1", - "foo2": "bar2", - }}) + }).remote("a")) == "b") + + +@pytest.mark.parametrize("use_runtime_env", [True, False]) +def test_override_environment_variables_multitenancy(shutdown_only, + use_runtime_env): + if use_runtime_env: + ray.init( + 
job_config=ray.job_config.JobConfig( + runtime_env={"env_vars": { + "foo1": "bar1", + "foo2": "bar2", + }})) + else: + ray.init( + job_config=ray.job_config.JobConfig(worker_env={ + "foo1": "bar1", + "foo2": "bar2", + })) @ray.remote def get_env(key): @@ -66,27 +100,48 @@ def get_env(key): assert ray.get(get_env.remote("foo1")) == "bar1" assert ray.get(get_env.remote("foo2")) == "bar2" - assert ray.get( - get_env.options(runtime_env={ - "env_vars": { + if use_runtime_env: + assert ray.get( + get_env.options(runtime_env={ + "env_vars": { + "foo1": "baz1", + } + }).remote("foo1")) == "baz1" + assert ray.get( + get_env.options(runtime_env={ + "env_vars": { + "foo1": "baz1", + } + }).remote("foo2")) == "bar2" + else: + assert ray.get( + get_env.options(override_environment_variables={ "foo1": "baz1", - } - }).remote("foo1")) == "baz1" - assert ray.get( - get_env.options(runtime_env={ - "env_vars": { + }).remote("foo1")) == "baz1" + assert ray.get( + get_env.options(override_environment_variables={ "foo1": "baz1", - } - }).remote("foo2")) == "bar2" + }).remote("foo2")) == "bar2" -def test_environment_variables_complex(shutdown_only): - ray.init( - runtime_env={"env_vars": { - "a": "job_a", - "b": "job_b", - "z": "job_z", - }}) +@pytest.mark.parametrize("use_runtime_env", [True, False]) +def test_override_environment_variables_complex(shutdown_only, + use_runtime_env): + if use_runtime_env: + ray.init(runtime_env={ + "env_vars": { + "a": "job_a", + "b": "job_b", + "z": "job_z", + } + }) + else: + ray.init( + job_config=ray.job_config.JobConfig(worker_env={ + "a": "job_a", + "b": "job_b", + "z": "job_z", + })) @ray.remote def get_env(key): @@ -109,45 +164,69 @@ def get_task(self, key): return ray.get(get_env.remote(key)) def nested_get(self, key): - aa = NestedEnvGetter.options(runtime_env={ - "env_vars": { + if use_runtime_env: + aa = NestedEnvGetter.options(runtime_env={ + "env_vars": { + "c": "e", + "d": "dd", + } + }).remote() + else: + aa = NestedEnvGetter.options(override_environment_variables={ "c": "e", "d": "dd", - } - }).remote() + }).remote() return ray.get(aa.get.remote(key)) - a = EnvGetter.options(runtime_env={ - "env_vars": { + if use_runtime_env: + a = EnvGetter.options(runtime_env={ + "env_vars": { + "a": "b", + "c": "d", + } + }).remote() + else: + a = EnvGetter.options(override_environment_variables={ "a": "b", "c": "d", - } - }).remote() - + }).remote() assert (ray.get(a.get.remote("a")) == "b") assert (ray.get(a.get_task.remote("a")) == "b") assert (ray.get(a.nested_get.remote("a")) == "b") assert (ray.get(a.nested_get.remote("c")) == "e") assert (ray.get(a.nested_get.remote("d")) == "dd") - assert (ray.get( - get_env.options(runtime_env={ - "env_vars": { + if use_runtime_env: + assert (ray.get( + get_env.options(runtime_env={ + "env_vars": { + "a": "b", + } + }).remote("a")) == "b") + else: + assert (ray.get( + get_env.options(override_environment_variables={ "a": "b", - } - }).remote("a")) == "b") + }).remote("a")) == "b") assert (ray.get(a.get.remote("z")) == "job_z") assert (ray.get(a.get_task.remote("z")) == "job_z") assert (ray.get(a.nested_get.remote("z")) == "job_z") - assert (ray.get( - get_env.options(runtime_env={ - "env_vars": { + if use_runtime_env: + assert (ray.get( + get_env.options(runtime_env={ + "env_vars": { + "a": "b", + } + }).remote("z")) == "job_z") + else: + assert (ray.get( + get_env.options(override_environment_variables={ "a": "b", - } - }).remote("z")) == "job_z") + }).remote("z")) == "job_z") -def test_environment_variables_reuse(shutdown_only): 
+@pytest.mark.parametrize("use_runtime_env", [True, False]) +def test_override_environment_variables_reuse(shutdown_only, use_runtime_env): """Test that new tasks don't incorrectly reuse previous environments.""" ray.init() @@ -165,20 +244,32 @@ def g(): return os.environ.get(env_var_name) assert ray.get(f.remote()) is None - assert ray.get( - f.options(runtime_env={ - "env_vars": { + if use_runtime_env: + assert ray.get( + f.options(runtime_env={ + "env_vars": { + env_var_name: val1 + } + }).remote()) == val1 + else: + assert ray.get( + f.options(override_environment_variables={ env_var_name: val1 - } - }).remote()) == val1 + }).remote()) == val1 assert ray.get(f.remote()) is None assert ray.get(g.remote()) is None - assert ray.get( - f.options(runtime_env={ - "env_vars": { + if use_runtime_env: + assert ray.get( + f.options(runtime_env={ + "env_vars": { + env_var_name: val2 + } + }).remote()) == val2 + else: + assert ray.get( + f.options(override_environment_variables={ env_var_name: val2 - } - }).remote()) == val2 + }).remote()) == val2 assert ray.get(g.remote()) is None assert ray.get(f.remote()) is None @@ -187,7 +278,9 @@ def g(): # there aren't enough CPUs (2-4 on Travis CI vs. likely 8 on Buildkite) and # worker processes are being killed to adhere to the soft limit. @pytest.mark.skipif(sys.platform == "darwin", reason="Flaky on Travis CI.") -def test_environment_variables_env_caching(shutdown_only): +@pytest.mark.parametrize("use_runtime_env", [True, False]) +def test_override_environment_variables_env_caching(shutdown_only, + use_runtime_env): """Test that workers with specified envs are cached and reused. When a new task or actor is created with a new runtime env, a @@ -214,7 +307,10 @@ def g(): return task() def get_options(val): - return {"runtime_env": {"env_vars": {env_var_name: val}}} + if use_runtime_env: + return {"override_environment_variables": {env_var_name: val}} + else: + return {"runtime_env": {"env_vars": {env_var_name: val}}} # Empty runtime env does not set our env var. 
assert ray.get(f.remote())[0] is None diff --git a/python/ray/tests/test_runtime_env_plugin.py b/python/ray/tests/test_runtime_env_plugin.py deleted file mode 100644 index 629cdca4e6d25..0000000000000 --- a/python/ray/tests/test_runtime_env_plugin.py +++ /dev/null @@ -1,75 +0,0 @@ -import os -import tempfile - -import pytest -from ray._private.runtime_env.context import RuntimeEnvContext -from ray._private.runtime_env.plugin import RuntimeEnvPlugin - -import ray - -MY_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env_plugin.MyPlugin" - - -class MyPlugin(RuntimeEnvPlugin): - env_key = "MY_PLUGIN_TEST_ENVIRONMENT_KEY" - - @staticmethod - def validate(runtime_env_dict: dict) -> str: - value = runtime_env_dict["plugins"][MY_PLUGIN_CLASS_PATH] - if value == "fail": - raise ValueError("not allowed") - return value - - @staticmethod - def modify_context(uri: str, runtime_env_dict: dict, - ctx: RuntimeEnvContext) -> None: - plugin_config_dict = runtime_env_dict["plugins"][MY_PLUGIN_CLASS_PATH] - ctx.env_vars[MyPlugin.env_key] = str(plugin_config_dict["env_value"]) - ctx.command_prefix.append( - f"echo {plugin_config_dict['tmp_content']} > " - f"{plugin_config_dict['tmp_file']}") - ctx.py_executable = ( - plugin_config_dict["prefix_command"] + " " + ctx.py_executable) - - -def test_simple_env_modification_plugin(ray_start_regular): - _, tmp_file_path = tempfile.mkstemp() - - @ray.remote - def f(): - import psutil - with open(tmp_file_path, "r") as f: - content = f.read().strip() - return { - "env_value": os.environ[MyPlugin.env_key], - "tmp_content": content, - "nice": psutil.Process().nice(), - } - - with pytest.raises(ValueError, match="not allowed"): - f.options(runtime_env={ - "plugins": { - MY_PLUGIN_CLASS_PATH: "fail" - } - }).remote() - - output = ray.get( - f.options( - runtime_env={ - "plugins": { - MY_PLUGIN_CLASS_PATH: { - "env_value": 42, - "tmp_file": tmp_file_path, - "tmp_content": "hello", - # See https://en.wikipedia.org/wiki/Nice_(Unix) - "prefix_command": "nice -n 19", - } - } - }).remote()) - - assert output == {"env_value": "42", "tmp_content": "hello", "nice": 19} - - -if __name__ == "__main__": - import sys - sys.exit(pytest.main(["-sv", __file__])) diff --git a/python/ray/tests/test_runtime_env_validation.py b/python/ray/tests/test_runtime_env_validation.py deleted file mode 100644 index 1f3fc254cec29..0000000000000 --- a/python/ray/tests/test_runtime_env_validation.py +++ /dev/null @@ -1,379 +0,0 @@ -import os -import pytest -import sys -import tempfile -from pathlib import Path -import yaml - -from ray._private.runtime_env.validation import ( - parse_and_validate_excludes, parse_and_validate_working_dir, - parse_and_validate_conda, parse_and_validate_pip, - parse_and_validate_env_vars, ParsedRuntimeEnv, - override_task_or_actor_runtime_env) - -CONDA_DICT = {"dependencies": ["pip", {"pip": ["pip-install-test==0.5"]}]} - -PIP_LIST = ["requests==1.0.0", "pip-install-test"] - - -@pytest.fixture -def test_directory(): - with tempfile.TemporaryDirectory() as tmp_dir: - path = Path(tmp_dir) - subdir = path / "subdir" - subdir.mkdir(parents=True) - requirements_file = subdir / "requirements.txt" - with requirements_file.open(mode="w") as f: - print("\n".join(PIP_LIST), file=f) - - good_conda_file = subdir / "good_conda_env.yaml" - with good_conda_file.open(mode="w") as f: - yaml.dump(CONDA_DICT, f) - - bad_conda_file = subdir / "bad_conda_env.yaml" - with bad_conda_file.open(mode="w") as f: - print("% this is not a YAML file %", file=f) - - old_dir = os.getcwd() - os.chdir(tmp_dir) - 
yield subdir, requirements_file, good_conda_file, bad_conda_file - os.chdir(old_dir) - - -class TestValidateWorkingDir: - @pytest.mark.parametrize("absolute_path", [True, False]) - def test_validate_working_dir_valid_path(self, test_directory, - absolute_path): - subdir, _, _, _ = test_directory - - rel1 = "." - assert parse_and_validate_working_dir( - rel1, is_task_or_actor=False) == rel1 - - if absolute_path: - subdir = subdir.resolve() - - rel2 = str(subdir) - assert parse_and_validate_working_dir( - rel2, is_task_or_actor=False) == rel2 - - def test_validate_working_dir_absolute_path(self, test_directory): - subdir, _, _, _ = test_directory - - abspath = str(subdir.resolve()) - assert parse_and_validate_working_dir( - abspath, is_task_or_actor=False) == abspath - - def test_validate_working_dir_invalid_path(self): - with pytest.raises(ValueError): - parse_and_validate_working_dir("fake_path", is_task_or_actor=False) - - def test_validate_working_dir_invalid_types(self): - with pytest.raises(TypeError): - parse_and_validate_working_dir( - { - "working_dir": 1 - }, is_task_or_actor=False) - - def test_validate_working_dir_reject_task_or_actor(self): - # Can't pass working_dir for tasks/actors. - with pytest.raises(NotImplementedError): - parse_and_validate_working_dir( - { - "working_dir": "." - }, is_task_or_actor=True) - - -class TestValidateExcludes: - def test_validate_excludes_invalid_types(self): - with pytest.raises(TypeError): - parse_and_validate_excludes(1) - - with pytest.raises(TypeError): - parse_and_validate_excludes(True) - - with pytest.raises(TypeError): - parse_and_validate_excludes("string") - - with pytest.raises(TypeError): - parse_and_validate_excludes(["string", 1]) - - def test_validate_excludes_empty_list(self): - assert ParsedRuntimeEnv({"excludes": []}) == {} - - -@pytest.mark.skipif( - sys.platform == "win32", reason="Conda option not supported on Windows.") -class TestValidateConda: - def test_validate_conda_invalid_types(self): - with pytest.raises(TypeError): - parse_and_validate_conda(1) - - with pytest.raises(TypeError): - parse_and_validate_conda(True) - - def test_validate_conda_str(self, test_directory): - assert parse_and_validate_conda("my_env_name") == "my_env_name" - - def test_validate_conda_invalid_path(self): - with pytest.raises(ValueError): - parse_and_validate_conda("../bad_path.yaml") - - @pytest.mark.parametrize("absolute_path", [True, False]) - def test_validate_conda_valid_file(self, test_directory, absolute_path): - _, _, good_conda_file, _ = test_directory - - if absolute_path: - good_conda_file = good_conda_file.resolve() - - assert parse_and_validate_conda(str(good_conda_file)) == CONDA_DICT - - @pytest.mark.parametrize("absolute_path", [True, False]) - def test_validate_conda_invalid_file(self, test_directory, absolute_path): - _, _, _, bad_conda_file = test_directory - - if absolute_path: - bad_conda_file = bad_conda_file.resolve() - - with pytest.raises(ValueError): - parse_and_validate_conda(str(bad_conda_file)) - - def test_validate_conda_valid_dict(self): - assert parse_and_validate_conda(CONDA_DICT) == CONDA_DICT - - -@pytest.mark.skipif( - sys.platform == "win32", reason="Pip option not supported on Windows.") -class TestValidatePip: - def test_validate_pip_invalid_types(self): - with pytest.raises(TypeError): - parse_and_validate_pip(1) - - with pytest.raises(TypeError): - parse_and_validate_pip(True) - - def test_validate_pip_invalid_path(self): - with pytest.raises(ValueError): - parse_and_validate_pip("../bad_path.txt") - 
- @pytest.mark.parametrize("absolute_path", [True, False]) - def test_validate_pip_valid_file(self, test_directory, absolute_path): - _, requirements_file, _, _ = test_directory - - if absolute_path: - requirements_file = requirements_file.resolve() - - result = parse_and_validate_pip(str(requirements_file)) - assert result == PIP_LIST - - def test_validate_pip_valid_list(self): - result = parse_and_validate_pip(PIP_LIST) - assert result == PIP_LIST - - -class TestValidateEnvVars: - def test_type_validation(self): - # Only strings allowed. - with pytest.raises(TypeError, match=".*Dict[str, str]*"): - parse_and_validate_env_vars({"INT_ENV": 1}) - - with pytest.raises(TypeError, match=".*Dict[str, str]*"): - parse_and_validate_env_vars({1: "hi"}) - - -class TestParsedRuntimeEnv: - @pytest.mark.parametrize("is_task_or_actor", [True, False]) - def test_empty(self, is_task_or_actor): - assert ParsedRuntimeEnv({}, is_task_or_actor=is_task_or_actor) == {} - - @pytest.mark.skipif( - sys.platform == "win32", reason="Pip option not supported on Windows.") - @pytest.mark.parametrize("is_task_or_actor", [True, False]) - def test_serialization(self, is_task_or_actor): - env1 = ParsedRuntimeEnv( - { - "pip": ["requests"], - "env_vars": { - "hi1": "hi1", - "hi2": "hi2" - } - }, - is_task_or_actor=is_task_or_actor) - - env2 = ParsedRuntimeEnv( - { - "env_vars": { - "hi2": "hi2", - "hi1": "hi1" - }, - "pip": ["requests"] - }, - is_task_or_actor=is_task_or_actor) - - assert env1 == env2 - - serialized_env1 = env1.serialize() - serialized_env2 = env2.serialize() - - # Key ordering shouldn't matter. - assert serialized_env1 == serialized_env2 - - deserialized_env1 = ParsedRuntimeEnv.deserialize(serialized_env1) - deserialized_env2 = ParsedRuntimeEnv.deserialize(serialized_env2) - - assert env1 == deserialized_env1 == env2 == deserialized_env2 - - @pytest.mark.parametrize("is_task_or_actor", [True, False]) - def test_reject_pip_and_conda(self, is_task_or_actor): - with pytest.raises(ValueError): - ParsedRuntimeEnv( - { - "pip": ["requests"], - "conda": "env_name" - }, - is_task_or_actor=is_task_or_actor) - - @pytest.mark.skipif( - sys.platform == "win32", - reason="Conda and pip options not supported on Windows.") - @pytest.mark.parametrize("is_task_or_actor", [True, False]) - def test_ray_commit_injection(self, is_task_or_actor): - # Should not be injected if no pip and conda. - result = ParsedRuntimeEnv( - { - "env_vars": { - "hi": "hi" - } - }, is_task_or_actor=is_task_or_actor) - assert "_ray_commit" not in result - - # Should be injected if pip or conda present. - result = ParsedRuntimeEnv( - { - "pip": ["requests"], - }, is_task_or_actor=is_task_or_actor) - assert "_ray_commit" in result - - result = ParsedRuntimeEnv( - { - "conda": "env_name" - }, is_task_or_actor=is_task_or_actor) - assert "_ray_commit" in result - - # Should not override if passed. - result = ParsedRuntimeEnv( - { - "conda": "env_name", - "_ray_commit": "Blah" - }, - is_task_or_actor=is_task_or_actor) - assert result["_ray_commit"] == "Blah" - - @pytest.mark.parametrize("is_task_or_actor", [True, False]) - def test_inject_current_ray(self, is_task_or_actor): - # Should not be injected if not provided by env var. - result = ParsedRuntimeEnv( - { - "env_vars": { - "hi": "hi" - } - }, is_task_or_actor=is_task_or_actor) - assert "_inject_current_ray" not in result - - os.environ["RAY_RUNTIME_ENV_LOCAL_DEV_MODE"] = "1" - - # Should be injected if provided by env var. 
- result = ParsedRuntimeEnv({}, is_task_or_actor=is_task_or_actor) - assert result["_inject_current_ray"] - - # Should be preserved if passed. - result = ParsedRuntimeEnv( - { - "_inject_current_ray": False - }, is_task_or_actor=is_task_or_actor) - assert not result["_inject_current_ray"] - - del os.environ["RAY_RUNTIME_ENV_LOCAL_DEV_MODE"] - - -class TestOverrideRuntimeEnvs: - def test_override_uris(self): - child = {} - parent = {"uris": ["a", "b"]} - assert override_task_or_actor_runtime_env(child, parent) == parent - - child = {"uris": ["a", "b"]} - parent = {"uris": ["c", "d"]} - assert override_task_or_actor_runtime_env(child, parent) == child - - child = {"uris": ["a", "b"]} - parent = {} - assert override_task_or_actor_runtime_env(child, parent) == child - - def test_override_env_vars(self): - # (child, parent, expected) - TEST_CASES = [ - ({}, {}, {}), - (None, None, None), - ({"a": "b"}, {}, {"a": "b"}), - ({"a": "b"}, None, {"a": "b"}), - ({}, {"a": "b"}, {"a": "b"}), - (None, {"a": "b"}, {"a": "b"}), - ({"a": "b"}, {"a": "d"}, {"a": "b"}), - ({"a": "b"}, {"c": "d"}, {"a": "b", "c": "d"}), - ({"a": "b"}, {"a": "e", "c": "d"}, {"a": "b", "c": "d"}) - ] # yapf: disable - - for idx, (child, parent, expected) in enumerate(TEST_CASES): - child = {"env_vars": child} if child is not None else {} - parent = {"env_vars": parent} if parent is not None else {} - expected = {"env_vars": expected} if expected is not None else {} - assert override_task_or_actor_runtime_env( - child, parent) == expected, f"TEST_INDEX:{idx}" - - def test_uri_inherit(self): - child_env = {} - parent_env = {"working_dir": "other_dir", "uris": ["a", "b"]} - result_env = override_task_or_actor_runtime_env(child_env, parent_env) - assert result_env == {"uris": ["a", "b"]} - - # The dicts passed in should not be mutated. - assert child_env == {} - assert parent_env == {"working_dir": "other_dir", "uris": ["a", "b"]} - - def test_uri_override(self): - child_env = {"uris": ["c", "d"]} - parent_env = {"working_dir": "other_dir", "uris": ["a", "b"]} - result_env = override_task_or_actor_runtime_env(child_env, parent_env) - assert result_env["uris"] == ["c", "d"] - assert result_env.get("working_dir") is None - - # The dicts passed in should not be mutated. - assert child_env == {"uris": ["c", "d"]} - assert parent_env == {"working_dir": "other_dir", "uris": ["a", "b"]} - - def test_no_mutate(self): - child_env = {} - parent_env = {"working_dir": "other_dir", "uris": ["a", "b"]} - result_env = override_task_or_actor_runtime_env(child_env, parent_env) - assert result_env == {"uris": ["a", "b"]} - - # The dicts passed in should not be mutated. 
- assert child_env == {} - assert parent_env == {"working_dir": "other_dir", "uris": ["a", "b"]} - - def test_inherit_conda(self): - child_env = {"uris": ["a"]} - parent_env = {"conda": "my-env-name", "uris": ["a", "b"]} - result_env = override_task_or_actor_runtime_env(child_env, parent_env) - assert result_env == {"uris": ["a"], "conda": "my-env-name"} - - def test_inherit_pip(self): - child_env = {"uris": ["a"]} - parent_env = {"pip": ["pkg-name"], "uris": ["a", "b"]} - result_env = override_task_or_actor_runtime_env(child_env, parent_env) - assert result_env == {"uris": ["a"], "pip": ["pkg-name"]} - - -if __name__ == "__main__": - sys.exit(pytest.main(["-sv", __file__])) diff --git a/python/ray/tests/test_scheduling.py b/python/ray/tests/test_scheduling.py index b834d67e0c67c..10a4ab846e844 100644 --- a/python/ray/tests/test_scheduling.py +++ b/python/ray/tests/test_scheduling.py @@ -2,7 +2,6 @@ import collections import logging import platform -import subprocess import sys import time import unittest @@ -550,8 +549,8 @@ def __init__(self): def get_location(self): return ray.worker.global_worker.node.unique_id - @ray.remote(num_cpus=0.5) - def task_cpu(): + @ray.remote + def task_cpu(num_cpus=0.5): time.sleep(10) return ray.worker.global_worker.node.unique_id @@ -579,100 +578,6 @@ def launcher(): cluster.shutdown() -@pytest.mark.parametrize( - "ray_start_cluster", [{ - "num_cpus": 0, - "num_nodes": 1, - }], indirect=True) -def test_head_node_without_cpu(ray_start_cluster): - @ray.remote(num_cpus=1) - def f(): - return 1 - - f.remote() - - check_count = 0 - demand_1cpu = " {'CPU': 1.0}:" - while True: - status = subprocess.check_output(["ray", "status"]).decode() - if demand_1cpu in status: - break - check_count += 1 - assert check_count < 5, f"Incorrect demand. Last status {status}" - time.sleep(1) - - @ray.remote(num_cpus=2) - def g(): - return 2 - - g.remote() - - check_count = 0 - demand_2cpu = " {'CPU': 2.0}:" - while True: - status = subprocess.check_output(["ray", "status"]).decode() - if demand_1cpu in status and demand_2cpu in status: - break - check_count += 1 - assert check_count < 5, f"Incorrect demand. Last status {status}" - time.sleep(1) - - -@pytest.mark.skipif(sys.platform == "win32", reason="Fails on windows") -def test_gpu_scheduling_liveness(ray_start_cluster): - """Check if the GPU scheduling is in progress when - it is used with the placement group - Issue: https://github.com/ray-project/ray/issues/19130 - """ - cluster = ray_start_cluster - # Start a node without a gpu. - cluster.add_node(num_cpus=6) - ray.init(address=cluster.address) - - NUM_CPU_BUNDLES = 10 - - @ray.remote(num_cpus=1) - class Worker(object): - def __init__(self, i): - self.i = i - - def work(self): - time.sleep(0.1) - print("work ", self.i) - - @ray.remote(num_cpus=1, num_gpus=1) - class Trainer(object): - def __init__(self, i): - self.i = i - - def train(self): - time.sleep(0.2) - print("train ", self.i) - - bundles = [{"CPU": 1, "GPU": 1}] - bundles += [{"CPU": 1} for _ in range(NUM_CPU_BUNDLES)] - - pg = ray.util.placement_group(bundles, strategy="PACK") - o = pg.ready() - # Artificial delay to simulate the real world workload. - time.sleep(3) - print("Scaling up.") - cluster.add_node(num_cpus=6, num_gpus=1) - ray.get(o) - - workers = [ - Worker.options(placement_group=pg).remote(i) - for i in range(NUM_CPU_BUNDLES) - ] - trainer = Trainer.options(placement_group=pg).remote(0) - - # If the gpu scheduling doesn't properly work, the below - # code will hang. 
- ray.get( - [workers[i].work.remote() for i in range(NUM_CPU_BUNDLES)], timeout=30) - ray.get(trainer.train.remote(), timeout=30) - - if __name__ == "__main__": import pytest sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tests/test_tls_auth.py b/python/ray/tests/test_tls_auth.py index 01b234ceb8315..057c2e0b2ae32 100644 --- a/python/ray/tests/test_tls_auth.py +++ b/python/ray/tests/test_tls_auth.py @@ -1,81 +1,23 @@ # coding: utf-8 -import logging import os import sys import pytest -import ray +import logging logger = logging.getLogger(__name__) -@pytest.mark.parametrize("use_tls", [True], indirect=True) -def test_put_get_with_tls(shutdown_only, use_tls): - ray.init(num_cpus=0) - - for i in range(100): - value_before = i * 10**6 - object_ref = ray.put(value_before) - value_after = ray.get(object_ref) - assert value_before == value_after - - for i in range(100): - value_before = i * 10**6 * 1.0 - object_ref = ray.put(value_before) - value_after = ray.get(object_ref) - assert value_before == value_after - - for i in range(100): - value_before = "h" * i - object_ref = ray.put(value_before) - value_after = ray.get(object_ref) - assert value_before == value_after - - for i in range(100): - value_before = [1] * i - object_ref = ray.put(value_before) - value_after = ray.get(object_ref) - assert value_before == value_after - - -@pytest.mark.parametrize("use_tls", [True], indirect=True) -def test_submit_with_tls(shutdown_only, use_tls): - ray.init(num_cpus=2, num_gpus=1, resources={"Custom": 1}) - - @ray.remote - def f(n): - return list(range(n)) - - id1, id2, id3 = f._remote(args=[3], num_returns=3) - assert ray.get([id1, id2, id3]) == [0, 1, 2] - - @ray.remote - class Actor: - def __init__(self, x, y=0): - self.x = x - self.y = y - - def method(self, a, b=0): - return self.x, self.y, a, b - - a = Actor._remote( - args=[0], kwargs={"y": 1}, num_gpus=1, resources={"Custom": 1}) - - id1, id2, id3, id4 = a.method._remote( - args=["test"], kwargs={"b": 2}, num_returns=4) - assert ray.get([id1, id2, id3, id4]) == [0, 1, "test", 2] - - @pytest.mark.skipif( sys.platform == "darwin", reason=("Cryptography doesn't install in Mac build pipeline")) @pytest.mark.parametrize("use_tls", [True], indirect=True) def test_client_connect_to_tls_server(use_tls, init_and_serve): - from ray.util.client import ray as ray_client + from ray.util.client import ray os.environ["RAY_USE_TLS"] = "0" with pytest.raises(ConnectionError): - ray_client.connect("localhost:50051") + ray.connect("localhost:50051") os.environ["RAY_USE_TLS"] = "1" - ray_client.connect("localhost:50051") + ray.connect("localhost:50051") diff --git a/python/ray/tests/test_traceback.py b/python/ray/tests/test_traceback.py index fa48ec62f09cb..3081bcc6ec3d4 100644 --- a/python/ray/tests/test_traceback.py +++ b/python/ray/tests/test_traceback.py @@ -270,45 +270,6 @@ def __repr__(self): assert label_dict["repr"] == actor_repr -def test_unpickleable_stacktrace(): - expected_output = """System error: Failed to unpickle serialized exception -traceback: Traceback (most recent call last): - File "FILE", line ZZ, in from_bytes - return pickle.loads(ray_exception.serialized_exception) -TypeError: __init__() missing 1 required positional argument: 'arg' - -The above exception was the direct cause of the following exception: - -Traceback (most recent call last): - File "FILE", line ZZ, in deserialize_objects - obj = self._deserialize_object(data, metadata, object_ref) - File "FILE", line ZZ, in _deserialize_object - return RayError.from_bytes(obj) 
- File "FILE", line ZZ, in from_bytes - raise RuntimeError(msg) from e -RuntimeError: Failed to unpickle serialized exception""" - - class NoPickleError(OSError): - def __init__(self, arg): - pass - - def g(a): - raise NoPickleError("asdf") - - @ray.remote - def f(): - a = 3 - b = 4 - c = a + b - return g(c) - - try: - ray.get(f.remote()) - except Exception as ex: - print(repr(scrub_traceback(str(ex)))) - assert clean_noqa(expected_output) == scrub_traceback(str(ex)) - - if __name__ == "__main__": import pytest import sys diff --git a/python/ray/tune/BUILD b/python/ray/tune/BUILD index 47330c64c7ec6..c8bcfe31a8b92 100644 --- a/python/ray/tune/BUILD +++ b/python/ray/tune/BUILD @@ -160,7 +160,7 @@ py_test( size = "small", srcs = ["tests/test_logger.py"], deps = [":tune_lib"], - tags = ["team:ml"], + tags = ["team:ml", "jenkins_only"], ) py_test( diff --git a/python/ray/tune/analysis/experiment_analysis.py b/python/ray/tune/analysis/experiment_analysis.py index fbaa7207a04a0..e7cfc31810e1d 100644 --- a/python/ray/tune/analysis/experiment_analysis.py +++ b/python/ray/tune/analysis/experiment_analysis.py @@ -1,11 +1,9 @@ import json import logging import os -import warnings from numbers import Number from typing import Any, Dict, List, Optional, Tuple -from ray.util.debug import log_once from ray.tune.utils import flatten_dict from ray.tune.utils.serialization import TuneFunctionDecoder from ray.tune.utils.util import is_nan_or_inf @@ -558,17 +556,6 @@ def best_result(self) -> Dict: "the metric and mode explicitly and fetch the last result.") return self.best_trial.last_result - def _delimiter(self): - # Deprecate: 1.9 (default should become `/`) - delimiter = os.environ.get("TUNE_RESULT_DELIM", ".") - if delimiter == "." and log_once("delimiter_deprecation"): - warnings.warn( - "Dataframes will use '/' instead of '.' to delimit " - "nested result keys in future versions of Ray. For forward " - "compatibility, set the environment variable " - "TUNE_RESULT_DELIM='/'") - return delimiter - @property def best_result_df(self) -> DataFrame: """Get the best result of the experiment as a pandas dataframe. @@ -582,9 +569,7 @@ def best_result_df(self) -> DataFrame: if not pd: raise ValueError("`best_result_df` requires pandas. Install with " "`pip install pandas`.") - - best_result = flatten_dict( - self.best_result, delimiter=self._delimiter()) + best_result = flatten_dict(self.best_result, delimiter=".") return pd.DataFrame.from_records([best_result], index="trial_id") @property @@ -594,13 +579,12 @@ def results(self) -> Dict[str, Dict]: @property def results_df(self) -> DataFrame: - """Get all the last results as a pandas dataframe.""" if not pd: - raise ValueError("`results_df` requires pandas. Install with " + raise ValueError("`best_result_df` requires pandas. 
Install with " "`pip install pandas`.") return pd.DataFrame.from_records( [ - flatten_dict(trial.last_result, delimiter=self._delimiter()) + flatten_dict(trial.last_result, delimiter=".") for trial in self.trials ], index="trial_id") diff --git a/python/ray/tune/commands.py b/python/ray/tune/commands.py index 5d47605c63181..7fbbe9776bde2 100644 --- a/python/ray/tune/commands.py +++ b/python/ray/tune/commands.py @@ -116,9 +116,10 @@ def list_trials(experiment_path, _check_tabulate() try: - checkpoints_df = Analysis(experiment_path).dataframe() # last result - except TuneError as e: - raise click.ClickException("No trial data found!") from e + checkpoints_df = Analysis(experiment_path).dataframe( + metric="episode_reward_mean", mode="max") + except TuneError: + raise click.ClickException("No trial data found!") def key_filter(k): return k in DEFAULT_CLI_KEYS or k.startswith(CONFIG_PREFIX) diff --git a/python/ray/tune/durable_trainable.py b/python/ray/tune/durable_trainable.py index 77b80e510af2b..db822434f1223 100644 --- a/python/ray/tune/durable_trainable.py +++ b/python/ray/tune/durable_trainable.py @@ -171,16 +171,14 @@ def durable(trainable: Union[str, Type[Trainable], Callable]): A durable trainable class wrapped around your trainable. """ - overwrite_name = None if isinstance(trainable, str): trainable_cls = get_trainable_cls(trainable) - overwrite_name = f"Durable{trainable}" else: trainable_cls = trainable if not inspect.isclass(trainable_cls): # Function API - return wrap_function(trainable_cls, durable=True, name=overwrite_name) + return wrap_function(trainable_cls, durable=True) if not issubclass(trainable_cls, Trainable): raise ValueError( @@ -189,14 +187,8 @@ def durable(trainable: Union[str, Type[Trainable], Callable]): f"it does. Got: {type(trainable_cls)}") # else: Class API - - # Class is already durable - - if issubclass(trainable_cls, DurableTrainable): - return trainable_cls - class _WrappedDurableTrainable(DurableTrainable, trainable_cls): - _name = overwrite_name or (trainable_cls.__name__ if hasattr( - trainable_cls, "__name__") else "durable_trainable") + _name = trainable_cls.__name__ if hasattr(trainable_cls, "__name__") \ + else "durable_trainable" return _WrappedDurableTrainable diff --git a/python/ray/tune/function_runner.py b/python/ray/tune/function_runner.py index e4c2018068d7a..ae4235aa89099 100644 --- a/python/ray/tune/function_runner.py +++ b/python/ray/tune/function_runner.py @@ -10,8 +10,6 @@ from functools import partial from numbers import Number -from typing import Any, Callable, Optional - from six.moves import queue from ray.util.debug import log_once @@ -532,10 +530,7 @@ def _report_thread_runner_error(self, block=False): pass -def wrap_function(train_func: Callable[[Any], Any], - durable: bool = False, - warn: bool = True, - name: Optional[str] = None): +def wrap_function(train_func, durable=False, warn=True): inherit_from = (FunctionRunner, ) if hasattr(train_func, "__mixins__"): @@ -567,8 +562,8 @@ def wrap_function(train_func: Callable[[Any], Any], "arguments to be `func(config, checkpoint_dir=None)`.") class ImplicitFunc(*inherit_from): - _name = name or (train_func.__name__ - if hasattr(train_func, "__name__") else "func") + _name = train_func.__name__ if hasattr(train_func, "__name__") \ + else "func" def _trainable_func(self, config, reporter, checkpoint_dir): if not use_checkpoint and not use_reporter: diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py index f555e684b4466..edc8dcb5482d1 100644 --- a/python/ray/tune/logger.py 
+++ b/python/ray/tune/logger.py @@ -169,8 +169,8 @@ class TBXLogger(Logger): {"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2} """ - VALID_HPARAMS = (str, bool, int, float, list, type(None)) - VALID_NP_HPARAMS = (np.bool8, np.float32, np.float64, np.int32, np.int64) + VALID_HPARAMS = (str, bool, np.bool8, int, np.integer, float, list, + type(None)) def _init(self): try: @@ -254,18 +254,10 @@ def _try_log_hparams(self, result): if isinstance(v, self.VALID_HPARAMS) } - np_params = { - k: v.tolist() - for k, v in flat_params.items() - if isinstance(v, self.VALID_NP_HPARAMS) - } - - scrubbed_params.update(np_params) - removed = { k: v for k, v in flat_params.items() - if not isinstance(v, self.VALID_HPARAMS + self.VALID_NP_HPARAMS) + if not isinstance(v, self.VALID_HPARAMS) } if removed: logger.info( @@ -593,7 +585,8 @@ class TBXLoggerCallback(LoggerCallback): {"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2} """ - VALID_HPARAMS = (str, bool, int, float, list, type(None)) + # NoneType is not supported on the last TBX release yet. + VALID_HPARAMS = (str, bool, int, float, list) VALID_NP_HPARAMS = (np.bool8, np.float32, np.float64, np.int32, np.int64) def __init__(self): diff --git a/python/ray/tune/progress_reporter.py b/python/ray/tune/progress_reporter.py index 52f19f8029da2..0b69faa51550d 100644 --- a/python/ray/tune/progress_reporter.py +++ b/python/ray/tune/progress_reporter.py @@ -1,6 +1,5 @@ from __future__ import print_function -import datetime from typing import Dict, List, Optional, Union import collections @@ -9,17 +8,15 @@ import numpy as np import time -from ray.util.annotations import PublicAPI, DeveloperAPI -from ray.util.queue import Queue - from ray.tune.callback import Callback from ray.tune.logger import pretty_print, logger -from ray.tune.result import ( - DEFAULT_METRIC, EPISODE_REWARD_MEAN, MEAN_ACCURACY, MEAN_LOSS, NODE_IP, - PID, TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL, AUTO_RESULT_KEYS) -from ray.tune.trial import DEBUG_PRINT_INTERVAL, Trial, Location +from ray.tune.result import (DEFAULT_METRIC, EPISODE_REWARD_MEAN, + MEAN_ACCURACY, MEAN_LOSS, TRAINING_ITERATION, + TIME_TOTAL_S, TIMESTEPS_TOTAL, AUTO_RESULT_KEYS) +from ray.tune.trial import DEBUG_PRINT_INTERVAL, Trial from ray.tune.utils import unflattened_lookup from ray.tune.utils.log import Verbosity, has_verbosity +from ray.util.annotations import PublicAPI, DeveloperAPI try: from collections.abc import Mapping, MutableMapping @@ -162,8 +159,6 @@ def __init__( self._max_report_freqency = max_report_frequency self._last_report_time = 0 - self._start_time = time.time() - self._metric = metric self._mode = mode @@ -193,12 +188,6 @@ def set_search_properties(self, metric: Optional[str], def set_total_samples(self, total_samples: int): self._total_samples = total_samples - def set_start_time(self, timestamp: Optional[float] = None): - if timestamp is not None: - self._start_time = time.time() - else: - self._start_time = timestamp - def should_report(self, trials: List[Trial], done: bool = False): if time.time() - self._last_report_time > self._max_report_freqency: self._last_report_time = time.time() @@ -278,11 +267,7 @@ def _progress_str(self, if not self._metrics_override: user_metrics = self._infer_user_metrics(trials, self._infer_limit) self._metric_columns.update(user_metrics) - messages = [ - "== Status ==", - time_passed_str(self._start_time, time.time()), - memory_debug_str(), *sys_info - ] + messages = ["== Status ==", memory_debug_str(), *sys_info] if done: max_progress = None max_error = None @@ 
-431,32 +416,15 @@ def __init__(
                 "to `tune.run()` instead.")
         self._overwrite = overwrite
-        self._output_queue = None
-
-    def set_output_queue(self, queue: Queue):
-        self._output_queue = queue
 
     def report(self, trials: List[Trial], done: bool, *sys_info: Dict):
-        overwrite = self._overwrite
+        from IPython.display import clear_output
+        from IPython.core.display import display, HTML
+        if self._overwrite:
+            clear_output(wait=True)
         progress_str = self._progress_str(
             trials, done, *sys_info, fmt="html", delim="<br>
") - - def update_output(): - from IPython.display import clear_output - from IPython.core.display import display, HTML - - if overwrite: - clear_output(wait=True) - - display(HTML(progress_str)) - - if self._output_queue is not None: - # If an output queue is set, send callable (e.g. when using - # Ray client) - self._output_queue.put(update_output) - else: - # Else, output directly - update_output() + display(HTML(progress_str)) @PublicAPI @@ -542,33 +510,6 @@ def memory_debug_str(): "to resolve)") -def time_passed_str(start_time: float, current_time: float): - current_time_dt = datetime.datetime.fromtimestamp(current_time) - start_time_dt = datetime.datetime.fromtimestamp(start_time) - delta: datetime.timedelta = current_time_dt - start_time_dt - - rest = delta.total_seconds() - days = rest // (60 * 60 * 24) - - rest -= days * (60 * 60 * 24) - hours = rest // (60 * 60) - - rest -= hours * (60 * 60) - minutes = rest // 60 - - seconds = rest - minutes * 60 - - if days > 0: - running_for_str = f"{days:.0f} days, " - else: - running_for_str = "" - - running_for_str += f"{hours:02.0f}:{minutes:02.0f}:{seconds:05.2f}" - - return (f"Current time: {current_time_dt:%Y-%m-%d %H:%M:%S} " - f"(running for {running_for_str})") - - def _get_trials_by_state(trials: List[Trial]): trials_by_state = collections.defaultdict(list) for t in trials: @@ -833,18 +774,6 @@ def _fair_filter_trials(trials_by_state: Dict[str, List[Trial]], return filtered_trials -def _get_trial_location(trial: Trial, result: dict) -> Location: - # we get the location from the result, as the one in trial will be - # reset when trial terminates - node_ip, pid = result.get(NODE_IP, None), result.get(PID, None) - if node_ip and pid: - location = Location(node_ip, pid) - else: - # fallback to trial location if there hasn't been a report yet - location = trial.location - return location - - def _get_trial_info(trial: Trial, parameters: List[str], metrics: List[str]): """Returns the following information about a trial: @@ -857,8 +786,7 @@ def _get_trial_info(trial: Trial, parameters: List[str], metrics: List[str]): """ result = trial.last_result config = trial.config - location = _get_trial_location(trial, result) - trial_info = [str(trial), trial.status, str(location)] + trial_info = [str(trial), trial.status, str(trial.location)] trial_info += [ unflattened_lookup(param, config, default=None) for param in parameters ] diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index 959fba6c0dcff..52ec1102a5f78 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -93,18 +93,11 @@ class _TrialCleanup: Args: threshold (int): Number of futures to hold at once. If the threshold is passed, cleanup will kick in and remove futures. - force_cleanup (int): Grace periods for forceful actor termination. - If 0, actors will not be forcefully terminated. """ - def __init__(self, - threshold: int = TRIAL_CLEANUP_THRESHOLD, - force_cleanup: int = 0): + def __init__(self, threshold: int = TRIAL_CLEANUP_THRESHOLD): self.threshold = threshold self._cleanup_map = {} - if force_cleanup < 0: - force_cleanup = 0 - self._force_cleanup = force_cleanup def add(self, trial: Trial, actor: ActorHandle): """Adds a trial actor to be stopped. @@ -130,27 +123,15 @@ def cleanup(self, partial: bool = True): If partial=False, all futures are expected to return. If a future does not return within the timeout period, the cleanup terminates. 
""" - # At this point, self._cleanup_map holds the last references - # to actors. Removing those references either one-by-one - # (graceful termination case) or all at once, by reinstantiating - # self._cleanup_map (forceful termination case) will cause Ray - # to kill the actors during garbage collection. logger.debug("Cleaning up futures") num_to_keep = int(self.threshold) / 2 if partial else 0 while len(self._cleanup_map) > num_to_keep: dones, _ = ray.wait( - list(self._cleanup_map), - timeout=DEFAULT_GET_TIMEOUT - if not self._force_cleanup else self._force_cleanup) + list(self._cleanup_map), timeout=DEFAULT_GET_TIMEOUT) if not dones: logger.warning( "Skipping cleanup - trainable.stop did not return in " "time. Consider making `stop` a faster operation.") - if not partial and self._force_cleanup: - logger.warning( - "Forcing trainable cleanup by terminating actors.") - self._cleanup_map = {} - return else: done = dones[0] del self._cleanup_map[done] @@ -184,9 +165,7 @@ def __init__(self, # We use self._paused to store paused trials here. self._paused = {} - force_trial_cleanup = int( - os.environ.get("TUNE_FORCE_TRIAL_CLEANUP_S", "0")) - self._trial_cleanup = _TrialCleanup(force_cleanup=force_trial_cleanup) + self._trial_cleanup = _TrialCleanup() self._has_cleaned_up_pgs = False self._reuse_actors = reuse_actors # The maxlen will be updated when `set_max_pending_trials()` is called diff --git a/python/ray/tune/registry.py b/python/ray/tune/registry.py index 9f143db42d37d..d83c727179387 100644 --- a/python/ray/tune/registry.py +++ b/python/ray/tune/registry.py @@ -1,8 +1,5 @@ import logging -import uuid - from types import FunctionType -from typing import Optional import ray import ray.cloudpickle as pickle @@ -117,25 +114,23 @@ def check_serializability(key, value): _global_registry.register(TEST, key, value) -def _make_key(prefix, category, key): +def _make_key(category, key): """Generate a binary key for the given category and key. Args: - prefix (str): Prefix category (str): The category of the item key (str): The unique identifier for the item Returns: The key to use for storing a the value. """ - return (b"TuneRegistry:" + prefix.encode("ascii") + b":" + - category.encode("ascii") + b"/" + key.encode("ascii")) + return (b"TuneRegistry:" + category.encode("ascii") + b"/" + + key.encode("ascii")) class _Registry: - def __init__(self, prefix: Optional[str] = None): + def __init__(self): self._to_flush = {} - self._prefix = prefix or uuid.uuid4().hex[:8] def register(self, category, key, value): """Registers the value with the global registry. 
@@ -153,14 +148,14 @@ def register(self, category, key, value):
 
     def contains(self, category, key):
         if _internal_kv_initialized():
-            value = _internal_kv_get(_make_key(self._prefix, category, key))
+            value = _internal_kv_get(_make_key(category, key))
             return value is not None
         else:
             return (category, key) in self._to_flush
 
     def get(self, category, key):
         if _internal_kv_initialized():
-            value = _internal_kv_get(_make_key(self._prefix, category, key))
+            value = _internal_kv_get(_make_key(category, key))
             if value is None:
                 raise ValueError(
                     "Registry value for {}/{} doesn't exist.".format(
@@ -171,12 +166,11 @@ def get(self, category, key):
 
     def flush_values(self):
         for (category, key), value in self._to_flush.items():
-            _internal_kv_put(
-                _make_key(self._prefix, category, key), value, overwrite=True)
+            _internal_kv_put(_make_key(category, key), value, overwrite=True)
         self._to_flush.clear()
 
 
-_global_registry = _Registry(prefix="global")
+_global_registry = _Registry()
 ray.worker._post_init_hooks.append(_global_registry.flush_values)
 
 
diff --git a/python/ray/tune/result.py b/python/ray/tune/result.py
index c166da2c0ce8e..e9eb7f40212dc 100644
--- a/python/ray/tune/result.py
+++ b/python/ray/tune/result.py
@@ -67,10 +67,6 @@
 DEFAULT_RESULT_KEYS = (TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL,
                        MEAN_ACCURACY, MEAN_LOSS)
 
-# Metrics that don't require at least one iteration to complete
-DEBUG_METRICS = (TRIAL_ID, "experiment_id", "date", "timestamp", PID, HOSTNAME,
-                 NODE_IP, "config")
-
 # Make sure this doesn't regress
 AUTO_RESULT_KEYS = (
     TRAINING_ITERATION,
diff --git a/python/ray/tune/schedulers/hyperband.py b/python/ray/tune/schedulers/hyperband.py
index 1a671f7b24996..8b7830b79d150 100644
--- a/python/ray/tune/schedulers/hyperband.py
+++ b/python/ray/tune/schedulers/hyperband.py
@@ -80,8 +80,6 @@ class HyperBandScheduler(FIFOScheduler):
             reaching max_t. Defaults to True.
     """
 
-    _supports_buffered_results = False
-
     def __init__(self,
                  time_attr: str = "training_iteration",
                  metric: Optional[str] = None,
diff --git a/python/ray/tune/schedulers/trial_scheduler.py b/python/ray/tune/schedulers/trial_scheduler.py
index 64e9d99613ae8..a0626416afe00 100644
--- a/python/ray/tune/schedulers/trial_scheduler.py
+++ b/python/ray/tune/schedulers/trial_scheduler.py
@@ -14,16 +14,10 @@ class TrialScheduler:
 
     _metric = None
 
-    _supports_buffered_results = True
-
     @property
     def metric(self):
         return self._metric
 
-    @property
-    def supports_buffered_results(self):
-        return self._supports_buffered_results
-
     def set_search_properties(self, metric: Optional[str],
                               mode: Optional[str]) -> bool:
         """Pass search properties to scheduler.
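
The two scheduler hunks above strip the buffered-results capability flag. A toy sketch (not Ray code; class names are illustrative) of the removed pattern, where a base-class default is surfaced through a property and HyperBand opts out:

# Toy sketch of the pattern removed above.
class SchedulerBase:
    _supports_buffered_results = True

    @property
    def supports_buffered_results(self) -> bool:
        return self._supports_buffered_results


class HyperBandLike(SchedulerBase):
    # HyperBand stops trials while processing brackets, so consuming
    # buffered results could trigger that stop logic repeatedly.
    _supports_buffered_results = False


assert SchedulerBase().supports_buffered_results
assert not HyperBandLike().supports_buffered_results
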
diff --git a/python/ray/tune/suggest/bohb.py b/python/ray/tune/suggest/bohb.py index e8f15c5082866..52ebf84e9acc2 100644 --- a/python/ray/tune/suggest/bohb.py +++ b/python/ray/tune/suggest/bohb.py @@ -235,10 +235,10 @@ def to_wrapper(self, trial_id: str, result: Dict) -> _BOHBJobWrapper: def on_pause(self, trial_id: str): self.paused.add(trial_id) - self.running.discard(trial_id) + self.running.remove(trial_id) def on_unpause(self, trial_id: str): - self.paused.discard(trial_id) + self.paused.remove(trial_id) self.running.add(trial_id) @staticmethod diff --git a/python/ray/tune/tests/test_api.py b/python/ray/tune/tests/test_api.py index 598b2a2dccf59..bb49de900fb1d 100644 --- a/python/ray/tune/tests/test_api.py +++ b/python/ray/tune/tests/test_api.py @@ -226,14 +226,6 @@ class B(Trainable): self.assertRaises(TypeError, lambda: register_trainable("foo", A)) self.assertRaises(TypeError, lambda: Experiment("foo", A)) - def testRegisterDurableTrainableTwice(self): - def train(config, reporter): - pass - - register_trainable("foo", train) - register_trainable("foo", tune.durable("foo")) - register_trainable("foo", tune.durable("foo")) - def testTrainableCallable(self): def dummy_fn(config, reporter, steps): reporter(timesteps_total=steps, done=True) diff --git a/python/ray/tune/tests/test_cluster.py b/python/ray/tune/tests/test_cluster.py index 87a6f42f7af25..98dd4e4b2da58 100644 --- a/python/ray/tune/tests/test_cluster.py +++ b/python/ray/tune/tests/test_cluster.py @@ -190,7 +190,7 @@ def test_remove_node_before_result(start_connected_emptyhead_cluster): running_trials = _get_running_trials(runner) assert len(running_trials) == 1 assert _check_trial_running(running_trials[0]) - assert not trial.has_reported_at_least_once + assert not trial.last_result assert trial.status == Trial.RUNNING cluster.remove_node(node) cluster.add_node(num_cpus=1) diff --git a/python/ray/tune/tests/test_logger.py b/python/ray/tune/tests/test_logger.py index 84c633a9b0884..ef75bfcfb49d5 100644 --- a/python/ray/tune/tests/test_logger.py +++ b/python/ray/tune/tests/test_logger.py @@ -230,6 +230,16 @@ def testLegacyBadTBX(self): logger.close() assert "INFO" in cm.output[0] + config = {"None": None} + t = Trial( + evaluated_params=config, trial_id="tbx", logdir=self.test_dir) + logger = TBXLogger(config=config, logdir=self.test_dir, trial=t) + logger.on_result(result(0, 4)) + logger.on_result(result(2, 4, score=[1, 2, 3], hello={"world": 1})) + with self.assertLogs("ray.tune.logger", level="INFO") as cm: + logger.close() + assert "INFO" in cm.output[0] + def testBadTBX(self): config = {"b": (1, 2, 3)} t = Trial( @@ -243,6 +253,18 @@ def testBadTBX(self): logger.on_trial_complete(3, [], t) assert "INFO" in cm.output[0] + config = {"None": None} + t = Trial( + evaluated_params=config, trial_id="tbx", logdir=self.test_dir) + logger = TBXLoggerCallback() + logger.on_trial_result(0, [], t, result(0, 4)) + logger.on_trial_result(1, [], t, result(1, 5)) + logger.on_trial_result( + 2, [], t, result(2, 6, score=[1, 2, 3], hello={"world": 1})) + with self.assertLogs("ray.tune.logger", level="INFO") as cm: + logger.on_trial_complete(3, [], t) + assert "INFO" in cm.output[0] + if __name__ == "__main__": import pytest diff --git a/python/ray/tune/tests/test_progress_reporter.py b/python/ray/tune/tests/test_progress_reporter.py index 6978d2c128c6f..2e85fe0a6b368 100644 --- a/python/ray/tune/tests/test_progress_reporter.py +++ b/python/ray/tune/tests/test_progress_reporter.py @@ -3,14 +3,13 @@ import os import unittest from 
unittest.mock import MagicMock, Mock, patch - from ray import tune from ray._private.test_utils import run_string_as_driver from ray.tune.trial import Trial from ray.tune.result import AUTO_RESULT_KEYS -from ray.tune.progress_reporter import ( - CLIReporter, JupyterNotebookReporter, _fair_filter_trials, best_trial_str, - detect_reporter, trial_progress_str, time_passed_str) +from ray.tune.progress_reporter import (CLIReporter, JupyterNotebookReporter, + _fair_filter_trials, best_trial_str, + detect_reporter, trial_progress_str) EXPECTED_RESULT_1 = """Result logdir: /foo Number of trials: 5 (1 PENDING, 3 RUNNING, 1 TERMINATED) @@ -61,92 +60,76 @@ END_TO_END_COMMAND = """ import ray from ray import tune -from ray.tune.trial import Location -from ray.tune.progress_reporter import _get_trial_location -from unittest.mock import patch - - -def mock_get_trial_location(trial, result): - location = _get_trial_location(trial, result) - if location.pid: - return Location("123.123.123.123", "1") - return location +reporter = tune.progress_reporter.CLIReporter(metric_columns=["done"]) -with patch("ray.tune.progress_reporter._get_trial_location", - mock_get_trial_location): - reporter = tune.progress_reporter.CLIReporter(metric_columns=["done"]) +def f(config): + return {"done": True} - def f(config): - return {"done": True} - - ray.init(num_cpus=1) - tune.run_experiments( - { - "one": { - "run": f, - "config": { - "a": tune.grid_search(list(range(10))), - }, - }, - "two": { - "run": f, - "config": { - "b": tune.grid_search(list(range(10))), - }, - }, - "three": { - "run": f, - "config": { - "c": tune.grid_search(list(range(10))), - }, - }, +ray.init(num_cpus=1) +tune.run_experiments({ + "one": { + "run": f, + "config": { + "a": tune.grid_search(list(range(10))), + }, + }, + "two": { + "run": f, + "config": { + "b": tune.grid_search(list(range(10))), }, - verbose=3, - progress_reporter=reporter)""" + }, + "three": { + "run": f, + "config": { + "c": tune.grid_search(list(range(10))), + }, + }, +}, verbose=3, progress_reporter=reporter)""" EXPECTED_END_TO_END_START = """Number of trials: 30/30 (29 PENDING, 1 RUNNING) -+---------------+----------+-------------------+-----+-----+ -| Trial name | status | loc | a | b | -|---------------+----------+-------------------+-----+-----| -| f_xxxxx_00000 | RUNNING | 123.123.123.123:1 | 0 | | -| f_xxxxx_00001 | PENDING | | 1 | |""" ++---------------+----------+-------+-----+-----+ +| Trial name | status | loc | a | b | +|---------------+----------+-------+-----+-----| +| f_xxxxx_00000 | RUNNING | | 0 | | +| f_xxxxx_00001 | PENDING | | 1 | |""" EXPECTED_END_TO_END_END = """Number of trials: 30/30 (30 TERMINATED) -+---------------+------------+-------------------+-----+-----+-----+--------+ -| Trial name | status | loc | a | b | c | done | -|---------------+------------+-------------------+-----+-----+-----+--------| -| f_xxxxx_00000 | TERMINATED | 123.123.123.123:1 | 0 | | | True | -| f_xxxxx_00001 | TERMINATED | 123.123.123.123:1 | 1 | | | True | -| f_xxxxx_00002 | TERMINATED | 123.123.123.123:1 | 2 | | | True | -| f_xxxxx_00003 | TERMINATED | 123.123.123.123:1 | 3 | | | True | -| f_xxxxx_00004 | TERMINATED | 123.123.123.123:1 | 4 | | | True | -| f_xxxxx_00005 | TERMINATED | 123.123.123.123:1 | 5 | | | True | -| f_xxxxx_00006 | TERMINATED | 123.123.123.123:1 | 6 | | | True | -| f_xxxxx_00007 | TERMINATED | 123.123.123.123:1 | 7 | | | True | -| f_xxxxx_00008 | TERMINATED | 123.123.123.123:1 | 8 | | | True | -| f_xxxxx_00009 | TERMINATED | 123.123.123.123:1 | 9 | | | 
True | -| f_xxxxx_00010 | TERMINATED | 123.123.123.123:1 | | 0 | | True | -| f_xxxxx_00011 | TERMINATED | 123.123.123.123:1 | | 1 | | True | -| f_xxxxx_00012 | TERMINATED | 123.123.123.123:1 | | 2 | | True | -| f_xxxxx_00013 | TERMINATED | 123.123.123.123:1 | | 3 | | True | -| f_xxxxx_00014 | TERMINATED | 123.123.123.123:1 | | 4 | | True | -| f_xxxxx_00015 | TERMINATED | 123.123.123.123:1 | | 5 | | True | -| f_xxxxx_00016 | TERMINATED | 123.123.123.123:1 | | 6 | | True | -| f_xxxxx_00017 | TERMINATED | 123.123.123.123:1 | | 7 | | True | -| f_xxxxx_00018 | TERMINATED | 123.123.123.123:1 | | 8 | | True | -| f_xxxxx_00019 | TERMINATED | 123.123.123.123:1 | | 9 | | True | -| f_xxxxx_00020 | TERMINATED | 123.123.123.123:1 | | | 0 | True | -| f_xxxxx_00021 | TERMINATED | 123.123.123.123:1 | | | 1 | True | -| f_xxxxx_00022 | TERMINATED | 123.123.123.123:1 | | | 2 | True | -| f_xxxxx_00023 | TERMINATED | 123.123.123.123:1 | | | 3 | True | -| f_xxxxx_00024 | TERMINATED | 123.123.123.123:1 | | | 4 | True | -| f_xxxxx_00025 | TERMINATED | 123.123.123.123:1 | | | 5 | True | -| f_xxxxx_00026 | TERMINATED | 123.123.123.123:1 | | | 6 | True | -| f_xxxxx_00027 | TERMINATED | 123.123.123.123:1 | | | 7 | True | -| f_xxxxx_00028 | TERMINATED | 123.123.123.123:1 | | | 8 | True | -| f_xxxxx_00029 | TERMINATED | 123.123.123.123:1 | | | 9 | True | -+---------------+------------+-------------------+-----+-----+-----+--------+""" # noqa ++---------------+------------+-------+-----+-----+-----+--------+ +| Trial name | status | loc | a | b | c | done | +|---------------+------------+-------+-----+-----+-----+--------| +| f_xxxxx_00000 | TERMINATED | | 0 | | | True | +| f_xxxxx_00001 | TERMINATED | | 1 | | | True | +| f_xxxxx_00002 | TERMINATED | | 2 | | | True | +| f_xxxxx_00003 | TERMINATED | | 3 | | | True | +| f_xxxxx_00004 | TERMINATED | | 4 | | | True | +| f_xxxxx_00005 | TERMINATED | | 5 | | | True | +| f_xxxxx_00006 | TERMINATED | | 6 | | | True | +| f_xxxxx_00007 | TERMINATED | | 7 | | | True | +| f_xxxxx_00008 | TERMINATED | | 8 | | | True | +| f_xxxxx_00009 | TERMINATED | | 9 | | | True | +| f_xxxxx_00010 | TERMINATED | | | 0 | | True | +| f_xxxxx_00011 | TERMINATED | | | 1 | | True | +| f_xxxxx_00012 | TERMINATED | | | 2 | | True | +| f_xxxxx_00013 | TERMINATED | | | 3 | | True | +| f_xxxxx_00014 | TERMINATED | | | 4 | | True | +| f_xxxxx_00015 | TERMINATED | | | 5 | | True | +| f_xxxxx_00016 | TERMINATED | | | 6 | | True | +| f_xxxxx_00017 | TERMINATED | | | 7 | | True | +| f_xxxxx_00018 | TERMINATED | | | 8 | | True | +| f_xxxxx_00019 | TERMINATED | | | 9 | | True | +| f_xxxxx_00020 | TERMINATED | | | | 0 | True | +| f_xxxxx_00021 | TERMINATED | | | | 1 | True | +| f_xxxxx_00022 | TERMINATED | | | | 2 | True | +| f_xxxxx_00023 | TERMINATED | | | | 3 | True | +| f_xxxxx_00024 | TERMINATED | | | | 4 | True | +| f_xxxxx_00025 | TERMINATED | | | | 5 | True | +| f_xxxxx_00026 | TERMINATED | | | | 6 | True | +| f_xxxxx_00027 | TERMINATED | | | | 7 | True | +| f_xxxxx_00028 | TERMINATED | | | | 8 | True | +| f_xxxxx_00029 | TERMINATED | | | | 9 | True | ++---------------+------------+-------+-----+-----+-----+--------+""" EXPECTED_END_TO_END_AC = """Number of trials: 30/30 (30 TERMINATED) +---------------+------------+-------+-----+-----+-----+ @@ -234,26 +217,15 @@ def f(config): Trial train_xxxxx_00002 reported acc=8 with parameters={'do': 'twice'}. """ + \ "This trial completed." 
-VERBOSE_TRIAL_DETAIL = """+-------------------+----------+-------------------+----------+ -| Trial name | status | loc | do | -|-------------------+----------+-------------------+----------| -| train_xxxxx_00000 | RUNNING | 123.123.123.123:1 | complete |""" +VERBOSE_TRIAL_DETAIL = """+-------------------+----------+-------+----------+ +| Trial name | status | loc | do | +|-------------------+----------+-------+----------| +| train_xxxxx_00000 | RUNNING | | complete |""" VERBOSE_CMD = """from ray import tune import random import numpy as np import time -from ray.tune.trial import Location -from ray.tune.progress_reporter import _get_trial_location -from unittest.mock import patch - - -def mock_get_trial_location(trial, result): - location = _get_trial_location(trial, result) - if location.pid: - return Location("123.123.123.123", "1") - return location - def train(config): if config["do"] == "complete": @@ -270,14 +242,11 @@ def train(config): random.seed(1234) np.random.seed(1234) - -with patch("ray.tune.progress_reporter._get_trial_location", - mock_get_trial_location): - tune.run( - train, - config={ - "do": tune.grid_search(["complete", "once", "twice"]) - },""" +tune.run( + train, + config={ + "do": tune.grid_search(["complete", "once", "twice"]) + },""" # Add "verbose=3)" etc @@ -455,27 +424,6 @@ def testProgressStr(self): best1 = best_trial_str(trials[1], "metric_1") assert best1 == EXPECTED_BEST_1 - def testTimeElapsed(self): - # Sun Feb 7 14:18:40 2016 -0800 - # (time of the first Ray commit) - time_start = 1454825920 - time_now = ( - time_start + 1 * 60 * 60 # 1 hour - + 31 * 60 # 31 minutes - + 22 # 22 seconds - ) # time to second commit - - # Local timezone output can be tricky, so we don't check the - # day and the hour in this test. - output = time_passed_str(time_start, time_now) - self.assertIn("Current time: 2016-02-", output) - self.assertIn(":50:02 (running for 01:31:22.00)", output) - - time_now += 2 * 60 * 60 * 24 # plus two days - output = time_passed_str(time_start, time_now) - self.assertIn("Current time: 2016-02-", output) - self.assertIn(":50:02 (running for 2 days, 01:31:22.00)", output) - def testCurrentBestTrial(self): trials = [] for i in range(5): diff --git a/python/ray/tune/tests/test_ray_trial_executor.py b/python/ray/tune/tests/test_ray_trial_executor.py index f5d87e7dd1926..a21664a2c11ee 100644 --- a/python/ray/tune/tests/test_ray_trial_executor.py +++ b/python/ray/tune/tests/test_ray_trial_executor.py @@ -2,7 +2,6 @@ import os import pytest -import time import unittest import ray @@ -12,7 +11,7 @@ from ray.tune.callback import Callback from ray.tune.ray_trial_executor import RayTrialExecutor from ray.tune.registry import _global_registry, TRAINABLE_CLASS -from ray.tune.result import PID, TRAINING_ITERATION, TRIAL_ID +from ray.tune.result import TRAINING_ITERATION from ray.tune.suggest import BasicVariantGenerator from ray.tune.trial import Trial, Checkpoint from ray.tune.resources import Resources @@ -253,68 +252,6 @@ def reset_config(self, config): self.assertEqual(trial.experiment_tag, "modified_mock") self.assertEqual(Trial.RUNNING, trial.status) - def testForceTrialCleanup(self): - class B(Trainable): - def step(self): - print("Step start") - time.sleep(10) - print("Step done") - return dict(my_metric=1, timesteps_this_iter=1, done=True) - - def reset_config(self, config): - self.config = config - return True - - def cleanup(self): - print("Cleanup start") - time.sleep(10) - print("Cleanup done") - - # First check if the trials terminate gracefully 
by default - trials = self.generate_trials({ - "run": B, - "config": { - "foo": 0 - }, - }, "grid_search") - trial = trials[0] - self.trial_executor.start_trial(trial) - self.assertEqual(Trial.RUNNING, trial.status) - time.sleep(5) - print("Stop trial") - self.trial_executor.stop_trial(trial) - print("Start trial cleanup") - start = time.time() - self.trial_executor.cleanup([trial]) - self.assertGreaterEqual(time.time() - start, 12.0) - - # Check forceful termination. It should run for much less than the - # sleep periods in the Trainable - trials = self.generate_trials({ - "run": B, - "config": { - "foo": 0 - }, - }, "grid_search") - trial = trials[0] - os.environ["TUNE_FORCE_TRIAL_CLEANUP_S"] = "1" - self.trial_executor = RayTrialExecutor(queue_trials=False) - os.environ["TUNE_FORCE_TRIAL_CLEANUP_S"] = "0" - self.trial_executor.start_trial(trial) - self.assertEqual(Trial.RUNNING, trial.status) - time.sleep(5) - print("Stop trial") - self.trial_executor.stop_trial(trial) - print("Start trial cleanup") - start = time.time() - self.trial_executor.cleanup([trial]) - self.assertLess(time.time() - start, 5.0) - - # also check if auto-filled metrics were returned - self.assertIn(PID, trial.last_result) - self.assertIn(TRIAL_ID, trial.last_result) - self.assertNotIn("my_metric", trial.last_result) - @staticmethod def generate_trials(spec, name): suggester = BasicVariantGenerator() @@ -543,10 +480,6 @@ def tearDown(self): ray.shutdown() _register_all() # re-register the evicted objects - def testForceTrialCleanup(self): - self.skipTest("Skipping as force trial cleanup is not applicable" - " for local mode.") - if __name__ == "__main__": import sys diff --git a/python/ray/tune/tests/test_trial_runner_3.py b/python/ray/tune/tests/test_trial_runner_3.py index 44341ebf99cf6..e467eafa5e51e 100644 --- a/python/ray/tune/tests/test_trial_runner_3.py +++ b/python/ray/tune/tests/test_trial_runner_3.py @@ -555,8 +555,7 @@ def testTrialNoSave(self): self.assertTrue( runner2.get_trial("checkpoint").status == Trial.TERMINATED) self.assertTrue(runner2.get_trial("pending").status == Trial.PENDING) - self.assertTrue( - not runner2.get_trial("pending").has_reported_at_least_once) + self.assertTrue(not runner2.get_trial("pending").last_result) runner2.step() def testCheckpointWithFunction(self): diff --git a/python/ray/tune/tests/test_trial_runner_callbacks.py b/python/ray/tune/tests/test_trial_runner_callbacks.py index f9cf300948ea6..16f40b7602712 100644 --- a/python/ray/tune/tests/test_trial_runner_callbacks.py +++ b/python/ray/tune/tests/test_trial_runner_callbacks.py @@ -154,7 +154,7 @@ def testCallbackSteps(self): result = {TRAINING_ITERATION: 1, "metric": 800, "done": False} self.executor.results[trials[1]] = result self.executor.next_trial = trials[1] - self.assertTrue(not trials[1].has_reported_at_least_once) + self.assertEqual(trials[1].last_result, {}) self.trial_runner.step() self.assertEqual(self.callback.state["trial_result"]["iteration"], 3) self.assertEqual(self.callback.state["trial_result"]["trial"].trial_id, diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py index 798c08192ab21..0e0a2dd65c701 100644 --- a/python/ray/tune/tests/test_trial_scheduler.py +++ b/python/ray/tune/tests/test_trial_scheduler.py @@ -847,7 +847,6 @@ def __init__(self, i, config): self.resources = Resources(1, 0) self.custom_trial_name = None self.custom_dirname = None - self._default_result_or_future = None def on_checkpoint(self, checkpoint): self.restored_checkpoint = 
checkpoint.value diff --git a/python/ray/tune/tests/test_trial_scheduler_pbt.py b/python/ray/tune/tests/test_trial_scheduler_pbt.py index 31a8f02132101..81b90dcfeebf2 100644 --- a/python/ray/tune/tests/test_trial_scheduler_pbt.py +++ b/python/ray/tune/tests/test_trial_scheduler_pbt.py @@ -192,13 +192,8 @@ def MockTrainingFuncSync(config, checkpoint_dir=None): "checkpoint") with open(checkpoint_path, "wb") as fp: pickle.dump((a, iter), fp) - # Different sleep times so that asynch test runs do not - # randomly succeed. If well performing trials finish later, - # then bad performing trials will already have continued - # to train, which is exactly what we want to test when - # comparing sync vs. async. - time.sleep(a / 20) # Score gets better every iteration. + time.sleep(1) tune.report(mean_accuracy=iter + a, a=a) self.MockTrainingFuncSync = MockTrainingFuncSync @@ -206,10 +201,7 @@ def MockTrainingFuncSync(config, checkpoint_dir=None): def tearDown(self): ray.shutdown() - def synchSetup(self, synch, param=None): - if param is None: - param = [10, 20, 30] - + def synchSetup(self, synch, param=[10, 20, 30]): scheduler = PopulationBasedTraining( time_attr="training_iteration", metric="mean_accuracy", diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index 3299d7aa4e861..7e63147ca4e00 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -8,18 +8,17 @@ import sys import tempfile import time -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Union import uuid import ray import ray.cloudpickle as pickle from ray.tune.resources import Resources from ray.tune.result import ( - DEBUG_METRICS, DEFAULT_RESULTS_DIR, HOSTNAME, NODE_IP, PID, - SHOULD_CHECKPOINT, TIME_THIS_ITER_S, TIME_TOTAL_S, TIMESTEPS_THIS_ITER, - DONE, TIMESTEPS_TOTAL, EPISODES_THIS_ITER, EPISODES_TOTAL, - TRAINING_ITERATION, RESULT_DUPLICATE, TRIAL_ID, TRIAL_INFO, STDOUT_FILE, - STDERR_FILE) + DEFAULT_RESULTS_DIR, SHOULD_CHECKPOINT, TIME_THIS_ITER_S, + TIMESTEPS_THIS_ITER, DONE, TIMESTEPS_TOTAL, EPISODES_THIS_ITER, + EPISODES_TOTAL, TRAINING_ITERATION, RESULT_DUPLICATE, TRIAL_INFO, + STDOUT_FILE, STDERR_FILE) from ray.tune.utils import UtilMonitor from ray.tune.utils.placement_groups import PlacementGroupFactory from ray.tune.utils.trainable import TrainableUtil @@ -155,40 +154,6 @@ def get_current_ip(self): self._local_ip = ray.util.get_node_ip_address() return self._local_ip - def get_auto_filled_metrics(self, - now: Optional[datetime] = None, - time_this_iter: Optional[float] = None, - debug_metrics_only: bool = False) -> dict: - """Return a dict with metrics auto-filled by the trainable. - - If ``debug_metrics_only`` is True, only metrics that don't - require at least one iteration will be returned - (``ray.tune.result.DEBUG_METRICS``). 
- """ - if now is None: - now = datetime.today() - autofilled = { - TRIAL_ID: self.trial_id, - "experiment_id": self._experiment_id, - "date": now.strftime("%Y-%m-%d_%H-%M-%S"), - "timestamp": int(time.mktime(now.timetuple())), - TIME_THIS_ITER_S: time_this_iter, - TIME_TOTAL_S: self._time_total, - PID: os.getpid(), - HOSTNAME: platform.node(), - NODE_IP: self._local_ip, - "config": self.config, - "time_since_restore": self._time_since_restore, - "timesteps_since_restore": self._timesteps_since_restore, - "iterations_since_restore": self._iterations_since_restore - } - if debug_metrics_only: - autofilled = { - k: v - for k, v in autofilled.items() if k in DEBUG_METRICS - } - return autofilled - def is_actor(self): try: actor_id = ray.worker.global_worker.actor_id @@ -324,7 +289,19 @@ def train(self): result.setdefault("neg_mean_loss", -result["mean_loss"]) now = datetime.today() - result.update(self.get_auto_filled_metrics(now, time_this_iter)) + result.update( + experiment_id=self._experiment_id, + date=now.strftime("%Y-%m-%d_%H-%M-%S"), + timestamp=int(time.mktime(now.timetuple())), + time_this_iter_s=time_this_iter, + time_total_s=self._time_total, + pid=os.getpid(), + hostname=platform.node(), + node_ip=self._local_ip, + config=self.config, + time_since_restore=self._time_since_restore, + timesteps_since_restore=self._timesteps_since_restore, + iterations_since_restore=self._iterations_since_restore) monitor_data = self._monitor.get_data() if monitor_data: diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index ede51f26ba5b1..6398b53f2292f 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -8,20 +8,19 @@ import re import shutil import time -from typing import Callable, Dict, Optional, Sequence, Union +from typing import Callable, Dict, Sequence, Union import uuid import ray import ray.cloudpickle as cloudpickle -from ray.exceptions import GetTimeoutError, RayActorError +from ray.exceptions import GetTimeoutError from ray.tune import TuneError from ray.tune.checkpoint_manager import Checkpoint, CheckpointManager # NOTE(rkn): We import ray.tune.registry here instead of importing the names we # need because there are cyclic imports that may cause specific names to not # have been defined yet. See https://github.com/ray-project/ray/issues/1716. from ray.tune.registry import get_trainable_cls, validate_trainable -from ray.tune.result import (DEFAULT_RESULTS_DIR, DONE, NODE_IP, PID, - TRAINING_ITERATION, TRIAL_ID) +from ray.tune.result import DEFAULT_RESULTS_DIR, DONE, TRAINING_ITERATION from ray.tune.resources import Resources, \ json_to_resources, resources_to_json from ray.tune.utils.placement_groups import PlacementGroupFactory, \ @@ -300,9 +299,7 @@ def __init__(self, self.max_failures = max_failures # Local trial state that is updated during the run - self._last_result = {} - self._default_result_or_future: Union[ray.ObjectRef, dict, None] = ( - None) + self.last_result = {} self.last_update_time = -float("inf") # stores in memory max/min/avg/last-n-avg/last result for each @@ -397,52 +394,6 @@ def _setup_resources(self, log_always: bool = False): resource_kwargs["has_placement_group"] = True self.resources = Resources(**resource_kwargs) - def _get_default_result_or_future(self) -> Optional[dict]: - """Calls ray.get on self._default_result_or_future and assigns back. - - Returns None in case of exceptions. - Will also set the trial location if runner is set. 
- """ - if self._default_result_or_future and isinstance( - self._default_result_or_future, ray.ObjectRef): - try: - self._default_result_or_future = ray.get( - self._default_result_or_future) - except RayActorError: # error during initialization - self._default_result_or_future = None - if self._default_result_or_future and self.runner: - self.set_location( - Location( - self._default_result_or_future.get(NODE_IP), - self._default_result_or_future.get(PID))) - return self._default_result_or_future - - @property - def last_result(self) -> dict: - # The logic in here is as follows: - # 1. If the trial has reported at least once, last_result would have - # been set and therefore would not be empty. We can just return it. - # 2. If the trial has not reported at least once but we have the - # future for the default results dict, (obtained through - # Trainable.get_auto_filled_metrics), we get that future - # and return it. - # 3. In the worst case where we have nothing, we just set the - # trial_id and return that. - result = self._last_result - if not {k for k in result if k != TRIAL_ID}: - self._get_default_result_or_future() - result = self._default_result_or_future or result - result.setdefault(TRIAL_ID, self.trial_id) - return result - - @last_result.setter - def last_result(self, val: dict): - self._last_result = val - - @property - def has_reported_at_least_once(self) -> bool: - return bool(self._last_result) - @property def node_ip(self): return self.location.hostname @@ -548,11 +499,6 @@ def update_resources( def set_runner(self, runner): self.runner = runner - if runner: - # Do not block here, the result will be gotten when last_result - # property is accessed - self._default_result_or_future = ( - runner.get_auto_filled_metrics.remote(debug_metrics_only=True)) self.checkpoint_manager.delete = CheckpointDeleter( self._trainable_name(), runner, self.node_ip) # No need to invalidate state cache: runner is not stored in json @@ -657,7 +603,7 @@ def update_last_result(self, result, terminate=False): if self.experiment_tag: result.update(experiment_tag=self.experiment_tag) - self.set_location(Location(result.get(NODE_IP), result.get(PID))) + self.set_location(Location(result.get("node_ip"), result.get("pid"))) self.last_result = result self.last_update_time = time.time() @@ -783,7 +729,6 @@ def __getstate__(self): state["_state_json"] = None state["_state_valid"] = False - state["_default_result_or_future"] = None return copy.deepcopy(state) diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index f4ce4ea70d001..0d91ee3b8bc65 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -15,9 +15,8 @@ from ray.tune.callback import CallbackList from ray.tune.stopper import NoopStopper from ray.tune.ray_trial_executor import RayTrialExecutor -from ray.tune.result import (DEBUG_METRICS, DEFAULT_METRIC, DONE, - TIME_THIS_ITER_S, RESULT_DUPLICATE, - SHOULD_CHECKPOINT) +from ray.tune.result import (DEFAULT_METRIC, TIME_THIS_ITER_S, + RESULT_DUPLICATE, SHOULD_CHECKPOINT) from ray.tune.syncer import CloudSyncer, get_cloud_syncer from ray.tune.trial import Checkpoint, Trial from ray.tune.schedulers import FIFOScheduler, TrialScheduler @@ -196,9 +195,7 @@ class TrialRunner: """ CKPT_FILE_TMPL = "experiment_state-{}.json" - VALID_RESUME_TYPES = [ - True, "LOCAL", "REMOTE", "PROMPT", "ERRORED_ONLY", "AUTO" - ] + VALID_RESUME_TYPES = [True, "LOCAL", "REMOTE", "PROMPT", "ERRORED_ONLY"] RAISE = "RAISE" def __init__(self, @@ -418,7 +415,7 @@ def 
_validate_resume(self, resume_type): Args: resume_type: One of True, "REMOTE", "LOCAL", - "PROMPT", "ERRORED_ONLY", "AUTO". + "PROMPT", "ERRORED_ONLY". """ # TODO: Consider supporting ERRORED_ONLY+REMOTE? if not resume_type: @@ -429,54 +426,11 @@ def _validate_resume(self, resume_type): # Not clear if we need this assertion, since we should always have a # local checkpoint dir. assert self._local_checkpoint_dir or self._remote_checkpoint_dir - - if resume_type == "AUTO": - if self._remote_checkpoint_dir: - logger.info( - f"Trying to find and download experiment checkpoint at " - f"{self._remote_checkpoint_dir}") - # Todo: This syncs the entire experiment including trial - # checkpoints. We should exclude these in the future. - try: - self._syncer.sync_down_if_needed() - self._syncer.wait() - except TuneError as e: - logger.warning( - f"Got error when trying to sync down: {e} " - f"\nPlease check this error message for potential " - f"access problems - if a directory was not found, " - f"that is expected at this stage when you're starting " - f"a new experiment.") - logger.info( - "No remote checkpoint was found or an error occurred " - "when trying to download the experiment checkpoint. " - "Please check the previous warning message for more " - "details. " - "Ray Tune will now start a new experiment.") - return False - logger.info( - "A remote experiment checkpoint was found and will be " - "used to restore the previous experiment state.") - return True - elif not self.checkpoint_exists(self._local_checkpoint_dir): - logger.info("No local checkpoint was found. " - "Ray Tune will now start a new experiment.") - return False - logger.info( - "A local experiment checkpoint was found and will be used " - "to restore the previous experiment state.") - return True - if resume_type in [True, "LOCAL", "PROMPT", "ERRORED_ONLY"]: if not self.checkpoint_exists(self._local_checkpoint_dir): raise ValueError( - f"You called resume ({resume_type}) when no checkpoint " - f"exists in local directory " - f"({self._local_checkpoint_dir}). If you want to start " - f"a new experiment, use `resume=\"AUTO\"` or " - f"`resume=None`. If you expected an experiment to " - f"already exist, check if you supplied the correct " - f"`local_dir` to `tune.run()`.") + f"Called resume ({resume_type}) when no checkpoint exists " + f"in local directory ({self._local_checkpoint_dir}).") elif resume_type == "PROMPT": if click.confirm(f"Resume from local directory? " f"({self._local_checkpoint_dir})"): @@ -494,22 +448,12 @@ def _validate_resume(self, resume_type): "`upload_dir` set to `tune.run(sync_config=...)`.") # Try syncing down the upload directory. - logger.info(f"Downloading experiment checkpoint from " - f"{self._remote_checkpoint_dir}") - # Todo: This syncs the entire experiment including trial - # checkpoints. We should exclude these in the future. - try: - self._syncer.sync_down_if_needed() - self._syncer.wait() - except TuneError as e: - raise RuntimeError( - "Syncing the remote experiment checkpoint to the driver " - "failed. Please check the error message. If you want to " - "start a new experiment, use `resume=\"AUTO\"` or " - "`resume=None`. If you expected an experiment to " - "already exist, check if you supplied the correct " - "`upload_dir` to the `tune.SyncConfig` passed to " - "`tune.run()`.") from e + logger.info("Downloading from %s", self._remote_checkpoint_dir) + # TODO(ujvl): Note that this syncs down the entire directory, + # which may also contain trial checkpoints. 
We should selectively + # sync the necessary files instead. + self._syncer.sync_down_if_needed() + self._syncer.wait() if not self.checkpoint_exists(self._local_checkpoint_dir): raise ValueError("Called resume when no checkpoint exists " @@ -919,8 +863,6 @@ def _process_trial_result(self, trial, result): flat_result = flatten_dict(result) self._validate_result_metrics(flat_result) - _trigger_callback_complete = False - if self._stopper(trial.trial_id, result) or trial.should_stop(flat_result): result.update(done=True) @@ -940,7 +882,8 @@ def _process_trial_result(self, trial, result): trial=trial, result=result.copy()) - _trigger_callback_complete = True + self._callbacks.on_trial_complete( + iteration=self._iteration, trials=self._trials, trial=trial) decision = TrialScheduler.STOP else: with warn_if_slow("scheduler.on_trial_result"): @@ -976,10 +919,6 @@ def _process_trial_result(self, trial, result): # the global checkpoint state. self._checkpoint_trial_if_needed(trial, force=force_checkpoint) - if _trigger_callback_complete: - self._callbacks.on_trial_complete( - iteration=self._iteration, trials=self._trials, trial=trial) - if trial.is_saving: # Cache decision to execute on after the save is processed. # This prevents changing the trial's state or kicking off @@ -993,18 +932,15 @@ def _process_trial_result(self, trial, result): def _validate_result_metrics(self, result): """ Check if any of the required metrics was not reported - in the last result. If the only items are ``done`` or any of - DEBUG_METRICS, this means that no result was ever received and - the trial just returned. This is also okay and will not raise - an error. + in the last result. If the only item is `done=True`, this + means that no result was ever received and the trial just + returned. This is also okay and will not raise an error. This will ignore checking for the DEFAULT_METRIC. """ - if int(os.environ.get( - "TUNE_DISABLE_STRICT_METRIC_CHECKING", 0)) != 1 and (len({ - k - for k in result if k not in list(DEBUG_METRICS) + [DONE] - }) > 1): + if int(os.environ.get("TUNE_DISABLE_STRICT_METRIC_CHECKING", + 0)) != 1 and (len(result) > 1 + or "done" not in result): base_metric = self._metric \ if self._metric != DEFAULT_METRIC else None scheduler_metric = self._scheduler_alg.metric \ diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index a3553879633a9..8077f7c6e6cd2 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -11,15 +11,13 @@ import ray from ray.util.annotations import PublicAPI -from ray.util.queue import Queue, Empty from ray.tune.analysis import ExperimentAnalysis from ray.tune.callback import Callback from ray.tune.error import TuneError from ray.tune.experiment import Experiment, convert_to_experiment_list from ray.tune.logger import Logger -from ray.tune.progress_reporter import (detect_reporter, ProgressReporter, - JupyterNotebookReporter) +from ray.tune.progress_reporter import detect_reporter, ProgressReporter from ray.tune.ray_trial_executor import RayTrialExecutor from ray.tune.registry import get_trainable_cls from ray.tune.stopper import Stopper @@ -316,48 +314,7 @@ def run( # Make sure tune.run is called on the sever node. remote_run = force_on_current_node(remote_run) - # JupyterNotebooks don't work with remote tune runs out of the box - # (e.g. via Ray client) as they don't have access to the main - # process stdout. So we introduce a queue here that accepts - # callables, which will then be executed on the driver side. 
- if isinstance(progress_reporter, JupyterNotebookReporter): - execute_queue = Queue(actor_options={ - "num_cpus": 0, - **force_on_current_node(None) - }) - progress_reporter.set_output_queue(execute_queue) - - def get_next_queue_item(): - try: - return execute_queue.get(block=False) - except Empty: - return None - - else: - # If we don't need a queue, use this dummy get fn instead of - # scheduling an unneeded actor - def get_next_queue_item(): - return None - - def _handle_execute_queue(): - execute_item = get_next_queue_item() - while execute_item: - if isinstance(execute_item, Callable): - execute_item() - - execute_item = get_next_queue_item() - - remote_future = remote_run.remote(_remote=False, **remote_run_kwargs) - - # ray.wait(...)[1] returns futures that are not ready, yet - while ray.wait([remote_future], timeout=0.2)[1]: - # Check if we have items to execute - _handle_execute_queue() - - # Handle queue one last time - _handle_execute_queue() - - return ray.get(remote_future) + return ray.get(remote_run.remote(_remote=False, **remote_run_kwargs)) del remote_run_kwargs @@ -384,34 +341,8 @@ def _handle_execute_queue(): if num_samples == -1: num_samples = sys.maxsize - result_buffer_length = None - - # Create scheduler here as we need access to some of its properties - if isinstance(scheduler, str): - # importing at top level causes a recursive dependency - from ray.tune.schedulers import create_scheduler - scheduler = create_scheduler(scheduler) - scheduler = scheduler or FIFOScheduler() - - if scheduler.supports_buffered_results: - # Result buffering with a Hyperband scheduler is a bad idea, as - # hyperband tries to stop trials when processing brackets. With result - # buffering, we might trigger this multiple times when evaluating - # a single trial, which leads to unexpected behavior. - env_result_buffer_length = os.getenv("TUNE_RESULT_BUFFER_LENGTH", "") - if env_result_buffer_length: - warnings.warn( - f"You are using a {type(scheduler)} scheduler, but " - f"TUNE_RESULT_BUFFER_LENGTH is set " - f"({env_result_buffer_length}). 
This can lead to undesired " - f"and faulty behavior, so the buffer length was forcibly set " - f"to 1 instead.") - result_buffer_length = 1 - trial_executor = trial_executor or RayTrialExecutor( - reuse_actors=reuse_actors, - queue_trials=queue_trials, - result_buffer_length=result_buffer_length) + reuse_actors=reuse_actors, queue_trials=queue_trials) if isinstance(run_or_experiment, list): experiments = run_or_experiment else: @@ -464,6 +395,11 @@ def _handle_execute_queue(): if is_local_mode: max_concurrent_trials = 1 + if isinstance(scheduler, str): + # importing at top level causes a recursive dependency + from ray.tune.schedulers import create_scheduler + scheduler = create_scheduler(scheduler) + if not search_alg: search_alg = BasicVariantGenerator( max_concurrent=max_concurrent_trials or 0) @@ -511,6 +447,7 @@ def _handle_execute_queue(): "does not contain any more parameter definitions - include " "them in the search algorithm's search space if necessary.") + scheduler = scheduler or FIFOScheduler() if not scheduler.set_search_properties(metric, mode): raise ValueError( "You passed a `metric` or `mode` argument to `tune.run()`, but " @@ -594,7 +531,6 @@ def sigint_handler(sig, frame): signal.signal(signal.SIGINT, sigint_handler) tune_start = time.time() - progress_reporter.set_start_time(tune_start) while not runner.is_finished() and not state[signal.SIGINT]: runner.step() if has_verbosity(Verbosity.V1_EXPERIMENT): diff --git a/python/ray/tune/utils/util.py b/python/ray/tune/utils/util.py index c41179c43b845..0f4612c66c047 100644 --- a/python/ray/tune/utils/util.py +++ b/python/ray/tune/utils/util.py @@ -645,15 +645,14 @@ def get_current_node_resource_key() -> str: raise ValueError("Cannot found the node dictionary for current node.") -def force_on_current_node(task_or_actor=None): +def force_on_current_node(task_or_actor): """Given a task or actor, place it on the current node. If using Ray Client, the current node is the client server node. Args: task_or_actor: A Ray remote function or class to place on the - current node. If None, returns the options dict to pass to - another actor. + current node. Returns: The provided task or actor, but with options modified to force @@ -661,10 +660,6 @@ def force_on_current_node(task_or_actor=None): """ node_resource_key = get_current_node_resource_key() options = {"resources": {node_resource_key: 0.01}} - - if task_or_actor is None: - return options - return task_or_actor.options(**options) diff --git a/python/ray/util/__init__.py b/python/ray/util/__init__.py index 7a326154bf1e3..3177925e68fc0 100644 --- a/python/ray/util/__init__.py +++ b/python/ray/util/__init__.py @@ -19,7 +19,7 @@ @PublicAPI(stability="beta") -@client_mode_hook(auto_init=True) +@client_mode_hook def list_named_actors(all_namespaces: bool = False) -> List[str]: """List all named actors in the system. 
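
Among the `tune.py` removals above is the guard reconciling an explicit `TUNE_RESULT_BUFFER_LENGTH` setting with a scheduler that cannot consume buffered results. Judging from the removed warning text, the intended behavior is sketched below (standalone; `resolve_result_buffer_length` is a hypothetical name, not the real API):

import os
import warnings


def resolve_result_buffer_length(scheduler):
    # Returns None to defer to downstream defaults; forces a length of 1
    # when buffering is unsafe for this scheduler but explicitly enabled.
    if getattr(scheduler, "supports_buffered_results", True):
        return None
    env_length = os.getenv("TUNE_RESULT_BUFFER_LENGTH", "")
    if env_length:
        warnings.warn(
            f"You are using a {type(scheduler)} scheduler, but "
            f"TUNE_RESULT_BUFFER_LENGTH is set ({env_length}). The buffer "
            f"length will be forcibly set to 1 instead.")
        return 1
    return None
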
diff --git a/python/ray/util/client/__init__.py b/python/ray/util/client/__init__.py
index 29b95c850c2a4..f11b692d56f42 100644
--- a/python/ray/util/client/__init__.py
+++ b/python/ray/util/client/__init__.py
@@ -5,6 +5,7 @@
 import os
 import sys
 import logging
+import json
 import threading
 
 import grpc
@@ -65,7 +66,7 @@ def connect(self,
         job_config = job_config or JobConfig()
         job_config.set_ray_namespace(namespace)
         if job_config is not None:
-            runtime_env = job_config.runtime_env
+            runtime_env = json.loads(job_config.get_serialized_runtime_env())
             if runtime_env.get("pip") or runtime_env.get("conda"):
                 logger.warning("The 'pip' or 'conda' field was specified in "
                                "the runtime env, so it may take some time to "
diff --git a/python/ray/util/client/client_pickler.py b/python/ray/util/client/client_pickler.py
index 0faf3c99c68cd..9c1ebef68d565 100644
--- a/python/ray/util/client/client_pickler.py
+++ b/python/ray/util/client/client_pickler.py
@@ -49,17 +49,12 @@
 else:
     import pickle  # noqa: F401
 
-
 # NOTE(barakmich): These PickleStubs are really close to
-# the data for an execution, with no arguments. Combine the two?
-class PickleStub(
-        NamedTuple("PickleStub", [("type", str), ("client_id", str),
-                                  ("ref_id", bytes), ("name", Optional[str]),
-                                  ("baseline_options", Optional[Dict])])):
-    def __reduce__(self):
-        # PySpark's namedtuple monkey patch breaks compatibility with
-        # cloudpickle. Thus we revert this patch here if it exists.
-        return object.__reduce__(self)
+# the data for an execution, with no arguments. Combine the two?
+PickleStub = NamedTuple("PickleStub",
+                        [("type", str), ("client_id", str), ("ref_id", bytes),
+                         ("name", Optional[str]),
+                         ("baseline_options", Optional[Dict])])
 
 
 class ClientPickler(cloudpickle.CloudPickler):
diff --git a/python/ray/util/client/options.py b/python/ray/util/client/options.py
index ec6c568d5b347..9c9df946d0cf5 100644
--- a/python/ray/util/client/options.py
+++ b/python/ray/util/client/options.py
@@ -36,6 +36,7 @@
     "placement_group_bundle_index": (),
     "placement_group_capture_child_tasks": (),
     "runtime_env": (),
+    "override_environment_variables": (),
 }
diff --git a/python/ray/util/client/server/proxier.py b/python/ray/util/client/server/proxier.py
index 98ad26c93d8b4..0fb2f07429b1d 100644
--- a/python/ray/util/client/server/proxier.py
+++ b/python/ray/util/client/server/proxier.py
@@ -27,10 +27,10 @@
 from ray.util.client.server.dataservicer import _get_reconnecting_from_context
 from ray._private.client_mode_hook import disable_client_hook
 from ray._private.parameter import RayParams
-from ray._private.runtime_env.context import RuntimeEnvContext
+from ray._private.runtime_env import RuntimeEnvContext
 from ray._private.services import ProcessInfo, start_ray_client_server
-from ray._private.utils import detect_fate_sharing_support
-from ray._private.tls_utils import add_port_to_grpc_server
+from ray._private.utils import (detect_fate_sharing_support,
+                                add_port_to_grpc_server)
 
 # Import psutil after ray so the packaged version is used.
 import psutil
@@ -264,9 +264,7 @@ def start_specific_server(self, client_id: str,
             f"ray_client_server_{specific_server.port}", unique=True)
 
         serialized_runtime_env = job_config.get_serialized_runtime_env()
-        if not serialized_runtime_env or serialized_runtime_env == "{}":
-            # TODO(edoakes): can we just remove this case and always send it
-            # to the agent?
+        if serialized_runtime_env == "{}":
             serialized_runtime_env_context = RuntimeEnvContext().serialize()
         else:
             serialized_runtime_env_context = self._create_runtime_env(
diff --git a/python/ray/util/client/server/server.py b/python/ray/util/client/server/server.py
index 27a10d18e3b11..351b981d0a17c 100644
--- a/python/ray/util/client/server/server.py
+++ b/python/ray/util/client/server/server.py
@@ -35,7 +35,7 @@
 from ray.ray_constants import env_integer
 from ray.util.placement_group import PlacementGroup
 from ray._private.client_mode_hook import disable_client_hook
-from ray._private.tls_utils import add_port_to_grpc_server
+from ray._private.utils import add_port_to_grpc_server
 
 logger = logging.getLogger(__name__)
diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py
index 4b45ac0c761ee..b5c50215c5488 100644
--- a/python/ray/util/client/worker.py
+++ b/python/ray/util/client/worker.py
@@ -128,9 +128,6 @@ def __init__(
             self._connect_channel()
         self._has_connected = True
 
-        # Has Ray been initialized on the server?
-        self._serverside_ray_initialized = False
-
         # Initialize the streams to finish protocol negotiation.
         self.data_client = DataClient(self, self._client_id, self.metadata)
         self.reference_count: Dict[bytes, int] = defaultdict(int)
@@ -362,8 +359,8 @@ def get(self, vals, *, timeout: Optional[float] = None) -> Any:
                 logger.debug("Internal retry for get {}".format(to_get))
         if len(to_get) != len(res):
             raise Exception(
-                "Mismatched number of items in request ({}) and response ({})".
-                format(len(to_get), len(res)))
+                "Mismatched number of items in request ({}) and response ({})"
+                .format(len(to_get), len(res)))
         if isinstance(vals, ClientObjectRef):
             res = res[0]
         return res
@@ -650,17 +647,10 @@ def list_named_actors(self, all_namespaces: bool) -> List[Dict[str, str]]:
         return json.loads(self.data_client.ListNamedActors(req).actors_json)
 
     def is_initialized(self) -> bool:
-        if not self.is_connected() or self.server is None:
-            return False
-        if not self._serverside_ray_initialized:
-            # We only check that Ray is initialized on the server once to
-            # avoid making an RPC every time this function is called. This is
-            # safe to do because Ray only 'un-initializes' on the server when
-            # the Client connection is torn down.
-            self._serverside_ray_initialized = self.get_cluster_info(
+        if self.server is not None:
+            return self.get_cluster_info(
                 ray_client_pb2.ClusterInfoType.IS_INITIALIZED)
-
-        return self._serverside_ray_initialized
+        return False
 
     def ping_server(self, timeout=None) -> bool:
         """Simple health check.
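
The `worker.py` hunk above reverts `is_initialized()` to issue the RPC on every call. For contrast, a standalone sketch of the one-shot caching being removed; `check_server` is a hypothetical stand-in for the `get_cluster_info` RPC:

class CachedInitCheck:
    def __init__(self, check_server):
        self._check_server = check_server  # hypothetical RPC stand-in
        self._initialized = False

    def is_initialized(self) -> bool:
        # Cache the first positive answer; per the removed comment, this is
        # safe because the server only "un-initializes" when the client
        # connection is torn down.
        if not self._initialized:
            self._initialized = bool(self._check_server())
        return self._initialized
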
diff --git a/python/ray/util/dask/scheduler_utils.py b/python/ray/util/dask/scheduler_utils.py index dba0c660b0c5b..a1805048c989b 100644 --- a/python/ray/util/dask/scheduler_utils.py +++ b/python/ray/util/dask/scheduler_utils.py @@ -371,11 +371,8 @@ def fire_task(): return nested_get(result, state["cache"]) -def apply_sync(func, args=(), kwds=None, callback=None): +def apply_sync(func, args=(), kwds={}, callback=None): """ A naive synchronous version of apply_async """ - if kwds is None: - kwds = {} - res = func(*args, **kwds) if callback is not None: callback(res) diff --git a/python/ray/util/placement_group.py b/python/ray/util/placement_group.py index 933695ea0fbe1..43741556f54e1 100644 --- a/python/ray/util/placement_group.py +++ b/python/ray/util/placement_group.py @@ -25,7 +25,7 @@ def _export_bundle_reservation_check_method_if_needed(): if bundle_reservation_check: return - @ray.remote(num_cpus=0) + @ray.remote(num_cpus=0, max_calls=0) def bundle_reservation_check_func(placement_group): return placement_group @@ -307,7 +307,7 @@ def get_current_placement_group() -> Optional[PlacementGroup]: None if the current task or actor wasn't created with any placement group. """ - if client_mode_should_convert(auto_init=True): + if client_mode_should_convert(): # Client mode is only a driver. return None worker = ray.worker.global_worker diff --git a/python/ray/util/sgd/torch/torch_runner.py b/python/ray/util/sgd/torch/torch_runner.py index 90b8c0adb44cc..77bf9e1454ea9 100644 --- a/python/ray/util/sgd/torch/torch_runner.py +++ b/python/ray/util/sgd/torch/torch_runner.py @@ -5,6 +5,7 @@ import ray import torch +from ray.util.sgd.torch.constants import USE_FP16, NUM_STEPS from ray.util.sgd import utils from ray.util.sgd.torch.utils import choose_amp_backend @@ -62,7 +63,6 @@ def setup_operator(self): world_rank=0, local_rank=0, is_distributed=False, - device=None, use_gpu=self.use_gpu, use_fp16=self.use_fp16, use_tqdm=self.use_tqdm, @@ -121,6 +121,11 @@ def train_epoch(self, info = info or {} self._toggle_profiling(profile=profile) + info.update({ + NUM_STEPS: num_steps, + USE_FP16: self.use_fp16, + "epoch_idx": self.epochs, + }) with self.timers.record("train_epoch"): if iterator is not None: # Dataset will provide us with a list of tuples but we @@ -136,11 +141,7 @@ def format_batch(batch): else: iterator = self.make_iterator( training=True, num_steps=num_steps) - train_stats = self.training_operator.train_epoch( - iterator, - info=info, - num_steps=num_steps, - epoch_idx=self.epochs) + train_stats = self.training_operator.train_epoch(iterator, info) # This is so that `epochs` is first in ordering. 
stats = dict(epoch=self.epochs, **train_stats) @@ -150,6 +151,7 @@ def format_batch(batch): def validate(self, num_steps=None, profile=False, info=None): """Evaluates the model on the validation data set.""" + info = info or {} self._toggle_profiling(profile=profile) with self.timers.record("validation"): diff --git a/python/ray/util/sgd/torch/training_operator.py b/python/ray/util/sgd/torch/training_operator.py index 3a37436c43e92..7143d5c558fd0 100644 --- a/python/ray/util/sgd/torch/training_operator.py +++ b/python/ray/util/sgd/torch/training_operator.py @@ -11,8 +11,11 @@ from ray.util.annotations import PublicAPI from ray.util.sgd.utils import (TimerCollection, AverageMeterCollection, NUM_SAMPLES) -from ray.util.sgd.torch.constants import (SCHEDULER_STEP_EPOCH, NUM_STEPS, - SCHEDULER_STEP_BATCH, USE_FP16) +from ray.util.sgd.torch.constants import ( + SCHEDULER_STEP_EPOCH, + NUM_STEPS, + SCHEDULER_STEP_BATCH, +) from ray.util.sgd.torch.utils import choose_amp_backend from torch.nn.parallel import DistributedDataParallel @@ -128,15 +131,14 @@ def __init__(self, config, world_rank, local_rank, - is_distributed, - use_gpu, - device, + is_distributed=False, + device=None, + use_gpu=False, use_fp16=False, use_tqdm=False, wrap_ddp=False, add_dist_sampler=False, scheduler_step_freq=None): - # You are not expected to override this method. self._world_rank = world_rank self._local_rank = local_rank @@ -454,7 +456,7 @@ def should_wrap_dataloader(loader): self._validation_loader = with_sampler( self._validation_loader) - def train_epoch(self, iterator, info=None, num_steps=None, epoch_idx=0): + def train_epoch(self, iterator, info): """Runs one standard training pass over the training dataloader. By default, this method will iterate over the given iterator and @@ -487,10 +489,8 @@ def train_epoch(self, ...): Args: iterator (iter): Iterator over the training data for the entire epoch. This iterator is expected to be entirely consumed. - info (Optional[dict]): Dictionary for information to be used for - custom training operations. - num_steps (Optional[int]): Number of steps in the iterator. - epoch_idx (int): Index of current epoch. + info (dict): Dictionary for information to be used for custom + training operations. Returns: A dict of metrics from training. @@ -499,14 +499,6 @@ def train_epoch(self, ...): raise RuntimeError("Either set self.model in setup function or " "override this method to implement a custom " "training loop.") - - info = info or {} - - info.update({ - NUM_STEPS: num_steps, - USE_FP16: self.use_fp16, - "epoch_idx": epoch_idx - }) model = self.model scheduler = None if hasattr(self, "scheduler"): @@ -644,7 +636,7 @@ def train_batch(self, batch, batch_info): return {"train_loss": loss.item(), NUM_SAMPLES: target.size(0)} - def validate(self, val_iterator, info=None): + def validate(self, val_iterator, info): """Runs one standard validation pass over the val_iterator. This will call ``model.eval()`` and ``torch.no_grad`` when iterating @@ -656,8 +648,8 @@ def validate(self, val_iterator, info=None): Args: val_iterator (iter): Iterable constructed from the validation dataloader. - info: (Optional[dict]): Dictionary for information to be used for - custom validation operations. + info: (dict): Dictionary for information to be used for custom + validation operations. Returns: A dict of metrics from the evaluation. 
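
The runner/operator hunks above move the step and precision bookkeeping back to the runner, which now hands `train_epoch` a plain required `info` dict. A minimal sketch of that post-revert handoff (the constant values below are placeholders, not the real `ray.util.sgd.torch.constants`):

# Placeholder values; the real constants live in
# ray.util.sgd.torch.constants.
NUM_STEPS = "num_steps"
USE_FP16 = "use_fp16"


class ToyOperator:
    # Post-revert signature: info is a required positional dict.
    def train_epoch(self, iterator, info):
        batches = sum(1 for _ in iterator)
        return {"epoch_idx": info["epoch_idx"], "batches": batches}


# Runner side: pack bookkeeping into info before delegating.
info = {NUM_STEPS: 10, USE_FP16: False, "epoch_idx": 0}
stats = ToyOperator().train_epoch(iter(range(10)), info)
assert stats == {"epoch_idx": 0, "batches": 10}
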
@@ -670,8 +662,6 @@ def validate(self, val_iterator, info=None): raise RuntimeError("Either set self.model in setup function or " "override this method to implement a custom " "validation loop.") - - info = info or {} model = self.model metric_meters = AverageMeterCollection() @@ -1161,13 +1151,13 @@ def schedulers(self): def get_test_operator(operator_cls): class _TestingOperator(operator_cls): - def train_epoch(self, iterator, info, **kwargs): + def train_epoch(self, iterator, info): func = self.config.get("custom_func") if callable(func): return func(self, iterator, info) return {"done": 1} - def validate(self, iterator, info, **kwargs): + def validate(self, iterator, info): return self.train_epoch(iterator, info) return _TestingOperator diff --git a/python/ray/util/sgd/v2/BUILD b/python/ray/util/sgd/v2/BUILD index 7081a53b75591..1f3bb55976689 100644 --- a/python/ray/util/sgd/v2/BUILD +++ b/python/ray/util/sgd/v2/BUILD @@ -24,16 +24,6 @@ py_test( "--max_train_steps=2", "--start_local", "--num_workers=2"] ) -py_test( - name = "tune_cifar_pytorch_pbt_example", - size = "medium", - main = "examples/tune_cifar_pytorch_pbt_example.py", - srcs = ["examples/tune_cifar_pytorch_pbt_example.py"], - tags = ["team:ml", "exclusive", "pytorch"], - deps = [":sgd_v2_lib"], - args = ["--smoke-test"] -) - py_test( name = "tune_linear_example", size = "medium", @@ -57,14 +47,6 @@ py_test( deps = [":sgd_v2_lib"] ) -py_test( - name = "test_gpu", - size = "medium", - srcs = ["tests/test_gpu.py"], - tags = ["team:ml", "exclusive", "gpu_only"], - deps = [":sgd_v2_lib"] -) - py_test( name = "test_session", size = "small", @@ -89,15 +71,6 @@ py_test( deps = [":sgd_v2_lib"] ) -py_test( - name = "test_utils", - size = "small", - srcs = ["tests/test_utils.py"], - tags = ["team:ml", "exclusive"], - deps = [":sgd_v2_lib"] -) - - py_test( name = "test_worker_group", size = "medium", diff --git a/python/ray/util/sgd/v2/__init__.py b/python/ray/util/sgd/v2/__init__.py index 8fb122c160345..49d68ce97309d 100644 --- a/python/ray/util/sgd/v2/__init__.py +++ b/python/ray/util/sgd/v2/__init__.py @@ -8,6 +8,6 @@ __all__ = [ "BackendConfig", "CheckpointStrategy", "HorovodConfig", "load_checkpoint", - "local_rank", "report", "save_checkpoint", "SGDIterator", - "TensorflowConfig", "SGDCallback", "TorchConfig", "Trainer", "world_rank" + "local_rank", "report", "save_checkpoint", "SGDCallback", "SGDIterator", + "TensorflowConfig", "TorchConfig", "Trainer", "world_rank" ] diff --git a/python/ray/util/sgd/v2/backends/backend.py b/python/ray/util/sgd/v2/backends/backend.py index 4feec51a5eb08..24d8b59f1e413 100644 --- a/python/ray/util/sgd/v2/backends/backend.py +++ b/python/ray/util/sgd/v2/backends/backend.py @@ -12,7 +12,7 @@ from ray.util.sgd.v2.checkpoint import CheckpointStrategy from ray.util.sgd.v2.constants import ENABLE_DETAILED_AUTOFILLED_METRICS_ENV, \ TUNE_INSTALLED, TUNE_CHECKPOINT_FILE_NAME, \ - TUNE_CHECKPOINT_ID, ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV + TUNE_CHECKPOINT_ID from ray.util.sgd.v2.session import TrainingResultType, TrainingResult from ray.util.sgd.v2.session import init_session, get_session, shutdown_session from ray.util.sgd.v2.utils import construct_path, check_for_failure @@ -275,21 +275,15 @@ def start(self, if initialization_hook: self._initialization_hook = initialization_hook self.worker_group.execute(initialization_hook) - - share_cuda_visible_devices_enabled = bool( - env_integer(ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV, - self._backend.share_cuda_visible_devices)) - - if (self._num_gpus_per_worker > 0 
- and share_cuda_visible_devices_enabled): - self._share_cuda_visible_devices() + if self._num_gpus_per_worker > 0: + self._setup_gpus() self._backend.on_start(self.worker_group, self._backend_config) except RayActorError as exc: logger.exception(str(exc)) self._increment_failures() self._restart() - def _share_cuda_visible_devices(self): + def _setup_gpus(self): """Sets CUDA_VISIBLE_DEVICES on all workers. For each worker, CUDA_VISIBLE_DEVICES will be set to the GPU IDs @@ -691,18 +685,6 @@ def _increment_failures(self): class Backend(metaclass=abc.ABCMeta): - """Metaclass for distributed communication backend. - - Attributes: - share_cuda_visible_devices (bool): If True, each worker - process will have CUDA_VISIBLE_DEVICES set as the visible device - IDs of all workers on the same node for this training instance. - If False, each worker will have CUDA_VISIBLE_DEVICES set to the - device IDs allocated by Ray for that worker. - """ - - share_cuda_visible_devices: bool = False - def on_start(self, worker_group: WorkerGroup, backend_config: BackendConfig): """Logic for starting this backend.""" diff --git a/python/ray/util/sgd/v2/backends/horovod.py b/python/ray/util/sgd/v2/backends/horovod.py index 4382130ae5749..4f424d5212dec 100644 --- a/python/ray/util/sgd/v2/backends/horovod.py +++ b/python/ray/util/sgd/v2/backends/horovod.py @@ -52,8 +52,6 @@ def init_env_vars(world_rank: int, world_size: int, node_id: str): class HorovodBackend(Backend): - share_cuda_visible_devices: bool = True - def on_start(self, worker_group: WorkerGroup, backend_config: HorovodConfig): diff --git a/python/ray/util/sgd/v2/backends/torch.py b/python/ray/util/sgd/v2/backends/torch.py index 1d1f0d39f366f..7d76b179c8d2d 100644 --- a/python/ray/util/sgd/v2/backends/torch.py +++ b/python/ray/util/sgd/v2/backends/torch.py @@ -92,8 +92,6 @@ def shutdown_torch(destroy_process_group=False): class TorchBackend(Backend): - share_cuda_visible_devices: bool = True - def on_start(self, worker_group: WorkerGroup, backend_config: TorchConfig): if len(worker_group) > 1 and dist.is_available(): # Set the appropriate training backend. diff --git a/python/ray/util/sgd/v2/constants.py b/python/ray/util/sgd/v2/constants.py index b0dc39e9cbfbc..6ebd428f7b1cb 100644 --- a/python/ray/util/sgd/v2/constants.py +++ b/python/ray/util/sgd/v2/constants.py @@ -44,7 +44,3 @@ # This needs to be added to the checkpoint dictionary so if the Tune trial # is restarted, the checkpoint_id can continue to increment. TUNE_CHECKPOINT_ID = "_current_checkpoint_id" - -# Integer value which if set will override the value of -# Backend.share_cuda_visible_devices. 1 for True, 0 for False. 
-ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV = "SGD_ENABLE_SHARE_CUDA_VISIBLE_DEVICES" diff --git a/python/ray/util/sgd/v2/examples/tensorflow_mnist_example.py b/python/ray/util/sgd/v2/examples/tensorflow_mnist_example.py index c299808c916aa..f87380cf9ce16 100644 --- a/python/ray/util/sgd/v2/examples/tensorflow_mnist_example.py +++ b/python/ray/util/sgd/v2/examples/tensorflow_mnist_example.py @@ -72,7 +72,7 @@ def train_func(config): return results -def train_tensorflow_mnist(num_workers=2, use_gpu=False): +def train_tensorflow_mnist(num_workers=1, use_gpu=False): trainer = Trainer( backend="tensorflow", num_workers=num_workers, use_gpu=use_gpu) trainer.start() @@ -98,7 +98,7 @@ def train_tensorflow_mnist(num_workers=2, use_gpu=False): "--num-workers", "-n", type=int, - default=2, + default=1, help="Sets number of workers for training.") parser.add_argument( "--use-gpu", diff --git a/python/ray/util/sgd/v2/examples/tune_cifar_pytorch_pbt_example.py b/python/ray/util/sgd/v2/examples/tune_cifar_pytorch_pbt_example.py deleted file mode 100644 index 1ff8054be367c..0000000000000 --- a/python/ray/util/sgd/v2/examples/tune_cifar_pytorch_pbt_example.py +++ /dev/null @@ -1,200 +0,0 @@ -import numpy as np -import argparse -from filelock import FileLock - -import ray -from ray import tune -from ray.tune import CLIReporter -from ray.tune.schedulers import PopulationBasedTraining - -import torch -import torch.nn as nn -from torch.utils.data import DataLoader, DistributedSampler, Subset -from torchvision.datasets import CIFAR10 -import torchvision.transforms as transforms -from torch.nn.parallel import DistributedDataParallel - -from ray.util.sgd.torch.resnet import ResNet18 - -import ray.util.sgd.v2 as sgd -from ray.util.sgd.v2 import Trainer - - -def train(dataloader, model, loss_fn, optimizer, device): - size = len(dataloader.dataset) - for batch, (X, y) in enumerate(dataloader): - X, y = X.to(device), y.to(device) - - # Compute prediction error - pred = model(X) - loss = loss_fn(pred, y) - - # Backpropagation - optimizer.zero_grad() - loss.backward() - optimizer.step() - - if batch % 100 == 0: - loss, current = loss.item(), batch * len(X) - print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]") - - -def validate(dataloader, model, loss_fn, device): - size = len(dataloader.dataset) - num_batches = len(dataloader) - model.eval() - test_loss, correct = 0, 0 - with torch.no_grad(): - for X, y in dataloader: - X, y = X.to(device), y.to(device) - pred = model(X) - test_loss += loss_fn(pred, y).item() - correct += (pred.argmax(1) == y).type(torch.float).sum().item() - test_loss /= num_batches - correct /= size - print(f"Test Error: \n " - f"Accuracy: {(100 * correct):>0.1f}%, " - f"Avg loss: {test_loss:>8f} \n") - return {"loss": test_loss} - - -def train_func(config): - device = torch.device(f"cuda:{sgd.local_rank()}" - if torch.cuda.is_available() else "cpu") - - epochs = config.pop("epochs", 3) - model = ResNet18(config) - model = model.to(device) - model = DistributedDataParallel( - model, - device_ids=[device.index] if torch.cuda.is_available() else None) - - # Create optimizer. - optimizer = torch.optim.SGD( - model.parameters(), - lr=config.get("lr", 0.1), - momentum=config.get("momentum", 0.9)) - - # Load in training and validation data. 
- transform_train = transforms.Compose([ - transforms.RandomCrop(32, padding=4), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), - (0.2023, 0.1994, 0.2010)), - ]) # meanstd transformation - - transform_test = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), - (0.2023, 0.1994, 0.2010)), - ]) - - with FileLock(".ray.lock"): - train_dataset = CIFAR10( - root="~/data", - train=True, - download=True, - transform=transform_train) - validation_dataset = CIFAR10( - root="~/data", - train=False, - download=False, - transform=transform_test) - - if config.get("test_mode"): - train_dataset = Subset(train_dataset, list(range(64))) - validation_dataset = Subset(validation_dataset, list(range(64))) - - train_loader = DataLoader( - train_dataset, - batch_size=config["batch_size"], - sampler=DistributedSampler(train_dataset)) - validation_loader = DataLoader( - validation_dataset, - batch_size=config["batch_size"], - sampler=DistributedSampler(validation_dataset)) - - # Create loss. - criterion = nn.CrossEntropyLoss() - - results = [] - - for _ in range(epochs): - train(train_loader, model, criterion, optimizer, device) - result = validate(validation_loader, model, criterion, device) - sgd.report(**result) - results.append(result) - - return results - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--address", - required=False, - type=str, - help="the address to use for Redis") - parser.add_argument( - "--num-workers", - "-n", - type=int, - default=2, - help="Sets number of workers for training.") - parser.add_argument( - "--num-epochs", type=int, default=5, help="Number of epochs to train.") - parser.add_argument( - "--smoke-test", - action="store_true", - default=False, - help="Finish quickly for testing.") - parser.add_argument( - "--use-gpu", - action="store_true", - default=False, - help="Enables GPU training") - - args, _ = parser.parse_known_args() - if args.smoke_test: - ray.init(num_cpus=4) - else: - ray.init(address=args.address) - - trainer = Trainer( - "torch", num_workers=args.num_workers, use_gpu=args.use_gpu) - Trainable = trainer.to_tune_trainable(train_func) - pbt_scheduler = PopulationBasedTraining( - time_attr="training_iteration", - metric="loss", - mode="min", - perturbation_interval=1, - hyperparam_mutations={ - # distribution for resampling - "lr": lambda: np.random.uniform(0.001, 1), - # allow perturbations within this set of categorical values - "momentum": [0.8, 0.9, 0.99], - }) - - reporter = CLIReporter() - reporter.add_metric_column("loss", "loss") - - analysis = tune.run( - Trainable, - num_samples=4, - config={ - "lr": tune.choice([0.001, 0.01, 0.1]), - "momentum": 0.8, - "batch_size": 128 * args.num_workers, - "epochs": args.num_epochs, - "test_mode": args.smoke_test # whether to to subset the data - }, - stop={"training_iteration": 2 if args.smoke_test else 100}, - max_failures=3, # used for fault tolerance - checkpoint_freq=3, # used for fault tolerance - keep_checkpoints_num=1, # used for fault tolerance - verbose=2, - progress_reporter=reporter, - scheduler=pbt_scheduler) - - print(analysis.get_best_config(metric="loss", mode="min")) diff --git a/python/ray/util/sgd/v2/tests/test_backend.py b/python/ray/util/sgd/v2/tests/test_backend.py index 65ac486dd9df1..985029b808118 100644 --- a/python/ray/util/sgd/v2/tests/test_backend.py +++ b/python/ray/util/sgd/v2/tests/test_backend.py @@ -8,7 +8,6 @@ from ray.util.sgd import 
v2 as sgd from ray.util.sgd.v2.backends.backend import BackendConfig, BackendExecutor from ray.util.sgd.v2.backends.tensorflow import TensorflowConfig -from ray.util.sgd.v2.constants import ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV from ray.util.sgd.v2.worker_group import WorkerGroup from ray.util.sgd.v2.backends.torch import TorchConfig @@ -322,7 +321,6 @@ def get_resources(): num_workers, expected_results = worker_results - os.environ[ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV] = "1" e = BackendExecutor( config, num_workers=num_workers, @@ -351,7 +349,6 @@ def get_resources(): num_workers, expected_results = worker_results - os.environ[ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV] = "1" e = BackendExecutor( config, num_workers=num_workers, @@ -377,7 +374,6 @@ def get_resources(): num_workers, expected_results = worker_results - os.environ[ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV] = "1" e = BackendExecutor( config, num_workers=num_workers, diff --git a/python/ray/util/sgd/v2/tests/test_gpu.py b/python/ray/util/sgd/v2/tests/test_gpu.py deleted file mode 100644 index 845e768cd6d47..0000000000000 --- a/python/ray/util/sgd/v2/tests/test_gpu.py +++ /dev/null @@ -1,92 +0,0 @@ -import pytest - -import ray -from ray.util.sgd.v2 import Trainer -from ray.util.sgd.v2.examples.horovod.horovod_example import train_func as \ - horovod_torch_train_func -from ray.util.sgd.v2.examples.tensorflow_mnist_example import train_func as \ - tensorflow_mnist_train_func -from ray.util.sgd.v2.examples.train_fashion_mnist_example import train_func \ - as fashion_mnist_train_func -from test_tune import torch_fashion_mnist, tune_tensorflow_mnist - - -@pytest.fixture -def ray_start_4_cpus_2_gpus(): - address_info = ray.init(num_cpus=4, num_gpus=2) - yield address_info - # The code after the yield will run as teardown code. 
- ray.shutdown() - - -def test_tensorflow_mnist_gpu(ray_start_4_cpus_2_gpus): - num_workers = 2 - epochs = 3 - - trainer = Trainer("tensorflow", num_workers=num_workers, use_gpu=True) - config = {"lr": 1e-3, "batch_size": 64, "epochs": epochs} - trainer.start() - results = trainer.run(tensorflow_mnist_train_func, config) - trainer.shutdown() - - assert len(results) == num_workers - result = results[0] - - loss = result["loss"] - assert len(loss) == epochs - assert loss[-1] < loss[0] - - accuracy = result["accuracy"] - assert len(accuracy) == epochs - assert accuracy[-1] > accuracy[0] - - -def test_torch_fashion_mnist_gpu(ray_start_4_cpus_2_gpus): - num_workers = 2 - epochs = 3 - - trainer = Trainer("torch", num_workers=num_workers, use_gpu=True) - config = {"lr": 1e-3, "batch_size": 64, "epochs": epochs} - trainer.start() - results = trainer.run(fashion_mnist_train_func, config) - trainer.shutdown() - - assert len(results) == num_workers - - for result in results: - assert len(result) == epochs - assert result[-1] < result[0] - - -def test_horovod_torch_mnist_gpu(ray_start_4_cpus_2_gpus): - num_workers = 2 - num_epochs = 2 - trainer = Trainer("horovod", num_workers, use_gpu=True) - trainer.start() - results = trainer.run( - horovod_torch_train_func, - config={ - "num_epochs": num_epochs, - "lr": 1e-3 - }) - trainer.shutdown() - - assert len(results) == num_workers - for worker_result in results: - assert len(worker_result) == num_epochs - assert worker_result[num_epochs - 1] < worker_result[0] - - -def test_tune_fashion_mnist_gpu(ray_start_4_cpus_2_gpus): - torch_fashion_mnist(num_workers=2, use_gpu=True, num_samples=1) - - -def test_tune_tensorflow_mnist_gpu(ray_start_4_cpus_2_gpus): - tune_tensorflow_mnist(num_workers=2, use_gpu=True, num_samples=1) - - -if __name__ == "__main__": - import pytest - import sys - - sys.exit(pytest.main(["-v", "-x", "-s", __file__])) diff --git a/python/ray/util/sgd/v2/tests/test_trainer.py b/python/ray/util/sgd/v2/tests/test_trainer.py index 9795017283a0b..f7da6310a2a96 100644 --- a/python/ray/util/sgd/v2/tests/test_trainer.py +++ b/python/ray/util/sgd/v2/tests/test_trainer.py @@ -5,24 +5,26 @@ import horovod.torch as hvd_torch import pytest - import ray import ray.util.sgd.v2 as sgd +import tensorflow as tf +import torch from ray._private.test_utils import wait_for_condition from ray.util.sgd.v2 import Trainer, TorchConfig, TensorflowConfig, \ HorovodConfig from ray.util.sgd.v2.backends.backend import BackendConfig, Backend, \ BackendExecutor from ray.util.sgd.v2.callbacks.callback import SGDCallback -from ray.util.sgd.v2.examples.horovod.horovod_example import train_func as \ - horovod_torch_train_func, HorovodTrainClass -from ray.util.sgd.v2.constants import ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV from ray.util.sgd.v2.examples.tensorflow_mnist_example import train_func as \ tensorflow_mnist_train_func from ray.util.sgd.v2.examples.train_fashion_mnist_example import train_func \ - as fashion_mnist_train_func + as \ + fashion_mnist_train_func from ray.util.sgd.v2.examples.train_linear_example import train_func as \ linear_train_func + +from ray.util.sgd.v2.examples.horovod.horovod_example import train_func as \ + horovod_torch_train_func, HorovodTrainClass from ray.util.sgd.v2.worker_group import WorkerGroup @@ -496,6 +498,31 @@ def test_tensorflow_mnist(ray_start_2_cpus): assert accuracy[-1] > accuracy[0] +@pytest.mark.skipif( + len(tf.config.list_physical_devices("GPU")) < 2, + reason="Only run if multiple GPUs are available.") +def 
test_tensorflow_mnist_gpu(ray_start_2_cpus_2_gpus): + num_workers = 2 + epochs = 3 + + trainer = Trainer("tensorflow", num_workers=num_workers, use_gpu=True) + config = {"lr": 1e-3, "batch_size": 64, "epochs": epochs} + trainer.start() + results = trainer.run(tensorflow_mnist_train_func, config) + trainer.shutdown() + + assert len(results) == num_workers + result = results[0] + + loss = result["loss"] + assert len(loss) == epochs + assert loss[-1] < loss[0] + + accuracy = result["accuracy"] + assert len(accuracy) == epochs + assert accuracy[-1] > accuracy[0] + + def test_torch_linear(ray_start_2_cpus): num_workers = 2 epochs = 3 @@ -530,6 +557,26 @@ def test_torch_fashion_mnist(ray_start_2_cpus): assert result[-1] < result[0] +@pytest.mark.skipif( + torch.cuda.device_count() < 2, + reason="Only run if multiple GPUs are available.") +def test_torch_fashion_mnist_gpu(ray_start_2_cpus_2_gpus): + num_workers = 2 + epochs = 3 + + trainer = Trainer("torch", num_workers=num_workers, use_gpu=True) + config = {"lr": 1e-3, "batch_size": 64, "epochs": epochs} + trainer.start() + results = trainer.run(fashion_mnist_train_func, config) + trainer.shutdown() + + assert len(results) == num_workers + + for result in results: + assert len(result) == epochs + assert result[-1] < result[0] + + def test_horovod_simple(ray_start_2_cpus): def simple_fn(): hvd_torch.init() @@ -563,6 +610,28 @@ def test_horovod_torch_mnist(ray_start_2_cpus): assert worker_result[num_epochs - 1] < worker_result[0] +@pytest.mark.skipif( + torch.cuda.device_count() < 2, + reason="Only run if multiple GPUs are available.") +def test_horovod_torch_mnist_gpu(ray_start_2_cpus_2_gpus): + num_workers = 2 + num_epochs = 2 + trainer = Trainer("horovod", num_workers, use_gpu=True) + trainer.start() + results = trainer.run( + horovod_torch_train_func, + config={ + "num_epochs": num_epochs, + "lr": 1e-3 + }) + trainer.shutdown() + + assert len(results) == num_workers + for worker_result in results: + assert len(worker_result) == num_epochs + assert worker_result[num_epochs - 1] < worker_result[0] + + def test_horovod_torch_mnist_stateful(ray_start_2_cpus): num_workers = 2 num_epochs = 2 @@ -917,6 +986,7 @@ def test_resources(ray_start_4_cpus_4_gpus_4_extra, resource, num_requested): def test_gpu_requests(ray_start_4_cpus_4_gpus_4_extra): + # GPUs should not be requested if `use_gpu` is False. with pytest.raises(ValueError): Trainer( @@ -936,8 +1006,6 @@ def test_gpu_requests(ray_start_4_cpus_4_gpus_4_extra): def get_resources(): return os.environ["CUDA_VISIBLE_DEVICES"] - os.environ[ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV] = "1" - # 0 GPUs will be requested and should not raise an error. 
 trainer = Trainer(TestConfig(), num_workers=2, use_gpu=False)
 trainer.start()
diff --git a/python/ray/util/sgd/v2/tests/test_tune.py b/python/ray/util/sgd/v2/tests/test_tune.py
index 0ec1db59542f8..fb9d39b6df8b0 100644
--- a/python/ray/util/sgd/v2/tests/test_tune.py
+++ b/python/ray/util/sgd/v2/tests/test_tune.py
@@ -1,13 +1,18 @@
 import os
 import pytest
+
+import torch
+import tensorflow as tf
+
 import ray
-import ray.util.sgd.v2 as sgd
 from ray import tune, cloudpickle
 from ray.tune import TuneError
+
+import ray.util.sgd.v2 as sgd
 from ray.util.sgd.v2 import Trainer
-from ray.util.sgd.v2.backends.backend import Backend, BackendConfig
 from ray.util.sgd.v2.constants import TUNE_CHECKPOINT_FILE_NAME
+from ray.util.sgd.v2.backends.backend import Backend, BackendConfig
 from ray.util.sgd.v2.examples.tensorflow_mnist_example import train_func as \
     tensorflow_mnist_train_func
 from ray.util.sgd.v2.examples.train_fashion_mnist_example import train_func \
@@ -23,6 +28,14 @@ def ray_start_2_cpus():
     ray.shutdown()
 
 
+@pytest.fixture
+def ray_start_2_cpus_2_gpus():
+    address_info = ray.init(num_cpus=2, num_gpus=2)
+    yield address_info
+    # The code after the yield will run as teardown code.
+    ray.shutdown()
+
+
 @pytest.fixture
 def ray_start_8_cpus():
     address_info = ray.init(num_cpus=8)
@@ -70,6 +83,13 @@ def test_tune_torch_fashion_mnist(ray_start_8_cpus):
     torch_fashion_mnist(num_workers=2, use_gpu=False, num_samples=2)
 
 
+@pytest.mark.skipif(
+    torch.cuda.device_count() < 2,
+    reason="Only run if multiple GPUs are available.")
+def test_tune_fashion_mnist_gpu(ray_start_2_cpus_2_gpus):
+    torch_fashion_mnist(num_workers=2, use_gpu=True, num_samples=1)
+
+
 def tune_tensorflow_mnist(num_workers, use_gpu, num_samples):
     epochs = 2
     trainer = Trainer("tensorflow", num_workers=num_workers, use_gpu=use_gpu)
@@ -93,6 +113,13 @@ def test_tune_tensorflow_mnist(ray_start_8_cpus):
     tune_tensorflow_mnist(num_workers=2, use_gpu=False, num_samples=2)
 
 
+@pytest.mark.skipif(
+    len(tf.config.list_physical_devices("GPU")) < 2,
+    reason="Only run if multiple GPUs are available.")
+def test_tune_tensorflow_mnist_gpu(ray_start_2_cpus_2_gpus):
+    tune_tensorflow_mnist(num_workers=2, use_gpu=True, num_samples=1)
+
+
 def test_tune_error(ray_start_2_cpus):
     def train_func(config):
         raise RuntimeError("Error in training function!")
diff --git a/python/ray/util/tracing/tracing_helper.py b/python/ray/util/tracing/tracing_helper.py
index 68696fe29c46d..73fb61c00767c 100644
--- a/python/ray/util/tracing/tracing_helper.py
+++ b/python/ray/util/tracing/tracing_helper.py
@@ -290,8 +290,6 @@ def _invocation_remote_span(
         # If tracing feature flag is not on, perform a no-op.
         # Tracing doesn't work for cross lang yet.
if not is_tracing_enabled() or self._is_cross_language: - if kwargs is not None: - assert "_ray_trace_ctx" not in kwargs return method(self, args, kwargs, *_args, **_kwargs) assert "_ray_trace_ctx" not in kwargs @@ -367,7 +365,8 @@ def _invocation_actor_class_remote_span( # If tracing feature flag is not on, perform a no-op if not is_tracing_enabled(): - assert "_ray_trace_ctx" not in kwargs + if not self.__ray_metadata__.is_cross_language: + kwargs["_ray_trace_ctx"] = None return method(self, args, kwargs, *_args, **_kwargs) class_name = self.__ray_metadata__.class_name @@ -405,8 +404,6 @@ def _start_span( # If tracing feature flag is not on, perform a no-op if (not is_tracing_enabled() or self._actor_ref()._ray_is_cross_language): - if kwargs is not None: - assert "_ray_trace_ctx" not in kwargs return method(self, args, kwargs, *_args, **_kwargs) class_name = (self._actor_ref() diff --git a/python/ray/worker.py b/python/ray/worker.py index 97849d8d2750f..9f5dd31ca6da3 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -191,8 +191,8 @@ def current_session_and_job(self): @property def runtime_env(self): """Get the runtime env in json format""" - return json.loads(self.core_worker.get_job_config() - .runtime_env.serialized_runtime_env) + return json.loads( + self.core_worker.get_job_config().runtime_env.raw_json) def get_serialization_context(self, job_id=None): """Get the SerializationContext of the job that this worker is processing. @@ -223,6 +223,9 @@ def check_connected(self): Exception: An exception is raised if the worker is not connected. """ if not self.connected: + if os.environ.get("RAY_ENABLE_AUTO_CONNECT", "") != "0": + ray.client().connect() + return raise RaySystemError("Ray has not been started yet. You can " "start Ray with 'ray.init()'.") @@ -476,7 +479,7 @@ def print_logs(self): @PublicAPI -@client_mode_hook(auto_init=True) +@client_mode_hook def get_gpu_ids(): """Get the IDs of the GPUs that are available to the worker. @@ -573,7 +576,7 @@ def get_dashboard_url(): @PublicAPI -@client_mode_hook(auto_init=False) +@client_mode_hook def init( address: Optional[str] = None, *, @@ -602,6 +605,7 @@ def init( _memory: Optional[int] = None, _redis_password: str = ray_constants.REDIS_DEFAULT_PASSWORD, _temp_dir: Optional[str] = None, + _lru_evict: bool = False, _metrics_export_port: Optional[int] = None, _system_config: Optional[Dict[str, str]] = None, _tracing_startup_hook: Optional[Callable] = None, @@ -878,6 +882,7 @@ def init( start_initial_python_workers_for_first_job=( job_config is None or job_config.runtime_env is None), _system_config=_system_config, + lru_evict=_lru_evict, enable_object_reconstruction=_enable_object_reconstruction, metrics_export_port=_metrics_export_port, tracing_startup_hook=_tracing_startup_hook) @@ -919,6 +924,7 @@ def init( object_ref_seed=None, temp_dir=_temp_dir, _system_config=_system_config, + lru_evict=_lru_evict, enable_object_reconstruction=_enable_object_reconstruction, metrics_export_port=_metrics_export_port) _global_node = ray.node.Node( @@ -968,7 +974,7 @@ def init( @PublicAPI -@client_mode_hook(auto_init=False) +@client_mode_hook def shutdown(_exiting_interpreter: bool = False): """Disconnect the worker, and terminate processes started by ray.init(). @@ -1234,7 +1240,7 @@ def listen_error_messages_raylet(worker, threads_stopped): @PublicAPI -@client_mode_hook(auto_init=False) +@client_mode_hook def is_initialized() -> bool: """Check if ray.init has been called yet. 
@@ -1553,7 +1559,7 @@ def show_in_dashboard(message: str, key: str = "", dtype: str = "text"): @PublicAPI -@client_mode_hook(auto_init=True) +@client_mode_hook def get(object_refs: Union[ray.ObjectRef, List[ray.ObjectRef]], *, timeout: Optional[float] = None) -> Union[Any, List[Any]]: @@ -1642,7 +1648,7 @@ def get(object_refs: Union[ray.ObjectRef, List[ray.ObjectRef]], @PublicAPI -@client_mode_hook(auto_init=True) +@client_mode_hook def put(value: Any, *, _owner: Optional["ray.actor.ActorHandle"] = None) -> ray.ObjectRef: """Store an object in the object store. @@ -1696,7 +1702,7 @@ def put(value: Any, *, @PublicAPI -@client_mode_hook(auto_init=True) +@client_mode_hook def wait(object_refs: List[ray.ObjectRef], *, num_returns: int = 1, @@ -1803,7 +1809,7 @@ def wait(object_refs: List[ray.ObjectRef], @PublicAPI -@client_mode_hook(auto_init=True) +@client_mode_hook def get_actor(name: str, namespace: Optional[str] = None) -> "ray.actor.ActorHandle": """Get a handle to a named actor. @@ -1835,7 +1841,7 @@ def get_actor(name: str, @PublicAPI -@client_mode_hook(auto_init=True) +@client_mode_hook def kill(actor: "ray.actor.ActorHandle", *, no_restart: bool = True): """Kill an actor forcefully. @@ -1864,7 +1870,7 @@ def kill(actor: "ray.actor.ActorHandle", *, no_restart: bool = True): @PublicAPI -@client_mode_hook(auto_init=True) +@client_mode_hook def cancel(object_ref: ray.ObjectRef, *, force: bool = False, @@ -1926,7 +1932,6 @@ def make_decorator(num_returns=None, max_restarts=None, max_task_retries=None, runtime_env=None, - placement_group="default", worker=None, retry_exceptions=None): def decorator(function_or_class): @@ -1958,7 +1963,7 @@ def decorator(function_or_class): Language.PYTHON, function_or_class, None, num_cpus, num_gpus, memory, object_store_memory, resources, accelerator_type, num_returns, max_calls, max_retries, retry_exceptions, - runtime_env, placement_group) + runtime_env) if inspect.isclass(function_or_class): if num_returns is not None: @@ -2096,6 +2101,15 @@ def method(self): retry_exceptions (bool): Only for *remote functions*. This specifies whether application-level errors should be retried up to max_retries times. + override_environment_variables (Dict[str, str]): (Deprecated in Ray + 1.4.0, will be removed in Ray 1.6--please use the ``env_vars`` + field of :ref:`runtime-environments` instead.) This specifies + environment variables to override for the actor or task. The + overrides are propagated to all child actors and tasks. This + is a dictionary mapping variable names to their values. Existing + variables can be overridden, new ones can be created, and an + existing variable can be unset by setting it to an empty string. + Note: can only be set via `.options()`. 
""" worker = global_worker @@ -2107,8 +2121,7 @@ def method(self): valid_kwargs = [ "num_returns", "num_cpus", "num_gpus", "memory", "object_store_memory", "resources", "accelerator_type", "max_calls", "max_restarts", - "max_task_retries", "max_retries", "runtime_env", "retry_exceptions", - "placement_group" + "max_task_retries", "max_retries", "runtime_env", "retry_exceptions" ] error_string = ("The @ray.remote decorator must be applied either " "with no arguments and no parentheses, for example " @@ -2141,7 +2154,6 @@ def method(self): object_store_memory = kwargs.get("object_store_memory") max_retries = kwargs.get("max_retries") runtime_env = kwargs.get("runtime_env") - placement_group = kwargs.get("placement_group", "default") retry_exceptions = kwargs.get("retry_exceptions") return make_decorator( @@ -2157,6 +2169,5 @@ def method(self): max_task_retries=max_task_retries, max_retries=max_retries, runtime_env=runtime_env, - placement_group=placement_group, worker=worker, retry_exceptions=retry_exceptions) diff --git a/python/ray/workers/setup_worker.py b/python/ray/workers/setup_worker.py index b40737c1a8ad0..23fbc6e8e150d 100644 --- a/python/ray/workers/setup_worker.py +++ b/python/ray/workers/setup_worker.py @@ -3,8 +3,7 @@ import logging import os -from ray._private.runtime_env.context import RuntimeEnvContext -from ray.core.generated.common_pb2 import Language +from ray._private.runtime_env import RuntimeEnvContext logger = logging.getLogger(__name__) @@ -27,9 +26,6 @@ type=str, help="the worker allocated resource") -parser.add_argument( - "--language", type=str, help="the language type of the worker") - def get_tmp_dir(remaining_args): for arg in remaining_args: @@ -121,5 +117,5 @@ def start_worker_in_container(container_option, args, remaining_args): # probably not even go through this codepath. runtime_env_context = RuntimeEnvContext.deserialize( args.serialized_runtime_env_context or "{}") - runtime_env_context.exec_worker(remaining_args, - Language.Value(args.language)) + + runtime_env_context.exec_worker(remaining_args) diff --git a/python/ray/workflow/common.py b/python/ray/workflow/common.py index 169b318ed5d76..e883cdfacd0b4 100644 --- a/python/ray/workflow/common.py +++ b/python/ray/workflow/common.py @@ -32,8 +32,7 @@ def get_qualname(f): def ensure_ray_initialized(): - if not ray.is_initialized(): - ray.init() + ray.worker.global_worker.check_connected() @dataclass diff --git a/python/ray/workflow/execution.py b/python/ray/workflow/execution.py index b660c65fa8a9e..6de22cef05943 100644 --- a/python/ray/workflow/execution.py +++ b/python/ray/workflow/execution.py @@ -32,9 +32,8 @@ def run(entry_workflow: Workflow, # Workflow ID format: {Entry workflow UUID}.{Unix time to nanoseconds} workflow_id = f"{str(uuid.uuid4())}.{time.time():.9f}" - logger.info( - f"Workflow job created. [id=\"{workflow_id}\", storage_url=" - f"\"{store.storage_url}\"]. Type: {entry_workflow.data.step_type} ") + logger.info(f"Workflow job created. 
[id=\"{workflow_id}\", storage_url=" + f"\"{store.storage_url}\"].") with workflow_context.workflow_step_context(workflow_id, store.storage_url): @@ -52,7 +51,7 @@ def run(entry_workflow: Workflow, # - it's a new workflow # TODO (yic): follow up with force rerun if entry_workflow.data.step_type != StepType.FUNCTION or not wf_exists: - commit_step(ws, "", entry_workflow, exception=None) + commit_step(ws, "", entry_workflow, None) workflow_manager = get_or_create_management_actor() ignore_existing = (entry_workflow.data.step_type != StepType.FUNCTION) # NOTE: It is important to 'ray.get' the returned output. This diff --git a/python/ray/workflow/recovery.py b/python/ray/workflow/recovery.py index 8c64c2cba4100..58902b4419681 100644 --- a/python/ray/workflow/recovery.py +++ b/python/ray/workflow/recovery.py @@ -51,8 +51,8 @@ def _recover_workflow_step(args: List[Any], kwargs: Dict[str, Any], def _construct_resume_workflow_from_step( - reader: workflow_storage.WorkflowStorage, step_id: StepID, - input_map: Dict[StepID, Any]) -> Union[Workflow, StepID]: + reader: workflow_storage.WorkflowStorage, + step_id: StepID) -> Union[Workflow, StepID]: """Try to construct a workflow (step) that recovers the workflow step. If the workflow step already has an output checkpointing file, we return the workflow step id instead. @@ -60,8 +60,6 @@ def _construct_resume_workflow_from_step( Args: reader: The storage reader for inspecting the step. step_id: The ID of the step we want to recover. - input_map: This is a context storing the input which has been loaded. - This context is important for dedupe Returns: A workflow that recovers the step, or a ID of a step @@ -72,8 +70,8 @@ def _construct_resume_workflow_from_step( # we already have the output return step_id if isinstance(result.output_step_id, str): - return _construct_resume_workflow_from_step( - reader, result.output_step_id, input_map) + return _construct_resume_workflow_from_step(reader, + result.output_step_id) # output does not exists or not valid. try to reconstruct it. 
     if not result.is_recoverable():
         raise WorkflowStepNotRecoverableError(step_id)
@@ -81,14 +79,7 @@ with serialization.objectref_cache():
         input_workflows = []
         for i, _step_id in enumerate(result.workflows):
-            # Check whether the step has been loaded or not to avoid
-            # duplication
-            if _step_id in input_map:
-                r = input_map[_step_id]
-            else:
-                r = _construct_resume_workflow_from_step(
-                    reader, _step_id, input_map)
-                input_map[_step_id] = r
+            r = _construct_resume_workflow_from_step(reader, _step_id)
             if isinstance(r, Workflow):
                 input_workflows.append(r)
             else:
@@ -128,15 +119,15 @@ def _resume_workflow_step_executor(workflow_id: str, step_id: "StepID",
     try:
         store = storage.create_storage(store_url)
         wf_store = workflow_storage.WorkflowStorage(workflow_id, store)
-        r = _construct_resume_workflow_from_step(wf_store, step_id, {})
+        r = _construct_resume_workflow_from_step(wf_store, step_id)
     except Exception as e:
         raise WorkflowNotResumableError(workflow_id) from e
 
     if isinstance(r, Workflow):
-        with workflow_context.workflow_step_context(
-                workflow_id, store.storage_url, last_step_of_workflow=True):
-            from ray.workflow.step_executor import execute_workflow
-            result = execute_workflow(r)
+        with workflow_context.workflow_step_context(workflow_id,
+                                                    store.storage_url):
+            from ray.workflow.step_executor import (execute_workflow)
+            result = execute_workflow(r, last_step_of_workflow=True)
         return result.persisted_output, result.volatile_output
     assert isinstance(r, StepID)
     return wf_store.load_step_output(r), None
diff --git a/python/ray/workflow/step_executor.py b/python/ray/workflow/step_executor.py
index b5416b5a40218..878c7b40bf451 100644
--- a/python/ray/workflow/step_executor.py
+++ b/python/ray/workflow/step_executor.py
@@ -134,9 +134,33 @@ def _resolve_step_inputs(
     return signature.recover_args(flattened_args)
 
 
-def execute_workflow(workflow: "Workflow") -> "WorkflowExecutionResult":
+def execute_workflow(
+        workflow: "Workflow",
+        outer_most_step_id: Optional[str] = None,
+        last_step_of_workflow: bool = False) -> "WorkflowExecutionResult":
     """Execute workflow.
+    To fully explain what we are doing, we need to introduce some syntax first.
+    The syntax for dependencies between workflow steps
+    "B.step(A.step())" is "A - B"; the syntax for nested workflow steps
+    "def A(): return B.step()" is "A / B".
+
+    In a chain/DAG of step dependencies, the "output step" is the step of last
+    (topological) order. For example, in "A - B - C", C is the output step.
+
+    In a chain of nested workflow steps, the initial "output step" is
+    called the "outer most step" for other "output steps". For example, in
+    "A / B / C / D", "A" is the outer most step for "B", "C", "D";
+    in the hybrid workflow "((A - B) / C / D) - (E / (F - G) / H)",
+    "B" is the outer most step for "C", "D"; "E" is the outer most step
+    for "G", "H".
+
+    Args:
+        workflow: The workflow to be executed.
+        outer_most_step_id: The ID of the outer most workflow. None if it
+            does not exist.
+        last_step_of_workflow: True if this step generates the output of the
+            workflow (including nested steps).
 
     Returns:
         An object ref that represent the result.
""" @@ -149,8 +173,8 @@ def execute_workflow(workflow: "Workflow") -> "WorkflowExecutionResult": **workflow_data.ray_options).remote( workflow_data.step_type, workflow_data.func_body, workflow_context.get_workflow_step_context(), workflow.step_id, - baked_inputs, workflow_data.catch_exceptions, - workflow_data.max_retries) + baked_inputs, outer_most_step_id, workflow_data.catch_exceptions, + workflow_data.max_retries, last_step_of_workflow) if not isinstance(persisted_output, WorkflowOutputType): raise TypeError("Unexpected return type of the workflow.") @@ -173,6 +197,7 @@ async def _write_step_inputs(wf_storage: workflow_storage.WorkflowStorage, # TODO(suquark): in the future we should write to storage directly # with plasma store object in memory. args_obj = ray.get(inputs.inputs.args) + workflow_id = wf_storage._workflow_id storage = wf_storage._storage save_tasks = [ @@ -188,13 +213,19 @@ async def _write_step_inputs(wf_storage: workflow_storage.WorkflowStorage, await asyncio.gather(*save_tasks) -def commit_step(store: workflow_storage.WorkflowStorage, step_id: "StepID", - ret: Union["Workflow", Any], exception: Optional[Exception]): +def commit_step(store: workflow_storage.WorkflowStorage, + step_id: "StepID", + ret: Union["Workflow", Any], + exception: Optional[Exception], + outer_most_step_id: Optional[str] = None): """Checkpoint the step output. Args: store: The storage the current workflow is using. step_id: The ID of the step. ret: The returned object of the workflow step. + outer_most_step_id: The ID of the outer most workflow. None if it + does not exists. See "step_executor.execute_workflow" for detailed + explanation. """ from ray.workflow.common import Workflow if isinstance(ret, Workflow): @@ -205,12 +236,7 @@ def commit_step(store: workflow_storage.WorkflowStorage, step_id: "StepID", ] asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks)) - context = workflow_context.get_workflow_step_context() - store.save_step_output( - step_id, - ret, - exception=exception, - outer_most_step_id=context.outer_most_step_id) + store.save_step_output(step_id, ret, exception, outer_most_step_id) def _wrap_run(func: Callable, step_type: StepType, step_id: "StepID", @@ -302,11 +328,12 @@ def _wrap_run(func: Callable, step_type: StepType, step_id: "StepID", @ray.remote(num_returns=2) -def _workflow_step_executor(step_type: StepType, func: Callable, - context: workflow_context.WorkflowStepContext, - step_id: "StepID", - baked_inputs: "_BakedWorkflowInputs", - catch_exceptions: bool, max_retries: int) -> Any: +def _workflow_step_executor( + step_type: StepType, func: Callable, + context: workflow_context.WorkflowStepContext, step_id: "StepID", + baked_inputs: "_BakedWorkflowInputs", outer_most_step_id: "StepID", + catch_exceptions: bool, max_retries: int, + last_step_of_workflow: bool) -> Any: """Executor function for workflow step. Args: @@ -315,9 +342,13 @@ def _workflow_step_executor(step_type: StepType, func: Callable, context: Workflow step context. Used to access correct storage etc. step_id: The ID of the step. baked_inputs: The processed inputs for the step. + outer_most_step_id: See "step_executor.execute_workflow" for + explanation. catch_exceptions: If set to be true, return (Optional[Result], Optional[Error]) instead of Result. max_retries: Max number of retries encounter of a failure. + last_step_of_workflow: The step that generates the output of the + workflow (including nested steps). Returns: Workflow step output. 
@@ -330,7 +361,7 @@ def _workflow_step_executor(step_type: StepType, func: Callable, func, step_type, step_id, catch_exceptions, max_retries, *args, **kwargs) except Exception as e: - commit_step(store, step_id, None, e) + commit_step(store, step_id, None, e, outer_most_step_id) raise e if step_type == StepType.READONLY_ACTOR_METHOD: if isinstance(volatile_output, Workflow): @@ -340,28 +371,26 @@ def _workflow_step_executor(step_type: StepType, func: Callable, assert not isinstance(persisted_output, Workflow) else: store = workflow_storage.get_workflow_storage() - commit_step(store, step_id, persisted_output, None) - outer_most_step_id = context.outer_most_step_id + commit_step(store, step_id, persisted_output, None, outer_most_step_id) if isinstance(persisted_output, Workflow): if step_type == StepType.FUNCTION: # Passing down outer most step so inner nested steps would # access the same outer most step. - if not context.outer_most_step_id: + if not outer_most_step_id: # The current workflow step returns a nested workflow, and # there is no outer step for the current step. So the # current step is the outer most step for the inner nested # workflow steps. outer_most_step_id = workflow_context.get_current_step_id() assert volatile_output is None - # Execute sub-workflow. Pass down "outer_most_step_id". - with workflow_context.fork_workflow_step_context( - outer_most_step_id=outer_most_step_id): - result = execute_workflow(persisted_output) + # execute sub-workflow + result = execute_workflow(persisted_output, outer_most_step_id, + last_step_of_workflow) # When virtual actor returns a workflow in the method, # the volatile_output and persisted_output will be put together persisted_output = result.persisted_output volatile_output = result.volatile_output - elif context.last_step_of_workflow: + elif last_step_of_workflow: # advance the progress of the workflow store.advance_progress(step_id) _record_step_status(step_id, WorkflowStatus.SUCCESSFUL) @@ -386,11 +415,9 @@ class _BakedWorkflowInputs: @classmethod def from_workflow_inputs(cls, inputs: "WorkflowInputs"): - with workflow_context.fork_workflow_step_context( - outer_most_step_id=None, last_step_of_workflow=False): - workflow_outputs = [ - execute_workflow(w).persisted_output for w in inputs.workflows - ] + workflow_outputs = [ + execute_workflow(w).persisted_output for w in inputs.workflows + ] return cls(inputs.args, workflow_outputs, inputs.workflow_refs) def __reduce__(self): @@ -400,10 +427,7 @@ def __reduce__(self): def _record_step_status(step_id: "StepID", status: "WorkflowStatus", - outputs: Optional[List["ObjectRef"]] = None) -> None: - if outputs is None: - outputs = [] - + outputs: List["ObjectRef"] = []) -> None: workflow_id = workflow_context.get_current_workflow_id() workflow_manager = get_management_actor() ray.get( diff --git a/python/ray/workflow/tests/test_basic_workflows_2.py b/python/ray/workflow/tests/test_basic_workflows_2.py index acecfb14dc014..dad390635cab7 100644 --- a/python/ray/workflow/tests/test_basic_workflows_2.py +++ b/python/ray/workflow/tests/test_basic_workflows_2.py @@ -1,13 +1,10 @@ -import os import pytest import ray import re from filelock import FileLock -from pathlib import Path from ray._private.test_utils import run_string_as_driver, SignalActor from ray import workflow from ray.tests.conftest import * # noqa -from unittest.mock import patch def test_init_twice(call_ray_start, reset_workflow, tmp_path): @@ -25,11 +22,9 @@ def test_init_twice(call_ray_start, reset_workflow, tmp_path): def 
test_init_twice_2(call_ray_start, reset_workflow, tmp_path): - with patch.dict(os.environ, {"RAY_ADDRESS": call_ray_start}): - run_string_as_driver(driver_script) - with pytest.raises( - RuntimeError, match=".*different from the workflow manager.*"): - workflow.init(str(tmp_path)) + run_string_as_driver(driver_script) + with pytest.raises(RuntimeError): + workflow.init(str(tmp_path)) @pytest.mark.parametrize( @@ -290,38 +285,6 @@ def f2(*w): f.run() -def test_dedupe_indirect(workflow_start_regular, tmp_path): - counter = Path(tmp_path) / "counter.txt" - lock = Path(tmp_path) / "lock.txt" - counter.write_text("0") - - @workflow.step - def incr(): - with FileLock(str(lock)): - c = int(counter.read_text()) - c += 1 - counter.write_text(f"{c}") - - @workflow.step - def identity(a): - return a - - @workflow.step - def join(*a): - return counter.read_text() - - # Here a is passed to two steps and we need to ensure - # it's only executed once - a = incr.step() - i1 = identity.step(a) - i2 = identity.step(a) - assert "1" == join.step(i1, i2).run() - assert "2" == join.step(i1, i2).run() - # pass a multiple times - assert "3" == join.step(a, a, a, a).run() - assert "4" == join.step(a, a, a, a).run() - - if __name__ == "__main__": import sys sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/workflow/tests/test_lifetime.py b/python/ray/workflow/tests/test_lifetime.py index 64a519fa19a48..8d12399369ac9 100644 --- a/python/ray/workflow/tests/test_lifetime.py +++ b/python/ray/workflow/tests/test_lifetime.py @@ -1,4 +1,3 @@ -import os import ray import time import pytest @@ -6,7 +5,6 @@ run_string_as_driver) from ray.tests.conftest import * # noqa from ray import workflow -from unittest.mock import patch driver_script = """ import time @@ -31,23 +29,21 @@ def foo(x): def test_workflow_lifetime_1(call_ray_start, reset_workflow): # Case 1: driver exits normally - with patch.dict(os.environ, {"RAY_ADDRESS": call_ray_start}): - run_string_as_driver(driver_script.format(5)) - workflow.init() - output = workflow.get_output("driver_terminated") - assert ray.get(output) == 20 + run_string_as_driver(driver_script.format(5)) + workflow.init() + output = workflow.get_output("driver_terminated") + assert ray.get(output) == 20 def test_workflow_lifetime_2(call_ray_start, reset_workflow): # Case 2: driver terminated - with patch.dict(os.environ, {"RAY_ADDRESS": call_ray_start}): - proc = run_string_as_driver_nonblocking(driver_script.format(100)) - time.sleep(10) - proc.kill() - time.sleep(1) - workflow.init() - output = workflow.get_output("driver_terminated") - assert ray.get(output) == 20 + proc = run_string_as_driver_nonblocking(driver_script.format(100)) + time.sleep(10) + proc.kill() + time.sleep(1) + workflow.init() + output = workflow.get_output("driver_terminated") + assert ray.get(output) == 20 if __name__ == "__main__": diff --git a/python/ray/workflow/workflow_access.py b/python/ray/workflow/workflow_access.py index c1b5d78d253a0..0524637cf08da 100644 --- a/python/ray/workflow/workflow_access.py +++ b/python/ray/workflow/workflow_access.py @@ -327,8 +327,8 @@ def load(wf_store, workflow_id, step_id): actor = get_management_actor() return actor.get_output.remote(workflow_id, result.output_step_id) - raise ValueError(f"Cannot load output from step id {step_id} " - f"in workflow {workflow_id}") + raise ValueError( + f"No such step id {step_id} in workflow {workflow_id}") return ray.put( _SelfDereferenceObject(None, diff --git a/python/ray/workflow/workflow_context.py 
b/python/ray/workflow/workflow_context.py
index 7dec0937695f5..ffbeaafb6ce7f 100644
--- a/python/ray/workflow/workflow_context.py
+++ b/python/ray/workflow/workflow_context.py
@@ -1,58 +1,40 @@
-from dataclasses import dataclass, field
 import logging
-from typing import Optional, List, TYPE_CHECKING
+from typing import Optional, List
 from contextlib import contextmanager
 
 from ray.workflow.common import WorkflowStatus
 
 logger = logging.getLogger(__name__)
 
-if TYPE_CHECKING:
-    from python.ray.workflow.common import StepID
-
-@dataclass
 class WorkflowStepContext:
-    """
-    The structure for saving workflow step context. The context provides
-    critical info (e.g. where to checkpoint, which is its parent step)
-    for the step to execute correctly.
-
-    To fully explain what we are doing, we need to introduce some syntax
-    first. The syntax for dependencies between workflow steps
-    "B.step(A.step())" is "A - B"; the syntax for nested workflow steps
-    "def A(): return B.step()" is "A / B".
-
-    In a chain/DAG of step dependencies, the "output step" is the step of
-    last (topological) order. For example, in "A - B - C", C is the
-    output step.
-
-    In a chain of nested workflow steps, the initial "output step" is
-    called the "outer most step" for other "output steps". For example, in
-    "A / B / C / D", "A" is the outer most step for "B", "C", "D";
-    in the hybrid workflow "((A - B) / C / D) - (E / (F - G) / H)",
-    "B" is the outer most step for "C", "D"; "E" is the outer most step
-    for "G", "H".
-    """
-    # ID of the workflow.
-    workflow_id: Optional[str] = None
-    # The storage of the workflow, used for checkpointing.
-    storage_url: Optional[str] = None
-    # The "calling stack" of the current workflow step. It describe
-    # the parent workflow steps.
-    workflow_scope: List[str] = field(default_factory=list)
-    # The ID of the outer most workflow. "None" if it does not exists.
-    outer_most_step_id: "Optional[StepID]" = None
-    # The step that generates the output of the workflow (including all
-    # nested steps).
-    last_step_of_workflow: bool = False
+    def __init__(self,
+                 workflow_id: str = None,
+                 storage_url: str = None,
+                 workflow_scope: List[str] = None):
+        """
+        The structure for saving workflow step context. The context provides
+        critical info (e.g. where to checkpoint, which is its parent step)
+        for the step to execute correctly.
+
+        Args:
+            workflow_id: The workflow job ID.
+            storage_url: The storage of the workflow, used for checkpointing.
+            workflow_scope: The "calling stack" of the current workflow step.
+                It describes the parent workflow steps.
+        """
+        self.workflow_id = workflow_id
+        self.storage_url = storage_url
+        self.workflow_scope = workflow_scope or []
+
+    def __reduce__(self):
+        return WorkflowStepContext, (self.workflow_id, self.storage_url,
+                                     self.workflow_scope)
 
 
 _context: Optional[WorkflowStepContext] = None
 
 
 @contextmanager
-def workflow_step_context(workflow_id,
-                          storage_url,
-                          last_step_of_workflow=False) -> None:
+def workflow_step_context(workflow_id, storage_url) -> None:
     """Initialize the workflow step context.
Args: @@ -63,48 +45,7 @@ def workflow_step_context(workflow_id, original_context = _context assert workflow_id is not None try: - _context = WorkflowStepContext( - workflow_id, - storage_url, - last_step_of_workflow=last_step_of_workflow) - yield - finally: - _context = original_context - - -_sentinel = object() - - -@contextmanager -def fork_workflow_step_context( - workflow_id: Optional[str] = _sentinel, - storage_url: Optional[str] = _sentinel, - workflow_scope: Optional[List[str]] = _sentinel, - outer_most_step_id: Optional[str] = _sentinel, - last_step_of_workflow: Optional[bool] = _sentinel): - """Fork the workflow step context. - Inherits the original value if no value is provided. - - Args: - workflow_id: The ID of the workflow. - storage_url: The storage the workflow is using. - """ - global _context - original_context = _context - assert workflow_id is not None - try: - _context = WorkflowStepContext( - workflow_id=original_context.workflow_id - if workflow_id is _sentinel else workflow_id, - storage_url=original_context.storage_url - if storage_url is _sentinel else storage_url, - workflow_scope=original_context.workflow_scope - if workflow_scope is _sentinel else workflow_scope, - outer_most_step_id=original_context.outer_most_step_id - if outer_most_step_id is _sentinel else outer_most_step_id, - last_step_of_workflow=original_context.last_step_of_workflow - if last_step_of_workflow is _sentinel else last_step_of_workflow, - ) + _context = WorkflowStepContext(workflow_id, storage_url) yield finally: _context = original_context diff --git a/python/ray/workflow/workflow_storage.py b/python/ray/workflow/workflow_storage.py index 5a188cca1a3f2..bf18f471483de 100644 --- a/python/ray/workflow/workflow_storage.py +++ b/python/ray/workflow/workflow_storage.py @@ -117,9 +117,9 @@ def load_step_output(self, step_id: StepID) -> Any: # In this case, there is no such step raise output_err - def save_step_output(self, step_id: StepID, ret: Union[Workflow, Any], *, + def save_step_output(self, step_id: StepID, ret: Union[Workflow, Any], exception: Optional[Exception], - outer_most_step_id: StepID) -> None: + outer_most_step_id: Optional[StepID]) -> None: """When a workflow step returns, 1. If the returned object is a workflow, this means we are a nested workflow. We save the output metadata that points to the workflow. @@ -130,7 +130,8 @@ def save_step_output(self, step_id: StepID, ret: Union[Workflow, Any], *, it means we are in the workflow job driver process. ret: The returned object from a workflow step. exception: This step should throw exception. - outer_most_step_id: See WorkflowStepContext. + outer_most_step_id: See + "step_executor.execute_workflow" for explanation. """ tasks = [] if isinstance(ret, Workflow): @@ -153,9 +154,14 @@ def save_step_output(self, step_id: StepID, ret: Union[Workflow, Any], *, # tasks.append(self._put(self._key_step_output(step_id), ret)) dynamic_output_id = step_id # TODO (yic): Delete exception file - tasks.append( - self._update_dynamic_output(outer_most_step_id, - dynamic_output_id)) + + # outer_most_step_id == "" indicates the root step of a + # workflow. This would directly update "outputs.json" in + # the workflow dir, and we want to avoid it. 
+ if outer_most_step_id is not None and outer_most_step_id != "": + tasks.append( + self._update_dynamic_output(outer_most_step_id, + dynamic_output_id)) else: assert ret is None promise = serialization.dump_to_storage( @@ -265,15 +271,10 @@ async def _update_dynamic_output(self, outer_most_step_id: StepID, critical for scalability of virtual actors. Args: - outer_most_step_id: See WorkflowStepContext for explanation. + outer_most_step_id: ID of outer_most_step. See + "step_executor.execute_workflow" for explanation. dynamic_output_step_id: ID of dynamic_step. """ - # outer_most_step_id == "" indicates the root step of a - # workflow. This would directly update "outputs.json" in - # the workflow dir, and we want to avoid it. - if outer_most_step_id is None or outer_most_step_id == "": - return - metadata = await self._get( self._key_step_output_metadata(outer_most_step_id), True) if (dynamic_output_step_id != metadata["output_step_id"] diff --git a/python/requirements.txt b/python/requirements.txt index 2f683373fbc05..4d0baeaf9ef80 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -5,7 +5,7 @@ # In short, if you change it here, PLEASE also change it in setup.py. # # setup.py install_requires -aiohttp>=3.7 +aiohttp==3.7 aioredis < 2 click >= 7.0 cloudpickle @@ -27,7 +27,7 @@ requests ## setup.py extras dm_tree flask -gym==0.19 +gym lz4 scikit-image opencv-python-headless==4.3.0.36 @@ -68,7 +68,6 @@ opentelemetry-exporter-otlp==1.1.0 pexpect Pillow; platform_system != "Windows" pygments -pyspark pytest==5.4.3 pytest-asyncio pytest-rerunfailures diff --git a/python/requirements/ml/requirements_rllib.txt b/python/requirements/ml/requirements_rllib.txt index 6bba94e49fc99..a81e52c9c1f08 100644 --- a/python/requirements/ml/requirements_rllib.txt +++ b/python/requirements/ml/requirements_rllib.txt @@ -10,9 +10,9 @@ kaggle_environments==1.7.11 # Unity3D testing mlagents_envs==0.27.0 # For tests on PettingZoo's multi-agent envs. -pettingzoo==1.11.1 +pettingzoo==1.11.0 pymunk==6.0.0 -supersuit==2.6.6 +supersuit # For testing in MuJoCo-like envs (in PyBullet). pybullet==3.1.7 # For tests on RecSim and Kaggle envs. 
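
A note on the specifier changes above (`aiohttp>=3.7` pinned to `aiohttp==3.7`, `supersuit==2.6.6` unpinned to `supersuit`): under PEP 440, `==3.7` admits only the 3.7 release itself (zero-padding makes 3.7.0 match, but 3.7.4 does not), while `>=3.7` admits every later version, and a bare name admits anything. A quick way to check what a specifier accepts, assuming the `packaging` library is installed:

from packaging.specifiers import SpecifierSet

# "==3.7" matches only the exact release (zero-padded), so patch
# releases such as 3.7.4 are excluded.
assert "3.7.0" in SpecifierSet("==3.7")
assert "3.7.4" not in SpecifierSet("==3.7")
# ">=3.7" admits any later version, including new major releases.
assert "3.8.1" in SpecifierSet(">=3.7")
# The wildcard form "==3.7.*" would re-admit patch releases.
assert "3.7.4" in SpecifierSet("==3.7.*")
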
diff --git a/python/requirements/requirements_default.txt b/python/requirements/requirements_default.txt index 2df14c6e7588d..4537b9f9ea2a6 100644 --- a/python/requirements/requirements_default.txt +++ b/python/requirements/requirements_default.txt @@ -1,4 +1,4 @@ -aiohttp>=3.7 +aiohttp aiohttp_cors aioredis<2 colorful diff --git a/python/requirements_linters.txt b/python/requirements_linters.txt index 69f457fea1688..6f5661b1f2b2f 100644 --- a/python/requirements_linters.txt +++ b/python/requirements_linters.txt @@ -1,6 +1,5 @@ flake8==3.9.1 flake8-comprehensions flake8-quotes==2.0.0 -flake8-bugbear==21.9.2 mypy==0.782 yapf==0.23.0 diff --git a/python/setup.py b/python/setup.py index 3fb8ff43ab262..62d1e4e36fa46 100644 --- a/python/setup.py +++ b/python/setup.py @@ -23,7 +23,7 @@ logger = logging.getLogger(__name__) SUPPORTED_PYTHONS = [(3, 6), (3, 7), (3, 8), (3, 9)] -SUPPORTED_BAZEL = (4, 2, 1) +SUPPORTED_BAZEL = (3, 4, 1) ROOT_DIR = os.path.dirname(__file__) BUILD_JAVA = os.getenv("RAY_INSTALL_JAVA") == "1" @@ -184,13 +184,8 @@ def get_packages(self): # in this directory if setup_spec.type == SetupType.RAY: setup_spec.extras = { - "data": [ - "pandas", - "pyarrow>=4.0.1", - "fsspec", - ], "default": [ - "aiohttp >= 3.7", + "aiohttp", "aiohttp_cors", "aioredis < 2", "colorful", @@ -539,19 +534,6 @@ def copy_file(target_dir, filename, rootdir): return 0 -def add_system_dlls(dlls, target_dir): - """ - Copy any required dlls required by the c-extension module and not already - provided by python. They will end up in the wheel next to the c-extension - module which will guarentee they are available at runtime. - """ - for dll in dlls: - # Installing Visual Studio will copy the runtime dlls to system32 - src = os.path.join(r"c:\Windows\system32", dll) - assert os.path.exists(src) - shutil.copy(src, target_dir) - - def pip_run(build_ext): build(True, BUILD_JAVA, True) @@ -576,13 +558,6 @@ def pip_run(build_ext): copied_files = 0 for filename in setup_spec.files_to_include: copied_files += copy_file(build_ext.build_lib, filename, ROOT_DIR) - if sys.platform == "win32": - # _raylet.pyd links to some MSVC runtime DLLS, this one may not be - # present on a user's machine. While vcruntime140.dll and - # vcruntime140_1.dll are also required, they are provided by CPython. 
- runtime_dlls = ["msvcp140.dll"] - add_system_dlls(runtime_dlls, os.path.join(build_ext.build_lib, "ray")) - copied_files += len(runtime_dlls) print("# of files copied to {}: {}".format(build_ext.build_lib, copied_files)) diff --git a/release/.buildkite/build_pipeline.py b/release/.buildkite/build_pipeline.py index 9cbfdbdc08a0c..96d58a2b54f2a 100644 --- a/release/.buildkite/build_pipeline.py +++ b/release/.buildkite/build_pipeline.py @@ -99,7 +99,6 @@ def __init__(self, name: str, retry: int = 0): "~/ray/release/nightly_tests/nightly_tests.yaml": [ "dask_on_ray_large_scale_test_no_spilling", "dask_on_ray_large_scale_test_spilling", - "pg_autoscaling_regression_test", ], "~/ray/release/long_running_tests/long_running_tests.yaml": [ SmokeTest("actor_deaths"), diff --git a/release/RELEASE_CHECKLIST.md b/release/RELEASE_CHECKLIST.md index e4770a1cb6fdd..f8a55bfff9aec 100644 --- a/release/RELEASE_CHECKLIST.md +++ b/release/RELEASE_CHECKLIST.md @@ -31,7 +31,6 @@ This checklist is meant to be used in conjunction with the RELEASE_PROCESS.rst d - [ ] Test passing - [ ] Results added to `release/release_logs` - [ ] microbenchmark -- [ ] `kubernetes` manual release tests pass - [ ] ``weekly`` release test suite - [ ] Test passing diff --git a/release/RELEASE_PROCESS.rst b/release/RELEASE_PROCESS.rst index b2f7d05db5492..59da95846cf3c 100644 --- a/release/RELEASE_PROCESS.rst +++ b/release/RELEASE_PROCESS.rst @@ -172,9 +172,6 @@ Release tests are added and maintained by the respective teams. As another example, if you just want to kick off all nightly RLLib tests, select the respective test suite and specify ``rllib`` in the test file filter. -6. **Kubernetes tests must be run manually.** Refer to ``kubernetes_manual_tests/README.md``. - Feel free to ping code owner(s) of OSS Kubernetes support to run these. 
- Identify and Resolve Release Blockers ------------------------------------- If a release blocking issue arises in the course of testing, you should diff --git a/release/alerts/xgboost_tests.py b/release/alerts/xgboost_tests.py index 8b77cc17f49c7..59ab2880adf76 100644 --- a/release/alerts/xgboost_tests.py +++ b/release/alerts/xgboost_tests.py @@ -43,9 +43,7 @@ def handle_result(created_on: datetime.datetime, category: str, else: # train scripts if test_name == "train_small": - # Leave a couple of seconds for ray connect setup - # (without connect it should finish in < 30) - target_time = 45 + target_time = 30 elif test_name == "train_moderate": target_time = 60 elif test_name == "train_gpu": diff --git a/release/e2e.py b/release/e2e.py index f47a0bbeecf08..1b5fe71d15923 100644 --- a/release/e2e.py +++ b/release/e2e.py @@ -264,30 +264,11 @@ def getenv_default(key: str, default: Optional[str] = None): } REPORT_S = 30 -RETRY_MULTIPLIER = 2 - - -def exponential_backoff_retry(f, retry_exceptions, initial_retry_delay_s, - max_retries): - retry_cnt = 0 - retry_delay_s = initial_retry_delay_s - while True: - try: - return f() - except retry_exceptions as e: - retry_cnt += 1 - if retry_cnt > max_retries: - raise - logger.info(f"Retry function call failed due to {e} " - f"in {retry_delay_s} seconds...") - time.sleep(retry_delay_s) - retry_delay_s *= RETRY_MULTIPLIER def maybe_fetch_api_token(): if GLOBAL_CONFIG["ANYSCALE_CLI_TOKEN"] is None: - logger.info( - "Missing ANYSCALE_CLI_TOKEN, retrieving from AWS secrets store") + print("Missing ANYSCALE_CLI_TOKEN, retrieving from AWS secrets store") # NOTE(simon) This should automatically retrieve # release-automation@anyscale.com's anyscale token GLOBAL_CONFIG["ANYSCALE_CLI_TOKEN"] = boto3.client( @@ -424,8 +405,7 @@ def populate_wheels_sanity_check(commit: Optional[str] = None): raise RuntimeError(f"Could not populate wheels sanity check command: " f"Commit hash missing. 
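The exponential_backoff_retry helper deleted above is generic enough to be worth keeping around; here is the same pattern as a self-contained sketch (semantics taken from the removed code: the delay doubles per attempt, and the last exception propagates once max_retries is exceeded):

    import logging
    import time

    logger = logging.getLogger(__name__)
    RETRY_MULTIPLIER = 2

    def exponential_backoff_retry(f, retry_exceptions, initial_retry_delay_s,
                                  max_retries):
        retry_cnt = 0
        retry_delay_s = initial_retry_delay_s
        while True:
            try:
                return f()
            except retry_exceptions as e:
                retry_cnt += 1
                if retry_cnt > max_retries:
                    raise
                logger.info(f"Call failed due to {e}, "
                            f"retrying in {retry_delay_s} seconds...")
                time.sleep(retry_delay_s)
                retry_delay_s *= RETRY_MULTIPLIER

    # usage: retry a flaky call up to 3 times, starting with a 10 s delay
    # exponential_backoff_retry(lambda: flaky_call(), (ConnectionError,),
    #                           initial_retry_delay_s=10, max_retries=3)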
Got: {commit}") - cmd = (f"python -c 'import ray; " - f"assert ray.__commit__ == \"{commit}\", ray.__commit__'") + cmd = f"python -c 'import ray; assert ray.__commit__ == \"{commit}\"'" os.environ["RAY_WHEELS_SANITY_CHECK"] = cmd @@ -483,7 +463,7 @@ def has_errored(result: Dict[Any, Any]) -> bool: return result.get("status", "invalid") != "finished" -def report_result(test_suite: str, test_name: str, status: str, last_logs: str, +def report_result(test_suite: str, test_name: str, status: str, logs: str, results: Dict[Any, Any], artifacts: Dict[Any, Any], category: str): now = datetime.datetime.utcnow() @@ -497,66 +477,67 @@ def report_result(test_suite: str, test_name: str, status: str, last_logs: str, f"results, artifacts, category) " f"VALUES (:created_on, :test_suite, :test_name, :status, :last_logs, " f":results, :artifacts, :category)") - parameters = [{ - "name": "created_on", - "typeHint": "TIMESTAMP", - "value": { - "stringValue": now.strftime("%Y-%m-%d %H:%M:%S") - }, - }, { - "name": "test_suite", - "value": { - "stringValue": test_suite - } - }, { - "name": "test_name", - "value": { - "stringValue": test_name - } - }, { - "name": "status", - "value": { - "stringValue": status - } - }, { - "name": "last_logs", - "value": { - "stringValue": last_logs - } - }, { - "name": "results", - "typeHint": "JSON", - "value": { - "stringValue": json.dumps(results) - }, - }, { - "name": "artifacts", - "typeHint": "JSON", - "value": { - "stringValue": json.dumps(artifacts) - }, - }, { - "name": "category", - "value": { - "stringValue": category - } - }] - - # Default boto3 call timeout is 45 seconds. - retry_delay_s = 64 - MAX_RDS_RETRY = 3 - exponential_backoff_retry( - lambda: rds_data_client.execute_statement( - database=GLOBAL_CONFIG["RELEASE_AWS_DB_NAME"], - parameters=parameters, - secretArn=GLOBAL_CONFIG["RELEASE_AWS_DB_SECRET_ARN"], - resourceArn=GLOBAL_CONFIG["RELEASE_AWS_DB_RESOURCE_ARN"], - schema=schema, - sql=sql), - retry_exceptions=rds_data_client.exceptions.StatementTimeoutException, - initial_retry_delay_s=retry_delay_s, - max_retries=MAX_RDS_RETRY) - logger.info("Result has been persisted to the databse") + + rds_data_client.execute_statement( + database=GLOBAL_CONFIG["RELEASE_AWS_DB_NAME"], + parameters=[ + { + "name": "created_on", + "typeHint": "TIMESTAMP", + "value": { + "stringValue": now.strftime("%Y-%m-%d %H:%M:%S") + }, + }, + { + "name": "test_suite", + "value": { + "stringValue": test_suite + } + }, + { + "name": "test_name", + "value": { + "stringValue": test_name + } + }, + { + "name": "status", + "value": { + "stringValue": status + } + }, + { + "name": "last_logs", + "value": { + "stringValue": logs + } + }, + { + "name": "results", + "typeHint": "JSON", + "value": { + "stringValue": json.dumps(results) + }, + }, + { + "name": "artifacts", + "typeHint": "JSON", + "value": { + "stringValue": json.dumps(artifacts) + }, + }, + { + "name": "category", + "value": { + "stringValue": category + } + }, + ], + secretArn=GLOBAL_CONFIG["RELEASE_AWS_DB_SECRET_ARN"], + resourceArn=GLOBAL_CONFIG["RELEASE_AWS_DB_RESOURCE_ARN"], + schema=schema, + sql=sql, + ) def log_results_and_artifacts(result: Dict): @@ -922,11 +903,7 @@ def wait_for_session_command_to_complete(create_session_command_result, # Sleep 1 sec before next check. 
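The report_result rewrite above inlines the RDS Data API call and drops the retry wrapper. The part worth spelling out is the parameter encoding: every bound value travels as a typed dict, with typeHint marking TIMESTAMP and JSON columns. A compact sketch of the same call shape (the sql_param helper and the ARN/database strings are illustrative placeholders):

    import datetime
    import json

    import boto3

    rds_data_client = boto3.client("rds-data")

    def sql_param(name, value, type_hint=None):
        # One RDS Data API string parameter, optionally with a type hint.
        param = {"name": name, "value": {"stringValue": value}}
        if type_hint:
            param["typeHint"] = type_hint
        return param

    now = datetime.datetime.utcnow()
    rds_data_client.execute_statement(
        database="release_db",                  # placeholder
        secretArn="arn:aws:secretsmanager:placeholder",
        resourceArn="arn:aws:rds:placeholder",
        schema="release_schema",                # placeholder
        sql=("INSERT INTO results (created_on, status, results) "
             "VALUES (:created_on, :status, :results)"),
        parameters=[
            sql_param("created_on", now.strftime("%Y-%m-%d %H:%M:%S"),
                      type_hint="TIMESTAMP"),
            sql_param("status", "finished"),
            sql_param("results", json.dumps({"ok": True}), type_hint="JSON"),
        ])

Note that without the removed exponential_backoff_retry wrapper, a StatementTimeoutException (the removed comment put the default boto3 call timeout at 45 seconds) now fails the report on the first occurrence.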
time.sleep(1) - result = exponential_backoff_retry( - lambda: sdk.get_session_command(session_command_id=scd_id), - retry_exceptions=Exception, - initial_retry_delay_s=10, - max_retries=3) + result = sdk.get_session_command(session_command_id=scd_id) completed = result.result.finished_at if state_str == "CMD_RUN": @@ -957,14 +934,10 @@ def wait_for_session_command_to_complete(create_session_command_result, def get_command_logs(session_controller: SessionController, scd_id: str, lines: int = 50): - result = exponential_backoff_retry( - lambda: session_controller.api_client.get_execution_logs_api_v2_session_commands_session_command_id_execution_logs_get( # noqa: E501 - session_command_id=scd_id, - start_line=-1 * lines, - end_line=0), - retry_exceptions=Exception, - initial_retry_delay_s=10, - max_retries=3) + result = session_controller.api_client.get_execution_logs_api_v2_session_commands_session_command_id_execution_logs_get( # noqa: E501 + session_command_id=scd_id, + start_line=-1 * lines, + end_line=0) return result.result.lines @@ -1804,7 +1777,7 @@ def run_test(test_config_file: str, report: bool = True, keep_results_dir: bool = False, session_name: Optional[str] = None, - app_config_id_override=None) -> Dict[str, Any]: + app_config_id_override=None): with open(test_config_file, "rt") as f: test_configs = yaml.load(f, Loader=yaml.FullLoader) @@ -1863,18 +1836,18 @@ def run_test(test_config_file: str, logger.info("Kicked off test. It's now up to the `--check` " "part of the script to track its process.") - return {} + return else: # `--check` or no kick off only if status == "nosession": logger.info(f"No running session found for test {test_name}, so " f"assuming everything is fine.") - return {} + return if status == "kickoff": logger.info(f"Test {test_name} is still running.") - return {} + return last_logs = result.get("last_logs", "No logs.") @@ -1884,7 +1857,7 @@ def run_test(test_config_file: str, test_suite=test_suite, test_name=test_name, status=status, - last_logs=last_logs, + logs=last_logs, results=result.get("results", {}), artifacts=result.get("artifacts", {}), category=category, @@ -1899,7 +1872,7 @@ def run_test(test_config_file: str, if has_errored(result): raise RuntimeError(last_logs) - return report_kwargs + return if __name__ == "__main__": @@ -1962,6 +1935,7 @@ def run_test(test_config_file: str, "You have to set the ANYSCALE_PROJECT environment variable!") maybe_fetch_api_token() + if args.ray_wheels: os.environ["RAY_WHEELS"] = str(args.ray_wheels) url = str(args.ray_wheels) @@ -1981,7 +1955,7 @@ def run_test(test_config_file: str, test_config_file = os.path.abspath(os.path.expanduser(args.test_config)) - result_dict = run_test( + run_test( test_config_file=test_config_file, test_name=args.test_name, project_id=GLOBAL_CONFIG["ANYSCALE_PROJECT"], @@ -1996,30 +1970,3 @@ def run_test(test_config_file: str, keep_results_dir=args.keep_results_dir, app_config_id_override=args.app_config_id_override, ) - - if result_dict: - # If we get a result dict, check if any alerts should be raised - from alert import SUITE_TO_FN, default_handle_result - - logger.info("Checking if results are valid...") - - handle_result_kwargs = result_dict.copy() - handle_result_kwargs["created_on"] = None - - test_suite = handle_result_kwargs.get("test_suite", None) - test_name = handle_result_kwargs.get("test_name", None) - category = handle_result_kwargs.get("category", None) - - handle_fn = SUITE_TO_FN.get(test_suite, None) - if not handle_fn: - logger.warning(f"No handle for suite 
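After the hunks above, the session-command poll and log fetch are bare SDK calls with no retry envelope. For readers skimming the diff, a simplified sketch of the polling loop's shape (wait_for_finish is an illustrative name; sdk is the Anyscale SDK client used by the surrounding code):

    import time

    def wait_for_finish(sdk, scd_id: str, timeout_s: int = 600):
        deadline = time.time() + timeout_s
        while time.time() < deadline:
            time.sleep(1)  # sleep 1 s between checks, as in the diff
            result = sdk.get_session_command(session_command_id=scd_id)
            if result.result.finished_at is not None:
                return result
        raise TimeoutError(f"Command {scd_id} did not finish in {timeout_s}s")

Any transient API error now propagates immediately instead of being absorbed by up to three backoff retries.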
{test_suite}") - alert = default_handle_result(**handle_result_kwargs) - else: - alert = handle_fn(**handle_result_kwargs) - - if alert: - # If we get an alert, the test failed. - raise RuntimeError(alert) - else: - logger.info(f"No alert raised for test {test_suite}/{test_name} " - f"({category}) - the test successfully passed!") diff --git a/release/golden_notebook_tests/dask_xgboost_app_config.yaml b/release/golden_notebook_tests/dask_xgboost_app_config.yaml index a05da857edef8..072b183099476 100755 --- a/release/golden_notebook_tests/dask_xgboost_app_config.yaml +++ b/release/golden_notebook_tests/dask_xgboost_app_config.yaml @@ -5,8 +5,9 @@ debian_packages: python: pip_packages: + - pytest - pandas>=1.3.0 # otherwise, a version mismatch between local and remote will cause an exception - - git+https://github.com/ray-project/xgboost_ray.git#egg=xgboost_ray[default] + - xgboost_ray[default] - dask - fastapi - uvicorn @@ -15,5 +16,5 @@ python: post_build_cmds: - pip uninstall -y ray || true - - pip install -U {{ env["RAY_WHEELS"] | default("ray") }} + - pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }} - {{ env["RAY_WHEELS_SANITY_CHECK"] | default("echo No Ray wheels sanity check") }} diff --git a/release/golden_notebook_tests/golden_notebook_tests.yaml b/release/golden_notebook_tests/golden_notebook_tests.yaml index e6d5838d10333..1fae1e1d65824 100644 --- a/release/golden_notebook_tests/golden_notebook_tests.yaml +++ b/release/golden_notebook_tests/golden_notebook_tests.yaml @@ -1,7 +1,4 @@ - name: dask_xgboost_test - owner: - mail: "antoni@anyscale.com" - slack: "@team_ml" cluster: app_config: dask_xgboost_app_config.yaml compute_template: compute_tpl.yaml @@ -11,18 +8,8 @@ autosuspend_mins: 10 timeout: 1200 script: python workloads/dask_xgboost_test.py - args: - [ - "--num-actors 4", - "--cpus-per-actor 4", - "--num-actors-inference 16", - "--cpus-per-actor-inference 1", - ] - name: modin_xgboost_test - owner: - mail: "antoni@anyscale.com" - slack: "@team_ml" cluster: app_config: modin_xgboost_app_config.yaml compute_template: compute_tpl.yaml @@ -32,13 +19,6 @@ autosuspend_mins: 10 timeout: 1200 script: python workloads/modin_xgboost_test.py - args: - [ - "--num-actors 4", - "--cpus-per-actor 4", - "--num-actors-inference 16", - "--cpus-per-actor-inference 1", - ] - name: torch_tune_serve_test owner: @@ -54,3 +34,4 @@ autosuspend_mins: 10 timeout: 1200 script: python workloads/torch_tune_serve_test.py + diff --git a/release/golden_notebook_tests/modin_xgboost_app_config.yaml b/release/golden_notebook_tests/modin_xgboost_app_config.yaml index 5fb35e7b03fdd..c17fa85ca0144 100755 --- a/release/golden_notebook_tests/modin_xgboost_app_config.yaml +++ b/release/golden_notebook_tests/modin_xgboost_app_config.yaml @@ -5,8 +5,7 @@ debian_packages: python: pip_packages: - - pandas>=1.3.0 # otherwise, a version mismatch between local and remote will cause an exception - - git+https://github.com/ray-project/xgboost_ray.git#egg=xgboost_ray[default] + - pytest - modin - s3fs - fastapi @@ -17,4 +16,4 @@ python: post_build_cmds: - pip uninstall -y ray || true - pip install -U {{ env["RAY_WHEELS"] | default("ray") }} - - {{ env["RAY_WHEELS_SANITY_CHECK"] | default("echo No Ray wheels sanity check") }} + - pip install git+https://github.com/ray-project/xgboost_ray.git#egg=xgboost_ray[default] diff --git a/release/golden_notebook_tests/workloads/dask_xgboost_test.py b/release/golden_notebook_tests/workloads/dask_xgboost_test.py index c10bf91d96754..99755eb4399bb 100644 --- 
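The block deleted at the end of e2e.py dispatched results to per-suite alert handlers; the pattern was a plain dict lookup with a default. A sketch mirroring the removed logic (SUITE_TO_FN and default_handle_result come from the release alert module, per the removed import):

    from alert import SUITE_TO_FN, default_handle_result

    def check_alerts(result_dict: dict) -> None:
        kwargs = result_dict.copy()
        kwargs["created_on"] = None  # the removed code nulled this field
        handle_fn = SUITE_TO_FN.get(kwargs.get("test_suite"), None)
        if not handle_fn:
            alert = default_handle_result(**kwargs)
        else:
            alert = handle_fn(**kwargs)
        if alert:
            # an alert string means the test failed its acceptance criteria
            raise RuntimeError(alert)

Since run_test now returns None rather than the report kwargs, there is nothing left to feed this dispatcher, which is why the whole block goes away.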
a/release/golden_notebook_tests/workloads/dask_xgboost_test.py +++ b/release/golden_notebook_tests/workloads/dask_xgboost_test.py @@ -1,28 +1,135 @@ -import ray +import argparse +import json import os import time -import json -from util import import_and_execute_test_script, wait_for_cluster_client -NOTEBOOK_PATH_RELATIVE_TO_RAY_REPO = ( - "doc/examples/dask_xgboost/dask_xgboost.py") +import dask +import dask.dataframe as dd +import ray +from ray import tune + +from ray.util.dask import ray_dask_get + +from xgboost_ray import RayDMatrix, RayParams, train, predict + +from utils.utils import is_anyscale_connect + +FILE_URL = "https://ray-ci-higgs.s3.us-west-2.amazonaws.com/" \ + "simpleHIGGS.csv" + + +def train_xgboost(config, train_df, test_df, target_column, ray_params): + # distributed loading of a parquet dataset + train_set = RayDMatrix(train_df, target_column) + test_set = RayDMatrix(test_df, target_column) + + evals_result = {} + + start_time = time.time() + # Train the classifier + bst = train( + params=config, + dtrain=train_set, + evals=[(test_set, "eval")], + evals_result=evals_result, + verbose_eval=False, + num_boost_round=100, + ray_params=ray_params) + print(f"Total time taken: {time.time()-start_time}") + + model_path = "model.xgb" + bst.save_model(model_path) + print("Final validation error: {:.4f}".format( + evals_result["eval"]["error"][-1])) + + return bst + + +def tune_xgboost(train_df, test_df, target_column): + # Set XGBoost config. + config = { + "tree_method": "approx", + "objective": "binary:logistic", + "eval_metric": ["logloss", "error"], + "eta": tune.loguniform(1e-4, 1e-1), + "subsample": tune.uniform(0.5, 1.0), + "max_depth": tune.randint(1, 9) + } + + ray_params = RayParams( + max_actor_restarts=1, gpus_per_actor=0, cpus_per_actor=4, num_actors=4) + + analysis = tune.run( + tune.with_parameters( + train_xgboost, + train_df=train_df, + test_df=test_df, + target_column=target_column, + ray_params=ray_params), + # Use the `get_tune_resources` helper function to set the resources. + resources_per_trial=ray_params.get_tune_resources(), + config=config, + num_samples=1, + metric="eval-error", + mode="min", + verbose=1) + + accuracy = 1. 
- analysis.best_result["eval-error"] + print(f"Best model parameters: {analysis.best_config}") + print(f"Best model total accuracy: {accuracy:.4f}") + + return analysis.best_config def main(): - import_and_execute_test_script(NOTEBOOK_PATH_RELATIVE_TO_RAY_REPO) + print("Loading HIGGS data.") + + dask.config.set(scheduler=ray_dask_get) + colnames = ["label"] + ["feature-%02d" % i for i in range(1, 29)] + data = dd.read_csv(FILE_URL, names=colnames) + + print("Loaded HIGGS data.") + + # partition on a column + df_train = data[(data["feature-01"] < 0.4)] + df_validation = data[(data["feature-01"] >= 0.4) + & (data["feature-01"] < 0.8)] + + config = { + "tree_method": "approx", + "objective": "binary:logistic", + "eval_metric": ["logloss", "error"], + } + + bst = train_xgboost( + config, df_train, df_validation, "label", + RayParams(max_actor_restarts=1, cpus_per_actor=4, num_actors=4)) + tune_xgboost(df_train, df_validation, "label") + inference_df = RayDMatrix( + df_train[sorted(df_train.columns)], ignore=["label", "partition"]) + predict( + bst, + inference_df, + ray_params=RayParams(cpus_per_actor=2, num_actors=16)) if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--smoke-test", + action="store_true", + help="Finish quickly for testing.") + args = parser.parse_args() + start = time.time() addr = os.environ.get("RAY_ADDRESS") job_name = os.environ.get("RAY_JOB_NAME", "dask_xgboost_test") - if addr is not None and addr.startswith("anyscale://"): + if is_anyscale_connect(addr): ray.init(address=addr, job_name=job_name) else: ray.init(address="auto") - wait_for_cluster_client(4, 600) main() taken = time.time() - start diff --git a/release/golden_notebook_tests/workloads/modin_xgboost_test.py b/release/golden_notebook_tests/workloads/modin_xgboost_test.py index d5fb36f07b23e..4180351e7cb40 100644 --- a/release/golden_notebook_tests/workloads/modin_xgboost_test.py +++ b/release/golden_notebook_tests/workloads/modin_xgboost_test.py @@ -1,28 +1,131 @@ -import ray +import argparse +import json import os import time -import json -from util import import_and_execute_test_script, wait_for_cluster_client -NOTEBOOK_PATH_RELATIVE_TO_RAY_REPO = ( - "doc/examples/modin_xgboost/modin_xgboost.py") +import modin.pandas as pd +import ray +from ray import tune +from xgboost_ray import RayDMatrix, RayParams, train, predict + +from utils.utils import is_anyscale_connect + +FILE_URL = "https://ray-ci-higgs.s3.us-west-2.amazonaws.com/" \ + "simpleHIGGS.csv" + + +def train_xgboost(config, train_df, test_df, target_column, ray_params): + # distributed loading of a parquet dataset + train_set = RayDMatrix(train_df, target_column) + test_set = RayDMatrix(test_df, target_column) + + evals_result = {} + + start_time = time.time() + # Train the classifier + bst = train( + params=config, + dtrain=train_set, + evals=[(test_set, "eval")], + evals_result=evals_result, + verbose_eval=False, + num_boost_round=100, + ray_params=ray_params) + print(f"Total time taken: {time.time()-start_time}") + + model_path = "model.xgb" + bst.save_model(model_path) + print("Final validation error: {:.4f}".format( + evals_result["eval"]["error"][-1])) + + return bst + + +def tune_xgboost(train_df, test_df, target_column): + # Set XGBoost config. 
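One detail of the rewritten dask_xgboost_test.py that is easy to miss: a single dask.config.set call is what routes every subsequent Dask operation onto the Ray cluster. A minimal standalone sketch of that wiring (file and columns shortened for illustration):

    import dask
    import dask.dataframe as dd
    import ray
    from ray.util.dask import ray_dask_get

    ray.init()
    dask.config.set(scheduler=ray_dask_get)  # Dask tasks now run as Ray tasks

    # From here on, dask.dataframe work executes on Ray workers.
    df = dd.read_csv("simpleHIGGS.csv", names=["label", "feature-01"])
    print(df["feature-01"].mean().compute())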
+ config = { + "tree_method": "approx", + "objective": "binary:logistic", + "eval_metric": ["logloss", "error"], + "eta": tune.loguniform(1e-4, 1e-1), + "subsample": tune.uniform(0.5, 1.0), + "max_depth": tune.randint(1, 9) + } + + ray_params = RayParams( + max_actor_restarts=1, gpus_per_actor=0, cpus_per_actor=1, num_actors=2) + + analysis = tune.run( + tune.with_parameters( + train_xgboost, + train_df=train_df, + test_df=test_df, + target_column=target_column, + ray_params=ray_params), + # Use the `get_tune_resources` helper function to set the resources. + resources_per_trial=ray_params.get_tune_resources(), + config=config, + num_samples=1, + metric="eval-error", + mode="min", + verbose=1) + + accuracy = 1. - analysis.best_result["eval-error"] + print(f"Best model parameters: {analysis.best_config}") + print(f"Best model total accuracy: {accuracy:.4f}") + + return analysis.best_config def main(): - import_and_execute_test_script(NOTEBOOK_PATH_RELATIVE_TO_RAY_REPO) + print("Loading HIGGS data.") + + colnames = ["label"] + ["feature-%02d" % i for i in range(1, 29)] + + data = pd.read_csv(FILE_URL, names=colnames) + + print("Loaded HIGGS data.") + + # partition on a column + df_train = data[(data["feature-01"] < 0.4)] + df_validation = data[(data["feature-01"] >= 0.4) + & (data["feature-01"] < 0.8)] + + config = { + "tree_method": "approx", + "objective": "binary:logistic", + "eval_metric": ["logloss", "error"], + } + + bst = train_xgboost( + config, df_train, df_validation, "label", + RayParams(max_actor_restarts=1, cpus_per_actor=4, num_actors=4)) + # tune_xgboost(df_train, df_validation, "label") # broken atm + inference_df = RayDMatrix( + df_train[sorted(df_train.columns)], ignore=["label", "partition"]) + predict( + bst, + inference_df, + ray_params=RayParams(cpus_per_actor=1, num_actors=16)) if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--smoke-test", + action="store_true", + help="Finish quickly for testing.") + args = parser.parse_args() + start = time.time() addr = os.environ.get("RAY_ADDRESS") job_name = os.environ.get("RAY_JOB_NAME", "modin_xgboost_test") - if addr is not None and addr.startswith("anyscale://"): + if is_anyscale_connect(addr): ray.init(address=addr, job_name=job_name) else: ray.init(address="auto") - wait_for_cluster_client(4, 600) main() taken = time.time() - start diff --git a/release/golden_notebook_tests/workloads/torch_tune_serve_test.py b/release/golden_notebook_tests/workloads/torch_tune_serve_test.py index 9b511d5765ae6..15bd43a575a7a 100644 --- a/release/golden_notebook_tests/workloads/torch_tune_serve_test.py +++ b/release/golden_notebook_tests/workloads/torch_tune_serve_test.py @@ -17,6 +17,8 @@ from torch.utils.data import DataLoader, Subset from torchvision.datasets import MNIST +from utils.utils import is_anyscale_connect + def load_mnist_data(train: bool, download: bool): transform = transforms.Compose( @@ -198,7 +200,7 @@ def test_predictions(test_mode=False): addr = os.environ.get("RAY_ADDRESS") job_name = os.environ.get("RAY_JOB_NAME", "torch_tune_serve_test") - if addr is not None and addr.startswith("anyscale://"): + if is_anyscale_connect(addr): client = ray.init(address=addr, job_name=job_name) else: client = ray.init(address="auto") diff --git a/release/golden_notebook_tests/workloads/util.py b/release/golden_notebook_tests/workloads/util.py deleted file mode 100644 index a0efc28b0e73a..0000000000000 --- a/release/golden_notebook_tests/workloads/util.py +++ /dev/null @@ -1,49 +0,0 @@ -from 
pathlib import Path -import importlib.util -import ray -import time - - -def import_and_execute_test_script(relative_path_to_test_script: str): - """Imports and executes a module from a path relative to Ray repo root.""" - # get the ray folder - ray_path = next( - x for x in Path(__file__).resolve().parents if str(x).endswith("/ray")) - notebook_path = ray_path.joinpath(relative_path_to_test_script) - assert notebook_path.exists() - - spec = importlib.util.spec_from_file_location("notebook_test", - notebook_path) - notebook_test_module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(notebook_test_module) - - -def wait_for_cluster_client(num_nodes: int, - max_time_s: int, - feedback_interval_s: int = 10): - assert ray.is_initialized() - curr_nodes = 0 - start = time.time() - next_feedback = start - max_time = start + max_time_s - while not curr_nodes >= num_nodes: - now = time.time() - - if now >= max_time: - raise RuntimeError( - f"Maximum wait time reached, but only " - f"{curr_nodes}/{num_nodes} nodes came up. Aborting.") - - if now >= next_feedback: - passed = now - start - print(f"Waiting for more nodes to come up: " - f"{curr_nodes}/{num_nodes} " - f"({passed:.0f} seconds passed)") - next_feedback = now + feedback_interval_s - - time.sleep(5) - curr_nodes = len(ray.nodes()) - - passed = time.time() - start - print(f"Cluster is up: {curr_nodes}/{num_nodes} nodes online after " - f"{passed:.0f} seconds") diff --git a/release/golden_notebook_tests/workloads/utils/utils.py b/release/golden_notebook_tests/workloads/utils/utils.py new file mode 100644 index 0000000000000..071f076c72aee --- /dev/null +++ b/release/golden_notebook_tests/workloads/utils/utils.py @@ -0,0 +1,5 @@ +def is_anyscale_connect(address: str) -> bool: + """Returns whether or not the Ray Address points to an Anyscale cluster.""" + is_anyscale_connect = address is not None and address.startswith( + "anyscale://") + return is_anyscale_connect diff --git a/release/kubernetes_manual_tests/README.md b/release/kubernetes_manual_tests/README.md deleted file mode 100644 index 12b61f272b079..0000000000000 --- a/release/kubernetes_manual_tests/README.md +++ /dev/null @@ -1,25 +0,0 @@ -# ray-k8s-tests - -These tests are not automated and thus **must be run manually** for each release. -If you have issues running them, bug the code owner(s) for OSS Kubernetes support. - -## How to run -1. Configure kubectl and Helm 3 to access a K8s cluster. -2. `git checkout releases/` -3. You might have to locally pip install the Ray wheel for the relevant commit (or pip install -e) in a conda env, see Ray client note below. -4. cd to this directory -5. `IMAGE=rayproject/ray: bash k8s_release_tests.sh` -6. Test outcomes will be reported at the end of the output. - -This runs three tests and does the necessary resource creation/teardown. The tests typically take about 15 minutes to finish. - -## Notes -0. Anyscale employees: You should have access to create a K8s cluster using either GKE or EKS, ask OSS Kubernetes code owner if in doubt. -1. Your Ray cluster should be able to accomodate 30 1-CPU pods to run all of the tests. -2. These tests use basic Ray client functionality -- your locally installed Ray version may need to be updated to match the one in the release image. -3. The tests do a poor job of Ray client port-forwarding process clean-up -- if a test fails, it's possible there might be a port-forwarding process stuck running in the background. To identify the rogue process run `ps aux | grep "port-forward"`. 
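The new utils/utils.py helper above is tiny but now guards three workload entrypoints, so a unit test is cheap insurance. A sketch, assuming pytest and the same import path the workloads use (the test file itself is hypothetical):

    # test_utils.py (hypothetical)
    from utils.utils import is_anyscale_connect

    def test_is_anyscale_connect():
        assert is_anyscale_connect("anyscale://my-cluster")
        assert not is_anyscale_connect("auto")
        assert not is_anyscale_connect("127.0.0.1:6379")
        assert not is_anyscale_connect(None)  # RAY_ADDRESS may be unset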
Then `kill` it. -4. There are some errors that will appear on the screen during the run -- that's normal, error recovery is being tested. - -## Running individual tests -To run any of the three individual tests, substitute in step 5 of **How to Run** `k8s-test.sh` or `helm-test.sh` or `k8s-test-scale.sh`. -It's the last of these that needs 30 1-cpu pods. 10 is enough for either of the other two. The scale test is currently somewhat flaky. Rerun it if it fails. diff --git a/release/kubernetes_manual_tests/helm-test.sh b/release/kubernetes_manual_tests/helm-test.sh deleted file mode 100755 index 273ddb5c1cc11..0000000000000 --- a/release/kubernetes_manual_tests/helm-test.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -set -x -kubectl create namespace helm-test -kubectl create namespace helm-test2 -KUBERNETES_OPERATOR_TEST_NAMESPACE=helm-test KUBERNETES_OPERATOR_TEST_IMAGE="$IMAGE" python ../../python/ray/tests/kubernetes_e2e/test_helm.py -kubectl delete namespace helm-test -kubectl delete namespace helm-test2 -kubectl delete -f ../../deploy/charts/ray/crds/cluster_crd.yaml diff --git a/release/kubernetes_manual_tests/k8s-test-scale.sh b/release/kubernetes_manual_tests/k8s-test-scale.sh deleted file mode 100755 index 59ea06c80f5f1..0000000000000 --- a/release/kubernetes_manual_tests/k8s-test-scale.sh +++ /dev/null @@ -1,11 +0,0 @@ -#!/bin/bash -set -x -kubectl create namespace scale-test -kubectl create namespace scale-test2 -KUBERNETES_OPERATOR_TEST_NAMESPACE=scale-test KUBERNETES_OPERATOR_TEST_IMAGE="$IMAGE" python ../../python/ray/tests/kubernetes_e2e/test_k8s_operator_scaling.py -kubectl -n scale-test delete --all rayclusters -kubectl -n scale-test2 delete --all rayclusters -kubectl delete -f ../../deploy/components/operator_cluster_scoped.yaml -kubectl delete namespace scale-test -kubectl delete namespace scale-test2 -kubectl delete -f ../../deploy/charts/ray/crds/cluster_crd.yaml diff --git a/release/kubernetes_manual_tests/k8s-test.sh b/release/kubernetes_manual_tests/k8s-test.sh deleted file mode 100755 index aa0ec6325d880..0000000000000 --- a/release/kubernetes_manual_tests/k8s-test.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash -set -x -kubectl create namespace basic-test -kubectl apply -f ../../deploy/charts/ray/crds/cluster_crd.yaml -KUBERNETES_OPERATOR_TEST_NAMESPACE=basic-test KUBERNETES_OPERATOR_TEST_IMAGE="$IMAGE" python ../../python/ray/tests/kubernetes_e2e/test_k8s_operator_basic.py -kubectl -n basic-test delete --all rayclusters -kubectl -n basic-test delete deployment ray-operator -kubectl delete namespace basic-test -kubectl delete -f ../../deploy/charts/ray/crds/cluster_crd.yaml diff --git a/release/kubernetes_manual_tests/k8s_release_tests.sh b/release/kubernetes_manual_tests/k8s_release_tests.sh deleted file mode 100644 index 6576dcdabfa39..0000000000000 --- a/release/kubernetes_manual_tests/k8s_release_tests.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/bin/bash -set -x -IMAGE="$IMAGE" bash k8s-test.sh -BASIC_SUCCEEDED=$? -IMAGE="$IMAGE" bash helm-test.sh -HELM_SUCCEEDED=$? -IMAGE="$IMAGE" bash k8s-test-scale.sh -SCALE_SUCCEEDED=$? - -if (( BASIC_SUCCEEDED == 0 )) -then - echo "k8s-test.sh succeeded" -else - echo "k8s-test.sh test failed" -fi - -if (( HELM_SUCCEEDED == 0 )) -then - echo "helm-test.sh test succeeded"; -else - echo "helm-test.sh test failed" -fi - -if (( SCALE_SUCCEEDED == 0)) -then - echo "k8s-test-scale.sh test succeeded"; -else - echo "k8s-test-scale.sh failed. Try re-running just the k8s-test-scale.sh. It's expected to be flaky." 
-fi - diff --git a/release/long_running_tests/tpl_cpu_1.yaml b/release/long_running_tests/tpl_cpu_1.yaml index a22bc5dfc95a7..1045aa8948456 100644 --- a/release/long_running_tests/tpl_cpu_1.yaml +++ b/release/long_running_tests/tpl_cpu_1.yaml @@ -22,8 +22,3 @@ aws: Value: '{{env["ANYSCALE_USER"]}}' - Key: anyscale-expiration Value: '{{env["EXPIRATION_2D"]}}' - - BlockDeviceMappings: - - DeviceName: /dev/sda1 - Ebs: - VolumeSize: 202 \ No newline at end of file diff --git a/release/nightly_tests/dask_on_ray/large_scale_dask_on_ray_app_config.yaml b/release/nightly_tests/dask_on_ray/large_scale_dask_on_ray_app_config.yaml index 1aa0b86782476..9b7a0a9a11d3f 100644 --- a/release/nightly_tests/dask_on_ray/large_scale_dask_on_ray_app_config.yaml +++ b/release/nightly_tests/dask_on_ray/large_scale_dask_on_ray_app_config.yaml @@ -1,4 +1,5 @@ base_image: "anyscale/ray-ml:pinned-nightly-py37" +env_vars: {} debian_packages: [] python: diff --git a/release/nightly_tests/dataset/app_config.yaml b/release/nightly_tests/dataset/app_config.yaml index 5f311fbabfe87..c0cc753990de9 100644 --- a/release/nightly_tests/dataset/app_config.yaml +++ b/release/nightly_tests/dataset/app_config.yaml @@ -1,4 +1,5 @@ base_image: "anyscale/ray-ml:pinned-nightly-py37-gpu" +env_vars: {} python: pip_packages: [] diff --git a/release/nightly_tests/dataset/dataset_shuffle_data_loader.py b/release/nightly_tests/dataset/dataset_shuffle_data_loader.py index e917624a4712b..da3a7d74649f0 100644 --- a/release/nightly_tests/dataset/dataset_shuffle_data_loader.py +++ b/release/nightly_tests/dataset/dataset_shuffle_data_loader.py @@ -85,7 +85,7 @@ def create_torch_iterator(split, batch_size, rank=None): def create_dataset(filenames, repeat_times): pipeline = ray.data.read_parquet(list(filenames))\ - .repeat(times=repeat_times).random_shuffle_each_window() + .repeat(times=repeat_times).random_shuffle() return pipeline diff --git a/release/nightly_tests/dataset/pipelined_ingestion_app.yaml b/release/nightly_tests/dataset/pipelined_ingestion_app.yaml index 23ee18a1008b7..2fbda804b9b50 100644 --- a/release/nightly_tests/dataset/pipelined_ingestion_app.yaml +++ b/release/nightly_tests/dataset/pipelined_ingestion_app.yaml @@ -1,4 +1,5 @@ base_image: "anyscale/ray-ml:pinned-nightly-py37-gpu" +env_vars: {} python: pip_packages: [] diff --git a/release/nightly_tests/dataset/pipelined_training.py b/release/nightly_tests/dataset/pipelined_training.py index c8c7486724755..d9a4b9245bee1 100644 --- a/release/nightly_tests/dataset/pipelined_training.py +++ b/release/nightly_tests/dataset/pipelined_training.py @@ -244,12 +244,12 @@ def __next__(self): i * num_rows // num_windows // num_workers for i in range(1, num_workers) ] - pipe = pipe.random_shuffle_each_window(_spread_resource_prefix="node:") + pipe = pipe.random_shuffle(_spread_resource_prefix="node:") pipe_shards = pipe.split_at_indices(split_indices) else: ds = ray.data.read_parquet(files, _spread_resource_prefix="node:") pipe = ds.repeat(epochs) - pipe = pipe.random_shuffle_each_window(_spread_resource_prefix="node:") + pipe = pipe.random_shuffle(_spread_resource_prefix="node:") pipe_shards = pipe.split(num_workers, equal=True) return pipe_shards diff --git a/release/nightly_tests/dataset/pipelined_training_app.yaml b/release/nightly_tests/dataset/pipelined_training_app.yaml index 23ee18a1008b7..2fbda804b9b50 100644 --- a/release/nightly_tests/dataset/pipelined_training_app.yaml +++ b/release/nightly_tests/dataset/pipelined_training_app.yaml @@ -1,4 +1,5 @@ base_image: 
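The two dataset hunks above track an API rename, replacing random_shuffle_each_window with random_shuffle at every pipeline call site (keeping the _spread_resource_prefix argument where it was present). On a DatasetPipeline the call still shuffles window by window rather than globally, which is what the old name spelled out. A minimal sketch mirroring the updated call sites:

    import ray

    # Shuffle each window of a pipelined dataset (the renamed call above).
    pipeline = ray.data.read_parquet(["s3://bucket/part-0.parquet"]) \
        .repeat(times=4) \
        .random_shuffle()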
"anyscale/ray-ml:pinned-nightly-py37-gpu" +env_vars: {} python: pip_packages: [] diff --git a/release/nightly_tests/dataset/shuffle_app_config.yaml b/release/nightly_tests/dataset/shuffle_app_config.yaml index d89acec77a973..ac02d79b90415 100644 --- a/release/nightly_tests/dataset/shuffle_app_config.yaml +++ b/release/nightly_tests/dataset/shuffle_app_config.yaml @@ -1,4 +1,5 @@ base_image: "anyscale/ray-ml:pinned-nightly-py37-gpu" +env_vars: {} python: pip_packages: ["boto3", "numpy", "torch", "tqdm", "pyarrow"] diff --git a/release/nightly_tests/decision_tree/decision_tree_app_config.yaml b/release/nightly_tests/decision_tree/decision_tree_app_config.yaml index 70ae8eb896d16..92f5d3707fe1c 100644 --- a/release/nightly_tests/decision_tree/decision_tree_app_config.yaml +++ b/release/nightly_tests/decision_tree/decision_tree_app_config.yaml @@ -1,4 +1,5 @@ base_image: "anyscale/ray-ml:pinned-nightly-py37" +env_vars: {} debian_packages: [] python: diff --git a/release/nightly_tests/many_nodes_tests/app_config.yaml b/release/nightly_tests/many_nodes_tests/app_config.yaml index 9586d050b0418..67eb10caac1e7 100644 --- a/release/nightly_tests/many_nodes_tests/app_config.yaml +++ b/release/nightly_tests/many_nodes_tests/app_config.yaml @@ -1,5 +1,5 @@ base_image: "anyscale/ray-ml:pinned-nightly-py37" -env_vars: {"RAY_gcs_server_rpc_server_thread_num": "8", "RAY_GCS_ACTOR_SCHEDULING_ENABLED": "true"} +env_vars: {} debian_packages: [] python: diff --git a/release/nightly_tests/nightly_tests.yaml b/release/nightly_tests/nightly_tests.yaml index 9482eade1e713..d932924ffa6a3 100644 --- a/release/nightly_tests/nightly_tests.yaml +++ b/release/nightly_tests/nightly_tests.yaml @@ -317,24 +317,13 @@ prepare: python wait_cluster.py 32 1000 script: python dask_on_ray/dask_on_ray_sort.py --nbytes 1_000_000_000_000 --npartitions 1000 --num-nodes 31 --ray --data-dir /tmp/ray --s3-bucket core-nightly-test -# TODO (yic): Add this back when we make it stable -# - name: many_nodes_actor_test -# cluster: -# app_config: many_nodes_tests/app_config.yaml -# compute_template: many_nodes_tests/compute_config.yaml - -# run: -# timeout: 7200 -# prepare: python wait_cluster.py 500 5400 -# script: python many_nodes_tests/actor_test.py --cpus-per-actor=0.1 --total-actors=10000 -# # TODO: enable failure test later -# #&& python many_nodes_tests/actor_test.py --cpus-per-actor=0.1 --total-actors=10000 --fail --no-report && python many_nodes_tests/actor_test.py --cpus-per-actor=0.1 --total-actors=10000 --no-report - -- name: pg_autoscaling_regression_test +- name: many_nodes_actor_test cluster: - app_config: placement_group_tests/app_config.yaml - compute_template: placement_group_tests/compute.yaml + app_config: many_nodes_tests/app_config.yaml + compute_template: many_nodes_tests/compute_config.yaml run: - timeout: 1200 - script: python placement_group_tests/pg_run.py + timeout: 7200 + prepare: python wait_cluster.py 500 5400 + script: python many_nodes_tests/actor_test.py --cpus-per-actor=0.1 --total-actors=10000 + # TODO(yic): Add extra test for python many_nodes_tests/actor_test.py --cpus-per-actor=0.1 --total-actors=10000 --fail --no-report && python many_nodes_tests/actor_test.py --cpus-per-actor=0.1 --total-actors=10000 --no-report diff --git a/release/nightly_tests/placement_group_tests/app_config.yaml b/release/nightly_tests/placement_group_tests/app_config.yaml deleted file mode 100644 index d30247838e1e9..0000000000000 --- a/release/nightly_tests/placement_group_tests/app_config.yaml +++ /dev/null @@ -1,12 +0,0 @@ 
-base_image: "anyscale/ray-ml:pinned-nightly-py37" -debian_packages: [] - -python: - pip_packages: [] - conda_packages: [] - -post_build_cmds: - - pip uninstall -y ray - - pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }} - - pip3 install -U ray[default] - - {{ env["RAY_WHEELS_SANITY_CHECK"] | default("echo No Ray wheels sanity check") }} diff --git a/release/nightly_tests/placement_group_tests/cluster.py b/release/nightly_tests/placement_group_tests/cluster.py deleted file mode 100644 index a12ed798a4e99..0000000000000 --- a/release/nightly_tests/placement_group_tests/cluster.py +++ /dev/null @@ -1,13 +0,0 @@ -import time -from ray.cluster_utils import Cluster - -cluster = Cluster() - -cluster.add_node(num_cpus=16) - -time.sleep(20) -print("Scaling up.") -cluster.add_node(num_cpus=16, num_gpus=1) - -print("Scaled up. Waiting for 1000 seconds until done.") -time.sleep(1000) diff --git a/release/nightly_tests/placement_group_tests/compute.yaml b/release/nightly_tests/placement_group_tests/compute.yaml deleted file mode 100644 index 5b619db7651a4..0000000000000 --- a/release/nightly_tests/placement_group_tests/compute.yaml +++ /dev/null @@ -1,27 +0,0 @@ -cloud_id: {{env["ANYSCALE_CLOUD_ID"]}} -region: us-west-2 - -aws: - BlockDeviceMappings: - - DeviceName: /dev/sda1 - Ebs: - VolumeSize: 500 - -head_node_type: - name: head_node - instance_type: m5.4xlarge - -worker_node_types: - - name: cpu_node - instance_type: m5.4xlarge - min_workers: 0 - max_workers: 2 - use_spot: false - - name: fake_gpu_node - instance_type: m5.4xlarge - min_workers: 0 - max_workers: 2 - use_spot: false - resources: - cpu: 16 - gpu: 1 diff --git a/release/nightly_tests/placement_group_tests/pg_run.py b/release/nightly_tests/placement_group_tests/pg_run.py deleted file mode 100644 index 7bb616c2dcaa3..0000000000000 --- a/release/nightly_tests/placement_group_tests/pg_run.py +++ /dev/null @@ -1,65 +0,0 @@ -import os -import time -import json - -import ray -from ray.util.placement_group import placement_group - -# Tests are supposed to run for 10 minutes. 
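The re-enabled many_nodes_actor_test above gates its run on `python wait_cluster.py 500 5400`. That script is not part of this patch; presumably it polls the cluster until the requested node count is up, much like the wait_for_cluster_client helper deleted from the golden-notebook workloads. A sketch under that assumption:

    # Assumed behavior of wait_cluster.py: wait_cluster.py NUM_NODES TIMEOUT_S
    import sys
    import time

    import ray

    def wait_for_nodes(num_nodes: int, max_time_s: int) -> None:
        ray.init(address="auto")
        deadline = time.time() + max_time_s
        while len(ray.nodes()) < num_nodes:
            if time.time() > deadline:
                raise RuntimeError(
                    f"Only {len(ray.nodes())}/{num_nodes} nodes came up.")
            time.sleep(5)

    if __name__ == "__main__":
        wait_for_nodes(int(sys.argv[1]), int(sys.argv[2]))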
-RUNTIME = 600 -NUM_CPU_BUNDLES = 30 - - -@ray.remote(num_cpus=1) -class Worker(object): - def __init__(self, i): - self.i = i - - def work(self): - time.sleep(0.1) - print("work ", self.i) - - -@ray.remote(num_cpus=1, num_gpus=1) -class Trainer(object): - def __init__(self, i): - self.i = i - - def train(self): - time.sleep(0.2) - print("train ", self.i) - - -def main(): - ray.init(address="auto") - - bundles = [{"CPU": 1, "GPU": 1}] - bundles += [{"CPU": 1} for _ in range(NUM_CPU_BUNDLES)] - - pg = placement_group(bundles, strategy="PACK") - - ray.get(pg.ready()) - - workers = [ - Worker.options(placement_group=pg).remote(i) - for i in range(NUM_CPU_BUNDLES) - ] - - trainer = Trainer.options(placement_group=pg).remote(0) - - start = time.time() - while True: - ray.get([workers[i].work.remote() for i in range(NUM_CPU_BUNDLES)]) - ray.get(trainer.train.remote()) - end = time.time() - if end - start > RUNTIME: - break - - if "TEST_OUTPUT_JSON" in os.environ: - out_file = open(os.environ["TEST_OUTPUT_JSON"], "w") - results = {} - json.dump(results, out_file) - - -if __name__ == "__main__": - main() diff --git a/release/nightly_tests/shuffle/shuffle_app_config.yaml b/release/nightly_tests/shuffle/shuffle_app_config.yaml index d30247838e1e9..67eb10caac1e7 100644 --- a/release/nightly_tests/shuffle/shuffle_app_config.yaml +++ b/release/nightly_tests/shuffle/shuffle_app_config.yaml @@ -1,4 +1,5 @@ base_image: "anyscale/ray-ml:pinned-nightly-py37" +env_vars: {} debian_packages: [] python: @@ -9,4 +10,5 @@ post_build_cmds: - pip uninstall -y ray - pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }} - pip3 install -U ray[default] + - echo {{env["DATESTAMP"]}} - {{ env["RAY_WHEELS_SANITY_CHECK"] | default("echo No Ray wheels sanity check") }} diff --git a/release/nightly_tests/shuffle_data_loader/shuffle_data_loader_app_config.yaml b/release/nightly_tests/shuffle_data_loader/shuffle_data_loader_app_config.yaml index 536c7b6da27f4..2fea571c90f77 100644 --- a/release/nightly_tests/shuffle_data_loader/shuffle_data_loader_app_config.yaml +++ b/release/nightly_tests/shuffle_data_loader/shuffle_data_loader_app_config.yaml @@ -1,4 +1,5 @@ base_image: "anyscale/ray-ml:pinned-nightly-py37" +env_vars: {} debian_packages: [] python: diff --git a/release/nightly_tests/stress_tests/stress_tests_app_config.yaml b/release/nightly_tests/stress_tests/stress_tests_app_config.yaml index 66c99bb3bfe5a..1f264f9fa1e44 100644 --- a/release/nightly_tests/stress_tests/stress_tests_app_config.yaml +++ b/release/nightly_tests/stress_tests/stress_tests_app_config.yaml @@ -1,4 +1,5 @@ base_image: "anyscale/ray-ml:pinned-nightly-py37" +env_vars: {} debian_packages: [] python: diff --git a/release/release_logs/1.7.0/benchmarks/many_actors.txt b/release/release_logs/1.7.0/benchmarks/many_actors.txt deleted file mode 100644 index 2995df9b7f18d..0000000000000 --- a/release/release_logs/1.7.0/benchmarks/many_actors.txt +++ /dev/null @@ -1,10 +0,0 @@ -{ - "actors_per_second": 333.2797984180003, - "num_actors": 10000, - "time": 30.0048189163208, - "success": "1", - "_runtime": 43.551865577697754, - "_session_url": "https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_han7mApDaGYvrbvhuLKBSGBz", - "_commit_url": "https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.7.0/2367a2cb9033913b68b1230316496ae273c25b54/ray-1.7.0-cp37-cp37m-manylinux2014_x86_64.whl", - "_stable": true -} diff --git a/release/release_logs/1.7.0/benchmarks/many_nodes.txt 
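A side note on the deleted pg_run.py above: it opened the TEST_OUTPUT_JSON file without ever closing it. If a variant of this script comes back, the idiomatic shape for the reporting step is a context manager; a short sketch:

    import json
    import os

    def write_test_output(results: dict) -> None:
        # Release tooling reads results from the TEST_OUTPUT_JSON path, if set.
        out_path = os.environ.get("TEST_OUTPUT_JSON")
        if out_path:
            with open(out_path, "w") as out_file:  # flushed and closed on exit
                json.dump(results, out_file)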
b/release/release_logs/1.7.0/benchmarks/many_nodes.txt deleted file mode 100644 index d6d5a3c0b6631..0000000000000 --- a/release/release_logs/1.7.0/benchmarks/many_nodes.txt +++ /dev/null @@ -1,10 +0,0 @@ -{ - "tasks_per_second": 3.224712885579051, - "num_tasks": 1000, - "time": 610.1051273345947, - "success": "1", - "_runtime": 620.4832813739777, - "_session_url": "https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_6f82dxdGaxTV4uZNSamTYGLY", - "_commit_url": "https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.7.0/2367a2cb9033913b68b1230316496ae273c25b54/ray-1.7.0-cp37-cp37m-manylinux2014_x86_64.whl", - "_stable": true -} diff --git a/release/release_logs/1.7.0/benchmarks/many_pgs.txt b/release/release_logs/1.7.0/benchmarks/many_pgs.txt deleted file mode 100644 index 560c050dcecb4..0000000000000 --- a/release/release_logs/1.7.0/benchmarks/many_pgs.txt +++ /dev/null @@ -1,10 +0,0 @@ -{ - "pgs_per_second": 17.06879130613137, - "num_pgs": 1000, - "time": 58.586456537246704, - "success": "1", - "_runtime": 69.5553240776062, - "_session_url": "https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_gr3X2VEThCAQrtiHrJRd8yxW", - "_commit_url": "https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.7.0/2367a2cb9033913b68b1230316496ae273c25b54/ray-1.7.0-cp37-cp37m-manylinux2014_x86_64.whl", - "_stable": true -} diff --git a/release/release_logs/1.7.0/benchmarks/many_tasks.txt b/release/release_logs/1.7.0/benchmarks/many_tasks.txt deleted file mode 100644 index fa9c7d8d41db2..0000000000000 --- a/release/release_logs/1.7.0/benchmarks/many_tasks.txt +++ /dev/null @@ -1,10 +0,0 @@ -{ - "tasks_per_second": 27.508657888123608, - "num_tasks": 10000, - "time": 663.5219151973724, - "success": "1", - "_runtime": 674.2678966522217, - "_session_url": "https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_XCJkRqS4HkuHLXehx7i6Fwvc", - "_commit_url": "https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.7.0/2367a2cb9033913b68b1230316496ae273c25b54/ray-1.7.0-cp37-cp37m-manylinux2014_x86_64.whl", - "_stable": true -} diff --git a/release/release_logs/1.7.0/microbenchmark.txt b/release/release_logs/1.7.0/microbenchmark.txt deleted file mode 100644 index b5fa29117583d..0000000000000 --- a/release/release_logs/1.7.0/microbenchmark.txt +++ /dev/null @@ -1,134 +0,0 @@ -{ - "single_client_get_calls": [ - 34647.91400708946, - 311.7390971967917 - ], - "single_client_put_calls": [ - 58969.83872190603, - 869.618205663433 - ], - "multi_client_put_calls": [ - 199832.5755298421, - 2482.9205035774476 - ], - "single_client_get_calls_Plasma_Store": [ - 7082.757370159696, - 146.62873820799672 - ], - "single_client_put_calls_Plasma_Store": [ - 6321.65654587901, - 11.077913617295936 - ], - "multi_client_put_calls_Plasma_Store": [ - 9186.218655830648, - 112.23231532820908 - ], - "single_client_put_gigabytes": [ - 20.299125005168346, - 5.063681202623047 - ], - "single_client_tasks_and_get_batch": [ - 13.14018865978927, - 0.3152301478634011 - ], - "multi_client_put_gigabytes": [ - 36.56441662881655, - 1.843382220404724 - ], - "single_client_get_object_containing_10k_refs": [ - 10.351906653488715, - 0.23442465466734483 - ], - "single_client_tasks_sync": [ - 1257.4155346823063, - 16.879731074181798 - ], - "single_client_tasks_async": [ - 13436.707639489237, - 467.0229967004351 - ], - "multi_client_tasks_async": [ - 37893.82918345513, - 2501.210898297811 - ], - "1_1_actor_calls_sync": [ - 
2018.517206134362, - 4.133444448098185 - ], - "1_1_actor_calls_async": [ - 5107.498479502846, - 155.05763494606228 - ], - "1_1_actor_calls_concurrent": [ - 4974.868578485068, - 46.89895438701842 - ], - "1_n_actor_calls_async": [ - 13035.656413458306, - 263.67959962428176 - ], - "n_n_actor_calls_async": [ - 42424.91241384691, - 909.2063842725172 - ], - "n_n_actor_calls_with_arg_async": [ - 2910.8727809194884, - 142.55651461439174 - ], - "1_1_async_actor_calls_sync": [ - 1434.0111494545497, - 15.145616176257736 - ], - "1_1_async_actor_calls_async": [ - 3227.631490168903, - 74.52309737428871 - ], - "1_1_async_actor_calls_with_args_async": [ - 2417.18007329992, - 42.010241468147406 - ], - "1_n_async_actor_calls_async": [ - 13212.476889889944, - 280.91562344862103 - ], - "n_n_async_actor_calls_async": [ - 32212.030653578477, - 4172.2556150359205 - ], - "client__get_calls": [ - 1518.5267029642152, - 18.33838666361156 - ], - "client__put_calls": [ - 869.7170835067376, - 8.603084105450836 - ], - "client__put_gigabytes": [ - 0.11768745420143228, - 0.002542373184018965 - ], - "client__tasks_and_put_batch": [ - 58861.12144186892, - 546.7701167395176 - ], - "client__1_1_actor_calls_sync": [ - 472.8343418119895, - 6.16968890867776 - ], - "client__1_1_actor_calls_async": [ - 742.6478263697102, - 2.886810073788351 - ], - "client__1_1_actor_calls_concurrent": [ - 729.3572241473628, - 19.903703549912592 - ], - "client__tasks_and_get_batch": [ - 0.6990944804839968, - 0.00738047968242822 - ], - "_runtime": 558.9188287258148, - "_session_url": "https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_AHVUzrAzUMiLZ4p9EEAbL68s", - "_commit_url": "https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.7.0/2367a2cb9033913b68b1230316496ae273c25b54/ray-1.7.0-cp37-cp37m-manylinux2014_x86_64.whl", - "_stable": true -} diff --git a/release/release_logs/1.7.0/scalability/object_store.txt b/release/release_logs/1.7.0/scalability/object_store.txt deleted file mode 100644 index 6917229b88dc5..0000000000000 --- a/release/release_logs/1.7.0/scalability/object_store.txt +++ /dev/null @@ -1,10 +0,0 @@ -{ - "broadcast_time": 611.015479593, - "object_size": 1073741824, - "num_nodes": 50, - "success": "1", - "_runtime": 620.4363269805908, - "_session_url": "https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_Chj4PHZqrEjbzc8Ni4RY1Fev", - "_commit_url": "https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.7.0/2367a2cb9033913b68b1230316496ae273c25b54/ray-1.7.0-cp37-cp37m-manylinux2014_x86_64.whl", - "_stable": true -} diff --git a/release/release_logs/1.7.0/scalability/single_node.txt b/release/release_logs/1.7.0/scalability/single_node.txt deleted file mode 100644 index c868fa3c8eb4e..0000000000000 --- a/release/release_logs/1.7.0/scalability/single_node.txt +++ /dev/null @@ -1,16 +0,0 @@ -{ - "args_time": 17.256289814000013, - "num_args": 10000, - "returns_time": 5.854934190999984, - "num_returns": 3000, - "get_time": 25.88724605799996, - "queued_time": 140.99555420300004, - "num_queued": 1000000, - "large_object_time": 294.249499343, - "large_object_size": 107374182400, - "success": "1", - "_runtime": 528.4356288909912, - "_session_url": "https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_ELgpggWSHiqhksawLcz4urEP", - "_commit_url": "https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.7.0/2367a2cb9033913b68b1230316496ae273c25b54/ray-1.7.0-cp37-cp37m-manylinux2014_x86_64.whl", - "_stable": true 
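The release_logs files deleted above all share one envelope: top-level metrics plus underscore-prefixed bookkeeping fields (_runtime, _session_url, _commit_url, _stable). A sketch of a loader that separates the two (the convention is inferred from the deleted files; the helper name is illustrative):

    import json

    def split_release_log(path: str):
        with open(path) as f:
            record = json.load(f)
        meta = {k: v for k, v in record.items() if k.startswith("_")}
        metrics = {k: v for k, v in record.items() if not k.startswith("_")}
        return metrics, meta

    # e.g. metrics -> {"tasks_per_second": ..., "num_tasks": ..., "success": "1"}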
-} diff --git a/release/release_logs/1.7.0/stress_tests/dead_actors.txt b/release/release_logs/1.7.0/stress_tests/dead_actors.txt deleted file mode 100644 index ab763e4173b75..0000000000000 --- a/release/release_logs/1.7.0/stress_tests/dead_actors.txt +++ /dev/null @@ -1,11 +0,0 @@ -{ - "success": 1, - "total_time": 130.34314274787903, - "avg_iteration_time": 1.303428828716278, - "max_iteration_time": 3.651247501373291, - "min_iteration_time": 0.09438443183898926, - "_runtime": 902.0143933296204, - "_session_url": "https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_pxDnaxYFzDNsyifjJNV1qhqs", - "_commit_url": "https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.7.0/2367a2cb9033913b68b1230316496ae273c25b54/ray-1.7.0-cp37-cp37m-manylinux2014_x86_64.whl", - "_stable": true -} diff --git a/release/release_logs/1.7.0/stress_tests/many_tasks.txt b/release/release_logs/1.7.0/stress_tests/many_tasks.txt deleted file mode 100644 index a0244c5b28489..0000000000000 --- a/release/release_logs/1.7.0/stress_tests/many_tasks.txt +++ /dev/null @@ -1,19 +0,0 @@ -{ - "success": 1, - "stage_0_time": 5.256332874298096, - "stage_1_time": 174.50774693489075, - "stage_1_avg_iteration_time": 17.450765538215638, - "stage_1_max_iteration_time": 17.627604961395264, - "stage_1_min_iteration_time": 17.23277997970581, - "stage_2_time": 268.01243686676025, - "stage_2_avg_iteration_time": 53.60213441848755, - "stage_2_max_iteration_time": 59.097413063049316, - "stage_2_min_iteration_time": 48.71518564224243, - "stage_3_creation_time": 0.5777060985565186, - "stage_3_time": 2066.70570230484, - "stage_4_spread": 3.2197082901427945, - "_runtime": 5045.744384527206, - "_session_url": "https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_b8v2V4Tr7vwee6tCDjTjdXLL", - "_commit_url": "https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.7.0/2367a2cb9033913b68b1230316496ae273c25b54/ray-1.7.0-cp37-cp37m-manylinux2014_x86_64.whl", - "_stable": true -} diff --git a/release/release_logs/1.7.0/stress_tests/placement_group.txt b/release/release_logs/1.7.0/stress_tests/placement_group.txt deleted file mode 100644 index cbe7c99c54a04..0000000000000 --- a/release/release_logs/1.7.0/stress_tests/placement_group.txt +++ /dev/null @@ -1,9 +0,0 @@ -{ - "success": 1, - "avg_pg_create_time_ms": 0.9874122837809874, - "avg_pg_remove_time_ms": 4.4027920900909265, - "_runtime": 458.8596382141113, - "_session_url": "https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_7uQL743cWCzdDT3ZYTpRDETi", - "_commit_url": "https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.7.0/2367a2cb9033913b68b1230316496ae273c25b54/ray-1.7.0-cp37-cp37m-manylinux2014_x86_64.whl", - "_stable": true -} diff --git a/release/serve_tests/serve_tests.yaml b/release/serve_tests/serve_tests.yaml index 4362ca296d909..06edd31be95eb 100644 --- a/release/serve_tests/serve_tests.yaml +++ b/release/serve_tests/serve_tests.yaml @@ -27,7 +27,6 @@ - name: serve_micro_benchmark cluster: app_config: app_config.yaml - # 16 CPUS compute_template: compute_tpl_single_node.yaml run: @@ -35,19 +34,5 @@ long_running: False script: python workloads/serve_micro_benchmark.py - smoke_test: - timeout: 600 - -- name: serve_cluster_fault_tolerance - cluster: - app_config: app_config.yaml - # 16 CPUS - compute_template: compute_tpl_single_node.yaml - - run: - timeout: 7200 - long_running: False - script: python workloads/serve_cluster_fault_tolerance.py - smoke_test: 
timeout: 600 \ No newline at end of file diff --git a/release/serve_tests/workloads/serve_cluster_fault_tolerance.py b/release/serve_tests/workloads/serve_cluster_fault_tolerance.py deleted file mode 100644 index 431c78b9c5df3..0000000000000 --- a/release/serve_tests/workloads/serve_cluster_fault_tolerance.py +++ /dev/null @@ -1,119 +0,0 @@ -""" -Test that a serve deployment can recover from cluster failures by resuming -from checkpoints of external source, such as s3. - -For product testing, we skip the part of actually starting new cluster as -it's Job Manager's responsibility, and only re-deploy to the same cluster -with remote checkpoint. -""" - -import click -import time -import requests -import uuid -import os - -from serve_test_cluster_utils import setup_local_single_node_cluster - -from serve_test_utils import (save_test_results) - -import ray -from ray import serve -from ray.serve.utils import logger - -# Deployment configs -DEFAULT_NUM_REPLICAS = 4 -DEFAULT_MAX_BATCH_SIZE = 16 - - -def request_with_retries(endpoint, timeout=3): - start = time.time() - while True: - try: - return requests.get( - "http://127.0.0.1:8000" + endpoint, timeout=timeout) - except requests.RequestException: - if time.time() - start > timeout: - raise TimeoutError - time.sleep(0.1) - - -@click.command() -def main(): - # Setup local cluster, note this cluster setup is the same for both - # local and product ray cluster env. - # Each test uses different ray namespace, thus kv storage key for each - # checkpoint is different to avoid collision. - namespace = uuid.uuid4().hex - - # IS_SMOKE_TEST is set by args of releaser's e2e.py - smoke_test = os.environ.get("IS_SMOKE_TEST", "1") - if smoke_test == "1": - checkpoint_path = "file://checkpoint.db" - else: - checkpoint_path = "s3://serve-nightly-tests/fault-tolerant-test-checkpoint" # noqa: E501 - - _, cluster = setup_local_single_node_cluster( - 1, checkpoint_path=checkpoint_path, namespace=namespace) - - # Deploy for the first time - @serve.deployment(name="echo", num_replicas=DEFAULT_NUM_REPLICAS) - class Echo: - def __init__(self): - return True - - def __call__(self, request): - return "hii" - - Echo.deploy() - - # Ensure endpoint is working - for _ in range(5): - response = request_with_retries("/echo/", timeout=3) - assert response.text == "hii" - - logger.info("Initial deployment successful with working endpoint.") - - # Kill current cluster, recover from remote checkpoint and ensure endpoint - # is still available with expected results - - ray.kill(serve.api._global_client._controller, no_restart=True) - ray.shutdown() - cluster.shutdown() - serve.api._set_global_client(None) - - # Start another ray cluster with same namespace to resume from previous - # checkpoints with no new deploy() call. - setup_local_single_node_cluster( - 1, checkpoint_path=checkpoint_path, namespace=namespace) - - for _ in range(5): - response = request_with_retries("/echo/", timeout=3) - assert response.text == "hii" - - logger.info("Deployment recovery from s3 checkpoint is successful " - "with working endpoint.") - - # Delete dangling checkpoints. If script failed before this step, it's up - # to the TTL policy on s3 to clean up, but won't lead to collision with - # subsequent tests since each test run in different uuid namespace. - serve.shutdown() - ray.shutdown() - cluster.shutdown() - - # Checkpoints in S3 bucket are moved after 7 days with explicit lifecycle - # rules. Each checkpoint is ~260 Bytes in size from this test. 
- - # Save results - save_test_results( - { - "result": "success" - }, - default_output_file="/tmp/serve_cluster_fault_tolerance.json") - - -if __name__ == "__main__": - main() - import pytest - import sys - sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/release/serve_tests/workloads/serve_test_cluster_utils.py b/release/serve_tests/workloads/serve_test_cluster_utils.py index 3d9ccc44ae7f5..22e4e30cfdf35 100644 --- a/release/serve_tests/workloads/serve_test_cluster_utils.py +++ b/release/serve_tests/workloads/serve_test_cluster_utils.py @@ -6,16 +6,13 @@ from ray.cluster_utils import Cluster from ray.serve.utils import logger from ray.serve.config import DeploymentMode -from ray.serve.constants import DEFAULT_CHECKPOINT_PATH + # Cluster setup configs NUM_CPU_PER_NODE = 10 NUM_CONNECTIONS = 10 -def setup_local_single_node_cluster( - num_nodes: int, - checkpoint_path: str = DEFAULT_CHECKPOINT_PATH, - namespace="serve"): +def setup_local_single_node_cluster(num_nodes): """Setup ray cluster locally via ray.init() and Cluster() Each actor is simulated in local process on single node, @@ -24,23 +21,19 @@ def setup_local_single_node_cluster( cluster = Cluster() for i in range(num_nodes): cluster.add_node( - redis_port=6380 if i == 0 else None, + redis_port=6379 if i == 0 else None, num_cpus=NUM_CPU_PER_NODE, num_gpus=0, resources={str(i): 2}, ) - ray.init( - address=cluster.address, dashboard_host="0.0.0.0", namespace=namespace) + ray.init(address=cluster.address, dashboard_host="0.0.0.0") serve_client = serve.start( - detached=True, - http_options={"location": DeploymentMode.EveryNode}, - _checkpoint_path=checkpoint_path, - ) + http_options={"location": DeploymentMode.EveryNode}) - return serve_client, cluster + return serve_client -def setup_anyscale_cluster(checkpoint_path: str = DEFAULT_CHECKPOINT_PATH): +def setup_anyscale_cluster(): """Setup ray cluster at anyscale via ray.client() Note this is by default large scale and should be kicked off @@ -51,9 +44,7 @@ def setup_anyscale_cluster(checkpoint_path: str = DEFAULT_CHECKPOINT_PATH): # ray.client().env({}).connect() ray.init(address="auto") serve_client = serve.start( - http_options={"location": DeploymentMode.EveryNode}, - _checkpoint_path=checkpoint_path, - ) + http_options={"location": DeploymentMode.EveryNode}) return serve_client diff --git a/release/util/pip_download_test.sh b/release/util/pip_download_test.sh index c1d998b44e2b1..6ab91732ab255 100755 --- a/release/util/pip_download_test.sh +++ b/release/util/pip_download_test.sh @@ -56,7 +56,7 @@ do else failed=true fi - if bash sanity_check_cpp.sh; then + if sh sanity_check_cpp.sh; then echo "PYTHON ${PYTHON_VERSION} succeed sanity check C++." 
else cpp_failed=true diff --git a/rllib/BUILD b/rllib/BUILD index b09e149b14a22..f4c527bbb8099 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -876,13 +876,25 @@ py_test( srcs = ["agents/ddpg/tests/test_ddpg.py"] ) -# DQNTrainer +# DQNTrainer/SimpleQTrainer py_test( name = "test_dqn", tags = ["team:ml", "trainers_dir"], size = "large", srcs = ["agents/dqn/tests/test_dqn.py"] ) +py_test( + name = "test_r2d2", + tags = ["team:ml", "trainers_dir"], + size = "large", + srcs = ["agents/dqn/tests/test_r2d2.py"] +) +py_test( + name = "test_simple_q", + tags = ["team:ml", "trainers_dir"], + size = "medium", + srcs = ["agents/dqn/tests/test_simple_q.py"] +) # Dreamer py_test( @@ -990,22 +1002,6 @@ py_test( srcs = ["agents/qmix/tests/test_qmix.py"] ) -# R2D2Trainer -py_test( - name = "test_r2d2", - tags = ["team:ml", "trainers_dir"], - size = "large", - srcs = ["agents/dqn/tests/test_r2d2.py"] -) - -# RNNSACTrainer -py_test( - name = "test_rnnsac", - tags = ["team:ml", "trainers_dir"], - size = "medium", - srcs = ["agents/sac/tests/test_rnnsac.py"] -) - # SACTrainer py_test( name = "test_sac", @@ -1014,14 +1010,6 @@ py_test( srcs = ["agents/sac/tests/test_sac.py"] ) -# SimpleQTrainer -py_test( - name = "test_simple_q", - tags = ["team:ml", "trainers_dir"], - size = "medium", - srcs = ["agents/dqn/tests/test_simple_q.py"] -) - # TD3Trainer py_test( name = "test_td3", @@ -1340,38 +1328,18 @@ py_test( # -------------------------------------------------------------------- sh_test( - name = "env/tests/test_local_inference_cartpole", + name = "env/tests/test_local_inference", tags = ["team:ml", "env"], size = "medium", - srcs = ["env/tests/test_policy_client_server_setup.sh"], - args = ["local", "cartpole"], + srcs = ["env/tests/test_local_inference.sh"], data = glob(["examples/serving/*.py"]), ) sh_test( - name = "env/tests/test_remote_inference_cartpole", + name = "env/tests/test_remote_inference", tags = ["team:ml", "env"], size = "medium", - srcs = ["env/tests/test_policy_client_server_setup.sh"], - args = ["remote", "cartpole"], - data = glob(["examples/serving/*.py"]), -) - -sh_test( - name = "env/tests/test_local_inference_unity3d", - tags = ["team:ml", "env"], - size = "medium", - srcs = ["env/tests/test_policy_client_server_setup.sh"], - args = ["local", "unity3d"], - data = glob(["examples/serving/*.py"]), -) - -sh_test( - name = "env/tests/test_remote_inference_unity3d", - tags = ["team:ml", "env"], - size = "medium", - srcs = ["env/tests/test_policy_client_server_setup.sh"], - args = ["remote", "unity3d"], + srcs = ["env/tests/test_remote_inference.sh"], data = glob(["examples/serving/*.py"]), ) @@ -1382,13 +1350,6 @@ py_test( srcs = ["env/tests/test_record_env_wrapper.py"] ) -py_test( - name = "env/tests/test_remote_worker_envs", - tags = ["team:ml", "env"], - size = "medium", - srcs = ["env/tests/test_remote_worker_envs.py"] -) - py_test( name = "env/wrappers/tests/test_unity3d_env", tags = ["team:ml", "env"], @@ -1886,14 +1847,14 @@ py_test( args = ["TestSupportedMultiAgentOffPolicy"] ) -py_test( - name = "tests/test_supported_spaces_pg", - main = "tests/test_supported_spaces.py", - tags = ["team:ml", "tests_dir", "tests_dir_S"], - size = "large", - srcs = ["tests/test_supported_spaces.py"], - args = ["TestSupportedSpacesPG"] - ) +# py_test( +# name = "tests/test_supported_spaces_pg", +# main = "tests/test_supported_spaces.py", +# tags = ["team:ml", "tests_dir", "tests_dir_S"], +# size = "enormous", +# srcs = ["tests/test_supported_spaces.py"], +# args = ["TestSupportedSpacesPG"] +# ) 
py_test( name = "tests/test_supported_spaces_off_policy", diff --git a/rllib/agents/a3c/a3c_tf_policy.py b/rllib/agents/a3c/a3c_tf_policy.py index 6e7b362a4fd95..cbc5bbbd797d6 100644 --- a/rllib/agents/a3c/a3c_tf_policy.py +++ b/rllib/agents/a3c/a3c_tf_policy.py @@ -111,7 +111,7 @@ def grad_stats(policy: Policy, train_batch: SampleBatch, "grad_gnorm": tf.linalg.global_norm(grads), "vf_explained_var": explained_variance( train_batch[Postprocessing.VALUE_TARGETS], - policy.model.value_function()) + policy.model.value_function()), } diff --git a/rllib/agents/a3c/a3c_torch_policy.py b/rllib/agents/a3c/a3c_torch_policy.py index ea44f4767cfdc..99172adb814e0 100644 --- a/rllib/agents/a3c/a3c_torch_policy.py +++ b/rllib/agents/a3c/a3c_torch_policy.py @@ -72,25 +72,19 @@ def actor_critic_loss(policy: Policy, model: ModelV2, total_loss = (pi_err + value_err * policy.config["vf_loss_coeff"] - entropy * policy.config["entropy_coeff"]) - # Store values for stats function in model (tower), such that for - # multi-GPU, we do not override them during the parallel loss phase. - model.tower_stats["entropy"] = entropy - model.tower_stats["pi_err"] = pi_err - model.tower_stats["value_err"] = value_err + policy.entropy = entropy + policy.pi_err = pi_err + policy.value_err = value_err return total_loss def loss_and_entropy_stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]: - return { - "policy_entropy": torch.mean( - torch.stack(policy.get_tower_stats("entropy"))), - "policy_loss": torch.mean( - torch.stack(policy.get_tower_stats("pi_err"))), - "vf_loss": torch.mean( - torch.stack(policy.get_tower_stats("value_err"))), + "policy_entropy": policy.entropy, + "policy_loss": policy.pi_err, + "vf_loss": policy.value_err, } diff --git a/rllib/agents/a3c/tests/test_a2c.py b/rllib/agents/a3c/tests/test_a2c.py index 4c8a259245adc..2394b3f5812b7 100644 --- a/rllib/agents/a3c/tests/test_a2c.py +++ b/rllib/agents/a3c/tests/test_a2c.py @@ -3,7 +3,7 @@ import ray import ray.rllib.agents.a3c as a3c from ray.rllib.utils.test_utils import check_compute_single_action, \ - check_train_results, framework_iterator + framework_iterator class TestA2C(unittest.TestCase): @@ -29,7 +29,6 @@ def test_a2c_compilation(self): trainer = a3c.A2CTrainer(config=config, env=env) for i in range(num_iterations): results = trainer.train() - check_train_results(results) print(results) check_compute_single_action(trainer) trainer.stop() @@ -38,9 +37,7 @@ def test_a2c_exec_impl(ray_start_regular): config = {"min_iter_time_s": 0} for _ in framework_iterator(config): trainer = a3c.A2CTrainer(env="CartPole-v0", config=config) - results = trainer.train() - check_train_results(results) - print(results) + assert isinstance(trainer.train(), dict) check_compute_single_action(trainer) trainer.stop() @@ -51,9 +48,7 @@ def test_a2c_exec_impl_microbatch(ray_start_regular): } for _ in framework_iterator(config): trainer = a3c.A2CTrainer(env="CartPole-v0", config=config) - results = trainer.train() - check_train_results(results) - print(results) + assert isinstance(trainer.train(), dict) check_compute_single_action(trainer) trainer.stop() diff --git a/rllib/agents/a3c/tests/test_a3c.py b/rllib/agents/a3c/tests/test_a3c.py index 59147f213a7a5..6ffbab01f955f 100644 --- a/rllib/agents/a3c/tests/test_a3c.py +++ b/rllib/agents/a3c/tests/test_a3c.py @@ -3,7 +3,7 @@ import ray import ray.rllib.agents.a3c as a3c from ray.rllib.utils.test_utils import check_compute_single_action, \ - check_train_results, framework_iterator + framework_iterator class 
TestA3C(unittest.TestCase): @@ -31,7 +31,6 @@ def test_a3c_compilation(self): trainer = a3c.A3CTrainer(config=config, env=env) for i in range(num_iterations): results = trainer.train() - check_train_results(results) print(results) check_compute_single_action( trainer, include_state=config["model"]["use_lstm"]) diff --git a/rllib/agents/ars/tests/test_ars.py b/rllib/agents/ars/tests/test_ars.py index a78353de44ac4..b6bb3c8df7277 100644 --- a/rllib/agents/ars/tests/test_ars.py +++ b/rllib/agents/ars/tests/test_ars.py @@ -7,16 +7,9 @@ class TestARS(unittest.TestCase): - @classmethod - def setUpClass(cls): - ray.init(num_cpus=3) - - @classmethod - def tearDownClass(cls): - ray.shutdown() - def test_ars_compilation(self): """Test whether an ARSTrainer can be built on all frameworks.""" + ray.init(num_cpus=3) config = ars.DEFAULT_CONFIG.copy() # Keep it simple. config["model"]["fcnet_hiddens"] = [10] @@ -37,6 +30,7 @@ def test_ars_compilation(self): check_compute_single_action(trainer) trainer.stop() + ray.shutdown() if __name__ == "__main__": diff --git a/rllib/agents/cql/cql.py b/rllib/agents/cql/cql.py index 19f1573e29ba9..3c9c026c7bc34 100644 --- a/rllib/agents/cql/cql.py +++ b/rllib/agents/cql/cql.py @@ -14,11 +14,10 @@ from ray.rllib.execution.train_ops import MultiGPUTrainOneStep, TrainOneStep, \ UpdateTargetNetwork from ray.rllib.offline.shuffled_input import ShuffledInput -from ray.rllib.policy.policy import Policy +from ray.rllib.policy.policy import LEARNER_STATS_KEY, Policy from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils import merge_dicts from ray.rllib.utils.framework import try_import_tf, try_import_tfp -from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.typing import TrainerConfigDict tf1, tf, tfv = try_import_tf() diff --git a/rllib/agents/cql/cql_torch_policy.py b/rllib/agents/cql/cql_torch_policy.py index f62b23069a4fd..fed6470dc585e 100644 --- a/rllib/agents/cql/cql_torch_policy.py +++ b/rllib/agents/cql/cql_torch_policy.py @@ -14,12 +14,12 @@ build_sac_model_and_action_dist, optimizer_fn, ComputeTDErrorMixin, \ TargetNetworkMixin, setup_late_mixins, action_distribution_fn from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper +from ray.rllib.policy.policy import LEARNER_STATS_KEY from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.typing import LocalOptimizer, TensorType, \ TrainerConfigDict from ray.rllib.utils.torch_ops import apply_grad_clipping, \ @@ -250,29 +250,23 @@ def cql_loss(policy: Policy, model: ModelV2, critic_loss[1].backward(retain_graph=False) policy.critic_optims[1].step() - # Store values for stats function in model (tower), such that for - # multi-GPU, we do not override them during the parallel loss phase. - # SAC stats. - model.tower_stats["q_t"] = q_t_selected - model.tower_stats["policy_t"] = policy_t - model.tower_stats["log_pis_t"] = log_pis_t - model.tower_stats["actor_loss"] = actor_loss - model.tower_stats["critic_loss"] = critic_loss - model.tower_stats["alpha_loss"] = alpha_loss - model.tower_stats["log_alpha_value"] = model.log_alpha - model.tower_stats["alpha_value"] = alpha - model.tower_stats["target_entropy"] = model.target_entropy - # CQL stats. 
- model.tower_stats["cql_loss"] = cql_loss - - # TD-error tensor in final stats - # will be concatenated and retrieved for each individual batch item. - model.tower_stats["td_error"] = td_error - + # Save for stats function. + policy.q_t = q_t_selected + policy.policy_t = policy_t + policy.log_pis_t = log_pis_t + model.td_error = td_error + policy.actor_loss = actor_loss + policy.critic_loss = critic_loss + policy.alpha_loss = alpha_loss + policy.log_alpha_value = model.log_alpha + policy.alpha_value = alpha + policy.target_entropy = model.target_entropy + # CQL Stats. + policy.cql_loss = cql_loss if use_lagrange: - model.tower_stats["log_alpha_prime_value"] = model.log_alpha_prime[0] - model.tower_stats["alpha_prime_value"] = alpha_prime - model.tower_stats["alpha_prime_loss"] = alpha_prime_loss + policy.log_alpha_prime_value = model.log_alpha_prime[0] + policy.alpha_prime_value = alpha_prime + policy.alpha_prime_loss = alpha_prime_loss if obs.shape[0] == policy.config["train_batch_size"]: policy.alpha_prime_optim.zero_grad() @@ -280,27 +274,22 @@ def cql_loss(policy: Policy, model: ModelV2, policy.alpha_prime_optim.step() # Return all loss terms corresponding to our optimizers. - return tuple([actor_loss] + critic_loss + [alpha_loss] + - ([alpha_prime_loss] if use_lagrange else [])) + if use_lagrange: + return tuple([policy.actor_loss] + policy.critic_loss + + [policy.alpha_loss] + [policy.alpha_prime_loss]) + return tuple([policy.actor_loss] + policy.critic_loss + + [policy.alpha_loss]) def cql_stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]: - # Get SAC loss stats. - stats_dict = stats(policy, train_batch) - - # Add CQL loss stats to the dict. - stats_dict["cql_loss"] = torch.mean( - torch.stack(*policy.get_tower_stats("cql_loss"))) - + sac_dict = stats(policy, train_batch) + sac_dict["cql_loss"] = torch.mean(torch.stack(policy.cql_loss)) if policy.config["lagrangian"]: - stats_dict["log_alpha_prime_value"] = torch.mean( - torch.stack(policy.get_tower_stats("log_alpha_prime_value"))) - stats_dict["alpha_prime_value"] = torch.mean( - torch.stack(policy.get_tower_stats("alpha_prime_value"))) - stats_dict["alpha_prime_loss"] = torch.mean( - torch.stack(policy.get_tower_stats("alpha_prime_loss"))) - return stats_dict + sac_dict["log_alpha_prime_value"] = policy.log_alpha_prime_value + sac_dict["alpha_prime_value"] = policy.alpha_prime_value + sac_dict["alpha_prime_loss"] = policy.alpha_prime_loss + return sac_dict def cql_optimizer_fn(policy: Policy, config: TrainerConfigDict) -> \ diff --git a/rllib/agents/cql/tests/test_cql.py b/rllib/agents/cql/tests/test_cql.py index 7e3ef58896f67..9f8466a220e00 100644 --- a/rllib/agents/cql/tests/test_cql.py +++ b/rllib/agents/cql/tests/test_cql.py @@ -7,7 +7,7 @@ import ray.rllib.agents.cql as cql from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.test_utils import check_compute_single_action, \ - check_train_results, framework_iterator + framework_iterator tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() @@ -69,13 +69,10 @@ def test_cql_compilation(self): for fw in framework_iterator(config): trainer = cql.CQLTrainer(config=config) for i in range(num_iterations): - results = trainer.train() - check_train_results(results) - print(results) - eval_results = results.get("evaluation") - if eval_results: + results = trainer.train().get("evaluation") + if results: print(f"iter={trainer.iteration} " - f"R={eval_results['episode_reward_mean']}") + 
f"R={results['episode_reward_mean']}") check_compute_single_action(trainer) diff --git a/rllib/agents/ddpg/ddpg_tf_model.py b/rllib/agents/ddpg/ddpg_tf_model.py index f3c4a3ece6e9b..53d2d666dc60c 100644 --- a/rllib/agents/ddpg/ddpg_tf_model.py +++ b/rllib/agents/ddpg/ddpg_tf_model.py @@ -1,6 +1,6 @@ import numpy as np import gym -from typing import List, Optional +from typing import List from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.utils.framework import try_import_tf @@ -29,9 +29,9 @@ def __init__( model_config: ModelConfigDict, name: str, # Extra DDPGActionModel args: - actor_hiddens: Optional[List[int]] = None, + actor_hiddens: List[int] = [256, 256], actor_hidden_activation: str = "relu", - critic_hiddens: Optional[List[int]] = None, + critic_hiddens: List[int] = [256, 256], critic_hidden_activation: str = "relu", twin_q: bool = False, add_layer_norm: bool = False): @@ -48,12 +48,6 @@ def __init__( should be defined in subclasses of DDPGActionModel. """ - if actor_hiddens is None: - actor_hiddens = [256, 256] - - if critic_hiddens is None: - critic_hiddens = [256, 256] - super(DDPGTFModel, self).__init__(obs_space, action_space, num_outputs, model_config, name) diff --git a/rllib/agents/ddpg/ddpg_tf_policy.py b/rllib/agents/ddpg/ddpg_tf_policy.py index d3c295feba940..8c24a84c04a5e 100644 --- a/rllib/agents/ddpg/ddpg_tf_policy.py +++ b/rllib/agents/ddpg/ddpg_tf_policy.py @@ -28,7 +28,7 @@ from ray.rllib.utils.spaces.simplex import Simplex from ray.rllib.utils.tf_ops import huber_loss, make_tf_callable from ray.rllib.utils.typing import TrainerConfigDict, TensorType, \ - LocalOptimizer, ModelGradients + LocalOptimizer, ModelGradients, PolicyID from ray.util.debug import log_once tf1, tf, tfv = try_import_tf() @@ -429,17 +429,17 @@ def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space, TargetNetworkMixin.__init__(policy, config) -def validate_spaces(policy: Policy, observation_space: gym.spaces.Space, +def validate_spaces(pid: PolicyID, observation_space: gym.spaces.Space, action_space: gym.spaces.Space, config: TrainerConfigDict) -> None: if not isinstance(action_space, Box): raise UnsupportedSpaceException( "Action space ({}) of {} is not supported for " - "DDPG.".format(action_space, policy)) + "DDPG.".format(action_space, pid)) elif len(action_space.shape) > 1: raise UnsupportedSpaceException( "Action space ({}) of {} has multiple dimensions " - "{}. ".format(action_space, policy, action_space.shape) + + "{}. 
".format(action_space, pid, action_space.shape) + "Consider reshaping this into a single dimension, " "using a Tuple action space, or the multi-agent API.") diff --git a/rllib/agents/ddpg/ddpg_torch_model.py b/rllib/agents/ddpg/ddpg_torch_model.py index 615e0ea8b5814..2297ee0b2a815 100644 --- a/rllib/agents/ddpg/ddpg_torch_model.py +++ b/rllib/agents/ddpg/ddpg_torch_model.py @@ -1,6 +1,6 @@ import numpy as np import gym -from typing import List, Dict, Union, Optional +from typing import List, Dict, Union from ray.rllib.models.torch.misc import SlimFC from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 @@ -31,9 +31,9 @@ def __init__( model_config: ModelConfigDict, name: str, # Extra DDPGActionModel args: - actor_hiddens: Optional[List[int]] = None, + actor_hiddens: List[int] = [256, 256], actor_hidden_activation: str = "relu", - critic_hiddens: Optional[List[int]] = None, + critic_hiddens: List[int] = [256, 256], critic_hidden_activation: str = "relu", twin_q: bool = False, add_layer_norm: bool = False): @@ -51,12 +51,6 @@ def __init__( only defines the layers for the output heads. Those layers for forward() should be defined in subclasses of DDPGTorchModel. """ - if actor_hiddens is None: - actor_hiddens = [256, 256] - - if critic_hiddens is None: - critic_hiddens = [256, 256] - nn.Module.__init__(self) super(DDPGTorchModel, self).__init__(obs_space, action_space, num_outputs, model_config, name) diff --git a/rllib/agents/ddpg/ddpg_torch_policy.py b/rllib/agents/ddpg/ddpg_torch_policy.py index c6eb6bddbda6e..ef22a5e75fd47 100644 --- a/rllib/agents/ddpg/ddpg_torch_policy.py +++ b/rllib/agents/ddpg/ddpg_torch_policy.py @@ -172,17 +172,18 @@ def ddpg_actor_critic_loss(policy: Policy, model: ModelV2, _, [actor_loss, critic_loss] = model.custom_loss( [actor_loss, critic_loss], input_dict) - # Store values for stats function in model (tower), such that for - # multi-GPU, we do not override them during the parallel loss phase. - model.tower_stats["q_t"] = q_t - model.tower_stats["actor_loss"] = actor_loss - model.tower_stats["critic_loss"] = critic_loss - # TD-error tensor in final stats - # will be concatenated and retrieved for each individual batch item. - model.tower_stats["td_error"] = td_error + # Store values for stats function. + policy.q_t = q_t + policy.actor_loss = actor_loss + policy.critic_loss = critic_loss + + # Store td-error in model, such that for multi-GPU, we do not override + # them during the parallel loss phase. TD-error tensor in final stats + # can then be concatenated and retrieved for each individual batch item. + model.td_error = td_error # Return two loss terms (corresponding to the two optimizers, we create). 
- return actor_loss, critic_loss + return policy.actor_loss, policy.critic_loss def make_ddpg_optimizers(policy: Policy, @@ -216,16 +217,12 @@ def apply_gradients_fn(policy: Policy, gradients: GradInfoDict) -> None: def build_ddpg_stats(policy: Policy, batch: SampleBatch) -> Dict[str, TensorType]: - - q_t = torch.stack(policy.get_tower_stats("q_t")) stats = { - "actor_loss": torch.mean( - torch.stack(policy.get_tower_stats("actor_loss"))), - "critic_loss": torch.mean( - torch.stack(policy.get_tower_stats("critic_loss"))), - "mean_q": torch.mean(q_t), - "max_q": torch.max(q_t), - "min_q": torch.min(q_t), + "actor_loss": policy.actor_loss, + "critic_loss": policy.critic_loss, + "mean_q": torch.mean(policy.q_t), + "max_q": torch.max(policy.q_t), + "min_q": torch.min(policy.q_t), } return stats @@ -254,8 +251,8 @@ def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask, # (one TD-error value per item in batch to update PR weights). loss_fn(self, self.model, None, input_dict) - # `self.model.td_error` is set within actor_critic_loss call. - return self.model.tower_stats["td_error"] + # Self.td_error is set within actor_critic_loss call. + return self.model.td_error self.compute_td_error = compute_td_error diff --git a/rllib/agents/ddpg/tests/test_apex_ddpg.py b/rllib/agents/ddpg/tests/test_apex_ddpg.py index 16ebab9a1f9ae..61556fb9b961b 100644 --- a/rllib/agents/ddpg/tests/test_apex_ddpg.py +++ b/rllib/agents/ddpg/tests/test_apex_ddpg.py @@ -4,7 +4,7 @@ import ray import ray.rllib.agents.ddpg.apex as apex_ddpg from ray.rllib.utils.test_utils import check, check_compute_single_action, \ - check_train_results, framework_iterator + framework_iterator class TestApexDDPG(unittest.TestCase): @@ -40,9 +40,7 @@ def test_apex_ddpg_compilation_and_per_worker_epsilon_values(self): check(scale, [0.0] + expected) for _ in range(num_iterations): - results = trainer.train() - check_train_results(results) - print(results) + print(trainer.train()) check_compute_single_action(trainer) # Test again per-worker scale distribution diff --git a/rllib/agents/ddpg/tests/test_ddpg.py b/rllib/agents/ddpg/tests/test_ddpg.py index 7f72e03d0e30c..be404e720d48e 100644 --- a/rllib/agents/ddpg/tests/test_ddpg.py +++ b/rllib/agents/ddpg/tests/test_ddpg.py @@ -13,7 +13,7 @@ from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.numpy import fc, huber_loss, l2_loss, relu, sigmoid from ray.rllib.utils.test_utils import check, check_compute_single_action, \ - check_train_results, framework_iterator + framework_iterator from ray.rllib.utils.torch_ops import convert_to_torch_tensor tf1, tf, tfv = try_import_tf() @@ -45,7 +45,6 @@ def test_ddpg_compilation(self): trainer = ddpg.DDPGTrainer(config=config, env="Pendulum-v0") for i in range(num_iterations): results = trainer.train() - check_train_results(results) print(results) check_compute_single_action(trainer) # Ensure apply_gradient_fn is being called and updating global_step @@ -289,9 +288,8 @@ def test_ddpg_loss_function(self): elif fw == "torch": loss_torch(policy, policy.model, None, input_) - c, a, t = policy.get_tower_stats("critic_loss")[0], \ - policy.get_tower_stats("actor_loss")[0], \ - policy.get_tower_stats("td_error")[0] + c, a, t = policy.critic_loss, policy.actor_loss, \ + policy.model.td_error # Check pure loss values. 
check(c, expect_c) check(a, expect_a) diff --git a/rllib/agents/ddpg/tests/test_td3.py b/rllib/agents/ddpg/tests/test_td3.py index a542cf5a1574d..75b84e4ddc57e 100644 --- a/rllib/agents/ddpg/tests/test_td3.py +++ b/rllib/agents/ddpg/tests/test_td3.py @@ -5,7 +5,7 @@ import ray.rllib.agents.ddpg.td3 as td3 from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.test_utils import check, check_compute_single_action, \ - check_train_results, framework_iterator + framework_iterator tf1, tf, tfv = try_import_tf() @@ -30,7 +30,6 @@ def test_td3_compilation(self): num_iterations = 1 for i in range(num_iterations): results = trainer.train() - check_train_results(results) print(results) check_compute_single_action(trainer) trainer.stop() diff --git a/rllib/agents/dqn/apex.py b/rllib/agents/dqn/apex.py index 49c24b07ed3e7..74afc564f1708 100644 --- a/rllib/agents/dqn/apex.py +++ b/rllib/agents/dqn/apex.py @@ -33,7 +33,6 @@ from ray.rllib.utils import merge_dicts from ray.rllib.utils.actors import create_colocated from ray.rllib.utils.annotations import override -from ray.rllib.utils.metrics.learner_info import LEARNER_INFO from ray.rllib.utils.typing import SampleBatchType from ray.tune.trainable import Trainable from ray.tune.utils.placement_groups import PlacementGroupFactory @@ -228,7 +227,7 @@ def add_apex_metrics(result: dict) -> dict: result["info"].update({ "exploration_infos": exploration_infos, "learner_queue": learner_thread.learner_queue_size.stats(), - LEARNER_INFO: copy.deepcopy(learner_thread.learner_info), + "learner": copy.deepcopy(learner_thread.stats), "replay_shard_0": replay_stats, }) return result diff --git a/rllib/agents/dqn/dqn.py b/rllib/agents/dqn/dqn.py index 5f1eadf020a39..ac4b8f0dbb8e5 100644 --- a/rllib/agents/dqn/dqn.py +++ b/rllib/agents/dqn/dqn.py @@ -25,8 +25,7 @@ from ray.rllib.execution.rollout_ops import ParallelRollouts from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork, \ MultiGPUTrainOneStep -from ray.rllib.policy.policy import Policy -from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY +from ray.rllib.policy.policy import LEARNER_STATS_KEY, Policy from ray.rllib.utils.typing import TrainerConfigDict from ray.util.iter import LocalIterator @@ -201,17 +200,8 @@ def update_prio(item): td_error = info.get("td_error", info[LEARNER_STATS_KEY].get("td_error")) samples.policy_batches[policy_id].set_get_interceptor(None) - batch_indices = samples.policy_batches[policy_id].get( - "batch_indexes") - # In case the buffer stores sequences, TD-error could already - # be calculated per sequence chunk. 
- if len(batch_indices) != len(td_error): - T = local_replay_buffer.replay_sequence_length - assert len(batch_indices) > len( - td_error) and len(batch_indices) % T == 0 - batch_indices = batch_indices.reshape([-1, T])[:, 0] - assert len(batch_indices) == len(td_error) - prio_dict[policy_id] = (batch_indices, td_error) + prio_dict[policy_id] = (samples.policy_batches[policy_id] + .get("batch_indexes"), td_error) local_replay_buffer.update_priorities(prio_dict) return info_dict diff --git a/rllib/agents/dqn/dqn_torch_policy.py b/rllib/agents/dqn/dqn_torch_policy.py index a7826d0da489c..d060a1ce4012a 100644 --- a/rllib/agents/dqn/dqn_torch_policy.py +++ b/rllib/agents/dqn/dqn_torch_policy.py @@ -121,7 +121,7 @@ def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask, # Do forward pass on loss to update td error attribute build_q_losses(self, self.model, None, input_dict) - return self.model.tower_stats["q_loss"].td_error + return self.q_loss.td_error self.compute_td_error = compute_td_error @@ -216,9 +216,8 @@ def get_distribution_inputs_and_class( is_training=is_training) q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals - model.tower_stats["q_values"] = q_vals - - return q_vals, TorchCategorical, [] # state-out + policy.q_values = q_vals + return policy.q_values, TorchCategorical, [] # state-out def build_q_losses(policy: Policy, model, _, @@ -287,21 +286,19 @@ def build_q_losses(policy: Policy, model, _, q_probs_tp1_best = torch.sum( q_probs_tp1 * torch.unsqueeze(q_tp1_best_one_hot_selection, -1), 1) - q_loss = QLoss(q_t_selected, q_logits_t_selected, q_tp1_best, - q_probs_tp1_best, train_batch[PRIO_WEIGHTS], - train_batch[SampleBatch.REWARDS], - train_batch[SampleBatch.DONES].float(), config["gamma"], - config["n_step"], config["num_atoms"], config["v_min"], - config["v_max"]) + policy.q_loss = QLoss( + q_t_selected, q_logits_t_selected, q_tp1_best, q_probs_tp1_best, + train_batch[PRIO_WEIGHTS], train_batch[SampleBatch.REWARDS], + train_batch[SampleBatch.DONES].float(), config["gamma"], + config["n_step"], config["num_atoms"], config["v_min"], + config["v_max"]) - # Store values for stats function in model (tower), such that for - # multi-GPU, we do not override them during the parallel loss phase. - model.tower_stats["td_error"] = q_loss.td_error - # TD-error tensor in final stats - # will be concatenated and retrieved for each individual batch item. - model.tower_stats["q_loss"] = q_loss + # Store td-error in model, such that for multi-GPU, we do not override + # them during the parallel loss phase. TD-error tensor in final stats + # can then be concatenated and retrieved for each individual batch item. 
+ model.td_error = policy.q_loss.td_error - return q_loss.loss + return policy.q_loss.loss def adam_optimizer(policy: Policy, @@ -317,16 +314,9 @@ def adam_optimizer(policy: Policy, def build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]: - stats = {} - for stats_key in policy.model_gpu_towers[0].tower_stats[ - "q_loss"].stats.keys(): - stats[stats_key] = torch.mean( - torch.stack([ - t.tower_stats["q_loss"].stats[stats_key].to(policy.device) - for t in policy.model_gpu_towers if "q_loss" in t.tower_stats - ])) - stats["cur_lr"] = policy.cur_lr - return stats + return dict({ + "cur_lr": policy.cur_lr, + }, **policy.q_loss.stats) def setup_early_mixins(policy: Policy, obs_space, action_space, @@ -395,7 +385,7 @@ def grad_process_and_td_error_fn(policy: Policy, def extra_action_out_fn(policy: Policy, input_dict, state_batches, model, action_dist) -> Dict[str, TensorType]: - return {"q_values": model.tower_stats["q_values"]} + return {"q_values": policy.q_values} DQNTorchPolicy = build_policy_class( diff --git a/rllib/agents/dqn/learner_thread.py b/rllib/agents/dqn/learner_thread.py index 93bed4b18de5e..0f8d6f15bd79a 100644 --- a/rllib/agents/dqn/learner_thread.py +++ b/rllib/agents/dqn/learner_thread.py @@ -1,8 +1,9 @@ import queue import threading +from ray.rllib.evaluation.metrics import get_learner_stats +from ray.rllib.policy.policy import LEARNER_STATS_KEY from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder from ray.rllib.utils.timer import TimerStat from ray.rllib.utils.window_stat import WindowStat @@ -32,7 +33,7 @@ def __init__(self, local_worker): self.daemon = True self.weights_updated = False self.stopped = False - self.learner_info = {} + self.stats = {} def run(self): # Switch on eager mode if configured. @@ -48,18 +49,11 @@ def step(self): if replay is not None: prio_dict = {} with self.grad_timer: - # Use LearnerInfoBuilder as a unified way to build the - # final results dict from `learn_on_loaded_batch` call(s). - # This makes sure results dicts always have the same - # structure no matter the setup (multi-GPU, multi-agent, - # minibatch SGD, tf vs torch). - learner_info_builder = LearnerInfoBuilder(num_devices=1) - multi_agent_results = self.local_worker.learn_on_batch( - replay) - for pid, results in multi_agent_results.items(): - learner_info_builder.add_learn_on_batch_results( - results, pid) - td_error = results["td_error"] + grad_out = self.local_worker.learn_on_batch(replay) + for pid, info in grad_out.items(): + td_error = info.get( + "td_error", + info[LEARNER_STATS_KEY].get("td_error")) # Switch off auto-conversion from numpy to torch/tf # tensors for the indices. This may lead to errors # when sent to the buffer for processing @@ -68,7 +62,7 @@ def step(self): prio_dict[pid] = ( replay.policy_batches[pid].get("batch_indexes"), td_error) - self.learner_info = learner_info_builder.finalize() + self.stats[pid] = get_learner_stats(info) self.grad_timer.push_units_processed(replay.count) self.outqueue.put((ra, prio_dict, replay.count)) self.learner_queue_size.push(self.inqueue.qsize()) diff --git a/rllib/agents/dqn/r2d2.py b/rllib/agents/dqn/r2d2.py index d568272e957e9..7985b55fe305a 100644 --- a/rllib/agents/dqn/r2d2.py +++ b/rllib/agents/dqn/r2d2.py @@ -28,7 +28,7 @@ DEFAULT_CONFIG = dqn.DQNTrainer.merge_trainer_configs( dqn.DEFAULT_CONFIG, # See keys in impala.py, which are also supported. { - # Learning rate for adam optimizer. 
+    # Learning rate for adam optimizer
     "lr": 1e-4,
     # Discount factor.
     "gamma": 0.997,
@@ -40,6 +40,8 @@
     "num_workers": 2,
     # Batch mode must be complete_episodes.
     "batch_mode": "complete_episodes",
+    # R2D2 does not support n-step > 1 yet!
+    "n_step": 1,

     # If True, assume a zero-initialized state input (no matter where in
     # the episode the sequence is located).
@@ -69,6 +71,7 @@
     # Size of the replay buffer (in sequences, not timesteps).
     "buffer_size": 100000,
     # If True prioritized replay buffer will be used.
+    # Note: Not supported yet by R2D2!
     "prioritized_replay": False,
     # Set automatically: The number of contiguous environment steps to
     # replay at once. Will be calculated via
@@ -88,8 +91,7 @@ def validate_config(config: TrainerConfigDict) -> None:
     """Checks and updates the config based on settings.

-    Rewrites rollout_fragment_length to take into account burn-in and
-    max_seq_len truncation.
+    Rewrites rollout_fragment_length to take into account n_step truncation.
     """
     if config["replay_sequence_length"] != -1:
         raise ValueError(
@@ -100,9 +102,15 @@
     config["replay_sequence_length"] = \
         config["burn_in"] + config["model"]["max_seq_len"]

+    if config.get("prioritized_replay"):
+        raise ValueError("Prioritized replay is not supported for R2D2 yet!")
+
     if config.get("batch_mode") != "complete_episodes":
         raise ValueError("`batch_mode` must be 'complete_episodes'!")

+    if config["n_step"] > 1:
+        raise ValueError("`n_step` > 1 not yet supported by R2D2!")
+

 def calculate_rr_weights(config: TrainerConfigDict) -> List[float]:
     """Calculate the round robin weights for the rollout and train steps"""
diff --git a/rllib/agents/dqn/r2d2_tf_policy.py b/rllib/agents/dqn/r2d2_tf_policy.py
index d34c35a44976b..1d72d12e7e25b 100644
--- a/rllib/agents/dqn/r2d2_tf_policy.py
+++ b/rllib/agents/dqn/r2d2_tf_policy.py
@@ -156,7 +156,7 @@ def r2d2_loss(policy: Policy, model, _,
     def reduce_mean_valid(t):
         return tf.reduce_mean(tf.boolean_mask(t, seq_mask))

-    # Make sure to use the correct time indices:
+    # Make sure to use the correct time indices:
     # Q(t) - [gamma * r + Q^(t+1)]
     q_selected = tf.reshape(q_selected, [B, T])[:, :-1]
     td_error = q_selected - tf.stop_gradient(
@@ -164,9 +164,7 @@ def reduce_mean_valid(t):
     td_error = td_error * tf.cast(seq_mask, tf.float32)
     weights = tf.reshape(weights, [B, T])[:, :-1]
     policy._total_loss = reduce_mean_valid(weights * huber_loss(td_error))
-    # Store the TD-error per time chunk (b/c we need only one mean
-    # prioritized replay weight per stored sequence).
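As a worked example of the replay_sequence_length rule in the validate_config() hunk above (numbers are illustrative, not defaults): the burn-in prefix is replayed only to warm up the recurrent state, while max_seq_len steps are trained on, so both count toward each stored sequence:

    config = {"burn_in": 20, "model": {"max_seq_len": 20},
              "replay_sequence_length": -1}  # -1: must be left to auto-compute
    config["replay_sequence_length"] = \
        config["burn_in"] + config["model"]["max_seq_len"]
    assert config["replay_sequence_length"] == 40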
- policy._td_error = tf.reduce_mean(td_error, axis=-1) + policy._td_error = tf.reshape(td_error, [-1]) policy._loss_stats = { "mean_q": reduce_mean_valid(q_selected), "min_q": tf.reduce_min(q_selected), diff --git a/rllib/agents/dqn/r2d2_torch_policy.py b/rllib/agents/dqn/r2d2_torch_policy.py index 97c34327f7215..894c6dc2fb729 100644 --- a/rllib/agents/dqn/r2d2_torch_policy.py +++ b/rllib/agents/dqn/r2d2_torch_policy.py @@ -19,8 +19,8 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy import LearningRateSchedule from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.torch_ops import apply_grad_clipping, \ - concat_multi_gpu_td_errors, FLOAT_MIN, huber_loss, sequence_mask +from ray.rllib.utils.torch_ops import apply_grad_clipping, FLOAT_MIN, \ + huber_loss, sequence_mask from ray.rllib.utils.typing import TensorType, TrainerConfigDict torch, nn = try_import_torch() @@ -170,20 +170,16 @@ def reduce_mean_valid(t): td_error = q_selected - target.reshape([B, T])[:, :-1].detach() td_error = td_error * seq_mask weights = weights.reshape([B, T])[:, :-1] - total_loss = reduce_mean_valid(weights * huber_loss(td_error)) + policy._total_loss = reduce_mean_valid(weights * huber_loss(td_error)) + policy._td_error = td_error.reshape([-1]) + policy._loss_stats = { + "mean_q": reduce_mean_valid(q_selected), + "min_q": torch.min(q_selected), + "max_q": torch.max(q_selected), + "mean_td_error": reduce_mean_valid(td_error), + } - # Store values for stats function in model (tower), such that for - # multi-GPU, we do not override them during the parallel loss phase. - model.tower_stats["total_loss"] = total_loss - model.tower_stats["mean_q"] = reduce_mean_valid(q_selected) - model.tower_stats["min_q"] = torch.min(q_selected) - model.tower_stats["max_q"] = torch.max(q_selected) - model.tower_stats["mean_td_error"] = reduce_mean_valid(td_error) - # Store per time chunk (b/c we need only one mean - # prioritized replay weight per stored sequence). 
- model.tower_stats["td_error"] = torch.mean(td_error, dim=-1) - - return total_loss + return policy._total_loss def h_function(x, epsilon=1.0): @@ -237,23 +233,15 @@ def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask, # Do forward pass on loss to update td error attribute r2d2_loss(self, self.model, None, input_dict) - return self.model.tower_stats["td_error"] + return self._td_error self.compute_td_error = compute_td_error -def build_q_stats(policy: Policy, batch: SampleBatch) -> Dict[str, TensorType]: - - return { +def build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]: + return dict({ "cur_lr": policy.cur_lr, - "total_loss": torch.mean( - torch.stack(policy.get_tower_stats("total_loss"))), - "mean_q": torch.mean(torch.stack(policy.get_tower_stats("mean_q"))), - "min_q": torch.mean(torch.stack(policy.get_tower_stats("min_q"))), - "max_q": torch.mean(torch.stack(policy.get_tower_stats("max_q"))), - "mean_td_error": torch.mean( - torch.stack(policy.get_tower_stats("mean_td_error"))), - } + }, **policy._loss_stats) def setup_early_mixins(policy: Policy, obs_space, action_space, @@ -291,7 +279,7 @@ def extra_action_out_fn(policy: Policy, input_dict, state_batches, model, postprocess_fn=postprocess_nstep_and_prio, optimizer_fn=adam_optimizer, extra_grad_process_fn=grad_process_and_td_error_fn, - extra_learn_fetches_fn=concat_multi_gpu_td_errors, + extra_learn_fetches_fn=lambda policy: {"td_error": policy._td_error}, extra_action_out_fn=extra_action_out_fn, before_init=setup_early_mixins, before_loss_init=before_loss_init, diff --git a/rllib/agents/dqn/simple_q_tf_policy.py b/rllib/agents/dqn/simple_q_tf_policy.py index 13e62bca1fd9a..0801b6fd26e63 100644 --- a/rllib/agents/dqn/simple_q_tf_policy.py +++ b/rllib/agents/dqn/simple_q_tf_policy.py @@ -181,7 +181,7 @@ def compute_q_values(policy: Policy, explore, is_training=None) -> TensorType: model_out, _ = model({ - SampleBatch.OBS: obs, + SampleBatch.CUR_OBS: obs, "is_training": is_training if is_training is not None else policy._get_is_training_placeholder(), }, [], None) diff --git a/rllib/agents/dqn/simple_q_torch_policy.py b/rllib/agents/dqn/simple_q_torch_policy.py index 205fa6042e09e..055ce51598265 100644 --- a/rllib/agents/dqn/simple_q_torch_policy.py +++ b/rllib/agents/dqn/simple_q_torch_policy.py @@ -16,7 +16,7 @@ from ray.rllib.policy.torch_policy import TorchPolicy from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.torch_ops import concat_multi_gpu_td_errors, huber_loss +from ray.rllib.utils.torch_ops import huber_loss from ray.rllib.utils.typing import TensorType, TrainerConfigDict torch, nn = try_import_torch() @@ -112,20 +112,12 @@ def build_q_losses(policy: Policy, model, dist_class, td_error = q_t_selected - q_t_selected_target.detach() loss = torch.mean(huber_loss(td_error)) - # Store values for stats function in model (tower), such that for - # multi-GPU, we do not override them during the parallel loss phase. - model.tower_stats["loss"] = loss - # TD-error tensor in final stats - # will be concatenated and retrieved for each individual batch item. 
- model.tower_stats["td_error"] = td_error + # save TD error as an attribute for outside access + policy.td_error = td_error return loss -def stats_fn(policy: Policy, batch: SampleBatch) -> Dict[str, TensorType]: - return {"loss": torch.mean(torch.stack(policy.get_tower_stats("loss")))} - - def extra_action_out_fn(policy: Policy, input_dict, state_batches, model, action_dist) -> Dict[str, TensorType]: """Adds q-values to the action out dict.""" @@ -152,11 +144,10 @@ def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space, framework="torch", loss_fn=build_q_losses, get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG, - stats_fn=stats_fn, extra_action_out_fn=extra_action_out_fn, after_init=setup_late_mixins, make_model_and_action_dist=build_q_model_and_distribution, mixins=[TargetNetworkMixin], action_distribution_fn=get_distribution_inputs_and_class, - extra_learn_fetches_fn=concat_multi_gpu_td_errors, + extra_learn_fetches_fn=lambda policy: {"td_error": policy.td_error}, ) diff --git a/rllib/agents/dqn/tests/test_apex_dqn.py b/rllib/agents/dqn/tests/test_apex_dqn.py index 93702bf8d7c1b..63c051310baec 100644 --- a/rllib/agents/dqn/tests/test_apex_dqn.py +++ b/rllib/agents/dqn/tests/test_apex_dqn.py @@ -4,10 +4,8 @@ import ray import ray.rllib.agents.dqn.apex as apex from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID -from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \ - LEARNER_STATS_KEY from ray.rllib.utils.test_utils import check, check_compute_single_action, \ - check_train_results, framework_iterator + framework_iterator class TestApexDQN(unittest.TestCase): @@ -28,9 +26,7 @@ def test_apex_zero_workers(self): config["optimizer"]["num_replay_buffer_shards"] = 1 for _ in framework_iterator(config): trainer = apex.ApexTrainer(config=config, env="CartPole-v0") - results = trainer.train() - check_train_results(results) - print(results) + trainer.train() trainer.stop() def test_apex_dqn_compilation_and_per_worker_epsilon_values(self): @@ -57,9 +53,7 @@ def test_apex_dqn_compilation_and_per_worker_epsilon_values(self): check_compute_single_action(trainer) for i in range(2): - results = trainer.train() - check_train_results(results) - print(results) + print(trainer.train()) # Test again per-worker epsilon distribution # (should not have changed). @@ -103,8 +97,7 @@ def _step_n_times(trainer, n: int): """ for _ in range(n): results = trainer.train() - return results["info"][LEARNER_INFO][DEFAULT_POLICY_ID][ - LEARNER_STATS_KEY]["cur_lr"] + return results["info"]["learner"][DEFAULT_POLICY_ID]["cur_lr"] # Check eager execution frameworks here, since it's easier to control # exact timesteps with these frameworks. 
diff --git a/rllib/agents/dqn/tests/test_dqn.py b/rllib/agents/dqn/tests/test_dqn.py index fbf029a511243..dbf4876742b1f 100644 --- a/rllib/agents/dqn/tests/test_dqn.py +++ b/rllib/agents/dqn/tests/test_dqn.py @@ -4,7 +4,7 @@ import ray import ray.rllib.agents.dqn as dqn from ray.rllib.utils.test_utils import check, check_compute_single_action, \ - check_train_results, framework_iterator + framework_iterator class TestDQN(unittest.TestCase): @@ -30,7 +30,6 @@ def test_dqn_compilation(self): trainer = dqn.DQNTrainer(config=plain_config, env="CartPole-v0") for i in range(num_iterations): results = trainer.train() - check_train_results(results) print(results) check_compute_single_action(trainer) @@ -47,7 +46,6 @@ def test_dqn_compilation(self): trainer = dqn.DQNTrainer(config=rainbow_config, env="CartPole-v0") for i in range(num_iterations): results = trainer.train() - check_train_results(results) print(results) check_compute_single_action(trainer) diff --git a/rllib/agents/dqn/tests/test_r2d2.py b/rllib/agents/dqn/tests/test_r2d2.py index 44b2e0887a1c5..d6e0d52d285e8 100644 --- a/rllib/agents/dqn/tests/test_r2d2.py +++ b/rllib/agents/dqn/tests/test_r2d2.py @@ -4,7 +4,7 @@ import ray.rllib.agents.dqn as dqn from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.test_utils import check_compute_single_action, \ - check_train_results, framework_iterator + framework_iterator tf1, tf, tfv = try_import_tf() torch, nn = try_import_torch() @@ -43,7 +43,6 @@ def test_r2d2_compilation(self): trainer = dqn.R2D2Trainer(config=config, env="CartPole-v0") for i in range(num_iterations): results = trainer.train() - check_train_results(results) print(results) check_compute_single_action(trainer, include_state=True) diff --git a/rllib/agents/dqn/tests/test_simple_q.py b/rllib/agents/dqn/tests/test_simple_q.py index 299bf39f63e51..12cddac283208 100644 --- a/rllib/agents/dqn/tests/test_simple_q.py +++ b/rllib/agents/dqn/tests/test_simple_q.py @@ -10,7 +10,7 @@ from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.numpy import fc, one_hot, huber_loss from ray.rllib.utils.test_utils import check, check_compute_single_action, \ - check_train_results, framework_iterator + framework_iterator tf1, tf, tfv = try_import_tf() @@ -41,7 +41,6 @@ def test_simple_q_compilation(self): sb = rw.sample() assert sb.count == config["rollout_fragment_length"] results = trainer.train() - check_train_results(results) print(results) check_compute_single_action(trainer) diff --git a/rllib/agents/dreamer/dreamer.py b/rllib/agents/dreamer/dreamer.py index b3433f62cd5a0..4a8170f527875 100644 --- a/rllib/agents/dreamer/dreamer.py +++ b/rllib/agents/dreamer/dreamer.py @@ -7,12 +7,11 @@ from ray.rllib.agents.dreamer.dreamer_torch_policy import DreamerTorchPolicy from ray.rllib.agents.trainer_template import build_trainer from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \ - _get_shared_metrics + LEARNER_INFO, _get_shared_metrics from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch from ray.rllib.evaluation.metrics import collect_metrics from ray.rllib.agents.dreamer.dreamer_model import DreamerModel from ray.rllib.execution.rollout_ops import ParallelRollouts -from ray.rllib.utils.metrics.learner_info import LEARNER_INFO from ray.rllib.utils.typing import SampleBatchType logger = logging.getLogger(__name__) diff --git a/rllib/agents/impala/tests/test_impala.py b/rllib/agents/impala/tests/test_impala.py index 13e22240e34aa..ba5e28e82073c 100644 --- 
a/rllib/agents/impala/tests/test_impala.py +++ b/rllib/agents/impala/tests/test_impala.py @@ -4,10 +4,8 @@ import ray.rllib.agents.impala as impala from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \ - LEARNER_STATS_KEY from ray.rllib.utils.test_utils import check, \ - check_compute_single_action, check_train_results, framework_iterator + check_compute_single_action, framework_iterator tf1, tf, tfv = try_import_tf() @@ -41,10 +39,7 @@ def test_impala_compilation(self): # to do with LSTMs, though). trainer = impala.ImpalaTrainer(config=local_cfg, env=env) for i in range(num_iterations): - results = trainer.train() - check_train_results(results) - print(results) - + print(trainer.train()) check_compute_single_action( trainer, include_state=lstm, @@ -66,8 +61,7 @@ def test_impala_lr_schedule(self): config["env"] = "CartPole-v0" def get_lr(result): - return result["info"][LEARNER_INFO][DEFAULT_POLICY_ID][ - LEARNER_STATS_KEY]["cur_lr"] + return result["info"]["learner"][DEFAULT_POLICY_ID]["cur_lr"] for fw in framework_iterator(config, frameworks=("tf", "torch")): trainer = impala.ImpalaTrainer(config=config) diff --git a/rllib/agents/impala/vtrace_tf_policy.py b/rllib/agents/impala/vtrace_tf_policy.py index f5b5ddc4192db..99960a3206b2c 100644 --- a/rllib/agents/impala/vtrace_tf_policy.py +++ b/rllib/agents/impala/vtrace_tf_policy.py @@ -111,13 +111,8 @@ def __init__(self, self.mean_entropy = tf.reduce_mean(masked_entropy) # The summed weighted loss. - self.total_loss = self.pi_loss - self.entropy * entropy_coeff - - # Optional vf loss (or in a separate term due to separate - # optimizers/networks). - self.loss_wo_vf = self.total_loss - if not config["_separate_vf_optimizer"]: - self.total_loss += self.vf_loss * vf_loss_coeff + self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff - + self.entropy * entropy_coeff) def _make_time_major(policy, seq_lens, tensor, drop_last=False): @@ -225,10 +220,7 @@ def make_time_major(*args, **kw): clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"], clip_pg_rho_threshold=policy.config["vtrace_clip_pg_rho_threshold"]) - if policy.config.get("_separate_vf_optimizer"): - return policy.loss.loss_wo_vf, policy.loss.vf_loss - else: - return policy.loss.total_loss + return policy.loss.total_loss def stats(policy, train_batch): @@ -247,21 +239,13 @@ def stats(policy, train_batch): "vf_loss": policy.loss.mean_vf_loss, "vf_explained_var": explained_variance( tf.reshape(policy.loss.value_targets, [-1]), - tf.reshape(values_batched, [-1])) + tf.reshape(values_batched, [-1])), } def grad_stats(policy, train_batch, grads): - # We have support for more than one loss (list of lists of grads). - if policy.config.get("_tf_policy_handles_more_than_one_loss"): - grad_gnorm = [tf.linalg.global_norm(g) for g in grads] - # Old case: We have a single list of grads (only one loss term and - # optimizer). 
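Spelled out, the single-optimizer objective restored in the vtrace tf policy above is the usual weighted IMPALA sum; a sketch (the coefficient values here are placeholders, read from the trainer config in the real code):

    def impala_total_loss(pi_loss, vf_loss, entropy,
                          vf_loss_coeff=0.5, entropy_coeff=0.01):
        # Entropy acts as a bonus, hence subtracted; the removed code instead
        # returned vf_loss as a separate term when _separate_vf_optimizer was set.
        return pi_loss + vf_loss * vf_loss_coeff - entropy * entropy_coeff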
- else: - grad_gnorm = tf.linalg.global_norm(grads) - return { - "grad_gnorm": grad_gnorm, + "grad_gnorm": tf.linalg.global_norm(grads), } diff --git a/rllib/agents/impala/vtrace_torch_policy.py b/rllib/agents/impala/vtrace_torch_policy.py index c8738d1875f63..ec279cd5573b0 100644 --- a/rllib/agents/impala/vtrace_torch_policy.py +++ b/rllib/agents/impala/vtrace_torch_policy.py @@ -1,12 +1,10 @@ import gym import logging import numpy as np -from typing import Any, Dict import ray import ray.rllib.agents.impala.vtrace_torch as vtrace from ray.rllib.models.torch.torch_action_dist import TorchCategorical -from ray.rllib.policy.policy import Policy from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy import LearningRateSchedule, \ @@ -184,22 +182,17 @@ def _make_time_major(*args, **kw): clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"], clip_pg_rho_threshold=policy.config["vtrace_clip_pg_rho_threshold"]) - # Store values for stats function in model (tower), such that for - # multi-GPU, we do not override them during the parallel loss phase. - model.tower_stats["pi_loss"] = loss.pi_loss - model.tower_stats["vf_loss"] = loss.vf_loss - model.tower_stats["entropy"] = loss.entropy - model.tower_stats["mean_entropy"] = loss.mean_entropy - model.tower_stats["total_loss"] = loss.total_loss - - values_batched = make_time_major( - policy, - train_batch.get(SampleBatch.SEQ_LENS), - values, - drop_last=policy.config["vtrace"]) - model.tower_stats["vf_explained_var"] = explained_variance( - torch.reshape(loss.value_targets, [-1]), - torch.reshape(values_batched, [-1])) + # Store loss object only for multi-GPU tower 0. + if model is policy.model_gpu_towers[0]: + policy.loss = loss + values_batched = make_time_major( + policy, + train_batch.get(SampleBatch.SEQ_LENS), + values, + drop_last=policy.config["vtrace"]) + policy._vf_explained_var = explained_variance( + torch.reshape(loss.value_targets, [-1]), + torch.reshape(values_batched, [-1])), return loss.total_loss @@ -243,21 +236,15 @@ def make_time_major(policy, seq_lens, tensor, drop_last=False): return res -def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, Any]: - +def stats(policy, train_batch): return { "cur_lr": policy.cur_lr, - "total_loss": torch.mean( - torch.stack(policy.get_tower_stats("total_loss"))), - "policy_loss": torch.mean( - torch.stack(policy.get_tower_stats("pi_loss"))), - "entropy": torch.mean( - torch.stack(policy.get_tower_stats("mean_entropy"))), + "policy_loss": policy.loss.pi_loss, + "entropy": policy.loss.mean_entropy, "entropy_coeff": policy.entropy_coeff, "var_gnorm": global_norm(policy.model.trainable_variables()), - "vf_loss": torch.mean(torch.stack(policy.get_tower_stats("vf_loss"))), - "vf_explained_var": torch.mean( - torch.stack(policy.get_tower_stats("vf_explained_var"))), + "vf_loss": policy.loss.vf_loss, + "vf_explained_var": policy._vf_explained_var, } diff --git a/rllib/agents/maml/maml.py b/rllib/agents/maml/maml.py index c85d4f158b3c5..9d82a0e192cc5 100644 --- a/rllib/agents/maml/maml.py +++ b/rllib/agents/maml/maml.py @@ -8,12 +8,11 @@ from ray.rllib.agents.trainer_template import build_trainer from ray.rllib.evaluation.metrics import get_learner_stats from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \ - STEPS_TRAINED_COUNTER, _get_shared_metrics + STEPS_TRAINED_COUNTER, LEARNER_INFO, _get_shared_metrics from ray.rllib.policy.sample_batch import SampleBatch from 
ray.rllib.execution.metric_ops import CollectMetrics from ray.rllib.evaluation.metrics import collect_metrics from ray.rllib.utils.deprecation import DEPRECATED_VALUE -from ray.rllib.utils.metrics.learner_info import LEARNER_INFO from ray.util.iter import from_actors logger = logging.getLogger(__name__) @@ -99,10 +98,9 @@ def __call__(self, data_tuple): # Metric Updating metrics = _get_shared_metrics() metrics.counters[STEPS_SAMPLED_COUNTER] += samples.count - fetches = None for i in range(self.maml_optimizer_steps): fetches = self.workers.local_worker().learn_on_batch(samples) - learner_stats = get_learner_stats(fetches) + fetches = get_learner_stats(fetches) # Sync workers with meta policy self.workers.sync_weights() @@ -112,12 +110,11 @@ def __call__(self, data_tuple): # Update KLS def update(pi, pi_id): - assert "inner_kl" not in learner_stats, ( - "inner_kl should be nested under policy id key", learner_stats) - if pi_id in learner_stats: - assert "inner_kl" in learner_stats[pi_id], (learner_stats, - pi_id) - pi.update_kls(learner_stats[pi_id]["inner_kl"]) + assert "inner_kl" not in fetches, ( + "inner_kl should be nested under policy id key", fetches) + if pi_id in fetches: + assert "inner_kl" in fetches[pi_id], (fetches, pi_id) + pi.update_kls(fetches[pi_id]["inner_kl"]) else: logger.warning("No data for {}, not updating kl".format(pi_id)) diff --git a/rllib/agents/maml/tests/test_maml.py b/rllib/agents/maml/tests/test_maml.py index e1905b5cc853f..b84e028571907 100644 --- a/rllib/agents/maml/tests/test_maml.py +++ b/rllib/agents/maml/tests/test_maml.py @@ -3,7 +3,7 @@ import ray import ray.rllib.agents.maml as maml from ray.rllib.utils.test_utils import check_compute_single_action, \ - check_train_results, framework_iterator + framework_iterator class TestMAML(unittest.TestCase): @@ -34,9 +34,7 @@ def test_maml_compilation(self): env_ = "ray.rllib.examples.env.{}".format(env) trainer = maml.MAMLTrainer(config=config, env=env_) for i in range(num_iterations): - results = trainer.train() - check_train_results(results) - print(results) + trainer.train() check_compute_single_action( trainer, include_prev_action_reward=True) trainer.stop() diff --git a/rllib/agents/marwil/tests/test_bc.py b/rllib/agents/marwil/tests/test_bc.py index d6ac234897839..c6508330e43de 100644 --- a/rllib/agents/marwil/tests/test_bc.py +++ b/rllib/agents/marwil/tests/test_bc.py @@ -6,7 +6,7 @@ import ray.rllib.agents.marwil as marwil from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.test_utils import check_compute_single_action, \ - check_train_results, framework_iterator + framework_iterator tf1, tf, tfv = try_import_tf() @@ -51,11 +51,7 @@ def test_bc_compilation_and_learning_from_offline_file(self): trainer = marwil.BCTrainer(config=config, env="CartPole-v0") learnt = False for i in range(num_iterations): - results = trainer.train() - check_train_results(results) - print(results) - - eval_results = results.get("evaluation") + eval_results = trainer.train().get("evaluation") if eval_results: print("iter={} R={}".format( i, eval_results["episode_reward_mean"])) diff --git a/rllib/agents/marwil/tests/test_marwil.py b/rllib/agents/marwil/tests/test_marwil.py index b8ca7af86ae21..29c6b678ecf2c 100644 --- a/rllib/agents/marwil/tests/test_marwil.py +++ b/rllib/agents/marwil/tests/test_marwil.py @@ -9,7 +9,7 @@ from ray.rllib.offline import JsonReader from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.test_utils import check, check_compute_single_action, \ 
- check_train_results, framework_iterator + framework_iterator tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() @@ -57,11 +57,7 @@ def test_marwil_compilation_and_learning_from_offline_file(self): trainer = marwil.MARWILTrainer(config=config, env="CartPole-v0") learnt = False for i in range(num_iterations): - results = trainer.train() - check_train_results(results) - print(results) - - eval_results = results.get("evaluation") + eval_results = trainer.train().get("evaluation") if eval_results: print("iter={} R={} ".format( i, eval_results["episode_reward_mean"])) diff --git a/rllib/agents/mbmpo/mbmpo.py b/rllib/agents/mbmpo/mbmpo.py index aaf2d835e6c1f..0a537213ac193 100644 --- a/rllib/agents/mbmpo/mbmpo.py +++ b/rllib/agents/mbmpo/mbmpo.py @@ -26,11 +26,10 @@ get_learner_stats from ray.rllib.evaluation.worker_set import WorkerSet from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \ - STEPS_TRAINED_COUNTER, _get_shared_metrics + STEPS_TRAINED_COUNTER, LEARNER_INFO, _get_shared_metrics from ray.rllib.execution.metric_ops import CollectMetrics from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch from ray.rllib.utils.deprecation import DEPRECATED_VALUE -from ray.rllib.utils.metrics.learner_info import LEARNER_INFO from ray.rllib.utils.sgd import standardized from ray.rllib.utils.torch_ops import convert_to_torch_tensor from ray.rllib.utils.typing import EnvType, TrainerConfigDict @@ -161,19 +160,17 @@ def __call__(self, data_tuple): adapt_metrics_dict, prefix="MAMLIter{}".format(self.step_counter)) # MAML Meta-update. - fetches = None for i in range(self.maml_optimizer_steps): fetches = self.workers.local_worker().learn_on_batch(samples) - learner_stats = get_learner_stats(fetches) + fetches = get_learner_stats(fetches) # Update KLs. 
def update(pi, pi_id): - assert "inner_kl" not in learner_stats, ( - "inner_kl should be nested under policy id key", learner_stats) - if pi_id in learner_stats: - assert "inner_kl" in learner_stats[pi_id], (learner_stats, - pi_id) - pi.update_kls(learner_stats[pi_id]["inner_kl"]) + assert "inner_kl" not in fetches, ( + "inner_kl should be nested under policy id key", fetches) + if pi_id in fetches: + assert "inner_kl" in fetches[pi_id], (fetches, pi_id) + pi.update_kls(fetches[pi_id]["inner_kl"]) else: logger.warning("No data for {}, not updating kl".format(pi_id)) diff --git a/rllib/agents/mbmpo/tests/test_mbmpo.py b/rllib/agents/mbmpo/tests/test_mbmpo.py index 941686c3e717b..de708fd50d58c 100644 --- a/rllib/agents/mbmpo/tests/test_mbmpo.py +++ b/rllib/agents/mbmpo/tests/test_mbmpo.py @@ -3,7 +3,7 @@ import ray import ray.rllib.agents.mbmpo as mbmpo from ray.rllib.utils.test_utils import check_compute_single_action, \ - check_train_results, framework_iterator + framework_iterator class TestMBMPO(unittest.TestCase): @@ -28,12 +28,8 @@ def test_mbmpo_compilation(self): trainer = mbmpo.MBMPOTrainer( config=config, env="ray.rllib.examples.env.mbmpo_env.CartPoleWrapper") - for i in range(num_iterations): - results = trainer.train() - check_train_results(results) - print(results) - + trainer.train() check_compute_single_action( trainer, include_prev_action_reward=False) trainer.stop() diff --git a/rllib/agents/pg/pg_torch_policy.py b/rllib/agents/pg/pg_torch_policy.py index 34a17c5e03f97..d707f01f2364e 100644 --- a/rllib/agents/pg/pg_torch_policy.py +++ b/rllib/agents/pg/pg_torch_policy.py @@ -44,15 +44,11 @@ def pg_torch_loss( # L = -E[ log(pi(a|s)) * A] log_probs = action_dist.logp(train_batch[SampleBatch.ACTIONS]) - # Final policy loss. - policy_loss = -torch.mean( + # Save the loss in the policy object for the stats_fn below. + policy.pi_err = -torch.mean( log_probs * train_batch[Postprocessing.ADVANTAGES]) - # Store values for stats function in model (tower), such that for - # multi-GPU, we do not override them during the parallel loss phase. - model.tower_stats["policy_loss"] = policy_loss - - return policy_loss + return policy.pi_err def pg_loss_stats(policy: Policy, @@ -68,8 +64,8 @@ def pg_loss_stats(policy: Policy, """ return { - "policy_loss": torch.mean( - torch.stack(policy.get_tower_stats("policy_loss"))), + # `pi_err` (the loss) is stored inside `pg_torch_loss()`. 
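`.item()` below converts the zero-dimensional loss tensor into a plain Python float for the stats dict; unlike the averaged `get_tower_stats()` variant it replaces, it reads from the single stored tensor. A minimal illustration:

import torch

loss = torch.tensor(0.25)  # stand-in for the scalar policy loss
assert isinstance(loss.item(), float)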
+ "policy_loss": policy.pi_err.item(), } diff --git a/rllib/agents/pg/tests/test_pg.py b/rllib/agents/pg/tests/test_pg.py index 40b985cc8e488..44a52829beaf3 100644 --- a/rllib/agents/pg/tests/test_pg.py +++ b/rllib/agents/pg/tests/test_pg.py @@ -7,9 +7,8 @@ from ray.rllib.models.tf.tf_action_dist import Categorical from ray.rllib.models.torch.torch_action_dist import TorchCategorical from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.utils.numpy import fc -from ray.rllib.utils.test_utils import check, check_compute_single_action, \ - check_train_results, framework_iterator +from ray.rllib.utils import check, check_compute_single_action, fc, \ + framework_iterator class TestPG(unittest.TestCase): @@ -32,10 +31,7 @@ def test_pg_compilation(self): for env in ["FrozenLake-v0", "CartPole-v0"]: trainer = pg.PGTrainer(config=config, env=env) for i in range(num_iterations): - results = trainer.train() - check_train_results(results) - print(results) - + print(trainer.train()) check_compute_single_action( trainer, include_prev_action_reward=True) diff --git a/rllib/agents/ppo/appo_tf_policy.py b/rllib/agents/ppo/appo_tf_policy.py index 142b96d6e247f..455044bebfe1d 100644 --- a/rllib/agents/ppo/appo_tf_policy.py +++ b/rllib/agents/ppo/appo_tf_policy.py @@ -304,7 +304,7 @@ def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]: "vf_loss": policy._mean_vf_loss, "vf_explained_var": explained_variance( tf.reshape(policy._value_targets, [-1]), - tf.reshape(values_batched, [-1])) + tf.reshape(values_batched, [-1])), } if policy.config["vtrace"]: diff --git a/rllib/agents/ppo/appo_torch_policy.py b/rllib/agents/ppo/appo_torch_policy.py index 324b73bf5a6b7..f8ee24989d825 100644 --- a/rllib/agents/ppo/appo_torch_policy.py +++ b/rllib/agents/ppo/appo_torch_policy.py @@ -159,7 +159,7 @@ def reduce_mean_valid(t): torch.clamp(logp_ratio, 1 - policy.config["clip_param"], 1 + policy.config["clip_param"])) - mean_kl_loss = reduce_mean_valid(action_kl) + mean_kl = reduce_mean_valid(action_kl) mean_policy_loss = -reduce_mean_valid(surrogate_loss) # The value function loss. @@ -188,7 +188,7 @@ def reduce_mean_valid(t): torch.clamp(logp_ratio, 1 - policy.config["clip_param"], 1 + policy.config["clip_param"])) - mean_kl_loss = reduce_mean_valid(action_kl) + mean_kl = reduce_mean_valid(action_kl) mean_policy_loss = -reduce_mean_valid(surrogate_loss) # The value function loss. @@ -208,17 +208,16 @@ def reduce_mean_valid(t): # Optional additional KL Loss if policy.config["use_kl_loss"]: - total_loss += policy.kl_coeff * mean_kl_loss - - # Store values for stats function in model (tower), such that for - # multi-GPU, we do not override them during the parallel loss phase. - model.tower_stats["total_loss"] = total_loss - model.tower_stats["mean_policy_loss"] = mean_policy_loss - model.tower_stats["mean_kl_loss"] = mean_kl_loss - model.tower_stats["mean_vf_loss"] = mean_vf_loss - model.tower_stats["mean_entropy"] = mean_entropy - model.tower_stats["value_targets"] = value_targets - model.tower_stats["vf_explained_var"] = explained_variance( + total_loss += policy.kl_coeff * mean_kl + + policy._total_loss = total_loss + policy._mean_policy_loss = mean_policy_loss + # Backward compatibility: Deprecate policy._mean_kl. 
+ policy._mean_kl_loss = policy._mean_kl = mean_kl + policy._mean_vf_loss = mean_vf_loss + policy._mean_entropy = mean_entropy + policy._value_targets = value_targets + policy._vf_explained_var = explained_variance( torch.reshape(value_targets, [-1]), torch.reshape( values_time_major[:-1] @@ -240,28 +239,22 @@ def stats(policy: Policy, train_batch: SampleBatch): """ stats_dict = { "cur_lr": policy.cur_lr, - "total_loss": torch.mean( - torch.stack(policy.get_tower_stats("total_loss"))), - "policy_loss": torch.mean( - torch.stack(policy.get_tower_stats("mean_policy_loss"))), - "entropy": torch.mean( - torch.stack(policy.get_tower_stats("mean_entropy"))), + "policy_loss": policy._mean_policy_loss, + "entropy": policy._mean_entropy, "var_gnorm": global_norm(policy.model.trainable_variables()), - "vf_loss": torch.mean( - torch.stack(policy.get_tower_stats("mean_vf_loss"))), - "vf_explained_var": torch.mean( - torch.stack(policy.get_tower_stats("vf_explained_var"))), + "vf_loss": policy._mean_vf_loss, + "vf_explained_var": policy._vf_explained_var, } if policy.config["vtrace"]: is_stat_mean = torch.mean(policy._is_ratio, [0, 1]) is_stat_var = torch.var(policy._is_ratio, [0, 1]) - stats_dict["mean_IS"] = is_stat_mean - stats_dict["var_IS"] = is_stat_var + stats_dict.update({"mean_IS": is_stat_mean}) + stats_dict.update({"var_IS": is_stat_var}) if policy.config["use_kl_loss"]: - stats_dict["kl"] = policy.get_tower_stats("mean_kl_loss") - stats_dict["KL_Coeff"] = policy.kl_coeff + stats_dict.update({"kl": policy._mean_kl_loss}) + stats_dict.update({"KL_Coeff": policy.kl_coeff}) return stats_dict diff --git a/rllib/agents/ppo/ddppo.py b/rllib/agents/ppo/ddppo.py index b7c15918b16fe..d3eee646999e4 100644 --- a/rllib/agents/ppo/ddppo.py +++ b/rllib/agents/ppo/ddppo.py @@ -26,10 +26,9 @@ from ray.rllib.execution.rollout_ops import ParallelRollouts from ray.rllib.execution.metric_ops import StandardMetricsReporting from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \ - STEPS_TRAINED_COUNTER, LEARN_ON_BATCH_TIMER, \ + STEPS_TRAINED_COUNTER, LEARNER_INFO, LEARN_ON_BATCH_TIMER, \ _get_shared_metrics, _get_global_vars from ray.rllib.evaluation.rollout_worker import get_global_worker -from ray.rllib.utils.metrics.learner_info import LEARNER_INFO from ray.rllib.utils.sgd import do_minibatch_sgd from ray.rllib.utils.typing import TrainerConfigDict from ray.util.iter import LocalIterator @@ -76,11 +75,6 @@ "truncate_episodes": True, # This is auto set based on sample batch size. "train_batch_size": -1, - # Kl divergence penalty should be fixed to 0 in DDPPO because in order - # for it to be used as a penalty, we would have to un-decentralize - # DDPPO - "kl_coeff": 0.0, - "kl_target": 0.0 }, _allow_unknown_configs=True, ) @@ -137,13 +131,6 @@ def validate_config(config): raise ValueError( "Distributed data parallel requires truncate_episodes " "batch mode.") - # DDPPO doesn't support KL penalties like PPO-1. - # In order to support KL penalties, DDPPO would need to become - # undecentralized, which defeats the purpose of the algorithm. - # Users can still tune the entropy coefficient to control the - # policy entropy (similar to controlling the KL penalty). 
- if config["kl_coeff"] != 0.0 or config["kl_target"] != 0.0: - raise ValueError("DDPPO doesn't support KL penalties like PPO-1") def execution_plan(workers: WorkerSet, diff --git a/rllib/agents/ppo/ppo.py b/rllib/agents/ppo/ppo.py index e43d460087b84..e0ced5d82cdeb 100644 --- a/rllib/agents/ppo/ppo.py +++ b/rllib/agents/ppo/ppo.py @@ -20,11 +20,9 @@ StandardizeFields, SelectExperiences from ray.rllib.execution.train_ops import TrainOneStep, MultiGPUTrainOneStep from ray.rllib.execution.metric_ops import StandardMetricsReporting -from ray.rllib.policy.policy import Policy +from ray.rllib.policy.policy import LEARNER_STATS_KEY, Policy from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.deprecation import DEPRECATED_VALUE -from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \ - LEARNER_STATS_KEY from ray.rllib.utils.typing import TrainerConfigDict from ray.util.iter import LocalIterator @@ -219,12 +217,12 @@ def warn_about_bad_reward_scales(config, result): return result # Punt on handling multiagent case. # Warn about excessively high VF loss. - learner_info = result["info"][LEARNER_INFO] - if DEFAULT_POLICY_ID in learner_info: + learner_stats = result["info"]["learner"] + if DEFAULT_POLICY_ID in learner_stats: scaled_vf_loss = config["vf_loss_coeff"] * \ - learner_info[DEFAULT_POLICY_ID][LEARNER_STATS_KEY]["vf_loss"] + learner_stats[DEFAULT_POLICY_ID][LEARNER_STATS_KEY]["vf_loss"] - policy_loss = learner_info[DEFAULT_POLICY_ID][LEARNER_STATS_KEY][ + policy_loss = learner_stats[DEFAULT_POLICY_ID][LEARNER_STATS_KEY][ "policy_loss"] if config.get("model", {}).get("vf_share_layers") and \ scaled_vf_loss > 100: diff --git a/rllib/agents/ppo/ppo_torch_policy.py b/rllib/agents/ppo/ppo_torch_policy.py index 69e19e33d7817..f8f310e6b07e3 100644 --- a/rllib/agents/ppo/ppo_torch_policy.py +++ b/rllib/agents/ppo/ppo_torch_policy.py @@ -105,15 +105,15 @@ def reduce_mean_valid(t): policy.config["vf_loss_coeff"] * vf_loss - policy.entropy_coeff * curr_entropy) - # Store values for stats function in model (tower), such that for - # multi-GPU, we do not override them during the parallel loss phase. - model.tower_stats["total_loss"] = total_loss - model.tower_stats["mean_policy_loss"] = mean_policy_loss - model.tower_stats["mean_vf_loss"] = mean_vf_loss - model.tower_stats["vf_explained_var"] = explained_variance( + # Store stats in policy for stats_fn. + policy._total_loss = total_loss + policy._mean_policy_loss = mean_policy_loss + policy._mean_vf_loss = mean_vf_loss + policy._vf_explained_var = explained_variance( train_batch[Postprocessing.VALUE_TARGETS], model.value_function()) - model.tower_stats["mean_entropy"] = mean_entropy - model.tower_stats["mean_kl_loss"] = mean_kl_loss + policy._mean_entropy = mean_entropy + # Backward compatibility: Deprecate policy._mean_kl. 
+ policy._mean_kl_loss = policy._mean_kl = mean_kl_loss return total_loss @@ -132,17 +132,12 @@ def kl_and_loss_stats(policy: Policy, return { "cur_kl_coeff": policy.kl_coeff, "cur_lr": policy.cur_lr, - "total_loss": torch.mean( - torch.stack(policy.get_tower_stats("total_loss"))), - "policy_loss": torch.mean( - torch.stack(policy.get_tower_stats("mean_policy_loss"))), - "vf_loss": torch.mean( - torch.stack(policy.get_tower_stats("mean_vf_loss"))), - "vf_explained_var": torch.mean( - torch.stack(policy.get_tower_stats("vf_explained_var"))), - "kl": torch.mean(torch.stack(policy.get_tower_stats("mean_kl_loss"))), - "entropy": torch.mean( - torch.stack(policy.get_tower_stats("mean_entropy"))), + "total_loss": policy._total_loss, + "policy_loss": policy._mean_policy_loss, + "vf_loss": policy._mean_vf_loss, + "vf_explained_var": policy._vf_explained_var, + "kl": policy._mean_kl_loss, + "entropy": policy._mean_entropy, "entropy_coeff": policy.entropy_coeff, } diff --git a/rllib/agents/ppo/tests/test_appo.py b/rllib/agents/ppo/tests/test_appo.py index be007f3dd9995..32a5989263f7c 100644 --- a/rllib/agents/ppo/tests/test_appo.py +++ b/rllib/agents/ppo/tests/test_appo.py @@ -3,7 +3,7 @@ import ray import ray.rllib.agents.ppo as ppo from ray.rllib.utils.test_utils import check_compute_single_action, \ - check_train_results, framework_iterator + framework_iterator class TestAPPO(unittest.TestCase): @@ -27,9 +27,7 @@ def test_appo_compilation(self): _config["vtrace"] = False trainer = ppo.APPOTrainer(config=_config, env="CartPole-v0") for i in range(num_iterations): - results = trainer.train() - check_train_results(results) - print(results) + print(trainer.train()) check_compute_single_action(trainer) trainer.stop() @@ -38,9 +36,7 @@ def test_appo_compilation(self): _config["vtrace"] = True trainer = ppo.APPOTrainer(config=_config, env="CartPole-v0") for i in range(num_iterations): - results = trainer.train() - check_train_results(results) - print(results) + print(trainer.train()) check_compute_single_action(trainer) trainer.stop() @@ -59,12 +55,10 @@ def test_appo_two_tf_optimizers(self): num_iterations = 2 # Only supported for tf so far. 
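`framework_iterator()` used below runs the test body once per requested framework by mutating `config["framework"]` and yielding the framework string; roughly this (a simplified sketch that ignores the session handling in the real helper):

def framework_iterator(config, frameworks=("tf2", "tf", "torch")):
    for fw in frameworks:
        config["framework"] = fw
        yield fw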
- for _ in framework_iterator(config, frameworks=("tf2", "tf")): + for _ in framework_iterator(config, frameworks="tf"): trainer = ppo.APPOTrainer(config=config, env="CartPole-v0") for i in range(num_iterations): - results = trainer.train() - check_train_results(results) - print(results) + print(trainer.train()) check_compute_single_action(trainer) trainer.stop() diff --git a/rllib/agents/ppo/tests/test_ddppo.py b/rllib/agents/ppo/tests/test_ddppo.py index 0e8154a662d12..e1191cfb2cd35 100644 --- a/rllib/agents/ppo/tests/test_ddppo.py +++ b/rllib/agents/ppo/tests/test_ddppo.py @@ -1,13 +1,11 @@ import unittest -import pytest import ray import ray.rllib.agents.ppo as ppo from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID -from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \ - LEARNER_STATS_KEY +from ray.rllib.policy.policy import LEARNER_STATS_KEY from ray.rllib.utils.test_utils import check, check_compute_single_action, \ - check_train_results, framework_iterator + framework_iterator class TestDDPPO(unittest.TestCase): @@ -28,9 +26,7 @@ def test_ddppo_compilation(self): for _ in framework_iterator(config, frameworks="torch"): trainer = ppo.ddppo.DDPPOTrainer(config=config, env="CartPole-v0") for i in range(num_iterations): - results = trainer.train() - check_train_results(results) - print(results) + trainer.train() # Make sure, weights on all workers are the same (including # local one). weights = trainer.workers.foreach_worker( @@ -52,25 +48,13 @@ def test_ddppo_schedule(self): trainer = ppo.ddppo.DDPPOTrainer(config=config, env="CartPole-v0") for _ in range(num_iterations): result = trainer.train() - lr = result["info"][LEARNER_INFO][DEFAULT_POLICY_ID][ + lr = result["info"]["learner"][DEFAULT_POLICY_ID][ LEARNER_STATS_KEY]["cur_lr"] trainer.stop() assert lr == 0.0, "lr should anneal to 0.0" - def test_validate_config(self): - """Test if DDPPO will raise errors after invalid configs are passed.""" - config = ppo.ddppo.DEFAULT_CONFIG.copy() - config["kl_coeff"] = 1. - msg = "DDPPO doesn't support KL penalties like PPO-1" - # import ipdb; ipdb.set_trace() - with pytest.raises(ValueError, match=msg): - ppo.ddppo.DDPPOTrainer(config=config, env="CartPole-v0") - config["kl_coeff"] = 0. - config["kl_target"] = 1. - with pytest.raises(ValueError, match=msg): - ppo.ddppo.DDPPOTrainer(config=config, env="CartPole-v0") - if __name__ == "__main__": + import pytest import sys sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/agents/ppo/tests/test_ppo.py b/rllib/agents/ppo/tests/test_ppo.py index 198922ee7a338..2dfcec41010b5 100644 --- a/rllib/agents/ppo/tests/test_ppo.py +++ b/rllib/agents/ppo/tests/test_ppo.py @@ -14,12 +14,11 @@ from ray.rllib.models.tf.tf_action_dist import Categorical from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.models.torch.torch_action_dist import TorchCategorical +from ray.rllib.policy.policy import LEARNER_STATS_KEY from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch -from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \ - LEARNER_STATS_KEY from ray.rllib.utils.numpy import fc -from ray.rllib.utils.test_utils import check, check_compute_single_action, \ - check_train_results, framework_iterator +from ray.rllib.utils.test_utils import check, framework_iterator, \ + check_compute_single_action # Fake CartPole episode of n time steps. FAKE_BATCH = SampleBatch({ @@ -60,8 +59,7 @@ def _check_lr_tf(policy, policy_id): assert lr == optim_lr, "LR scheduling error!" 
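The `lr_schedule`/`entropy_coeff_schedule` configs in the PPO test below, e.g. `[[0, lr], [128, 0.0]]`, are breakpoint lists interpolated piecewise-linearly over timesteps, with the last value held afterwards. A minimal sketch of that interpolation for the two-breakpoint case (RLlib's PiecewiseSchedule generalizes this):

def lr_at(ts, schedule=((0, 5e-5), (128, 0.0))):
    (t0, v0), (t1, v1) = schedule
    if ts >= t1:
        return v1  # hold the final value
    return v0 + (v1 - v0) * (ts - t0) / (t1 - t0)

assert lr_at(0) == 5e-5 and lr_at(128) == 0.0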
def on_train_result(self, *, trainer, result: dict, **kwargs): - stats = result["info"][LEARNER_INFO][DEFAULT_POLICY_ID][ - LEARNER_STATS_KEY] + stats = result["info"]["learner"][DEFAULT_POLICY_ID][LEARNER_STATS_KEY] # Learning rate should go to 0 after 1 iter. check(stats["cur_lr"], 5e-5 if trainer.iteration == 1 else 0.0) # Entropy coeff goes to 0.05, then 0.0 (per iter). @@ -92,7 +90,7 @@ def test_ppo_compilation_and_schedule_mixins(self): config["model"]["lstm_cell_size"] = 10 config["model"]["max_seq_len"] = 20 # Use default-native keras models whenever possible. - # config["model"]["_use_default_native_models"] = True + config["model"]["_use_default_native_models"] = True # Setup lr- and entropy schedules for testing. config["lr_schedule"] = [[0, config["lr"]], [128, 0.0]] @@ -126,9 +124,7 @@ def test_ppo_compilation_and_schedule_mixins(self): check(lr, config["lr"]) for i in range(num_iterations): - results = trainer.train() - check_train_results(results) - print(results) + print(trainer.train()) check_compute_single_action( trainer, @@ -317,19 +313,6 @@ def test_ppo_loss_function(self): check(pl, np.mean(-pg_loss)) check(v, np.mean(vf_loss), decimals=4) check(tl, overall_loss, decimals=4) - elif fw == "torch": - check(policy.model.tower_stats["mean_kl_loss"], kl) - check(policy.model.tower_stats["mean_entropy"], entropy) - check(policy.model.tower_stats["mean_policy_loss"], - np.mean(-pg_loss)) - check( - policy.model.tower_stats["mean_vf_loss"], - np.mean(vf_loss), - decimals=4) - check( - policy.model.tower_stats["total_loss"], - overall_loss, - decimals=4) else: check(policy._mean_kl_loss, kl) check(policy._mean_entropy, entropy) diff --git a/rllib/agents/qmix/qmix_policy.py b/rllib/agents/qmix/qmix_policy.py index ca0324ce4d08f..6c1078cbad314 100644 --- a/rllib/agents/qmix/qmix_policy.py +++ b/rllib/agents/qmix/qmix_policy.py @@ -8,6 +8,7 @@ from ray.rllib.agents.qmix.model import RNNModel, _get_size from ray.rllib.env.multi_agent_env import ENV_STATE from ray.rllib.env.wrappers.group_agents_wrapper import GROUP_REWARDS +from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY from ray.rllib.models.torch.torch_action_dist import TorchCategorical from ray.rllib.policy.policy import Policy from ray.rllib.policy.rnn_sequencing import chop_into_sequences @@ -15,7 +16,6 @@ from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.modelv2 import _unpack_obs from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.annotations import override # Torch must be installed. diff --git a/rllib/agents/sac/rnnsac.py b/rllib/agents/sac/rnnsac.py index 79bf6cdc816bc..3fb67e50d8ced 100644 --- a/rllib/agents/sac/rnnsac.py +++ b/rllib/agents/sac/rnnsac.py @@ -11,6 +11,10 @@ { # Batch mode (see common config) "batch_mode": "complete_episodes", + # If True prioritized replay buffer will be used. + "prioritized_replay": False, + # RNNSAC does not suport n-step > 1 yet! + "n_step": 1, # If True, assume a zero-initialized state input (no matter where in # the episode the sequence is located). 
# If False, store the initial states along with each SampleBatch, use @@ -46,6 +50,9 @@ def validate_config(config: TrainerConfigDict) -> None: config["replay_sequence_length"] = \ config["burn_in"] + config["model"]["max_seq_len"] + if config["n_step"] > 1: + raise ValueError("`n_step` > 1 not yet supported by RNNSAC!") + def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]: """Policy class picker function. Class is chosen based on DL-framework. diff --git a/rllib/agents/sac/rnnsac_torch_policy.py b/rllib/agents/sac/rnnsac_torch_policy.py index faef59e1bee67..c0d223c0a4766 100644 --- a/rllib/agents/sac/rnnsac_torch_policy.py +++ b/rllib/agents/sac/rnnsac_torch_policy.py @@ -371,7 +371,6 @@ def reduce_mean_valid(t): critic_loss.append( reduce_mean_valid( train_batch[PRIO_WEIGHTS] * huber_loss(twin_td_error))) - td_error = td_error * seq_mask # Alpha- and actor losses. # Note: In the papers, alpha is used directly, here we take the log. @@ -402,21 +401,26 @@ def reduce_mean_valid(t): actor_loss = reduce_mean_valid(alpha.detach() * log_pis_t - q_t_det_policy) - # Store values for stats function in model (tower), such that for - # multi-GPU, we do not override them during the parallel loss phase. - model.tower_stats["q_t"] = q_t * seq_mask[..., None] - model.tower_stats["policy_t"] = policy_t * seq_mask[..., None] - model.tower_stats["log_pis_t"] = log_pis_t * seq_mask[..., None] - model.tower_stats["actor_loss"] = actor_loss - model.tower_stats["critic_loss"] = critic_loss - model.tower_stats["alpha_loss"] = alpha_loss - # Store per time chunk (b/c we need only one mean - # prioritized replay weight per stored sequence). - model.tower_stats["td_error"] = torch.mean( - td_error.reshape([-1, T]), dim=-1) + # Save for stats function. + policy.q_t = q_t * seq_mask[..., None] + policy.policy_t = policy_t * seq_mask[..., None] + policy.log_pis_t = log_pis_t * seq_mask[..., None] + + # Store td-error in model, such that for multi-GPU, we do not override + # them during the parallel loss phase. TD-error tensor in final stats + # can then be concatenated and retrieved for each individual batch item. + model.td_error = td_error * seq_mask + + policy.actor_loss = actor_loss + policy.critic_loss = critic_loss + policy.alpha_loss = alpha_loss + policy.log_alpha_value = model.log_alpha + policy.alpha_value = alpha + policy.target_entropy = model.target_entropy # Return all loss terms corresponding to our optimizers. - return tuple([actor_loss] + critic_loss + [alpha_loss]) + return tuple([policy.actor_loss] + policy.critic_loss + + [policy.alpha_loss]) RNNSACTorchPolicy = SACTorchPolicy.with_updates( diff --git a/rllib/agents/sac/sac_tf_model.py b/rllib/agents/sac/sac_tf_model.py index 0b78f65a526fb..546de04ab47c9 100644 --- a/rllib/agents/sac/sac_tf_model.py +++ b/rllib/agents/sac/sac_tf_model.py @@ -1,7 +1,6 @@ import gym from gym.spaces import Box, Discrete import numpy as np -import tree # pip install dm_tree from typing import Dict, List, Optional from ray.rllib.models.catalog import ModelCatalog @@ -268,18 +267,13 @@ def get_policy_output(self, model_out: TensorType) -> TensorType: Returns: TensorType: Distribution inputs for sampling actions. """ - # Model outs may come as original Tuple/Dict observations, concat them + # Model outs may come as original Tuple observations, concat them # here if this is the case. 
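The dict branch below concatenates the observation components along the feature axis; a toy illustration with hypothetical component names and shapes. Note this assumes every component is already a rank-2 [batch, feature] tensor, which is what the removed `tree.flatten` variant additionally guaranteed for rank-1 inputs:

import tensorflow as tf

model_out = {"pos": tf.zeros([4, 3]), "vel": tf.zeros([4, 2])}
flat = tf.concat(list(model_out.values()), axis=-1)  # shape [4, 5]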
if isinstance(self.action_model.obs_space, Box): if isinstance(model_out, (list, tuple)): model_out = tf.concat(model_out, axis=-1) elif isinstance(model_out, dict): - model_out = tf.concat( - [ - tf.expand_dims(val, 1) if len(val.shape) == 1 else val - for val in tree.flatten(model_out.values()) - ], - axis=-1) + model_out = tf.concat(list(model_out.values()), axis=-1) out, _ = self.action_model({"obs": model_out}, [], None) return out diff --git a/rllib/agents/sac/sac_tf_policy.py b/rllib/agents/sac/sac_tf_policy.py index 629de0efce536..111d8b717f494 100644 --- a/rllib/agents/sac/sac_tf_policy.py +++ b/rllib/agents/sac/sac_tf_policy.py @@ -6,6 +6,7 @@ from gym.spaces import Box, Discrete from functools import partial import logging +import numpy as np from typing import Dict, List, Optional, Tuple, Type, Union import ray @@ -52,6 +53,9 @@ def build_sac_model(policy: Policy, obs_space: gym.spaces.Space, target model will be created in this function and assigned to `policy.target_model`. """ + # With separate state-preprocessor (before obs+action concat). + num_outputs = int(np.product(obs_space.shape)) + # Force-ignore any additionally provided hidden layer sizes. # Everything should be configured using SAC's "Q_model" and "policy_model" # settings. @@ -66,7 +70,7 @@ def build_sac_model(policy: Policy, obs_space: gym.spaces.Space, model = ModelCatalog.get_model_v2( obs_space=obs_space, action_space=action_space, - num_outputs=None, + num_outputs=num_outputs, model_config=config["model"], framework=config["framework"], default_model=default_model_cls, @@ -86,7 +90,7 @@ def build_sac_model(policy: Policy, obs_space: gym.spaces.Space, policy.target_model = ModelCatalog.get_model_v2( obs_space=obs_space, action_space=action_space, - num_outputs=None, + num_outputs=num_outputs, model_config=config["model"], framework=config["framework"], default_model=default_model_cls, diff --git a/rllib/agents/sac/sac_torch_model.py b/rllib/agents/sac/sac_torch_model.py index 1fdc09412da13..64bbb40920453 100644 --- a/rllib/agents/sac/sac_torch_model.py +++ b/rllib/agents/sac/sac_torch_model.py @@ -1,7 +1,6 @@ import gym from gym.spaces import Box, Discrete import numpy as np -import tree # pip install dm_tree from typing import Dict, List, Optional from ray.rllib.models.catalog import ModelCatalog @@ -282,12 +281,7 @@ def get_policy_output(self, model_out: TensorType) -> TensorType: if isinstance(model_out, (list, tuple)): model_out = torch.cat(model_out, dim=-1) elif isinstance(model_out, dict): - model_out = torch.cat( - [ - torch.unsqueeze(val, 1) if len(val.shape) == 1 else val - for val in tree.flatten(model_out.values()) - ], - dim=-1) + model_out = torch.cat(list(model_out.values()), dim=-1) out, _ = self.action_model({"obs": model_out}, [], None) return out diff --git a/rllib/agents/sac/sac_torch_policy.py b/rllib/agents/sac/sac_torch_policy.py index dee2693abf29e..6bfdb98decc7b 100644 --- a/rllib/agents/sac/sac_torch_policy.py +++ b/rllib/agents/sac/sac_torch_policy.py @@ -5,7 +5,6 @@ import gym from gym.spaces import Box, Discrete import logging -import tree # pip install dm_tree from typing import Dict, List, Optional, Tuple, Type, Union import ray @@ -315,21 +314,26 @@ def actor_critic_loss( # the Q-net(s)' variables. actor_loss = torch.mean(alpha.detach() * log_pis_t - q_t_det_policy) - # Store values for stats function in model (tower), such that for - # multi-GPU, we do not override them during the parallel loss phase. 
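The removed lines here show the tower-stats pattern these hunks revert: each model copy (one per GPU tower) keeps its own stats dict, so parallel loss passes cannot clobber each other, whereas attributes pinned on the shared policy object are overwritten by whichever tower writes last. A toy contrast (stand-in classes, not RLlib API):

class _Model:
    def __init__(self):
        self.tower_stats = {}  # private to this tower

class _Policy:
    pass

tower_0, tower_1 = _Model(), _Model()
tower_0.tower_stats["loss"] = 0.1  # both values are kept,
tower_1.tower_stats["loss"] = 0.2  # one per tower

p = _Policy()
p.loss = 0.1
p.loss = 0.2  # on the shared policy, the last writer wins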
- model.tower_stats["q_t"] = q_t - model.tower_stats["policy_t"] = policy_t - model.tower_stats["log_pis_t"] = log_pis_t - model.tower_stats["actor_loss"] = actor_loss - model.tower_stats["critic_loss"] = critic_loss - model.tower_stats["alpha_loss"] = alpha_loss + # Save for stats function. + policy.q_t = q_t + policy.policy_t = policy_t + policy.log_pis_t = log_pis_t - # TD-error tensor in final stats - # will be concatenated and retrieved for each individual batch item. - model.tower_stats["td_error"] = td_error + # Store td-error in model, such that for multi-GPU, we do not override + # them during the parallel loss phase. TD-error tensor in final stats + # can then be concatenated and retrieved for each individual batch item. + model.td_error = td_error + + policy.actor_loss = actor_loss + policy.critic_loss = critic_loss + policy.alpha_loss = alpha_loss + policy.log_alpha_value = model.log_alpha + policy.alpha_value = alpha + policy.target_entropy = model.target_entropy # Return all loss terms corresponding to our optimizers. - return tuple([actor_loss] + critic_loss + [alpha_loss]) + return tuple([policy.actor_loss] + policy.critic_loss + + [policy.alpha_loss]) def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]: @@ -342,23 +346,17 @@ def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]: Returns: Dict[str, TensorType]: The stats dict. """ - q_t = torch.stack(policy.get_tower_stats("q_t")) - return { - "actor_loss": torch.mean( - torch.stack(policy.get_tower_stats("actor_loss"))), - "critic_loss": torch.mean( - torch.stack(tree.flatten(policy.get_tower_stats("critic_loss")))), - "alpha_loss": torch.mean( - torch.stack(policy.get_tower_stats("alpha_loss"))), - "alpha_value": torch.exp(policy.model.log_alpha), - "log_alpha_value": policy.model.log_alpha, - "target_entropy": policy.model.target_entropy, - "policy_t": torch.mean( - torch.stack(policy.get_tower_stats("policy_t"))), - "mean_q": torch.mean(q_t), - "max_q": torch.max(q_t), - "min_q": torch.min(q_t), + "actor_loss": torch.mean(policy.actor_loss), + "critic_loss": torch.mean(torch.stack(policy.critic_loss)), + "alpha_loss": torch.mean(policy.alpha_loss), + "alpha_value": torch.mean(policy.alpha_value), + "log_alpha_value": torch.mean(policy.log_alpha_value), + "target_entropy": policy.target_entropy, + "policy_t": torch.mean(policy.policy_t), + "mean_q": torch.mean(policy.q_t), + "max_q": torch.max(policy.q_t), + "min_q": torch.min(policy.q_t), } @@ -432,9 +430,9 @@ def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask, # (one TD-error value per item in batch to update PR weights). actor_critic_loss(self, self.model, None, input_dict) - # `self.model.td_error` is set within actor_critic_loss call. - # Return its updated value here. - return self.model.tower_stats["td_error"] + # `self.td_error` is set within actor_critic_loss call. Return + # its updated value here. + return self.td_error # Assign the method to policy (self) for later usage. 
self.compute_td_error = compute_td_error diff --git a/rllib/agents/sac/tests/test_rnnsac.py b/rllib/agents/sac/tests/test_rnnsac.py deleted file mode 100644 index f0e8c5a750c57..0000000000000 --- a/rllib/agents/sac/tests/test_rnnsac.py +++ /dev/null @@ -1,73 +0,0 @@ -import unittest - -import ray -import ray.rllib.agents.sac as sac -from ray.rllib.utils.framework import try_import_tf, try_import_torch -from ray.rllib.utils.test_utils import check_compute_single_action, \ - framework_iterator - -tf1, tf, tfv = try_import_tf() -torch, nn = try_import_torch() - - -class TestRNNSAC(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - ray.init() - - @classmethod - def tearDownClass(cls) -> None: - ray.shutdown() - - def test_rnnsac_compilation(self): - """Test whether a R2D2Trainer can be built on all frameworks.""" - config = sac.RNNSAC_DEFAULT_CONFIG.copy() - config["num_workers"] = 0 # Run locally. - - # Wrap with an LSTM and use a very simple base-model. - config["model"] = { - "max_seq_len": 20, - } - config["policy_model"] = { - "use_lstm": True, - "lstm_cell_size": 64, - "fcnet_hiddens": [10], - "lstm_use_prev_action": True, - "lstm_use_prev_reward": True, - } - config["Q_model"] = { - "use_lstm": True, - "lstm_cell_size": 64, - "fcnet_hiddens": [10], - "lstm_use_prev_action": True, - "lstm_use_prev_reward": True, - } - - # Test with PR activated. - config["prioritized_replay"] = True - - config["burn_in"] = 20 - config["zero_init_states"] = True - - config["lr"] = 5e-4 - - num_iterations = 1 - - # Test building an RNNSAC agent in all frameworks. - for _ in framework_iterator(config, frameworks="torch"): - trainer = sac.RNNSACTrainer(config=config, env="CartPole-v0") - for i in range(num_iterations): - results = trainer.train() - print(results) - - check_compute_single_action( - trainer, - include_state=True, - include_prev_action_reward=True, - ) - - -if __name__ == "__main__": - import pytest - import sys - sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/agents/sac/tests/test_sac.py b/rllib/agents/sac/tests/test_sac.py index 06083b33e3fa9..d9b1de208af33 100644 --- a/rllib/agents/sac/tests/test_sac.py +++ b/rllib/agents/sac/tests/test_sac.py @@ -1,5 +1,5 @@ from gym import Env -from gym.spaces import Box, Dict, Discrete, Tuple +from gym.spaces import Box, Discrete, Tuple import numpy as np import re import unittest @@ -21,9 +21,8 @@ from ray.rllib.utils.numpy import fc, huber_loss, relu from ray.rllib.utils.spaces.simplex import Simplex from ray.rllib.utils.test_utils import check, check_compute_single_action, \ - check_train_results, framework_iterator + framework_iterator from ray.rllib.utils.torch_ops import convert_to_torch_tensor -from ray import tune tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() @@ -72,6 +71,8 @@ def test_sac_compilation(self): config["num_workers"] = 0 # Run locally. 
config["n_step"] = 3 config["twin_q"] = True + config["clip_actions"] = False + config["normalize_actions"] = True config["learning_starts"] = 0 config["prioritized_replay"] = True config["rollout_fragment_length"] = 10 @@ -91,28 +92,22 @@ def test_sac_compilation(self): image_space = Box(-1.0, 1.0, shape=(84, 84, 3)) simple_space = Box(-1.0, 1.0, shape=(3, )) - tune.register_env( - "random_dict_env", lambda _: RandomEnv({ - "observation_space": Dict({ - "a": simple_space, - "b": Discrete(2), - "c": image_space, }), - "action_space": Box(-1.0, 1.0, shape=(1, )), })) - tune.register_env( - "random_tuple_env", lambda _: RandomEnv({ - "observation_space": Tuple([ - simple_space, Discrete(2), image_space]), - "action_space": Box(-1.0, 1.0, shape=(1, )), })) - for fw in framework_iterator(config): # Test for different env types (discrete w/ and w/o image, + cont). for env in [ - "random_dict_env", - "random_tuple_env", + RandomEnv, "MsPacmanNoFrameskip-v4", "CartPole-v0", ]: print("Env={}".format(env)) + if env == RandomEnv: + config["env_config"] = { + "observation_space": Tuple((simple_space, Discrete(2), + image_space)), + "action_space": Box(-1.0, 1.0, shape=(1, )), + } + else: + config["env_config"] = {} # Test making the Q-model a custom one for CartPole, otherwise, # use the default model. config["Q_model"]["custom_model"] = "batch_norm{}".format( @@ -121,7 +116,6 @@ def test_sac_compilation(self): trainer = sac.SACTrainer(config=config, env=env) for i in range(num_iterations): results = trainer.train() - check_train_results(results) print(results) check_compute_single_action(trainer) @@ -312,10 +306,8 @@ def test_sac_loss_function(self): elif fw == "torch": loss_torch(policy, policy.model, None, input_) - c, a, e, t = policy.get_tower_stats("critic_loss")[0], \ - policy.get_tower_stats("actor_loss")[0], \ - policy.get_tower_stats("alpha_loss")[0], \ - policy.get_tower_stats("td_error")[0] + c, a, e, t = policy.critic_loss, policy.actor_loss, \ + policy.alpha_loss, policy.model.td_error # Test actor gradients. 
policy.actor_optim.zero_grad() diff --git a/rllib/agents/tests/test_trainer.py b/rllib/agents/tests/test_trainer.py index 937206deac138..baf7b665963ca 100644 --- a/rllib/agents/tests/test_trainer.py +++ b/rllib/agents/tests/test_trainer.py @@ -13,7 +13,6 @@ from ray.rllib.examples.env.multi_agent import MultiAgentCartPole from ray.rllib.examples.parallel_evaluation_and_training import \ AssertNumEvalEpisodesCallback -from ray.rllib.utils.metrics.learner_info import LEARNER_INFO from ray.rllib.utils.test_utils import framework_iterator @@ -73,7 +72,7 @@ def test_add_delete_policy(self): trainer = pg.PGTrainer(config=config) pol0 = trainer.get_policy("p0") r = trainer.train() - self.assertTrue("p0" in r["info"][LEARNER_INFO]) + self.assertTrue("p0" in r["info"]["learner"]) for i in range(1, 3): def new_mapping_fn(agent_id, episode, worker, **kwargs): diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index a1f4b64ee2426..7147ba9ea85c7 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -8,24 +8,22 @@ import pickle import tempfile import time -from typing import Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Callable, Dict, List, Optional, Type, Union import ray from ray.actor import ActorHandle from ray.exceptions import RayError from ray.rllib.agents.callbacks import DefaultCallbacks from ray.rllib.env.env_context import EnvContext -from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.env.utils import gym_env_creator from ray.rllib.evaluation.collectors.simple_list_collector import \ SimpleListCollector -from ray.rllib.evaluation.episode import MultiAgentEpisode from ray.rllib.evaluation.metrics import collect_metrics from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.evaluation.worker_set import WorkerSet from ray.rllib.models import MODEL_DEFAULTS from ray.rllib.policy.policy import Policy, PolicySpec -from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils import deep_update, FilterManager, merge_dicts from ray.rllib.utils.annotations import Deprecated, DeveloperAPI, override, \ PublicAPI @@ -38,7 +36,7 @@ from ray.rllib.utils.spaces import space_utils from ray.rllib.utils.typing import AgentID, EnvInfoDict, EnvType, EpisodeID, \ PartialTrainerConfigDict, PolicyID, ResultDict, TensorStructType, \ - TensorType, TrainerConfigDict + TrainerConfigDict from ray.tune.logger import Logger, UnifiedLogger from ray.tune.registry import ENV_CREATOR, register_env, _global_registry from ray.tune.resources import Resources @@ -1009,29 +1007,17 @@ def _sync_weights_to_workers( @PublicAPI def compute_single_action( self, - observation: Optional[TensorStructType] = None, - state: Optional[List[TensorStructType]] = None, - *, - prev_action: Optional[TensorStructType] = None, - prev_reward: Optional[float] = None, - info: Optional[EnvInfoDict] = None, - input_dict: Optional[SampleBatch] = None, + observation: TensorStructType, + state: List[TensorStructType] = None, + prev_action: TensorStructType = None, + prev_reward: float = None, + info: EnvInfoDict = None, policy_id: PolicyID = DEFAULT_POLICY_ID, full_fetch: bool = False, - explore: Optional[bool] = None, - timestep: Optional[int] = None, - episode: Optional[MultiAgentEpisode] = None, - unsquash_action: Optional[bool] = None, - clip_action: Optional[bool] = None, - - # Deprecated args. 
- unsquash_actions=DEPRECATED_VALUE, - clip_actions=DEPRECATED_VALUE, - - # Kwargs placeholder for future compatibility. - **kwargs, - ) -> Union[TensorStructType, Tuple[TensorStructType, List[TensorType], - Dict[str, TensorType]]]: + explore: bool = None, + unsquash_actions: Optional[bool] = None, + clip_actions: Optional[bool] = None, + ) -> TensorStructType: """Computes an action for the specified policy on the local worker. Note that you can also access the policy object through @@ -1039,123 +1025,70 @@ def compute_single_action( directly. Args: - observation: Single (unbatched) observation from the - environment. - state: List of all RNN hidden (single, unbatched) state tensors. - prev_action: Single (unbatched) previous action value. - prev_reward: Single (unbatched) previous reward value. - info: Env info dict, if any. - input_dict: An optional SampleBatch that holds all the values - for: obs, state, prev_action, and prev_reward, plus maybe - custom defined views of the current env trajectory. Note - that only one of `obs` or `input_dict` must be non-None. - policy_id: Policy to query (only applies to multi-agent). - Default: "default_policy". - full_fetch: Whether to return extra action fetch results. - This is always set to True if `state` is specified. - explore: Whether to apply exploration to the action. - Default: None -> use self.config["explore"]. - timestep: The current (sampling) time step. - episode: This provides access to all of the internal episodes' - state, which may be useful for model-based or multi-agent - algorithms. - unsquash_action: Should actions be unsquashed according to the - env's/Policy's action space? If None, use the value of - self.config["normalize_actions"]. - clip_action: Should actions be clipped according to the - env's/Policy's action space? If None, use the value of - self.config["clip_actions"]. - - Keyword Args: - kwargs: forward compatibility placeholder + observation (TensorStructType): observation from the environment. + state (List[TensorStructType]): RNN hidden state, if any. If state + is not None, then all of compute_single_action(...) is returned + (computed action, rnn state(s), logits dictionary). + Otherwise compute_single_action(...)[0] is returned + (computed action). + prev_action (TensorStructType): Previous action value, if any. + prev_reward (float): Previous reward, if any. + info (EnvInfoDict): info object, if any + policy_id (PolicyID): Policy to query (only applies to + multi-agent). + full_fetch (bool): Whether to return extra action fetch results. + This is always set to True if RNN state is specified. + explore (bool): Whether to pick an exploitation or exploration + action (default: None -> use self.config["explore"]). + unsquash_actions (bool): Should actions be unsquashed according to + the env's/Policy's action space? + clip_actions (bool): Should actions be clipped according to the + env's/Policy's action space? Returns: - The computed action if full_fetch=False, or a tuple of a) the - full output of policy.compute_actions() if full_fetch=True - or we have an RNN-based Policy. + any: The computed action if full_fetch=False, or + tuple: The full output of policy.compute_actions() if + full_fetch=True or we have an RNN-based Policy. Raises: KeyError: If the `policy_id` cannot be found in this Trainer's local worker. 
""" - if clip_actions != DEPRECATED_VALUE: - deprecation_warning( - old="Trainer.compute_single_action(`clip_actions`=...)", - new="Trainer.compute_single_action(`clip_action`=...)", - error=False) - clip_action = clip_actions - if unsquash_actions != DEPRECATED_VALUE: - deprecation_warning( - old="Trainer.compute_single_action(`unsquash_actions`=...)", - new="Trainer.compute_single_action(`unsquash_action`=...)", - error=False) - unsquash_action = unsquash_actions - - # User provided an input-dict: Assert that `obs`, `prev_a|r`, `state` - # are all None. - err_msg = "Provide either `input_dict` OR [`observation`, ...] as " \ - "args to Trainer.compute_single_action!" - if input_dict is not None: - assert observation is None and prev_action is None and \ - prev_reward is None and state is None, err_msg - observation = input_dict[SampleBatch.OBS] - else: - assert observation is not None, err_msg - - # Get the policy to compute the action for (in the multi-agent case, - # Trainer may hold >1 policies). policy = self.get_policy(policy_id) if policy is None: raise KeyError( f"PolicyID '{policy_id}' not found in PolicyMap of the " f"Trainer's local worker!") + local_worker = self.workers.local_worker() + if state is None: + state = [] + # Check the preprocessor and preprocess, if necessary. pp = local_worker.preprocessors[policy_id] if pp and type(pp).__name__ != "NoPreprocessor": observation = pp.transform(observation) - observation = local_worker.filters[policy_id]( + filtered_observation = local_worker.filters[policy_id]( observation, update=False) - # Input-dict. - if input_dict is not None: - input_dict[SampleBatch.OBS] = observation - action, state, extra = policy.compute_single_action( - input_dict=input_dict, - explore=explore, - timestep=timestep, - episode=episode, - ) - # Individual args. - else: - action, state, extra = policy.compute_single_action( - obs=observation, - state=state, - prev_action=prev_action, - prev_reward=prev_reward, - info=info, - explore=explore, - timestep=timestep, - episode=episode, - ) - - # If we work in normalized action space (normalize_actions=True), - # we re-translate here into the env's action space. - if unsquash_action: - action = space_utils.unsquash_action(action, - policy.action_space_struct) - # Clip, according to env's action space. - elif clip_action: - action = space_utils.clip_action(action, - policy.action_space_struct) + # Compute the action. + result = policy.compute_single_action( + filtered_observation, + state, + prev_action, + prev_reward, + info, + unsquash_actions=unsquash_actions, + clip_actions=clip_actions, + explore=explore) # Return 3-Tuple: Action, states, and extra-action fetches. if state or full_fetch: - return action, state, extra + return result # Ensure backward compatibility. else: - return action + return result[0] @Deprecated(new="compute_single_action", error=False) def compute_action(self, *args, **kwargs): @@ -1165,21 +1098,15 @@ def compute_action(self, *args, **kwargs): def compute_actions( self, observations: TensorStructType, - state: Optional[List[TensorStructType]] = None, - *, - prev_action: Optional[TensorStructType] = None, - prev_reward: Optional[TensorStructType] = None, - info: Optional[EnvInfoDict] = None, - policy_id: PolicyID = DEFAULT_POLICY_ID, - full_fetch: bool = False, - explore: Optional[bool] = None, - timestep: Optional[int] = None, - episodes: Optional[List[MultiAgentEpisode]] = None, - unsquash_actions: Optional[bool] = None, - clip_actions: Optional[bool] = None, - # Deprecated. 
+ state: List[TensorStructType] = None, + prev_action: TensorStructType = None, + prev_reward: TensorStructType = None, + info=None, + policy_id=DEFAULT_POLICY_ID, + full_fetch=False, + explore=None, normalize_actions=None, - **kwargs, + clip_actions=None, ): """Computes an action for the specified policy on the local Worker. @@ -1187,46 +1114,30 @@ def compute_actions( self.get_policy(policy_id) and call compute_actions() on it directly. Args: - observation: observation from the environment. - state: RNN hidden state, if any. If state is not None, + observation (obj): observation from the environment. + state (dict): RNN hidden state, if any. If state is not None, then all of compute_single_action(...) is returned (computed action, rnn state(s), logits dictionary). Otherwise compute_single_action(...)[0] is returned (computed action). - prev_action: Previous action value, if any. - prev_reward: Previous reward, if any. - info: Env info dict, if any. - policy_id: Policy to query (only applies to multi-agent). - full_fetch: Whether to return extra action fetch results. + prev_action (obj): previous action value, if any + prev_reward (int): previous reward, if any + info (dict): info object, if any + policy_id (str): Policy to query (only applies to multi-agent). + full_fetch (bool): Whether to return extra action fetch results. This is always set to True if RNN state is specified. - explore: Whether to pick an exploitation or exploration + explore (bool): Whether to pick an exploitation or exploration action (default: None -> use self.config["explore"]). - timestep: The current (sampling) time step. - episodes: This provides access to all of the internal episodes' - state, which may be useful for model-based or multi-agent - algorithms. - unsquash_actions: Should actions be unsquashed according - to the env's/Policy's action space? If None, use - self.config["normalize_actions"]. - clip_actions: Should actions be clipped according to the - env's/Policy's action space? If None, use - self.config["clip_actions"]. - - Keyword Args: - kwargs: forward compatibility placeholder + normalize_actions (bool): Should actions be unsquashed according + to the env's/Policy's action space? + clip_actions (bool): Should actions be clipped according to the + env's/Policy's action space? Returns: any: The computed action if full_fetch=False, or tuple: The full output of policy.compute_actions() if full_fetch=True or we have an RNN-based Policy. """ - if normalize_actions is not None: - deprecation_warning( - old="Trainer.compute_actions(`normalize_actions`=...)", - new="Trainer.compute_actions(`unsquash_actions`=...)", - error=False) - unsquash_actions = normalize_actions - # Preprocess obs and states. state_defined = state is not None policy = self.get_policy(policy_id) @@ -1251,38 +1162,23 @@ def compute_actions( state = list(zip(*filtered_state)) state = [np.stack(s) for s in state] - input_dict = {SampleBatch.OBS: obs_batch} - if prev_action: - input_dict[SampleBatch.PREV_ACTIONS] = prev_action - if prev_reward: - input_dict[SampleBatch.PREV_REWARDS] = prev_reward - if info: - input_dict[SampleBatch.INFOS] = info - for i, s in enumerate(state): - input_dict[f"state_in_{i}"] = s - # Batch compute actions - actions, states, infos = policy.compute_actions_from_input_dict( - input_dict=input_dict, - explore=explore, - timestep=timestep, - episodes=episodes, - ) - - # Unbatch actions for the environment into a multi-agent dict. 
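`space_utils.unbatch()` used below splits a batched action array back into one action per agent key; a simplified sketch of the idea (the real helper also handles nested Tuple/Dict action spaces; the agent keys and values here are hypothetical):

import numpy as np

def unbatch(batched_actions):
    return [row for row in np.asarray(batched_actions)]

observations = {"agent_0": None, "agent_1": None}
batched = np.array([[0.1], [0.2]])
actions = dict(zip(observations, unbatch(batched)))  # one action per agent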
- single_actions = space_utils.unbatch(actions) - actions = {} - for key, a in zip(observations, single_actions): - # If we work in normalized action space (normalize_actions=True), - # we re-translate here into the env's action space. - if unsquash_actions: - a = space_utils.unsquash_action(a, policy.action_space_struct) - # Clip, according to env's action space. - elif clip_actions: - a = space_utils.clip_action(a, policy.action_space_struct) - actions[key] = a - - # Unbatch states into a multi-agent dict. + actions, states, infos = policy.compute_actions( + obs_batch, + state, + prev_action, + prev_reward, + info, + normalize_actions=normalize_actions, + clip_actions=clip_actions, + explore=explore) + + # Unbatch actions for the environment + atns, actions = space_utils.unbatch(actions), {} + for key, atn in zip(observations, atns): + actions[key] = atn + + # Unbatch states into a dict unbatched_states = {} for idx, agent_id in enumerate(observations): unbatched_states[agent_id] = [s[idx] for s in states] @@ -1507,7 +1403,6 @@ def collect_metrics(self, selected_workers=selected_workers) @classmethod - @override(Trainable) def resource_help(cls, config: TrainerConfigDict) -> str: return ("\n\nYou can adjust the resource requests of RLlib agents by " "setting `num_workers`, `num_gpus`, and other configs. See " @@ -1843,25 +1738,23 @@ def with_updates(**overrides) -> Type["Trainer"]: "build_trainer()` function!") def _register_if_needed(self, env_object: Union[str, EnvType, None], - config) -> Optional[str]: + config): if isinstance(env_object, str): return env_object elif isinstance(env_object, type): name = env_object.__name__ - if config.get("remote_worker_envs"): + # Add convenience `_get_spaces` method. - @ray.remote(num_cpus=0) - class _wrapper(env_object): - # Add convenience `_get_spaces` and `_is_multi_agent` - # methods. 
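The `remote_worker_envs` branch of `_register_if_needed()` turns the env class into a zero-CPU Ray actor, one actor per sub-env, whose methods are then driven through `.remote()` handles. A self-contained sketch with a hypothetical stand-in env:

import ray

class _ToyEnv:  # stand-in for the user's env class
    def __init__(self, cfg):
        self.cfg = cfg

    def reset(self):
        return 0

ray.init(ignore_reinit_error=True)
env_actor = ray.remote(num_cpus=0)(_ToyEnv).remote({"worker": 0})
assert ray.get(env_actor.reset.remote()) == 0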
- def _get_spaces(self): - return self.observation_space, self.action_space + def _get_spaces(s): + return s.observation_space, s.action_space - def _is_multi_agent(self): - return isinstance(self, MultiAgentEnv) + env_object._get_spaces = _get_spaces - register_env(name, lambda cfg: _wrapper.remote(cfg)) + if config.get("remote_worker_envs"): + register_env( + name, + lambda cfg: ray.remote(num_cpus=0)(env_object).remote(cfg)) else: register_env(name, lambda cfg: env_object(cfg)) return name diff --git a/rllib/contrib/alpha_zero/core/alpha_zero_policy.py b/rllib/contrib/alpha_zero/core/alpha_zero_policy.py index ad97829c04ba3..7b3b46e74747e 100644 --- a/rllib/contrib/alpha_zero/core/alpha_zero_policy.py +++ b/rllib/contrib/alpha_zero/core/alpha_zero_policy.py @@ -1,11 +1,10 @@ import numpy as np -from ray.rllib.policy.policy import Policy +from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY from ray.rllib.policy.torch_policy import TorchPolicy from ray.rllib.contrib.alpha_zero.core.mcts import Node, RootParentNode from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY torch, _ = try_import_torch() @@ -40,9 +39,9 @@ def compute_actions(self, **kwargs): input_dict = {"obs": obs_batch} - if prev_action_batch is not None: + if prev_action_batch: input_dict["prev_actions"] = prev_action_batch - if prev_reward_batch is not None: + if prev_reward_batch: input_dict["prev_rewards"] = prev_reward_batch return self.compute_actions_from_input_dict( diff --git a/rllib/contrib/bandits/agents/policy.py b/rllib/contrib/bandits/agents/policy.py index 07d837b4fc150..e47c91005232c 100644 --- a/rllib/contrib/bandits/agents/policy.py +++ b/rllib/contrib/bandits/agents/policy.py @@ -9,11 +9,11 @@ ParametricLinearModelThompsonSampling, ParametricLinearModelUCB from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.modelv2 import restore_original_dimensions +from ray.rllib.policy.policy import LEARNER_STATS_KEY from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy import TorchPolicy from ray.rllib.utils.annotations import override -from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.util.debug import log_once logger = logging.getLogger(__name__) diff --git a/rllib/contrib/bandits/examples/LinTS_train_wheel_env.py b/rllib/contrib/bandits/examples/LinTS_train_wheel_env.py index 4501a04357fee..dfe3b8c85156d 100644 --- a/rllib/contrib/bandits/examples/LinTS_train_wheel_env.py +++ b/rllib/contrib/bandits/examples/LinTS_train_wheel_env.py @@ -7,7 +7,6 @@ from ray.rllib.contrib.bandits.agents import LinTSTrainer from ray.rllib.contrib.bandits.envs import WheelBanditEnv -from ray.rllib.utils.metrics.learner_info import LEARNER_INFO def plot_model_weights(means, covs): @@ -44,7 +43,7 @@ def plot_model_weights(means, covs): trainer.train() info = trainer.train() - print(info["info"][LEARNER_INFO]) + print(info["info"]["learner"]) # Get model parameters means = [model.arms[i].theta.numpy() for i in range(5)] diff --git a/rllib/contrib/maddpg/maddpg_policy.py b/rllib/contrib/maddpg/maddpg_policy.py index 51a02f35afaea..86e417e5d3112 100644 --- a/rllib/contrib/maddpg/maddpg_policy.py +++ b/rllib/contrib/maddpg/maddpg_policy.py @@ -1,5 +1,6 @@ import ray from ray.rllib.agents.dqn.dqn_tf_policy import minimize_and_clip +from ray.rllib.evaluation.metrics import 
LEARNER_STATS_KEY from ray.rllib.evaluation.postprocessing import adjust_nstep from ray.rllib.models import ModelCatalog from ray.rllib.policy.sample_batch import SampleBatch @@ -8,7 +9,6 @@ from ray.rllib.policy.policy import Policy from ray.rllib.policy.tf_policy import TFPolicy from ray.rllib.utils.framework import try_import_tf, try_import_tfp -from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY import logging from gym.spaces import Box, Discrete diff --git a/rllib/contrib/sumo/connector.py b/rllib/contrib/sumo/connector.py index 0b795c45c8421..6b1d3d1d47e35 100644 --- a/rllib/contrib/sumo/connector.py +++ b/rllib/contrib/sumo/connector.py @@ -162,7 +162,7 @@ def _stopping_condition(self, current_step_counter, until_end): return True return False - def step(self, until_end=False, agents=None): + def step(self, until_end=False, agents=set()): """ Runs a "learning" step and returns if the simulation has finished. This function in meant to be called by the RLLIB Environment. @@ -176,9 +176,6 @@ def step(self, until_end=False, agents=None): Return: Bool. True iff the simulation is still ongoing. """ - if agents is None: - agents = set() - # Execute SUMO steps until the learning needs to happen current_step_counter = 0 logger.debug( diff --git a/rllib/env/base_env.py b/rllib/env/base_env.py index 8ee302eb24683..4b2c77fe1532b 100644 --- a/rllib/env/base_env.py +++ b/rllib/env/base_env.py @@ -1,6 +1,5 @@ from typing import Callable, Tuple, Optional, List, Dict, Any, TYPE_CHECKING -import ray from ray.rllib.env.external_env import ExternalEnv from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv from ray.rllib.env.multi_agent_env import MultiAgentEnv @@ -122,14 +121,10 @@ def to_base_env( env = _VectorEnvToBaseEnv(env) else: if remote_envs: - # Determine, whether the already existing sub-env (could - # be a ray.actor) is multi-agent or not. - multiagent = ray.get(env._is_multi_agent.remote()) if \ - hasattr(env, "_is_multi_agent") else False env = RemoteVectorEnv( make_env, num_envs, - multiagent=multiagent, + multiagent=False, remote_env_batch_wait_ms=remote_env_batch_wait_ms, existing_envs=[env], ) diff --git a/rllib/env/multi_agent_env.py b/rllib/env/multi_agent_env.py index 4840de357585a..ed5705bf725d0 100644 --- a/rllib/env/multi_agent_env.py +++ b/rllib/env/multi_agent_env.py @@ -163,8 +163,7 @@ def make_multi_agent(env_name_or_creator): """ class MultiEnv(MultiAgentEnv): - def __init__(self, config=None): - config = config or {} + def __init__(self, config): num = config.pop("num_agents", 1) if isinstance(env_name_or_creator, str): self.agents = [ diff --git a/rllib/env/policy_server_input.py b/rllib/env/policy_server_input.py index c7148a94a8a2d..26e96673adb5c 100644 --- a/rllib/env/policy_server_input.py +++ b/rllib/env/policy_server_input.py @@ -89,18 +89,10 @@ def get_metrics(): # and sends data and metrics into the queues. handler = _make_handler(self.rollout_worker, self.samples_queue, self.metrics_queue) - try: - import time - time.sleep(1) - HTTPServer.__init__(self, (address, port), handler) - except OSError: - print(f"Creating a PolicyServer on {address}:{port} failed!") - import time - time.sleep(1) - raise - - logger.info("Starting connector server at " - f"{self.server_name}:{self.server_port}") + HTTPServer.__init__(self, (address, port), handler) + + logger.info("Starting connector server at {}:{}".format( + self.server_name, self.server_port)) # Start the serving thread, listening on socket and handling commands. 
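The serving thread created below follows the standard stdlib pattern: run `serve_forever()` on a daemonized background thread so the rollout worker is not blocked, and stop it later via `shutdown()`. A minimal stand-alone sketch (handler class and port are hypothetical):

import threading
from http.server import BaseHTTPRequestHandler, HTTPServer

server = HTTPServer(("localhost", 9900), BaseHTTPRequestHandler)
t = threading.Thread(target=server.serve_forever, daemon=True)
t.start()
# ... handle requests ...
server.shutdown()  # stops serve_forever() from another thread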
serving_thread = threading.Thread( diff --git a/rllib/env/remote_vector_env.py b/rllib/env/remote_vector_env.py index 2d09302f59c15..aa2e958efee5a 100644 --- a/rllib/env/remote_vector_env.py +++ b/rllib/env/remote_vector_env.py @@ -29,8 +29,6 @@ def __init__(self, existing_envs: Optional[List[ray.actor.ActorHandle]] = None): # Could be creating local or remote envs. self.make_env = make_env - # Whether the given `make_env` callable already returns ray.remote - # objects or not. self.make_env_creates_actors = False # Already existing env objects (generated by the RolloutWorker). self.existing_envs = existing_envs or [] @@ -52,13 +50,9 @@ def poll(self) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict, self.actors = [] while len(self.actors) < self.num_envs: self.actors.append(self.make_env(len(self.actors))) - # `self.make_env` produces gym.Envs (or children thereof, such + # `self.make_env` produces gym.Envs (or other similar types, such # as MultiAgentEnv): Need to auto-wrap it here. The problem with - # this is that custom methods wil get lost. If you would like to - # keep your custom methods in your envs, you should provide the - # env class directly in your config (w/o tune.register_env()), - # such that your class will directly be made a @ray.remote - # (w/o the wrapping via `_Remote[Multi|Single]AgentEnv`). + # this is that custom methods wil get lost. else: def make_remote_env(i): @@ -131,15 +125,7 @@ def make_remote_env(i): def send_actions(self, action_dict: MultiEnvDict) -> None: for env_id, actions in action_dict.items(): actor = self.actors[env_id] - # `actor` is a simple single-agent (remote) env, e.g. a gym.Env - # that was made a @ray.remote. - if not self.multiagent and self.make_env_creates_actors: - obj_ref = actor.step.remote(actions[_DUMMY_AGENT_ID]) - # `actor` is already a _RemoteSingleAgentEnv or - # _RemoteMultiAgentEnv wrapper - # (handles the multi-agent action_dict automatically). - else: - obj_ref = actor.step.remote(actions) + obj_ref = actor.step.remote(actions) self.pending[obj_ref] = actor @override(BaseEnv) diff --git a/rllib/env/tests/test_local_inference.sh b/rllib/env/tests/test_local_inference.sh new file mode 100755 index 0000000000000..be910f173c620 --- /dev/null +++ b/rllib/env/tests/test_local_inference.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +rm -f last_checkpoint.out +pkill -f cartpole_server.py +sleep 1 + +if [ -f test_local_inference.sh ]; then + basedir="../../examples/serving" +else + basedir="rllib/examples/serving" # In bazel. +fi + +# Start server with 2 workers (will listen on ports 9900 and 9901 for client +# connections). +# Do not attempt to restore from checkpoint; leads to errors on travis. +(python $basedir/cartpole_server.py --run=PPO --num-workers=2 --no-restore 2>&1 | grep -v 200) & +server_pid=$! + +echo "Waiting for server to start" +while ! curl localhost:9900; do + sleep 1 +done +while ! curl localhost:9901; do + sleep 1 +done + +# Start client 1 (port 9900). +sleep 2 +(python $basedir/cartpole_client.py --inference-mode=local --port=9900) & +client1_pid=$! + +# Start client 2 (port 9901). +sleep 2 +(python $basedir/cartpole_client.py --inference-mode=local --port=9901) & +client2_pid=$! + +# Start client 3 (also port 9901) and run it until it reaches 150.0 +# reward. Then stop everything. 
+sleep 2 +python $basedir/cartpole_client.py --stop-reward=150.0 --inference-mode=local --port=9901 + +kill $server_pid $client1_pid $client2_pid || true diff --git a/rllib/env/tests/test_policy_client_server_setup.sh b/rllib/env/tests/test_policy_client_server_setup.sh deleted file mode 100755 index 4d458ee5b8dba..0000000000000 --- a/rllib/env/tests/test_policy_client_server_setup.sh +++ /dev/null @@ -1,63 +0,0 @@ -#!/bin/bash - -rm -f last_checkpoint.out - -if [ "$1" == "local" ]; then - inference_mode=local -else - inference_mode=remote -fi - -if [ "$2" == "cartpole" ]; then - server_script=cartpole_server.py - client_script=cartpole_client.py - stop_criterion="--stop-reward=150.0" -else - server_script=unity3d_server.py - client_script=unity3d_dummy_client.py - stop_criterion="--num-episodes=10" -fi - -pkill -f $server_script -sleep 1 - -if [ -f test_policy_client_server_setup.sh ]; then - basedir="../../examples/serving" -else - basedir="rllib/examples/serving" # In bazel. -fi - - -# Start server with 2 workers (will listen on ports 9900 and 9901 for client -# connections). -# Do not attempt to restore from checkpoint; leads to errors on travis. -(python $basedir/$server_script --run=PPO --num-workers=2 --no-restore 2>&1 | grep -v 200) & -server_pid=$! - -echo "Waiting for server to start ..." -while ! curl localhost:9900; do - sleep 1 -done -echo "Remote worker #1 on port 9900 is up!" -while ! curl localhost:9901; do - sleep 1 -done -echo "Remote worker #2 on port 9901 is up!" - -# Start client 1 (connect to port 9900). -sleep 2 -(python $basedir/$client_script --inference-mode=$inference_mode --port=9900) & -client1_pid=$! - -# Start client 2 (connect to port 9901). -sleep 2 -(python $basedir/$client_script --inference-mode=$inference_mode --port=9901) & -client2_pid=$! - -# Start client 3 (also connecting to port 9901) and run it until it reaches -# x reward (CartPole) or n episodes (dummy Unity3D). -# Then stop everything. -sleep 2 -python $basedir/$client_script $stop_criterion --inference-mode=$inference_mode --port=9901 - -kill $server_pid $client1_pid $client2_pid || true diff --git a/rllib/env/tests/test_remote_inference.sh b/rllib/env/tests/test_remote_inference.sh new file mode 100755 index 0000000000000..1a9ead838576c --- /dev/null +++ b/rllib/env/tests/test_remote_inference.sh @@ -0,0 +1,41 @@ +#!/bin/bash + +rm -f last_checkpoint.out +pkill -f cartpole_server.py +sleep 1 + +if [ -f test_local_inference.sh ]; then + basedir="../../examples/serving" +else + basedir="rllib/examples/serving" # In bazel. +fi + +# Do not attempt to restore from checkpoint; leads to errors on travis. +(python $basedir/cartpole_server.py --run=DQN --num-workers=2 --no-restore 2>&1 | grep -v 200) & +server_pid=$! + +echo "Waiting for server to start" +while ! curl localhost:9900; do + sleep 1 +done +while ! curl localhost:9901; do + sleep 1 +done + +# Start client 1 (port 9900). +sleep 2 +(python $basedir/cartpole_client.py --inference-mode=remote --port=9900) & +client1_pid=$! + +# Start client 2 (port 9901). +sleep 2 +(python $basedir/cartpole_client.py --inference-mode=remote --port=9901) & +client2_pid=$! + +# Start client 3 (also port 9901) and run it until it reaches 150.0 +# reward. Then stop everything. 
+sleep 2 +python $basedir/cartpole_client.py --stop-reward=150.0 --inference-mode=remote --port=9901 + +kill $server_pid $client1_pid $client2_pid || true + diff --git a/rllib/env/tests/test_remote_worker_envs.py b/rllib/env/tests/test_remote_worker_envs.py deleted file mode 100644 index ba80c7e4cede1..0000000000000 --- a/rllib/env/tests/test_remote_worker_envs.py +++ /dev/null @@ -1,98 +0,0 @@ -import gym -import numpy as np -from pettingzoo.butterfly import pistonball_v4 -from supersuit import normalize_obs_v0, dtype_v0, color_reduction_v0 -import unittest - -import ray -from ray.rllib.agents.pg import pg -from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv -from ray.rllib.examples.env.random_env import RandomEnv, RandomMultiAgentEnv -from ray.rllib.examples.remote_vector_env_with_custom_api import \ - NonVectorizedEnvToBeVectorizedIntoRemoteVectorEnv -from ray import tune - - -# Function that outputs the environment you wish to register. -def env_creator(config): - env = pistonball_v4.env(local_ratio=config.get("local_ratio", 0.2)) - env = dtype_v0(env, dtype=np.float32) - env = color_reduction_v0(env, mode="R") - env = normalize_obs_v0(env) - return env - - -tune.register_env("cartpole", lambda env_ctx: gym.make("CartPole-v0")) - -tune.register_env("pistonball", - lambda config: PettingZooEnv(env_creator(config))) - - -class TestRemoteWorkerEnvSetting(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - ray.init(num_cpus=4) - - @classmethod - def tearDownClass(cls) -> None: - ray.shutdown() - - def test_remote_worker_env(self): - config = pg.DEFAULT_CONFIG.copy() - config["remote_worker_envs"] = True - config["num_envs_per_worker"] = 4 - - # Simple string env definition (gym.make(...)). - config["env"] = "CartPole-v0" - trainer = pg.PGTrainer(config=config) - print(trainer.train()) - trainer.stop() - - # Using tune.register. - config["env"] = "cartpole" - trainer = pg.PGTrainer(config=config) - print(trainer.train()) - trainer.stop() - - # Using class directly. - config["env"] = RandomEnv - trainer = pg.PGTrainer(config=config) - print(trainer.train()) - trainer.stop() - - # Using class directly: Sub-class of gym.Env, - # which implements its own API. - config["env"] = NonVectorizedEnvToBeVectorizedIntoRemoteVectorEnv - trainer = pg.PGTrainer(config=config) - print(trainer.train()) - trainer.stop() - - def test_remote_worker_env_multi_agent(self): - config = pg.DEFAULT_CONFIG.copy() - config["remote_worker_envs"] = True - config["num_envs_per_worker"] = 4 - - # Full classpath provided. - config["env"] = \ - "ray.rllib.examples.env.random_env.RandomMultiAgentEnv" - trainer = pg.PGTrainer(config=config) - print(trainer.train()) - trainer.stop() - - # Using tune.register. - config["env"] = "pistonball" - trainer = pg.PGTrainer(config=config) - print(trainer.train()) - trainer.stop() - - # Using class directly. 
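# The deleted test above exercised the `remote_worker_envs` setting; as a
# minimal sketch of what it covered (config keys as in the test, env name
# illustrative):
#
#     config = pg.DEFAULT_CONFIG.copy()
#     config["env"] = "CartPole-v0"
#     config["remote_worker_envs"] = True  # each sub-env becomes a Ray actor
#     config["num_envs_per_worker"] = 4    # four remote sub-envs per worker
#     trainer = pg.PGTrainer(config=config)
#     trainer.train()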
- config["env"] = RandomMultiAgentEnv - trainer = pg.PGTrainer(config=config) - print(trainer.train()) - trainer.stop() - - -if __name__ == "__main__": - import pytest - import sys - sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/env/wrappers/unity3d_env.py b/rllib/env/wrappers/unity3d_env.py index 2f9f75e79cb1e..2ec8fd6282945 100644 --- a/rllib/env/wrappers/unity3d_env.py +++ b/rllib/env/wrappers/unity3d_env.py @@ -6,7 +6,6 @@ from typing import Callable, Optional, Tuple from ray.rllib.env.multi_agent_env import MultiAgentEnv -from ray.rllib.policy.policy import PolicySpec from ray.rllib.utils.typing import MultiAgentDict, PolicyID, AgentID logger = logging.getLogger(__name__) @@ -305,12 +304,10 @@ def get_policy_configs_for_game( # Policies (Unity: "behaviors") and agent-to-policy mapping fns. if game_name == "SoccerStrikersVsGoalie": policies = { - "Goalie": PolicySpec( - observation_space=obs_spaces["Goalie"], - action_space=action_spaces["Goalie"]), - "Striker": PolicySpec( - observation_space=obs_spaces["Striker"], - action_space=action_spaces["Striker"]), + "Goalie": (None, obs_spaces["Goalie"], action_spaces["Goalie"], + {}), + "Striker": (None, obs_spaces["Striker"], + action_spaces["Striker"], {}), } def policy_mapping_fn(agent_id, episode, worker, **kwargs): @@ -318,9 +315,8 @@ def policy_mapping_fn(agent_id, episode, worker, **kwargs): else: policies = { - game_name: PolicySpec( - observation_space=obs_spaces[game_name], - action_space=action_spaces[game_name]), + game_name: (None, obs_spaces[game_name], + action_spaces[game_name], {}), } def policy_mapping_fn(agent_id, episode, worker, **kwargs): diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index 7c5415375d230..40745251e2b64 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -756,11 +756,8 @@ def postprocess_episode( "True. 
Alternatively, set no_done_at_end=True to " "allow this.") - if len(pre_batches) > 1: - other_batches = pre_batches.copy() - del other_batches[agent_id] - else: - other_batches = {} + other_batches = pre_batches.copy() + del other_batches[agent_id] pid = self.agent_key_to_policy_id[(episode_id, agent_id)] policy = self.policy_map[pid] if any(pre_batch[SampleBatch.DONES][:-1]) or len( diff --git a/rllib/evaluation/metrics.py b/rllib/evaluation/metrics.py index 73c25f916f0bb..06afe96d3fc6f 100644 --- a/rllib/evaluation/metrics.py +++ b/rllib/evaluation/metrics.py @@ -9,8 +9,8 @@ from ray.rllib.evaluation.rollout_metrics import RolloutMetrics from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.offline.off_policy_estimator import OffPolicyEstimate +from ray.rllib.policy.policy import LEARNER_STATS_KEY from ray.rllib.utils.annotations import DeveloperAPI -from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.typing import GradInfoDict, LearnerStatsDict, ResultDict if TYPE_CHECKING: @@ -42,6 +42,7 @@ def get_learner_stats(grad_info: GradInfoDict) -> LearnerStatsDict: >>> print(get_stats(grad_info)) {"vf_loss": ..., "policy_loss": ...} """ + if LEARNER_STATS_KEY in grad_info: return grad_info[LEARNER_STATS_KEY] @@ -56,15 +57,10 @@ def get_learner_stats(grad_info: GradInfoDict) -> LearnerStatsDict: @DeveloperAPI def collect_metrics(local_worker: Optional["RolloutWorker"] = None, - remote_workers: Optional[List[ActorHandle]] = None, - to_be_collected: Optional[List[ObjectRef]] = None, + remote_workers: List[ActorHandle] = [], + to_be_collected: List[ObjectRef] = [], timeout_seconds: int = 180) -> ResultDict: """Gathers episode metrics from RolloutWorker instances.""" - if remote_workers is None: - remote_workers = [] - - if to_be_collected is None: - to_be_collected = [] episodes, to_be_collected = collect_episodes( local_worker, @@ -78,16 +74,11 @@ def collect_metrics(local_worker: Optional["RolloutWorker"] = None, @DeveloperAPI def collect_episodes( local_worker: Optional["RolloutWorker"] = None, - remote_workers: Optional[List[ActorHandle]] = None, - to_be_collected: Optional[List[ObjectRef]] = None, + remote_workers: List[ActorHandle] = [], + to_be_collected: List[ObjectRef] = [], timeout_seconds: int = 180 ) -> Tuple[List[Union[RolloutMetrics, OffPolicyEstimate]], List[ObjectRef]]: """Gathers new episodes metrics tuples from the given evaluators.""" - if remote_workers is None: - remote_workers = [] - - if to_be_collected is None: - to_be_collected = [] if remote_workers: pending = [ diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index 7151851587f73..a703b9f0a66e1 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -860,15 +860,14 @@ def compute_gradients( summarize(samples))) if isinstance(samples, MultiAgentBatch): grad_out, info_out = {}, {} - if self.policy_config.get("framework") == "tf": + if self.tf_sess is not None: + builder = TFRunBuilder(self.tf_sess, "compute_gradients") for pid, batch in samples.policy_batches.items(): if pid not in self.policies_to_train: continue - policy = self.policy_map[pid] - builder = TFRunBuilder(policy.get_session(), - "compute_gradients") grad_out[pid], info_out[pid] = ( - policy._build_compute_gradients(builder, batch)) + self.policy_map[pid]._build_compute_gradients( + builder, batch)) grad_out = {k: builder.get(v) for k, v in grad_out.items()} info_out = {k: builder.get(v) for k, v in info_out.items()} else: @@ -898,21 
+897,14 @@ def apply_gradients(self, grads: ModelGradients) -> Dict[PolicyID, Any]: if log_once("apply_gradients"): logger.info("Apply gradients:\n\n{}\n".format(summarize(grads))) if isinstance(grads, dict): - if self.policy_config.get("framework") == "tf": - builders = {} - outputs = {} - for pid, grad in grads.items(): - if pid not in self.policies_to_train: - continue - policy = self.policy_map[pid] - builders[pid] = TFRunBuilder(policy.get_session(), - "apply_gradients") - outputs[pid] = policy._build_apply_gradients( - builders[pid], grad) - return { - pid: builders[pid].get(op) - for pid, op in outputs.items() + if self.tf_sess is not None: + builder = TFRunBuilder(self.tf_sess, "apply_gradients") + outputs = { + pid: self.policy_map[pid]._build_apply_gradients( + builder, grad) + for pid, grad in grads.items() } + return {k: builder.get(v) for k, v in outputs.items()} else: return { pid: self.policy_map[pid].apply_gradients(g) diff --git a/rllib/examples/centralized_critic.py b/rllib/examples/centralized_critic.py index 09fdb3b968dea..0737355dc0dfe 100644 --- a/rllib/examples/centralized_critic.py +++ b/rllib/examples/centralized_critic.py @@ -179,7 +179,7 @@ def central_vf_stats(policy, train_batch, grads): return { "vf_explained_var": explained_variance( train_batch[Postprocessing.VALUE_TARGETS], - policy._central_value_out) + policy._central_value_out), } diff --git a/rllib/examples/custom_keras_model.py b/rllib/examples/custom_keras_model.py index c1c419d50e545..cec793dd17bb6 100644 --- a/rllib/examples/custom_keras_model.py +++ b/rllib/examples/custom_keras_model.py @@ -11,10 +11,9 @@ from ray.rllib.models.tf.misc import normc_initializer from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.models.tf.visionnet import VisionNetwork as MyVisionNetwork +from ray.rllib.policy.policy import LEARNER_STATS_KEY from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \ - LEARNER_STATS_KEY tf1, tf, tfv = try_import_tf() @@ -111,7 +110,7 @@ def metrics(self): # Tests https://github.com/ray-project/ray/issues/7293 def check_has_custom_metric(result): - r = result["result"]["info"][LEARNER_INFO] + r = result["result"]["info"]["learner"] if DEFAULT_POLICY_ID in r: r = r[DEFAULT_POLICY_ID].get(LEARNER_STATS_KEY, r[DEFAULT_POLICY_ID]) diff --git a/rllib/examples/custom_model_loss_and_metrics.py b/rllib/examples/custom_model_loss_and_metrics.py index 6a38084f01188..9cea42cdf639a 100644 --- a/rllib/examples/custom_model_loss_and_metrics.py +++ b/rllib/examples/custom_model_loss_and_metrics.py @@ -19,10 +19,9 @@ from ray.rllib.examples.models.custom_loss_model import CustomLossModel, \ TorchCustomLossModel from ray.rllib.models import ModelCatalog +from ray.rllib.policy.policy import LEARNER_STATS_KEY from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \ - LEARNER_STATS_KEY tf1, tf, tfv = try_import_tf() @@ -84,9 +83,9 @@ # Torch metrics structure. 
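# For reference, the results layout these assertions rely on, sketched with
# placeholder values (only the keys matter; DEFAULT_POLICY_ID is the literal
# "default_policy"):
#
#     info = {
#         "learner": {
#             "default_policy": {
#                 "learner_stats": {"policy_loss": ..., "vf_loss": ...},
#                 "model": {},           # torch: ModelV2.metrics() output
#                 "custom_metrics": {},  # torch: on_learn_on_batch output
#             },
#         },
#     }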
if args.framework == "torch": - assert LEARNER_STATS_KEY in info[LEARNER_INFO][DEFAULT_POLICY_ID] - assert "model" in info[LEARNER_INFO][DEFAULT_POLICY_ID] - assert "custom_metrics" in info[LEARNER_INFO][DEFAULT_POLICY_ID] + assert LEARNER_STATS_KEY in info["learner"][DEFAULT_POLICY_ID] + assert "model" in info["learner"][DEFAULT_POLICY_ID] + assert "custom_metrics" in info["learner"][DEFAULT_POLICY_ID] # TODO: (sven) Make sure the metrics structure gets unified between # tf and torch. Tf should work like current torch: @@ -97,5 +96,4 @@ # model: [return values of ModelV2's `metrics` method] # custom_metrics: [return values of callback: `on_learn_on_batch`] else: - assert "model" in info[LEARNER_INFO][DEFAULT_POLICY_ID][ - LEARNER_STATS_KEY] + assert "model" in info["learner"][DEFAULT_POLICY_ID][LEARNER_STATS_KEY] diff --git a/rllib/examples/deterministic_training.py b/rllib/examples/deterministic_training.py index e6fd21e56a9c3..528e002971c43 100644 --- a/rllib/examples/deterministic_training.py +++ b/rllib/examples/deterministic_training.py @@ -10,7 +10,6 @@ from ray.rllib.examples.env.env_using_remote_actor import \ CartPoleWithRemoteParamServer, ParameterStorage from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID -from ray.rllib.utils.metrics.learner_info import LEARNER_INFO from ray.rllib.utils.test_utils import check parser = argparse.ArgumentParser() @@ -61,7 +60,6 @@ check(results1["hist_stats"], results2["hist_stats"]) # As well as training behavior (minibatch sequence during SGD # iterations). - check( - results1["info"][LEARNER_INFO][DEFAULT_POLICY_ID]["learner_stats"], - results2["info"][LEARNER_INFO][DEFAULT_POLICY_ID]["learner_stats"]) + check(results1["info"]["learner"][DEFAULT_POLICY_ID]["learner_stats"], + results2["info"]["learner"][DEFAULT_POLICY_ID]["learner_stats"]) ray.shutdown() diff --git a/rllib/examples/env/coin_game_non_vectorized_env.py b/rllib/examples/env/coin_game_non_vectorized_env.py index e773bab36a6b9..5d725ade56d5d 100644 --- a/rllib/examples/env/coin_game_non_vectorized_env.py +++ b/rllib/examples/env/coin_game_non_vectorized_env.py @@ -13,7 +13,7 @@ from gym.utils import seeding from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.utils import override -from typing import Dict, Optional +from typing import Dict from ray.rllib.examples.env.utils.interfaces import InfoAccumulationInterface @@ -36,9 +36,7 @@ class CoinGame(InfoAccumulationInterface, MultiAgentEnv, gym.Env): np.array([-1, 0]), ] - def __init__(self, config: Optional[Dict] = None): - if config is None: - config = {} + def __init__(self, config: Dict = {}): self._validate_config(config) @@ -327,10 +325,7 @@ def _init_info(self): class AsymCoinGame(CoinGame): NAME = "AsymCoinGame" - def __init__(self, config: Optional[dict] = None): - if config is None: - config = {} - + def __init__(self, config: dict = {}): if "asymmetric" in config: assert config["asymmetric"] else: diff --git a/rllib/examples/env/coin_game_vectorized_env.py b/rllib/examples/env/coin_game_vectorized_env.py index 546a9b1a815b0..a71fa4327d399 100644 --- a/rllib/examples/env/coin_game_vectorized_env.py +++ b/rllib/examples/env/coin_game_vectorized_env.py @@ -21,9 +21,7 @@ class VectorizedCoinGame(CoinGame): Vectorized Coin Game environment. 
""" - def __init__(self, config=None): - if config is None: - config = {} + def __init__(self, config={}): super().__init__(config) @@ -161,10 +159,7 @@ def _load_env(self, env_state): class AsymVectorizedCoinGame(VectorizedCoinGame): NAME = "AsymCoinGame" - def __init__(self, config=None): - if config is None: - config = {} - + def __init__(self, config={}): if "asymmetric" in config: assert config["asymmetric"] else: diff --git a/rllib/examples/env/matrix_sequential_social_dilemma.py b/rllib/examples/env/matrix_sequential_social_dilemma.py index 9348a184890b8..97d222b3cff20 100644 --- a/rllib/examples/env/matrix_sequential_social_dilemma.py +++ b/rllib/examples/env/matrix_sequential_social_dilemma.py @@ -8,7 +8,7 @@ import logging from abc import ABC from collections import Iterable -from typing import Dict, Optional +from typing import Dict import numpy as np from gym.spaces import Discrete @@ -39,9 +39,7 @@ class MatrixSequentialSocialDilemma(InfoAccumulationInterface, MultiAgentEnv, episode. """ - def __init__(self, config: Optional[Dict] = None): - if config is None: - config = {} + def __init__(self, config: Dict = {}): assert "reward_randomness" not in config.keys() assert self.PAYOUT_MATRIX is not None diff --git a/rllib/examples/env/random_env.py b/rllib/examples/env/random_env.py index ceeca23424c24..b6b451fef7c33 100644 --- a/rllib/examples/env/random_env.py +++ b/rllib/examples/env/random_env.py @@ -14,9 +14,7 @@ class RandomEnv(gym.Env): configured as well. """ - def __init__(self, config=None): - config = config or {} - + def __init__(self, config): # Action space. self.action_space = config.get("action_space", Discrete(2)) # Observation space from which to sample. @@ -65,25 +63,3 @@ def step(self, action): # Multi-agent version of the RandomEnv. RandomMultiAgentEnv = make_multi_agent(lambda c: RandomEnv(c)) - - -# Large observation space "pre-compiled" random env (for testing). -class RandomLargeObsSpaceEnv(RandomEnv): - def __init__(self, config=None): - config = config or {} - config.update({ - "observation_space": gym.spaces.Box(-1.0, 1.0, (5000, )) - }) - super().__init__(config=config) - - -# Large observation space + cont. actions "pre-compiled" random env -# (for testing). -class RandomLargeObsSpaceEnvContActions(RandomEnv): - def __init__(self, config=None): - config = config or {} - config.update({ - "observation_space": gym.spaces.Box(-1.0, 1.0, (5000, )), - "action_space": gym.spaces.Box(-1.0, 1.0, (5, )), - }) - super().__init__(config=config) diff --git a/rllib/examples/pettingzoo_env.py b/rllib/examples/pettingzoo_env.py index 661f03f012088..5eeb962200849 100644 --- a/rllib/examples/pettingzoo_env.py +++ b/rllib/examples/pettingzoo_env.py @@ -42,17 +42,19 @@ def env_creator(config): # Register env register_env("pistonball", lambda config: PettingZooEnv(env_creator(config))) + env = PettingZooEnv(env_creator(config)) + observation_space = env.observation_space + action_space = env.action_space + del env # Configuration for multiagent setup with policy sharing: config["multiagent"] = { - # Setup a single, shared policy for all agents: "av". - # Use a simple set of strings (PolicyID) here. RLlib will - # automatically determine the policy class (Trainer's default class), - # observation- and action spaces (inferred from the env), and - # config overrides ({} in this case). - "policies": {"av"}, - # Map all agents to the "av" PolicyID. - "policy_mapping_fn": lambda agent_id, episode, worker, **kwargs: "av", + # Setup a single, shared policy for all agents. 
+ "policies": { + "av": (None, observation_space, action_space, {}) + }, + # Map all agents to that policy. + "policy_mapping_fn": lambda agent_id, episode, **kwargs: "av", } # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0. diff --git a/rllib/examples/remote_vector_env_with_custom_api.py b/rllib/examples/remote_vector_env_with_custom_api.py index c212249990611..1dcc65eda89f8 100644 --- a/rllib/examples/remote_vector_env_with_custom_api.py +++ b/rllib/examples/remote_vector_env_with_custom_api.py @@ -65,7 +65,7 @@ class NonVectorizedEnvToBeVectorizedIntoRemoteVectorEnv(TaskSettableEnv): of gym.Env). """ - def __init__(self, config=None): + def __init__(self, config): self.action_space = gym.spaces.Box(0, 1, shape=(1, )) self.observation_space = gym.spaces.Box(0, 1, shape=(2, )) self.task = 1 @@ -108,6 +108,7 @@ def on_train_result(self, *, trainer, result: dict, **kwargs) -> None: # Specify your custom (single, non-vectorized) env directly as a # class. This way, RLlib can auto-create Actors from this class # and handle everything correctly. + # TODO: Test for multi-agent case. "env": NonVectorizedEnvToBeVectorizedIntoRemoteVectorEnv, # Set up our own callbacks. "callbacks": TaskSettingCallback, diff --git a/rllib/examples/rock_paper_scissors_multiagent.py b/rllib/examples/rock_paper_scissors_multiagent.py index 0905314c1140b..bc7477a7f0716 100644 --- a/rllib/examples/rock_paper_scissors_multiagent.py +++ b/rllib/examples/rock_paper_scissors_multiagent.py @@ -9,19 +9,19 @@ import argparse import os -from pettingzoo.classic import rps_v2 import random from ray import tune from ray.rllib.agents.pg import PGTrainer, PGTFPolicy, PGTorchPolicy from ray.rllib.agents.registry import get_trainer_class -from ray.rllib.env import PettingZooEnv from ray.rllib.examples.policy.rock_paper_scissors_dummies import \ BeatLastHeuristic, AlwaysSameHeuristic from ray.rllib.policy.policy import PolicySpec from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.test_utils import check_learning_achieved from ray.tune.registry import register_env +from ray.rllib.env import PettingZooEnv +from pettingzoo.classic import rps_v2 tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() @@ -149,8 +149,8 @@ def entropy_policy_gradient_loss(policy, model, dist_class, train_batch): logits, _ = model.from_batch(train_batch) action_dist = dist_class(logits, model) if args.framework == "torch": - # Required by PGTorchPolicy's stats fn. - model.tower_stats["policy_loss"] = torch.tensor([0.0]) + # required by PGTorchPolicy's stats fn. 
+ policy.pi_err = torch.tensor([0.0]) return torch.mean(-0.1 * action_dist.entropy() - (action_dist.logp(train_batch["actions"]) * train_batch["advantages"])) diff --git a/rllib/examples/serving/cartpole_client.py b/rllib/examples/serving/cartpole_client.py index a368e6b44b852..4f9f247eda49b 100755 --- a/rllib/examples/serving/cartpole_client.py +++ b/rllib/examples/serving/cartpole_client.py @@ -54,7 +54,7 @@ "(Policy-computed) ones.") parser.add_argument( "--stop-reward", - type=float, + type=int, default=9999, help="Stop once the specified reward is reached.") parser.add_argument( diff --git a/rllib/examples/serving/unity3d_client.py b/rllib/examples/serving/unity3d_client.py index f3089abd402ae..8c8784ebf18ab 100644 --- a/rllib/examples/serving/unity3d_client.py +++ b/rllib/examples/serving/unity3d_client.py @@ -52,13 +52,9 @@ parser.add_argument( "--server", type=str, - default=SERVER_ADDRESS, - help="The Policy server's address to connect to from this client.") -parser.add_argument( - "--port", - type=int, - default=SERVER_PORT, - help="The port to use (on --server).") + default=SERVER_ADDRESS + ":" + str(SERVER_PORT), + help="The Policy server's address and port to connect to from this client." +) parser.add_argument( "--no-train", action="store_true", @@ -79,7 +75,7 @@ "learnt policy weights from the server?") parser.add_argument( "--stop-reward", - type=float, + type=int, default=9999, help="Stop once the specified reward is reached.") @@ -89,7 +85,7 @@ # Start the client for sending environment information (e.g. observations, # actions) to a policy server (listening on port 9900). client = PolicyClient( - "http://" + args.server + ":" + str(args.port), + "http://" + args.server, inference_mode=args.inference_mode, update_interval=args.update_interval_local_mode) diff --git a/rllib/examples/serving/unity3d_dummy_client.py b/rllib/examples/serving/unity3d_dummy_client.py deleted file mode 100644 index 93e7245f31a43..0000000000000 --- a/rllib/examples/serving/unity3d_dummy_client.py +++ /dev/null @@ -1,144 +0,0 @@ -""" -Dummy in-place replacement for the unity3d_client.py script -in case you don't have an actual Unity3D engine installed or just want -to test client/server connectivity with the unity3d_server.py script. - -This client script simply uses RLlib's RandomMultiAgentEnv to mimic -one of the ML Agents (Unity3D) example games (e.g. "3DBall"). - -To run this script on possibly different machines -against a central Policy server: - -1) Run (two separate shells/machines): -$ python unity3d_server.py --env 3DBall -$ python unity3d_dummy_client.py --env 3DBall --inference-mode=local -""" - -import argparse - -from ray.rllib.env.policy_client import PolicyClient -from ray.rllib.env.wrappers.unity3d_env import Unity3DEnv -from ray.rllib.examples.env.random_env import RandomMultiAgentEnv - -SERVER_ADDRESS = "localhost" -SERVER_PORT = 9900 - -parser = argparse.ArgumentParser() -parser.add_argument( - "--env", - type=str, - default="3DBall", - choices=[ - "3DBall", "3DBallHard", "FoodCollector", "GridFoodCollector", - "Pyramids", "Sorter", "Tennis", "VisualHallway", "Walker" - ], - help="The name of the Env to mimic. Only those examples supported so " - "far for which all agents have the same " - "observation- and action spaces (feel free to add more to this script!)") -parser.add_argument( - "--horizon", - type=int, - default=200, - help="The max. 
number of `step()`s for any episode (per agent) before " - "it'll be reset again automatically.") -parser.add_argument( - "--server", - type=str, - default=SERVER_ADDRESS, - help="The Policy server's address to connect to from this client.") -parser.add_argument( - "--port", - type=int, - default=SERVER_PORT, - help="The port to use (on --server).") -parser.add_argument( - "--no-train", - action="store_true", - help="Whether to disable training (on the server side).") -parser.add_argument( - "--inference-mode", - type=str, - default="local", - choices=["local", "remote"], - help="Whether to compute actions `local`ly or `remote`ly. Note that " - "`local` is much faster b/c observations/actions do not have to be " - "sent via the network.") -parser.add_argument( - "--update-interval-local-mode", - type=float, - default=10.0, - help="For `inference-mode=local`, every how many seconds do we update " - "learnt policy weights from the server?") -parser.add_argument( - "--num-episodes", - type=int, - default=10, - help="Stop once the specified number of episodes have been played.") - -if __name__ == "__main__": - args = parser.parse_args() - - # Start the client for sending environment information (e.g. observations, - # actions) to a policy server (listening on port 9900). - client = PolicyClient( - "http://" + args.server + ":" + str(args.port), - inference_mode=args.inference_mode, - update_interval=args.update_interval_local_mode) - - # Get the multi-agent policies dict and agent->policy - # mapping-fn. - policies, policy_mapping_fn = \ - Unity3DEnv.get_policy_configs_for_game(args.env) - - # Make sure all policies' obs- and action spaces are the same. - # If not, we won't be able to mimic the Unity3D env using RLlib's - # RandomMultiAgentEnv. - first_policy_spec = next(iter(policies.values())) - for pid, policy_spec in policies.items(): - assert policy_spec.observation_space == \ - first_policy_spec.observation_space - assert policy_spec.action_space == first_policy_spec.action_space - - # Start and reset the actual Unity3DEnv (either already running Unity3D - # editor or a binary (game) to be started automatically). - env = RandomMultiAgentEnv({ - # Same number of agents as the actual Unity3D game would have. - "num_agents": len(policies), - # Make sure we stick to the user given horizons using our - # RandomMultiAgentEnv options. - "max_episode_len": args.horizon, - "p_done": 0.0, - # Same obs- action spaces as the actual Unity3D game would have. - "observation_space": first_policy_spec.observation_space, - "action_space": first_policy_spec.action_space, - }) - obs = env.reset() - eid = client.start_episode(training_enabled=not args.no_train) - - # Keep track of the total reward per episode. - total_rewards_this_episode = 0.0 - - # Loop through the env until n episodes completed. - num_episodes = 0 - while True: - # Get actions from the Policy server given our current obs. - actions = client.get_action(eid, obs) - # Apply actions to our env. - obs, rewards, dones, infos = env.step(actions) - total_rewards_this_episode += sum(rewards.values()) - # Log rewards and single-agent dones. - client.log_returns(eid, rewards, infos, multiagent_done_dict=dones) - # Check whether all agents are done and end the episode, if necessary. - if dones["__all__"]: - print("Episode done: Reward={}".format(total_rewards_this_episode)) - - num_episodes += 1 - if num_episodes >= args.num_episodes: - quit(0) - - # End the episode and reset dummy Env. 
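# The client protocol the removed script followed, as a minimal sketch
# (method names as used above; single-agent form for brevity):
#
#     eid = client.start_episode(training_enabled=True)
#     obs = env.reset()
#     done = False
#     while not done:
#         action = client.get_action(eid, obs)
#         obs, reward, done, info = env.step(action)
#         client.log_returns(eid, reward, info)
#     client.end_episode(eid, obs)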
- total_rewards_this_episode = 0.0 - client.end_episode(eid, obs) - obs = env.reset() - # Start a new episode. - eid = client.start_episode(training_enabled=not args.no_train) diff --git a/rllib/examples/serving/unity3d_server.py b/rllib/examples/serving/unity3d_server.py index 04ce5567fc165..56c1a0089fe50 100755 --- a/rllib/examples/serving/unity3d_server.py +++ b/rllib/examples/serving/unity3d_server.py @@ -31,42 +31,24 @@ import os import ray -from ray.rllib.agents.registry import get_trainer_class +from ray.tune import register_env +from ray.rllib.agents.ppo import PPOTrainer from ray.rllib.env.policy_server_input import PolicyServerInput from ray.rllib.env.wrappers.unity3d_env import Unity3DEnv +from ray.rllib.examples.env.random_env import RandomMultiAgentEnv SERVER_ADDRESS = "localhost" SERVER_PORT = 9900 CHECKPOINT_FILE = "last_checkpoint_{}.out" parser = argparse.ArgumentParser() -parser.add_argument( - "--run", - default="PPO", - choices=["DQN", "PPO"], - help="The RLlib-registered algorithm to use.") -parser.add_argument( - "--framework", - choices=["tf", "tf2", "tfe", "torch"], - default="tf", - help="The DL framework specifier.") -parser.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of workers to use. Each worker will create " - "its own listening socket for incoming experiences.") parser.add_argument( "--env", type=str, default="3DBall", - choices=[ - "3DBall", "3DBallHard", "FoodCollector", "GridFoodCollector", - "Pyramids", "SoccerStrikersVsGoalie", "Sorter", "Tennis", - "VisualHallway", "Walker" - ], - help="The name of the Env to run in the Unity3D editor " - "(feel free to add more to this script!)") + choices=["3DBall", "SoccerStrikersVsGoalie"], + help="The name of the Env to run in the Unity3D editor. Either `3DBall` " + "or `SoccerStrikersVsGoalie` (feel free to add more to this script!)") parser.add_argument( "--port", type=int, @@ -89,21 +71,11 @@ args = parser.parse_args() ray.init() - # `InputReader` generator (returns None if no input reader is needed on - # the respective worker). - def _input(ioctx): - # We are remote worker or we are local worker with num_workers=0: - # Create a PolicyServerInput. - if ioctx.worker_index > 0 or ioctx.worker.num_workers == 0: - return PolicyServerInput( - ioctx, SERVER_ADDRESS, args.port + ioctx.worker_index - - (1 if ioctx.worker_index > 0 else 0)) - # No InputReader (PolicyServerInput) needed. - else: - return None - - # Get the multi-agent policies dict and agent->policy - # mapping-fn. + # Create a fake-env for the server. This env will never be used (neither + # for sampling, nor for evaluation) and its obs/action Spaces do not + # matter either (multi-agent config below defines Spaces per Policy). + register_env("fake_unity", lambda c: RandomMultiAgentEnv(c)) + policies, policy_mapping_fn = \ Unity3DEnv.get_policy_configs_for_game(args.env) @@ -111,31 +83,27 @@ def _input(ioctx): # build their own samplers (and also Policy objects iff # `inference_mode=local` on clients' command line). config = { - # Indicate that the Trainer we setup here doesn't need an actual env. - # Allow spaces to be determined by user (see below). - "env": None, - - # Use the `PolicyServerInput` to generate experiences. - "input": _input, - # Use n worker processes to listen on different ports. - "num_workers": args.num_workers, + # Use the connector server to generate experiences. 
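# PolicyServerInput acts as an InputReader: each `next()` call returns a
# SampleBatch built from experiences posted by connected clients. Sketch
# (a single worker serving one port, as configured below):
#
#     reader = PolicyServerInput(ioctx, SERVER_ADDRESS, SERVER_PORT)
#     batch = reader.next()  # blocks until clients have sent enough data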
+ "input": ( + lambda ioctx: PolicyServerInput(ioctx, SERVER_ADDRESS, args.port)), + # Use a single worker process (w/ SyncSampler) to run the server. + "num_workers": 0, # Disable OPE, since the rollouts are coming from online clients. "input_evaluation": [], # Other settings. "train_batch_size": 256, "rollout_fragment_length": 20, - # Multi-agent setup for the given env. + # Multi-agent setup for the particular env. "multiagent": { "policies": policies, "policy_mapping_fn": policy_mapping_fn, }, - # DL framework to use. - "framework": args.framework, + "framework": "tf", } # Create the Trainer used for Policy serving. - trainer = get_trainer_class(args.run)(config=config) + trainer = PPOTrainer(env="fake_unity", config=config) # Attempt to restore from checkpoint if possible. checkpoint_path = CHECKPOINT_FILE.format(args.env) diff --git a/rllib/examples/trajectory_view_api.py b/rllib/examples/trajectory_view_api.py index b4a288e013bd5..31ce04e879126 100644 --- a/rllib/examples/trajectory_view_api.py +++ b/rllib/examples/trajectory_view_api.py @@ -1,15 +1,13 @@ import argparse -import numpy as np import ray -from ray.rllib.agents.ppo import PPOTrainer +from ray import tune from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole from ray.rllib.examples.models.trajectory_view_utilizing_models import \ FrameStackingCartPoleModel, TorchFrameStackingCartPoleModel from ray.rllib.models.catalog import ModelCatalog from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.test_utils import check_learning_achieved -from ray import tune tf1, tf, tfv = try_import_tf() @@ -49,19 +47,18 @@ args = parser.parse_args() ray.init(num_cpus=3) - num_frames = 16 - ModelCatalog.register_custom_model( "frame_stack_model", FrameStackingCartPoleModel if args.framework != "torch" else TorchFrameStackingCartPoleModel) + tune.register_env("stateless_cartpole", lambda c: StatelessCartPole()) config = { - "env": StatelessCartPole, + "env": "stateless_cartpole", "model": { "vf_share_layers": True, "custom_model": "frame_stack_model", "custom_model_config": { - "num_frames": num_frames, + "num_frames": 16, }, # To compare against a simple LSTM: @@ -84,45 +81,8 @@ "timesteps_total": args.stop_timesteps, "episode_reward_mean": args.stop_reward, } - results = tune.run( - args.run, config=config, stop=stop, verbose=2, checkpoint_at_end=True) + results = tune.run(args.run, config=config, stop=stop, verbose=2) if args.as_test: check_learning_achieved(results, args.stop_reward) - - checkpoints = results.get_trial_checkpoints_paths( - trial=results.get_best_trial("episode_reward_mean", mode="max"), - metric="episode_reward_mean") - - checkpoint_path = checkpoints[0][0] - trainer = PPOTrainer(config) - trainer.restore(checkpoint_path) - - # Inference loop. - env = StatelessCartPole() - - # Run manual inference loop for n episodes. - for _ in range(10): - episode_reward = 0.0 - reward = 0.0 - action = 0 - done = False - obs = env.reset() - while not done: - # Create a dummy action using the same observation n times, - # as well as dummy prev-n-actions and prev-n-rewards. 
- action, state, logits = trainer.compute_single_action( - input_dict={ - "obs": obs, - "prev_n_obs": np.stack([obs for _ in range(num_frames)]), - "prev_n_actions": np.stack([0 for _ in range(num_frames)]), - "prev_n_rewards": np.stack( - [1.0 for _ in range(num_frames)]), - }, - full_fetch=True) - obs, reward, done, info = env.step(action) - episode_reward += reward - - print(f"Episode reward={episode_reward}") - ray.shutdown() diff --git a/rllib/execution/common.py b/rllib/execution/common.py index 25e4428bffb63..3349541dac2f6 100644 --- a/rllib/execution/common.py +++ b/rllib/execution/common.py @@ -22,6 +22,9 @@ LEARN_ON_BATCH_TIMER = "learn" LOAD_BATCH_TIMER = "load" +# Instant metrics (keys for metrics.info). +LEARNER_INFO = "learner" + # Asserts that an object is a type of SampleBatch. def _check_sample_batch_type(batch: SampleBatchType) -> None: diff --git a/rllib/execution/learner_thread.py b/rllib/execution/learner_thread.py index d8c6f93c146b1..be7b028cdb04f 100644 --- a/rllib/execution/learner_thread.py +++ b/rllib/execution/learner_thread.py @@ -3,11 +3,10 @@ import threading from typing import Dict, Optional +from ray.rllib.evaluation.metrics import get_learner_stats from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.execution.minibatch_buffer import MinibatchBuffer from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder, \ - LEARNER_INFO, LEARNER_STATS_KEY from ray.rllib.utils.timer import TimerStat from ray.rllib.utils.window_stat import WindowStat from ray.util.iter import _NextValueNotReady @@ -57,7 +56,7 @@ def __init__(self, local_worker: RolloutWorker, minibatch_buffer_size: int, self.load_wait_timer = TimerStat() self.daemon = True self.weights_updated = False - self.learner_info = {} + self.stats = {} self.stopped = False self.num_steps = 0 @@ -76,24 +75,12 @@ def step(self) -> Optional[_NextValueNotReady]: return _NextValueNotReady() with self.grad_timer: - # Use LearnerInfoBuilder as a unified way to build the final - # results dict from `learn_on_loaded_batch` call(s). - # This makes sure results dicts always have the same structure - # no matter the setup (multi-GPU, multi-agent, minibatch SGD, - # tf vs torch). 
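# After this revert the step goes back to the plain pattern sketched here
# (mirroring the `+` lines below; `get_learner_stats` unwraps
# LEARNER_STATS_KEY when present):
#
#     fetches = local_worker.learn_on_batch(batch)
#     stats = get_learner_stats(fetches)
#     outqueue.put((batch.count, stats))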
- learner_info_builder = LearnerInfoBuilder(num_devices=1) - multi_agent_results = self.local_worker.learn_on_batch(batch) - for pid, results in multi_agent_results.items(): - learner_info_builder.add_learn_on_batch_results(results, pid) - self.learner_info = learner_info_builder.finalize() - learner_stats = { - pid: info[LEARNER_STATS_KEY] - for pid, info in self.learner_info.items() - } + fetches = self.local_worker.learn_on_batch(batch) self.weights_updated = True + self.stats = get_learner_stats(fetches) self.num_steps += 1 - self.outqueue.put((batch.count, learner_stats)) + self.outqueue.put((batch.count, self.stats)) self.learner_queue_size.push(self.inqueue.qsize()) def add_learner_metrics(self, result: Dict) -> Dict: @@ -104,7 +91,7 @@ def timer_to_ms(timer): result["info"].update({ "learner_queue": self.learner_queue_size.stats(), - LEARNER_INFO: copy.deepcopy(self.learner_info), + "learner": copy.deepcopy(self.stats), "timing_breakdown": { "learner_grad_time_ms": timer_to_ms(self.grad_timer), "learner_load_time_ms": timer_to_ms(self.load_timer), diff --git a/rllib/execution/multi_gpu_learner_thread.py b/rllib/execution/multi_gpu_learner_thread.py index 1120be7a77d4b..0d230878ff609 100644 --- a/rllib/execution/multi_gpu_learner_thread.py +++ b/rllib/execution/multi_gpu_learner_thread.py @@ -1,15 +1,15 @@ import logging -from six.moves import queue import threading +from six.moves import queue + +from ray.rllib.evaluation.metrics import get_learner_stats +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.execution.learner_thread import LearnerThread from ray.rllib.execution.minibatch_buffer import MinibatchBuffer -from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.annotations import override from ray.rllib.utils.deprecation import deprecation_warning from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder, \ - LEARNER_STATS_KEY from ray.rllib.utils.timer import TimerStat from ray.rllib.evaluation.rollout_worker import RolloutWorker @@ -103,14 +103,18 @@ def __init__( self.train_batch_size = train_batch_size - self.policy_map = self.local_worker.policy_map - self.devices = next(iter(self.policy_map.values())).devices + # TODO: (sven) Allow multi-GPU to work for multi-agent as well. + self.policy = self.local_worker.policy_map[DEFAULT_POLICY_ID] - logger.info("MultiGPULearnerThread devices {}".format(self.devices)) - assert self.train_batch_size % len(self.devices) == 0 - assert self.train_batch_size >= len(self.devices),\ + logger.info("MultiGPULearnerThread devices {}".format( + self.policy.devices)) + assert self.train_batch_size % len(self.policy.devices) == 0 + assert self.train_batch_size >= len(self.policy.devices),\ "batch too small" + if set(self.local_worker.policy_map.keys()) != {DEFAULT_POLICY_ID}: + raise NotImplementedError("Multi-gpu mode for multi-agent") + self.tower_stack_indices = list(range(num_multi_gpu_tower_stacks)) # Two queues for tower stacks: @@ -142,39 +146,18 @@ def step(self) -> None: with self.load_wait_timer: buffer_idx, released = self.ready_tower_stacks_buffer.get() - get_num_samples_loaded_into_buffer = 0 with self.grad_timer: - # Use LearnerInfoBuilder as a unified way to build the final - # results dict from `learn_on_loaded_batch` call(s). - # This makes sure results dicts always have the same structure - # no matter the setup (multi-GPU, multi-agent, minibatch SGD, - # tf vs torch). 
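# Note: DEFAULT_POLICY_ID is the literal "default_policy", so the check
# added above makes the multi-GPU learner run only when that is the sole
# entry in the policy map:
#
#     from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
#     assert DEFAULT_POLICY_ID == "default_policy"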
- learner_info_builder = LearnerInfoBuilder( - num_devices=len(self.devices)) - - for pid in self.policy_map.keys(): - # Not a policy-to-train. - if pid not in self.local_worker.policies_to_train: - continue - policy = self.policy_map[pid] - default_policy_results = policy.learn_on_loaded_batch( - offset=0, buffer_index=buffer_idx) - learner_info_builder.add_learn_on_batch_results( - default_policy_results) - self.weights_updated = True - get_num_samples_loaded_into_buffer += \ - policy.get_num_samples_loaded_into_buffer(buffer_idx) - - self.learner_info = learner_info_builder.finalize() - learner_stats = { - pid: self.learner_info[pid][LEARNER_STATS_KEY] - for pid in self.learner_info.keys() - } + fetches = self.policy.learn_on_loaded_batch( + offset=0, buffer_index=buffer_idx) + self.weights_updated = True + self.stats = {DEFAULT_POLICY_ID: get_learner_stats(fetches)} if released: self.idle_tower_stacks.put(buffer_idx) - self.outqueue.put((get_num_samples_loaded_into_buffer, learner_stats)) + self.outqueue.put( + (self.policy.get_num_samples_loaded_into_buffer(buffer_idx), + self.stats)) self.learner_queue_size.push(self.inqueue.qsize()) @@ -197,7 +180,7 @@ def run(self) -> None: def _step(self) -> None: s = self.multi_gpu_learner_thread - policy_map = s.policy_map + policy = s.policy # Get a new batch from the data (inqueue). with self.queue_timer: @@ -208,14 +191,7 @@ def _step(self) -> None: # Load the batch into the idle stack. with self.load_timer: - for pid in policy_map.keys(): - if pid not in s.local_worker.policies_to_train: - continue - policy = policy_map[pid] - policy.load_batch_into_buffer( - batch=batch if isinstance(batch, SampleBatch) else - batch.policy_batches[pid], - buffer_index=buffer_idx) + policy.load_batch_into_buffer(batch=batch, buffer_index=buffer_idx) # Tag just-loaded stack as "ready". 
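# The tower-stack pipeline is a producer/consumer pair of queues; the loader
# thread and the learner thread hand buffer indices back and forth (sketch
# using the names above):
#
#     buffer_idx = s.idle_tower_stacks.get()       # loader: claim a stack
#     policy.load_batch_into_buffer(batch=batch, buffer_index=buffer_idx)
#     s.ready_tower_stacks.put(buffer_idx)         # hand over to learner
#     ...
#     buffer_idx, released = ready_tower_stacks_buffer.get()  # learner side
#     fetches = policy.learn_on_loaded_batch(offset=0, buffer_index=buffer_idx)
#     if released:
#         idle_tower_stacks.put(buffer_idx)        # recycle for reloading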
s.ready_tower_stacks.put(buffer_idx) diff --git a/rllib/execution/rollout_ops.py b/rllib/execution/rollout_ops.py index 1f65620b115eb..364a814c8c996 100644 --- a/rllib/execution/rollout_ops.py +++ b/rllib/execution/rollout_ops.py @@ -4,15 +4,14 @@ from ray.util.iter import from_actors, LocalIterator from ray.util.iter_metrics import SharedMetrics +from ray.rllib.evaluation.metrics import get_learner_stats from ray.rllib.evaluation.rollout_worker import get_global_worker from ray.rllib.evaluation.worker_set import WorkerSet from ray.rllib.execution.common import AGENT_STEPS_SAMPLED_COUNTER, \ - STEPS_SAMPLED_COUNTER, SAMPLE_TIMER, GRAD_WAIT_TIMER, \ + STEPS_SAMPLED_COUNTER, LEARNER_INFO, SAMPLE_TIMER, GRAD_WAIT_TIMER, \ _check_sample_batch_type, _get_shared_metrics from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ MultiAgentBatch -from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \ - LEARNER_STATS_KEY from ray.rllib.utils.sgd import standardized from ray.rllib.utils.typing import PolicyID, SampleBatchType, ModelGradients @@ -131,9 +130,7 @@ def __call__(self, item): (grads, info), count = item metrics = _get_shared_metrics() metrics.counters[STEPS_SAMPLED_COUNTER] += count - metrics.info[LEARNER_INFO] = { - DEFAULT_POLICY_ID: info - } if LEARNER_STATS_KEY in info else info + metrics.info[LEARNER_INFO] = get_learner_stats(info) metrics.timers[GRAD_WAIT_TIMER].push(time.perf_counter() - self.fetch_start_time) return grads, count @@ -165,24 +162,15 @@ def __init__(self, min_batch_size: int, count_steps_by: str = "env_steps"): def __call__(self, batch: SampleBatchType) -> List[SampleBatchType]: _check_sample_batch_type(batch) + self.buffer.append(batch) if self.count_steps_by == "env_steps": - size = batch.count + self.count += batch.count else: assert isinstance(batch, MultiAgentBatch), \ "`count_steps_by=agent_steps` only allowed in multi-agent " \ "environments!" - size = batch.agent_steps() - - # Incoming batch is an empty dummy batch -> Ignore. - # Possibly produced automatically by a PolicyServer to unblock - # an external env waiting for inputs from unresponsive/disconnected - # client(s). 
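# On the two counting modes above: `env_steps` counts one unit per env
# transition, while `agent_steps` sums over agents in a MultiAgentBatch.
# Worked example: a 2-agent env stepped 100 times gives batch.count == 100
# but batch.agent_steps() == 200.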
- if size == 0: - return [] - - self.count += size - self.buffer.append(batch) + self.count += batch.agent_steps() if self.count >= self.min_batch_size: if self.count > self.min_batch_size * 2: diff --git a/rllib/execution/train_ops.py b/rllib/execution/train_ops.py index e289d5a7f2fbb..6c0e089ef598a 100644 --- a/rllib/execution/train_ops.py +++ b/rllib/execution/train_ops.py @@ -1,21 +1,22 @@ import logging import numpy as np import math +import tree # pip install dm_tree from typing import List, Tuple, Any import ray +from ray.rllib.evaluation.metrics import get_learner_stats from ray.rllib.evaluation.worker_set import WorkerSet from ray.rllib.execution.common import \ AGENT_STEPS_TRAINED_COUNTER, APPLY_GRADS_TIMER, COMPUTE_GRADS_TIMER, \ - LAST_TARGET_UPDATE_TS, LEARN_ON_BATCH_TIMER, \ + LAST_TARGET_UPDATE_TS, LEARNER_INFO, LEARN_ON_BATCH_TIMER, \ LOAD_BATCH_TIMER, NUM_TARGET_UPDATES, STEPS_SAMPLED_COUNTER, \ STEPS_TRAINED_COUNTER, WORKER_UPDATE_TIMER, _check_sample_batch_type, \ _get_global_vars, _get_shared_metrics +from ray.rllib.policy.policy import LEARNER_STATS_KEY from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ MultiAgentBatch from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder, \ - LEARNER_INFO from ray.rllib.utils.sgd import do_minibatch_sgd from ray.rllib.utils.typing import PolicyID, SampleBatchType, ModelGradients @@ -61,7 +62,7 @@ def __call__(self, # train batch and loop through train batch `num_sgd_iter` times. if self.num_sgd_iter > 1 or self.sgd_minibatch_size > 0: lw = self.workers.local_worker() - learner_info = do_minibatch_sgd( + info = do_minibatch_sgd( batch, { pid: lw.get_policy(pid) for pid in self.policies @@ -69,10 +70,9 @@ def __call__(self, }, lw, self.num_sgd_iter, self.sgd_minibatch_size, []) # Single update step using train batch. else: - learner_info = \ - self.workers.local_worker().learn_on_batch(batch) + info = self.workers.local_worker().learn_on_batch(batch) - metrics.info[LEARNER_INFO] = learner_info + metrics.info[LEARNER_INFO] = info learn_timer.push_units_processed(batch.count) metrics.counters[STEPS_TRAINED_COUNTER] += batch.count if isinstance(batch, MultiAgentBatch): @@ -88,7 +88,7 @@ def __call__(self, e.set_weights.remote(weights, _get_global_vars()) # Also update global vars of the local worker. self.workers.local_worker().set_global_vars(_get_global_vars()) - return batch, learner_info + return batch, info class MultiGPUTrainOneStep: @@ -174,43 +174,56 @@ def __call__(self, # Execute minibatch SGD on loaded data. with learn_timer: - # Use LearnerInfoBuilder as a unified way to build the final - # results dict from `learn_on_loaded_batch` call(s). - # This makes sure results dicts always have the same structure - # no matter the setup (multi-GPU, multi-agent, minibatch SGD, - # tf vs torch). - learner_info_builder = LearnerInfoBuilder( - num_devices=len(self.devices)) - + fetches = {} for policy_id, samples_per_device in num_loaded_samples.items(): policy = self.local_worker.policy_map[policy_id] num_batches = max( 1, int(samples_per_device) // int(self.per_device_batch_size)) logger.debug("== sgd epochs for {} ==".format(policy_id)) + batch_fetches_all_towers = [] for _ in range(self.num_sgd_iter): permutation = np.random.permutation(num_batches) for batch_index in range(num_batches): # Learn on the pre-loaded data in the buffer. # Note: For minibatch SGD, the data is an offset into # the pre-loaded entire train batch. 
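# Offset arithmetic for the pre-loaded buffer (worked example): with
# per_device_batch_size=128 and num_batches=4, the offsets passed to
# `learn_on_loaded_batch` below are permutation[i] * 128, i.e. some shuffled
# order of {0, 128, 256, 384} within the loaded train batch.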
- results = policy.learn_on_loaded_batch( + batch_fetches = policy.learn_on_loaded_batch( permutation[batch_index] * self.per_device_batch_size, buffer_index=0) - learner_info_builder.add_learn_on_batch_results( - results, policy_id) - - # Tower reduce and finalize results. - learner_info = learner_info_builder.finalize() + # No towers: Single CPU. + if "tower_0" not in batch_fetches: + batch_fetches_all_towers.append(batch_fetches) + else: + batch_fetches_all_towers.append( + tree.map_structure_with_path( + lambda p, *s: all_tower_reduce(p, *s), + *(batch_fetches.pop( + "tower_{}".format(tower_num)) + for tower_num in range( + len(self.devices))))) + for k, v in batch_fetches.items(): + if k == LEARNER_STATS_KEY: + for k1, v1 in batch_fetches[k].items(): + batch_fetches_all_towers[-1][ + LEARNER_STATS_KEY][k1] = v1 + else: + batch_fetches_all_towers[-1][k] = v + + # Reduce mean across all minibatch SGD steps (axis=0 to keep + # all shapes as-is). + fetches[policy_id] = tree.map_structure( + lambda *s: None if s[0] is None else np.nanmean(s, axis=0), + *batch_fetches_all_towers) load_timer.push_units_processed(samples.count) learn_timer.push_units_processed(samples.count) metrics.counters[STEPS_TRAINED_COUNTER] += samples.count metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += samples.agent_steps() - metrics.info[LEARNER_INFO] = learner_info + metrics.info[LEARNER_INFO] = fetches if self.workers.remote_workers(): with metrics.timers[WORKER_UPDATE_TIMER]: @@ -221,13 +234,24 @@ def __call__(self, # Also update global vars of the local worker. self.workers.local_worker().set_global_vars(_get_global_vars()) - return samples, learner_info + return samples, fetches # Backward compatibility. TrainTFMultiGPU = MultiGPUTrainOneStep +def all_tower_reduce(path, *tower_data): + """Reduces stats across towers based on their stats-dict paths.""" + if len(path) == 1 and path[0] == "td_error": + return np.concatenate(tower_data, axis=0) + elif path[-1].startswith("min_"): + return np.nanmin(tower_data) + elif path[-1].startswith("max_"): + return np.nanmax(tower_data) + return np.nanmean(tower_data) + + class ComputeGradients: """Callable that computes gradients with respect to the policy loss. @@ -249,12 +273,7 @@ def __call__(self, samples: SampleBatchType) -> Tuple[ModelGradients, int]: metrics = _get_shared_metrics() with metrics.timers[COMPUTE_GRADS_TIMER]: grad, info = self.workers.local_worker().compute_gradients(samples) - # RolloutWorker.compute_gradients returns pure single agent stats - # in a non-multi agent setup. 
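# Worked example for `all_tower_reduce` above (values illustrative):
# "td_error" leaves are concatenated across towers, keys starting with
# "min_"/"max_" take nanmin/nanmax, and everything else is nanmean'd:
#
#     tower_0 = {"learner_stats": {"policy_loss": 0.2, "max_q": 1.0}}
#     tower_1 = {"learner_stats": {"policy_loss": 0.4, "max_q": 3.0}}
#     # reduced -> {"learner_stats": {"policy_loss": 0.3, "max_q": 3.0}}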
- if isinstance(samples, MultiAgentBatch): - metrics.info[LEARNER_INFO] = info - else: - metrics.info[LEARNER_INFO] = {DEFAULT_POLICY_ID: info} + metrics.info[LEARNER_INFO] = get_learner_stats(info) return grad, samples.count diff --git a/rllib/models/tests/test_preprocessors.py b/rllib/models/tests/test_preprocessors.py index 2107ddec0cbd0..015efe6edd723 100644 --- a/rllib/models/tests/test_preprocessors.py +++ b/rllib/models/tests/test_preprocessors.py @@ -10,7 +10,7 @@ get_preprocessor, NoPreprocessor, TupleFlatteningPreprocessor, \ OneHotPreprocessor, AtariRamPreprocessor, GenericPixelPreprocessor from ray.rllib.utils.test_utils import check, check_compute_single_action, \ - check_train_results, framework_iterator + framework_iterator class TestPreprocessors(unittest.TestCase): @@ -50,9 +50,7 @@ def test_preprocessing_disabled(self): for _ in framework_iterator(config): trainer = ppo.PPOTrainer(config=config) for i in range(num_iterations): - results = trainer.train() - check_train_results(results) - print(results) + print(trainer.train()) check_compute_single_action(trainer) trainer.stop() diff --git a/rllib/models/tf/complex_input_net.py b/rllib/models/tf/complex_input_net.py index c7323c41cab96..2236607d3f75f 100644 --- a/rllib/models/tf/complex_input_net.py +++ b/rllib/models/tf/complex_input_net.py @@ -38,8 +38,6 @@ def __init__(self, obs_space, action_space, num_outputs, model_config, assert isinstance(self.original_space, (Dict, Tuple)), \ "`obs_space.original_space` must be [Dict|Tuple]!" - self.processed_obs_space = self.original_space if \ - model_config.get("_disable_preprocessor_api") else obs_space super().__init__(self.original_space, action_space, num_outputs, model_config, name) @@ -126,10 +124,8 @@ def forward(self, input_dict, state, seq_lens): if SampleBatch.OBS in input_dict and "obs_flat" in input_dict: orig_obs = input_dict[SampleBatch.OBS] else: - orig_obs = restore_original_dimensions( - input_dict[SampleBatch.OBS], - self.processed_obs_space, - tensorlib="tf") + orig_obs = restore_original_dimensions(input_dict[SampleBatch.OBS], + self.obs_space, "tf") # Push image observations through our CNNs. outs = [] for i, component in enumerate(tree.flatten(orig_obs)): diff --git a/rllib/models/torch/complex_input_net.py b/rllib/models/torch/complex_input_net.py index ac053bab6ccf3..b795e4d5485c3 100644 --- a/rllib/models/torch/complex_input_net.py +++ b/rllib/models/torch/complex_input_net.py @@ -40,9 +40,6 @@ def __init__(self, obs_space, action_space, num_outputs, model_config, assert isinstance(self.original_space, (Dict, Tuple)), \ "`obs_space.original_space` must be [Dict|Tuple]!" - self.processed_obs_space = self.original_space if \ - model_config.get("_disable_preprocessor_api") else obs_space - nn.Module.__init__(self) TorchModelV2.__init__(self, self.original_space, action_space, num_outputs, model_config, name) @@ -143,10 +140,8 @@ def forward(self, input_dict, state, seq_lens): if SampleBatch.OBS in input_dict and "obs_flat" in input_dict: orig_obs = input_dict[SampleBatch.OBS] else: - orig_obs = restore_original_dimensions( - input_dict[SampleBatch.OBS], - self.processed_obs_space, - tensorlib="torch") + orig_obs = restore_original_dimensions(input_dict[SampleBatch.OBS], + self.obs_space, "tf") # Push image observations through our CNNs. 
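A note for reading the restored call above: `restore_original_dimensions()` takes the tensor framework as its third argument, so the torch model's `forward()` would be expected to pass "torch" where the tf model passes "tf". What the function does is re-nest a preprocessor-flattened observation back into its original Dict/Tuple space; a hand-rolled sketch for a Dict of Boxes (a toy unflatten, not RLlib's actual implementation):

    import numpy as np
    from gym.spaces import Box, Dict

    space = Dict({"pos": Box(-1, 1, (2, )), "vel": Box(-1, 1, (3, ))})
    flat = np.arange(5, dtype=np.float32)[None]  # batch of 1, flattened obs

    orig, i = {}, 0
    for key in sorted(space.spaces):  # gym Dict keeps keys sorted
        size = int(np.prod(space.spaces[key].shape))
        orig[key] = flat[:, i:i + size].reshape(
            (-1, ) + space.spaces[key].shape)
        i += size
    # orig["pos"].shape == (1, 2); orig["vel"].shape == (1, 3)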
outs = [] for i, component in enumerate(tree.flatten(orig_obs)): diff --git a/rllib/models/torch/torch_modelv2.py b/rllib/models/torch/torch_modelv2.py index a7cc38cef6c8a..5cde72c4422e6 100644 --- a/rllib/models/torch/torch_modelv2.py +++ b/rllib/models/torch/torch_modelv2.py @@ -46,14 +46,6 @@ def __init__(self, *args, **kwargs): name, framework="torch") - # Dict to store per multi-gpu tower stats into. - # In PyTorch multi-GPU, we use a single TorchPolicy and copy - # it's Model(s) n times (1 copy for each GPU). When computing the loss - # on each tower, we cannot store the stats (e.g. `entropy`) inside the - # policy object as this would lead to race conditions between the - # different towers all accessing the same property at the same time. - self.tower_stats = {} - @override(ModelV2) def variables(self, as_dict: bool = False) -> \ Union[List[TensorType], Dict[str, TensorType]]: diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py index 76bb4c6bb666a..169dc0bad7f41 100644 --- a/rllib/policy/eager_tf_policy.py +++ b/rllib/policy/eager_tf_policy.py @@ -11,14 +11,13 @@ from ray.util.debug import log_once from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.repeated_values import RepeatedValues -from ray.rllib.policy.policy import Policy +from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils import add_mixins, force_list from ray.rllib.utils.annotations import override from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.spaces.space_utils import normalize_action from ray.rllib.utils.tf_ops import get_gpu_devices from ray.rllib.utils.threading import with_lock @@ -66,17 +65,15 @@ def convert_eager_inputs(func): @functools.wraps(func) def _func(*args, **kwargs): if tf.executing_eagerly(): - eager_args = [_convert_to_tf(x) for x in args] + args = [_convert_to_tf(x) for x in args] # TODO: (sven) find a way to remove key-specific hacks. - eager_kwargs = { + kwargs = { k: _convert_to_tf( v, dtype=tf.int64 if k == "timestep" else None) for k, v in kwargs.items() if k not in {"info_batch", "episodes"} } - return func(*eager_args, **eager_kwargs) - else: - return func(*args, **kwargs) + return func(*args, **kwargs) return _func @@ -185,14 +182,6 @@ def apply_gradients(self, grads): return TracedEagerPolicy -class OptimizerWrapper: - def __init__(self, tape): - self.tape = tape - - def compute_gradients(self, loss, var_list): - return list(zip(self.tape.gradient(loss, var_list), var_list)) - - def build_eager_tf_policy( name, loss_fn, @@ -334,11 +323,8 @@ def __init__(self, observation_space, action_space, config): if getattr(self, "exploration", None): optimizers = self.exploration.get_exploration_optimizer( optimizers) - - # The list of local (tf) optimizers (one per loss term). - self._optimizers: List[LocalOptimizer] = optimizers - # Backward compatibility: A user's policy may only support a single - # loss term and optimizer (no lists). + # TODO: (sven) Allow tf policy to have more than 1 optimizer. + # Just like torch Policy does. 
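Backing up to the torch_modelv2.py hunk above: the comment block deleted there documents why per-tower stats dicts existed. The pattern, schematically (stand-in classes, not RLlib's):

    class TowerModel:
        def __init__(self):
            # One dict per model copy: towers computing the loss in
            # parallel threads write here instead of racing on a single
            # shared policy attribute.
            self.tower_stats = {}

    towers = [TowerModel() for _ in range(2)]  # e.g. one copy per GPU
    for i, tower in enumerate(towers):
        tower.tower_stats["entropy"] = 0.01 * (i + 1)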
self._optimizer: LocalOptimizer = \ optimizers[0] if optimizers else None @@ -446,7 +432,6 @@ def compute_actions(self, lambda s: tf.convert_to_tensor(s), obs_batch), }, _is_training=tf.constant(False)) - self._lazy_tensor_dict(input_dict) if prev_action_batch is not None: input_dict[SampleBatch.PREV_ACTIONS] = \ tf.convert_to_tensor(prev_action_batch) @@ -480,6 +465,7 @@ def compute_actions_from_input_dict( explore, timestep) @with_lock + @convert_eager_inputs @convert_eager_outputs def _compute_action_helper(self, input_dict, state_batches, episodes, explore, timestep): @@ -495,8 +481,7 @@ def _compute_action_helper(self, input_dict, state_batches, episodes, self._is_training = False self._state_in = state_batches or [] # Calculate RNN sequence lengths. - batch_size = int( - tree.flatten(input_dict[SampleBatch.OBS])[0].shape[0]) + batch_size = tree.flatten(input_dict[SampleBatch.OBS])[0].shape[0] seq_lens = tf.ones(batch_size, dtype=tf.int32) if state_batches \ else None @@ -543,7 +528,7 @@ def _compute_action_helper(self, input_dict, state_batches, episodes, dist_inputs, self.dist_class, state_out = \ action_distribution_fn( self, self.model, - input_dict[SampleBatch.OBS], + input_dict[SampleBatch.CUR_OBS], explore=explore, timestep=timestep, is_training=False) @@ -581,7 +566,7 @@ def _compute_action_helper(self, input_dict, state_batches, episodes, extra_fetches.update(extra_action_out_fn(self)) # Update our global timestep by the batch size. - self.global_timestep += batch_size + self.global_timestep += int(batch_size) return actions, state_out, extra_fetches @@ -740,78 +725,51 @@ def export_checkpoint(self, export_dir): def _get_is_training_placeholder(self): return tf.convert_to_tensor(self._is_training) + def _apply_gradients(self, grads_and_vars): + if apply_gradients_fn: + apply_gradients_fn(self, self._optimizer, grads_and_vars) + else: + self._optimizer.apply_gradients( + [(g, v) for g, v in grads_and_vars if g is not None]) + @with_lock def _compute_gradients(self, samples): """Computes and returns grads as eager tensors.""" - # Gather all variables for which to calculate losses. + with tf.GradientTape(persistent=compute_gradients_fn is not None) \ + as tape: + loss = loss_fn(self, self.model, self.dist_class, samples) + if isinstance(self.model, tf.keras.Model): variables = self.model.trainable_variables else: variables = self.model.trainable_variables() - # Calculate the loss(es) inside a tf GradientTape. - with tf.GradientTape(persistent=compute_gradients_fn is not None) \ - as tape: - losses = loss_fn(self, self.model, self.dist_class, samples) - losses = force_list(losses) - - # User provided a compute_gradients_fn. if compute_gradients_fn: - # Wrap our tape inside a wrapper, such that the resulting - # object looks like a "classic" tf.optimizer. This way, custom - # compute_gradients_fn will work on both tf static graph - # and tf-eager. - optimizer = OptimizerWrapper(tape) - # More than one loss terms/optimizers. - if self.config["_tf_policy_handles_more_than_one_loss"]: - grads_and_vars = compute_gradients_fn( - self, [optimizer] * len(losses), losses) - # Only one loss and one optimizer. - else: - grads_and_vars = [ - compute_gradients_fn(self, optimizer, losses[0]) - ] - # Default: Compute gradients using the above tape. 
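The single-loss flow this file reverts to computes gradients from one `tf.GradientTape`, and `compute_gradients_fn` callbacks written against the graph-mode optimizer API are served by the inline `OptimizerWrapper` restored just below. A standalone demonstration of that duck-typing trick (toy loss; tf2 eager mode assumed):

    import tensorflow as tf

    class OptimizerWrapper:
        """Exposes compute_gradients() backed by a GradientTape."""

        def __init__(self, tape):
            self.tape = tape

        def compute_gradients(self, loss, var_list):
            return list(zip(self.tape.gradient(loss, var_list), var_list))

    v = tf.Variable(3.0)
    with tf.GradientTape(persistent=True) as tape:
        loss = v * v
    grads_and_vars = OptimizerWrapper(tape).compute_gradients(loss, [v])
    # grads_and_vars[0][0] -> 6.0 (d(v^2)/dv at v = 3)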
+ + class OptimizerWrapper: + def __init__(self, tape): + self.tape = tape + + def compute_gradients(self, loss, var_list): + return list( + zip(self.tape.gradient(loss, var_list), var_list)) + + grads_and_vars = compute_gradients_fn(self, + OptimizerWrapper(tape), + loss) else: - grads_and_vars = [ - list(zip(tape.gradient(loss, variables), variables)) - for loss in losses - ] + grads_and_vars = list( + zip(tape.gradient(loss, variables), variables)) if log_once("grad_vars"): - for g_and_v in grads_and_vars: - for g, v in g_and_v: - if g is not None: - logger.info(f"Optimizing variable {v.name}") - - # `grads_and_vars` is returned a list (len=num optimizers/losses) - # of lists of (grad, var) tuples. - if self.config["_tf_policy_handles_more_than_one_loss"]: - grads = [[g for g, _ in g_and_v] for g_and_v in grads_and_vars] - # `grads_and_vars` is returned as a list of (grad, var) tuples. - else: - grads_and_vars = grads_and_vars[0] - grads = [g for g, _ in grads_and_vars] + for _, v in grads_and_vars: + logger.info("Optimizing variable {}".format(v.name)) + grads = [g for g, v in grads_and_vars] stats = self._stats(self, samples, grads) return grads_and_vars, stats - def _apply_gradients(self, grads_and_vars): - if apply_gradients_fn: - if self.config["_tf_policy_handles_more_than_one_loss"]: - apply_gradients_fn(self, self._optimizers, grads_and_vars) - else: - apply_gradients_fn(self, self._optimizer, grads_and_vars) - else: - if self.config["_tf_policy_handles_more_than_one_loss"]: - for i, o in enumerate(self._optimizers): - o.apply_gradients([(g, v) for g, v in grads_and_vars[i] - if g is not None]) - else: - self._optimizer.apply_gradients( - [(g, v) for g, v in grads_and_vars if g is not None]) - def _stats(self, outputs, samples, grads): fetches = {} diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index 6fd89f0117b97..3f75a8429c98a 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -14,8 +14,9 @@ from ray.rllib.utils.exploration.exploration import Exploration from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.from_config import from_config -from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space, \ - get_dummy_batch_for_space, unbatch +from ray.rllib.utils.spaces.space_utils import clip_action, \ + get_base_struct_from_space, get_dummy_batch_for_space, unbatch, \ + unsquash_action from ray.rllib.utils.typing import AgentID, ModelGradients, ModelWeights, \ TensorType, TensorStructType, TrainerConfigDict, Tuple, Union @@ -27,6 +28,10 @@ logger = logging.getLogger(__name__) +# By convention, metrics from optimizing the loss can be reported in the +# `grad_info` dict returned by learn_on_batch() / compute_grads() via this key. +LEARNER_STATS_KEY = "learner_stats" + # A policy spec used in the "config.multiagent.policies" specification dict # as values (keys are the policy IDs (str)). E.g.: # config: @@ -175,17 +180,16 @@ def compute_actions( @DeveloperAPI def compute_single_action( self, - obs: Optional[TensorStructType] = None, + obs: TensorStructType, state: Optional[List[TensorType]] = None, - *, prev_action: Optional[TensorStructType] = None, prev_reward: Optional[TensorStructType] = None, info: dict = None, - input_dict: Optional[SampleBatch] = None, episode: Optional["MultiAgentEpisode"] = None, + clip_actions: bool = None, explore: Optional[bool] = None, timestep: Optional[int] = None, - # Kwars placeholder for future compatibility. 
+ unsquash_actions: bool = None, **kwargs) -> \ Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]: """Unbatched version of compute_actions. @@ -195,13 +199,14 @@ def compute_single_action( state: List of RNN state inputs, if any. prev_action: Previous action value, if any. prev_reward: Previous reward, if any. - info: Info object, if any. - input_dict: A SampleBatch or input dict containing the - single (unbatched) Tensors to compute actions. If given, it'll - be used instead of `obs`, `state`, `prev_action|reward`, and - `info`. - episode: This provides access to all of the internal episode state, - which may be useful for model-based or multi-agent algorithms. + info (dict): Info object, if any. + episode: this provides access to all + of the internal episode state, which may be useful for + model-based or multi-agent algorithms. + unsquash_actions: Should actions be unsquashed according to + the Policy's action space? + clip_actions: Should actions be clipped according to the + Policy's action space? explore: Whether to pick an exploitation or exploration action (default: None -> use self.config["explore"]). @@ -215,37 +220,43 @@ def compute_single_action( - state_outs: List of RNN state outputs, if any. - info: Dictionary of extra features, if any. """ - # Build the input-dict used for the call to - # `self.compute_actions_from_input_dict()`. - if input_dict is None: - input_dict = {SampleBatch.OBS: obs} - if state is not None: - for i, s in enumerate(state): - input_dict[f"state_in_{i}"] = s - if prev_action is not None: - input_dict[SampleBatch.PREV_ACTIONS] = prev_action - if prev_reward is not None: - input_dict[SampleBatch.PREV_REWARDS] = prev_reward - if info is not None: - input_dict[SampleBatch.INFOS] = info - - # Batch all data in input dict. - input_dict = tree.map_structure_with_path( - lambda p, s: (s if p == "seq_lens" else s.unsqueeze(0) if - torch and isinstance(s, torch.Tensor) else - np.expand_dims(s, 0)), - input_dict) - + # If policy works in normalized space, we should unsquash the action. + # Use value of config.normalize_actions, if None. + unsquash_actions = \ + unsquash_actions if unsquash_actions is not None \ + else self.config["normalize_actions"] + clip_actions = clip_actions if clip_actions is not None else \ + self.config["clip_actions"] + + prev_action_batch = None + prev_reward_batch = None + info_batch = None episodes = None + state_batch = None + if prev_action is not None: + prev_action_batch = [prev_action] + if prev_reward is not None: + prev_reward_batch = [prev_reward] + if info is not None: + info_batch = [info] if episode is not None: episodes = [episode] - - out = self.compute_actions_from_input_dict( - input_dict=SampleBatch(input_dict), + if state is not None: + state_batch = [ + s.unsqueeze(0) + if torch and isinstance(s, torch.Tensor) else np.expand_dims( + s, 0) for s in state + ] + + out = self.compute_actions( + tree.map_structure(lambda s: np.array([s]), obs), + state_batch, + prev_action_batch=prev_action_batch, + prev_reward_batch=prev_reward_batch, + info_batch=info_batch, episodes=episodes, explore=explore, - timestep=timestep, - ) + timestep=timestep) # Some policies don't return a tuple, but always just a single action. # E.g. ES and ARS. @@ -260,6 +271,16 @@ def compute_single_action( assert len(single_action) == 1 single_action = single_action[0] + # If we work in normalized action space (normalize_actions=True), + # we re-translate here into the env's action space. 
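The branch that follows re-translates (unsquashes) or clips, and the two operations differ: unsquash re-scales a normalized action into the space's bounds, while clip keeps the raw value and merely bounds it. A toy re-implementation for a 1-D Box (the real helpers are `unsquash_action`/`clip_action` from `space_utils`, imported above):

    import numpy as np
    from gym.spaces import Box

    space = Box(0.0, 10.0, (1, ), dtype=np.float32)
    a = np.array([1.2], dtype=np.float32)  # raw policy output, ~[-1, 1]

    # Unsquash: map normalized [-1, 1] output into [low, high].
    unsquashed = space.low + (np.clip(a, -1.0, 1.0) + 1.0) * \
        (space.high - space.low) / 2.0
    # Clip: keep the raw value, just bound it to the space.
    clipped = np.clip(a, space.low, space.high)
    # unsquashed -> [10.], clipped -> [1.2]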
+ if unsquash_actions: + single_action = unsquash_action(single_action, + self.action_space_struct) + # Clip, according to env's action space. + elif clip_actions: + single_action = clip_action(single_action, + self.action_space_struct) + # Return action, internal state(s), infos. return single_action, [s[0] for s in state_out], \ {k: v[0] for k, v in info.items()} @@ -267,7 +288,7 @@ def compute_single_action( @DeveloperAPI def compute_actions_from_input_dict( self, - input_dict: Union[SampleBatch, Dict[str, TensorStructType]], + input_dict: SampleBatch, explore: bool = None, timestep: Optional[int] = None, episodes: Optional[List["MultiAgentEpisode"]] = None, @@ -279,19 +300,14 @@ def compute_actions_from_input_dict( to construct the input_dict for the Model. Args: - input_dict: A SampleBatch or input dict containing the Tensors + input_dict (SampleBatch): A SampleBatch containing the Tensors to compute actions. `input_dict` already abides to the Policy's as well as the Model's view requirements and can thus be passed to the Model as-is. - explore: Whether to pick an exploitation or exploration + explore (bool): Whether to pick an exploitation or exploration action (default: None -> use self.config["explore"]). - timestep: The current (sampling) time step. - episodes: This provides access to all of the internal episodes' - state, which may be useful for model-based or multi-agent - algorithms. - - Keyword Args: - kwargs: Forward compatibility placeholder. + timestep (Optional[int]): The current (sampling) time step. + kwargs: forward compatibility placeholder Returns: Tuple: diff --git a/rllib/policy/policy_template.py b/rllib/policy/policy_template.py index d3463df7eaf71..ea231ed2abc8f 100644 --- a/rllib/policy/policy_template.py +++ b/rllib/policy/policy_template.py @@ -7,13 +7,12 @@ from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 -from ray.rllib.policy.policy import Policy +from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy import TorchPolicy from ray.rllib.utils import add_mixins, force_list, NullContextManager from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.framework import try_import_torch, try_import_jax -from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.torch_ops import convert_to_non_torch_type from ray.rllib.utils.typing import ModelGradients, TensorType, \ TrainerConfigDict diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index 389278a1a4328..9192d5ba6d4d5 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -183,7 +183,7 @@ def concat_samples( >>> print(SampleBatch.concat_samples([b1, b2])) {"a": np.array([1, 2, 3]), "b": np.array([10, 11, 12])} """ - if any(isinstance(s, MultiAgentBatch) for s in samples): + if isinstance(samples[0], MultiAgentBatch): return MultiAgentBatch.concat_samples(samples) concatd_seq_lens = [] concat_samples = [] @@ -1171,12 +1171,7 @@ def concat_samples(samples: List["MultiAgentBatch"]) -> "MultiAgentBatch": policy_batches = collections.defaultdict(list) env_steps = 0 for s in samples: - # Some batches in `samples` are not MultiAgentBatch. if not isinstance(s, MultiAgentBatch): - # If empty SampleBatch: ok (just ignore). 
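Stepping back to the `compute_single_action()` body restored above: it batches each unbatched input by hand before delegating to `compute_actions()`. The wrapping is only this (toy inputs):

    import numpy as np
    import tree  # pip install dm_tree

    obs = {"pos": np.zeros(2), "vel": np.zeros(3)}  # hypothetical single obs
    state = [np.zeros(256)]

    obs_batch = tree.map_structure(lambda s: np.array([s]), obs)
    state_batch = [np.expand_dims(s, 0) for s in state]
    prev_action_batch = [1]  # scalars simply become length-1 lists
    # obs_batch["pos"].shape == (1, 2); state_batch[0].shape == (1, 256)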
- if isinstance(s, SampleBatch) and len(s) <= 0: - continue - # Otherwise: Error. raise ValueError( "`MultiAgentBatch.concat_samples()` can only concat " "MultiAgentBatch types, not {}!".format(type(s).__name__)) diff --git a/rllib/policy/tests/test_compute_log_likelihoods.py b/rllib/policy/tests/test_compute_log_likelihoods.py index 330ea381bddf9..52259d6ea6e60 100644 --- a/rllib/policy/tests/test_compute_log_likelihoods.py +++ b/rllib/policy/tests/test_compute_log_likelihoods.py @@ -57,7 +57,7 @@ def do_test_log_likelihood(run, explore=True, # Do not unsquash actions # (remain in normalized [-1.0; 1.0] space). - unsquash_action=False, + unsquash_actions=False, )) # Test all taken actions for their log-likelihoods vs expected values. diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py index 4f4deb15c05e3..bebc9fa185b26 100644 --- a/rllib/policy/tf_policy.py +++ b/rllib/policy/tf_policy.py @@ -10,16 +10,15 @@ import ray import ray.experimental.tf_utils from ray.util.debug import log_once -from ray.rllib.policy.policy import Policy +from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.utils import force_list -from ray.rllib.utils.annotations import Deprecated, DeveloperAPI, override +from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.debug import summarize -from ray.rllib.utils.deprecation import deprecation_warning +from ray.rllib.utils.annotations import Deprecated from ray.rllib.utils.framework import try_import_tf, get_variable -from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.schedules import PiecewiseSchedule from ray.rllib.utils.spaces.space_utils import normalize_action from ray.rllib.utils.tf_ops import get_gpu_devices @@ -424,18 +423,14 @@ def compute_actions( timestep = timestep if timestep is not None else self.global_timestep builder = TFRunBuilder(self.get_session(), "compute_actions") - - input_dict = {SampleBatch.OBS: obs_batch} - if state_batches: - for i, s in enumerate(state_batches): - input_dict[f"state_in_{i}"] = s - if prev_action_batch is not None: - input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch - if prev_reward_batch is not None: - input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch - to_fetch = self._build_compute_actions( - builder, input_dict=input_dict, explore=explore, timestep=timestep) + builder, + obs_batch=obs_batch, + state_batches=state_batches, + prev_action_batch=prev_action_batch, + prev_reward_batch=prev_reward_batch, + explore=explore, + timestep=timestep) # Execute session run to get action (and other fetches). fetched = builder.get(to_fetch) @@ -1010,12 +1005,6 @@ def _build_compute_actions(self, # TODO: (sven) This can be deprecated after trajectory view API flag is # removed and always True. 
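The `concat_samples()` dispatch touched above follows the docstring example already in this file; after the revert, the multi-agent path only triggers when the first element is a `MultiAgentBatch`:

    import numpy as np
    from ray.rllib.policy.sample_batch import SampleBatch

    b1 = SampleBatch({"a": np.array([1, 2]), "b": np.array([10, 11])})
    b2 = SampleBatch({"a": np.array([3]), "b": np.array([12])})
    out = SampleBatch.concat_samples([b1, b2])
    # out["a"] -> array([1, 2, 3])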
else: - if log_once("_build_compute_actions_input_dict"): - deprecation_warning( - old="_build_compute_actions(.., obs_batch=.., ..)", - new="_build_compute_actions(.., input_dict=..)", - error=False, - ) state_batches = state_batches or [] if len(self._state_inputs) != len(state_batches): raise ValueError( diff --git a/rllib/policy/tf_policy_template.py b/rllib/policy/tf_policy_template.py index f2ec7dfaadcc7..fb7e9519ec878 100644 --- a/rllib/policy/tf_policy_template.py +++ b/rllib/policy/tf_policy_template.py @@ -6,16 +6,15 @@ from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy from ray.rllib.policy import eager_tf_policy -from ray.rllib.policy.policy import Policy +from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_policy import TFPolicy from ray.rllib.utils import add_mixins, force_list from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY -from ray.rllib.utils.typing import AgentID, ModelGradients, TensorType, \ - TrainerConfigDict +from ray.rllib.utils.typing import AgentID, ModelGradients, PolicyID, \ + TensorType, TrainerConfigDict if TYPE_CHECKING: from ray.rllib.evaluation import MultiAgentEpisode @@ -54,7 +53,7 @@ def build_tf_policy( extra_learn_fetches_fn: Optional[Callable[[Policy], Dict[ str, TensorType]]] = None, validate_spaces: Optional[Callable[ - [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None, + [PolicyID, gym.Space, gym.Space, TrainerConfigDict], None]] = None, before_init: Optional[Callable[ [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None, before_loss_init: Optional[Callable[[ diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index bf1c69410ff83..f50729d005ed2 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -5,23 +5,21 @@ import math import numpy as np import os -import threading import time -import tree # pip install dm_tree +import threading from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union, \ TYPE_CHECKING import ray from ray.rllib.models.modelv2 import ModelV2 -from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 -from ray.rllib.policy.policy import Policy -from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size +from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper +from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size from ray.rllib.utils import force_list, NullContextManager from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.schedules import PiecewiseSchedule from ray.rllib.utils.spaces.space_utils import normalize_action from ray.rllib.utils.threading import with_lock @@ -705,34 +703,6 @@ def apply_gradients(self, gradients: ModelGradients) -> None: self._optimizers[0].step() - @DeveloperAPI - def get_tower_stats(self, stats_name: str) -> List[TensorStructType]: - """Returns 
list of per-tower stats, copied to this Policy's device. - - Args: - stats_name: The name of the stats to average over (this str - must exist as a key inside each tower's `tower_stats` dict). - - Returns: - The list of stats tensor (structs) of all towers, copied to this - Policy's device. - - Raises: - AssertionError: If the `stats_name` cannot be found in any one - of the tower's `tower_stats` dicts. - """ - data = [] - for tower in self.model_gpu_towers: - if stats_name in tower.tower_stats: - data.append( - tree.map_structure(lambda s: s.to(self.device), - tower.tower_stats[stats_name])) - assert len(data) > 0, \ - f"Stats `{stats_name}` not found in any of the towers (you have " \ - f"{len(self.model_gpu_towers)} towers in total)! Make " \ - "sure you call the loss function on at least one of the towers." - return data - @override(Policy) @DeveloperAPI def get_weights(self) -> ModelWeights: diff --git a/rllib/tests/test_exec_api.py b/rllib/tests/test_exec_api.py index 11339f08640b5..b415c4faadf46 100644 --- a/rllib/tests/test_exec_api.py +++ b/rllib/tests/test_exec_api.py @@ -4,7 +4,6 @@ from ray.rllib.agents.a3c import A2CTrainer from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \ STEPS_TRAINED_COUNTER -from ray.rllib.utils.metrics.learner_info import LEARNER_INFO from ray.rllib.utils.test_utils import framework_iterator @@ -30,7 +29,7 @@ def test_exec_plan_stats(ray_start_regular): result = trainer.train() assert isinstance(result, dict) assert "info" in result - assert LEARNER_INFO in result["info"] + assert "learner" in result["info"] assert STEPS_SAMPLED_COUNTER in result["info"] assert STEPS_TRAINED_COUNTER in result["info"] assert "timers" in result diff --git a/rllib/tests/test_supported_multi_agent.py b/rllib/tests/test_supported_multi_agent.py index 2c114cec4d02f..0f4063bb2e886 100644 --- a/rllib/tests/test_supported_multi_agent.py +++ b/rllib/tests/test_supported_multi_agent.py @@ -4,9 +4,7 @@ from ray.rllib.agents.registry import get_trainer_class from ray.rllib.examples.env.multi_agent import MultiAgentCartPole, \ MultiAgentMountainCar -from ray.rllib.policy.policy import PolicySpec -from ray.rllib.utils.test_utils import check_train_results, \ - framework_iterator +from ray.rllib.utils.test_utils import framework_iterator from ray.tune import register_env @@ -15,23 +13,7 @@ def check_support_multiagent(alg, config): lambda _: MultiAgentMountainCar({"num_agents": 2})) register_env("multi_agent_cartpole", lambda _: MultiAgentCartPole({"num_agents": 2})) - - # Simulate a simple multi-agent setup. 
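On the `get_tower_stats()` removal above: the helper gathered one named stat from every tower copy onto the policy's device. The gist, with stand-in tower objects:

    import torch

    class Tower:
        def __init__(self, value):
            self.tower_stats = {"td_error": torch.tensor([value])}

    towers = [Tower(0.1), Tower(0.2)]
    device = torch.device("cpu")
    data = [t.tower_stats["td_error"].to(device)
            for t in towers if "td_error" in t.tower_stats]
    assert len(data) > 0, "stat not found in any tower"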
- policies = { - "policy_0": PolicySpec(config={"gamma": 0.99}), - "policy_1": PolicySpec(config={"gamma": 0.95}), - } - policy_ids = list(policies.keys()) - - def policy_mapping_fn(agent_id, episode, worker, **kwargs): - pol_id = policy_ids[agent_id] - return pol_id - - config["multiagent"] = { - "policies": policies, - "policy_mapping_fn": policy_mapping_fn, - } - + config["log_level"] = "ERROR" for fw in framework_iterator(config): if fw in ["tf2", "tfe"] and \ alg in ["A3C", "APEX", "APEX_DDPG", "IMPALA"]: @@ -43,9 +25,7 @@ def policy_mapping_fn(agent_id, episode, worker, **kwargs): a = get_trainer_class(alg)( config=config, env="multi_agent_cartpole") - results = a.train() - check_train_results(results) - print(results) + print(a.train()) a.stop() diff --git a/rllib/tests/test_supported_spaces.py b/rllib/tests/test_supported_spaces.py index d290d3ef87f68..993558e77d223 100644 --- a/rllib/tests/test_supported_spaces.py +++ b/rllib/tests/test_supported_spaces.py @@ -69,11 +69,6 @@ def _do_check(alg, config, a_name, o_name): try: a = get_trainer_class(alg)(config=config, env=RandomEnv) - except ray.exceptions.RayActorError as e: - if isinstance(e.args[2], UnsupportedSpaceException): - stat = "unsupported" - else: - raise except UnsupportedSpaceException: stat = "unsupported" else: @@ -104,11 +99,10 @@ def _do_check(alg, config, a_name, o_name): _do_check(alg, config, a_name, o_name) # Do the remaining obs spaces. assert len(OBSERVATION_SPACES_TO_TEST) >= len(ACTION_SPACES_TO_TEST) - fixed_action_key = next(iter(ACTION_SPACES_TO_TEST.keys())) for i, o_name in enumerate(OBSERVATION_SPACES_TO_TEST.keys()): if i < len(ACTION_SPACES_TO_TEST): continue - _do_check(alg, config, fixed_action_key, o_name) + _do_check(alg, config, "discrete", o_name) class TestSupportedSpacesPG(unittest.TestCase): diff --git a/rllib/utils/__init__.py b/rllib/utils/__init__.py index e720bfebfc468..4f1f33083f01c 100644 --- a/rllib/utils/__init__.py +++ b/rllib/utils/__init__.py @@ -11,7 +11,7 @@ from ray.rllib.utils.schedules import LinearSchedule, PiecewiseSchedule, \ PolynomialSchedule, ExponentialSchedule, ConstantSchedule from ray.rllib.utils.test_utils import check, check_compute_single_action, \ - check_train_results, framework_iterator + framework_iterator from ray.tune.utils import merge_dicts, deep_update @@ -77,7 +77,6 @@ def __exit__(self, *args): "add_mixins", "check", "check_compute_single_action", - "check_train_results", "deep_update", "deprecation_warning", "fc", diff --git a/rllib/utils/exploration/stochastic_sampling.py b/rllib/utils/exploration/stochastic_sampling.py index 593233625de15..daa6089d483b4 100644 --- a/rllib/utils/exploration/stochastic_sampling.py +++ b/rllib/utils/exploration/stochastic_sampling.py @@ -1,7 +1,7 @@ import functools import gym import numpy as np -from typing import Optional, Union +from typing import Union from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.models.modelv2 import ModelV2 @@ -61,12 +61,11 @@ def __init__(self, dtype=np.int64) @override(Exploration) - def get_exploration_action( - self, - *, - action_distribution: ActionDistribution, - timestep: Optional[Union[int, TensorType]] = None, - explore: bool = True): + def get_exploration_action(self, + *, + action_distribution: ActionDistribution, + timestep: Union[int, TensorType], + explore: bool = True): if self.framework == "torch": return self._get_torch_exploration_action(action_distribution, timestep, explore) @@ -75,7 +74,7 @@ def get_exploration_action( timestep, explore) def 
_get_tf_exploration_action_op(self, action_dist, timestep, explore): - ts = self.last_timestep + 1 + ts = timestep if timestep is not None else self.last_timestep + 1 stochastic_actions = tf.cond( pred=tf.convert_to_tensor(ts < self.random_timesteps), @@ -101,7 +100,10 @@ def _get_tf_exploration_action_op(self, action_dist, timestep, explore): # Increment `last_timestep` by 1 (or set to `timestep`). if self.framework in ["tf2", "tfe"]: - self.last_timestep.assign_add(1) + if timestep is None: + self.last_timestep.assign_add(1) + else: + self.last_timestep.assign(timestep) return action, logp else: assign_op = (tf1.assign_add(self.last_timestep, 1) diff --git a/rllib/utils/metrics/__init__.py b/rllib/utils/metrics/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/rllib/utils/metrics/learner_info.py b/rllib/utils/metrics/learner_info.py deleted file mode 100644 index ebe44a7c9fcda..0000000000000 --- a/rllib/utils/metrics/learner_info.py +++ /dev/null @@ -1,84 +0,0 @@ -from collections import defaultdict -import numpy as np -import tree # pip install dm_tree -from typing import Dict - -from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID -from ray.rllib.utils.typing import PolicyID - -# Instant metrics (keys for metrics.info). -LEARNER_INFO = "learner" -# By convention, metrics from optimizing the loss can be reported in the -# `grad_info` dict returned by learn_on_batch() / compute_grads() via this key. -LEARNER_STATS_KEY = "learner_stats" - - -class LearnerInfoBuilder: - def __init__(self, num_devices: int = 1): - self.num_devices = num_devices - self.results_all_towers = defaultdict(list) - self.is_finalized = False - - def add_learn_on_batch_results( - self, - results: Dict, - policy_id: PolicyID = DEFAULT_POLICY_ID, - ) -> None: - """Adds a policy.learn_on_(loaded)?_batch() result to this builder. - - Args: - results: The results returned by Policy.learn_on_batch or - Policy.learn_on_loaded_batch. - policy_id: The policy's ID, whose learn_on_(loaded)_batch method - returned `results`. - """ - assert not self.is_finalized, \ - "LearnerInfo already finalized! Cannot add more results." - - # No towers: Single CPU. - if "tower_0" not in results: - self.results_all_towers[policy_id].append(results) - # Multi-GPU case: - else: - self.results_all_towers[policy_id].append( - tree.map_structure_with_path( - lambda p, *s: all_tower_reduce(p, *s), - *(results.pop("tower_{}".format(tower_num)) - for tower_num in range(self.num_devices)))) - for k, v in results.items(): - if k == LEARNER_STATS_KEY: - for k1, v1 in results[k].items(): - self.results_all_towers[policy_id][-1][ - LEARNER_STATS_KEY][k1] = v1 - else: - self.results_all_towers[policy_id][-1][k] = v - - def finalize(self): - self.is_finalized = True - - info = {} - for policy_id, results_all_towers in self.results_all_towers.items(): - # Reduce mean across all minibatch SGD steps (axis=0 to keep - # all shapes as-is). - info[policy_id] = tree.map_structure( - lambda *s: None if s[0] is None else np.nanmean(s, axis=0), - *results_all_towers) - - return info - - -def all_tower_reduce(path, *tower_data): - """Reduces stats across towers based on their stats-dict paths.""" - # TD-errors: Need to stay per batch item in order to be able to update - # each item's weight in a prioritized replay buffer. - if len(path) == 1 and path[0] == "td_error": - return np.concatenate(tower_data, axis=0) - - # Min stats: Reduce min. 
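In the tf2/tfe branch of the stochastic-sampling hunk above, `last_timestep` is a `tf.Variable`, and an explicitly passed timestep now fast-forwards it instead of always incrementing by one. Isolated:

    import tensorflow as tf

    last_timestep = tf.Variable(0, dtype=tf.int64)

    def bump(timestep=None):
        if timestep is None:
            last_timestep.assign_add(1)
        else:
            last_timestep.assign(timestep)

    bump()     # last_timestep -> 1
    bump(100)  # last_timestep -> 100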
- if path[-1].startswith("min_"): - return np.nanmin(tower_data) - # Max stats: Reduce max. - elif path[-1].startswith("max_"): - return np.nanmax(tower_data) - # Everything else: Reduce mean. - return np.nanmean(tower_data) diff --git a/rllib/utils/multi_agent.py b/rllib/utils/multi_agent.py index 50d5227c54e75..b23726cb393db 100644 --- a/rllib/utils/multi_agent.py +++ b/rllib/utils/multi_agent.py @@ -1,13 +1,9 @@ -from typing import Tuple - from ray.rllib.policy.policy import PolicySpec from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID -from ray.rllib.utils.typing import MultiAgentPolicyConfigDict, \ - PartialTrainerConfigDict +from ray.rllib.utils.typing import PartialTrainerConfigDict -def check_multi_agent(config: PartialTrainerConfigDict) -> \ - Tuple[MultiAgentPolicyConfigDict, bool]: +def check_multi_agent(config: PartialTrainerConfigDict): """Checks, whether a (partial) config defines a multi-agent setup. Args: @@ -15,25 +11,18 @@ def check_multi_agent(config: PartialTrainerConfigDict) -> \ to check for multi-agent. Returns: - The resulting (all fixed) multi-agent policy dict and whether we - have a multi-agent setup or not. + Tuple[MultiAgentPolicyConfigDict, bool]: The resulting (all + fixed) multi-agent policy dict and whether we have a + multi-agent setup or not. """ multiagent_config = config["multiagent"] policies = multiagent_config.get("policies") - - # Nothing specified in config dict -> Assume simple single agent setup - # with DEFAULT_POLICY_ID as only policy. if not policies: policies = {DEFAULT_POLICY_ID} - # Policies given as set (of PolicyIDs) -> Setup each policy automatically - # via empty PolicySpec (will make RLlib infer obs- and action spaces - # as well as the Policy's class). if isinstance(policies, set): policies = multiagent_config["policies"] = { pid: PolicySpec() for pid in policies } - # Is this a multi-agent setup? True, iff DEFAULT_POLICY_ID is only - # PolicyID found in policies dict. is_multiagent = len(policies) > 1 or DEFAULT_POLICY_ID not in policies return policies, is_multiagent diff --git a/rllib/utils/sgd.py b/rllib/utils/sgd.py index 6b4f060a95598..b163c2a36fcd4 100644 --- a/rllib/utils/sgd.py +++ b/rllib/utils/sgd.py @@ -1,17 +1,38 @@ """Utils for minibatch SGD across multiple RLlib policies.""" -import logging import numpy as np +import logging +from collections import defaultdict import random +from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch, \ MultiAgentBatch -from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder logger = logging.getLogger(__name__) -def standardized(array: np.ndarray): +def averaged(kv, axis=None): + """Average the value lists of a dictionary. + + For non-scalar values, we simply pick the first value. + + Args: + kv (dict): dictionary with values that are lists of floats. + + Returns: + dictionary with single averaged float as values. + """ + out = {} + for k, v in kv.items(): + if v[0] is not None and not isinstance(v[0], dict): + out[k] = np.mean(v, axis=axis) + else: + out[k] = v[0] + return out + + +def standardized(array): """Normalize the values in an array. Args: @@ -86,12 +107,7 @@ def do_minibatch_sgd(samples, policies, local_worker, num_sgd_iter, if isinstance(samples, SampleBatch): samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples}, samples.count) - # Use LearnerInfoBuilder as a unified way to build the final - # results dict from `learn_on_loaded_batch` call(s). 
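`averaged()`, restored in the sgd.py hunk above, is the small reducer that takes over from the deleted builder: list-valued stats get averaged, anything non-scalar passes through as its first value. Quick check:

    import numpy as np

    def averaged(kv, axis=None):
        out = {}
        for k, v in kv.items():
            if v[0] is not None and not isinstance(v[0], dict):
                out[k] = np.mean(v, axis=axis)
            else:
                out[k] = v[0]
        return out

    averaged({"policy_loss": [0.5, 0.3], "cur_lr": [1e-4, 1e-4]})
    # -> {"policy_loss": 0.4, "cur_lr": 0.0001}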
- # This makes sure results dicts always have the same structure - # no matter the setup (multi-GPU, multi-agent, minibatch SGD, - # tf vs torch). - learner_info_builder = LearnerInfoBuilder(num_devices=1) + fetches = defaultdict(dict) for policy_id in policies.keys(): if policy_id not in samples.policy_batches: continue @@ -100,14 +116,23 @@ def do_minibatch_sgd(samples, policies, local_worker, num_sgd_iter, for field in standardize_fields: batch[field] = standardized(batch[field]) + learner_stats = defaultdict(list) + model_stats = defaultdict(list) + custom_callbacks_stats = defaultdict(list) + for i in range(num_sgd_iter): for minibatch in minibatches(batch, sgd_minibatch_size): - results = (local_worker.learn_on_batch( + batch_fetches = (local_worker.learn_on_batch( MultiAgentBatch({ policy_id: minibatch }, minibatch.count)))[policy_id] - learner_info_builder.add_learn_on_batch_results( - results, policy_id) - - learner_info = learner_info_builder.finalize() - return learner_info + for k, v in batch_fetches.get(LEARNER_STATS_KEY, {}).items(): + learner_stats[k].append(v) + for k, v in batch_fetches.get("model", {}).items(): + model_stats[k].append(v) + for k, v in batch_fetches.get("custom_metrics", {}).items(): + custom_callbacks_stats[k].append(v) + fetches[policy_id][LEARNER_STATS_KEY] = averaged(learner_stats) + fetches[policy_id]["model"] = averaged(model_stats) + fetches[policy_id]["custom_metrics"] = averaged(custom_callbacks_stats) + return fetches diff --git a/rllib/utils/test_utils.py b/rllib/utils/test_utils.py index 5fcb16da6471e..f119d3806968f 100644 --- a/rllib/utils/test_utils.py +++ b/rllib/utils/test_utils.py @@ -1,12 +1,10 @@ from collections import Counter import copy -from gym.spaces import Box +import gym import logging import numpy as np -import random import re import time -import tree # pip install dm_tree from typing import Any, Dict, List import yaml @@ -31,8 +29,7 @@ def framework_iterator(config=None, frameworks=("tf2", "tf", "tfe", "torch"), - session=False, - with_eager_tracing=False): + session=False): """An generator that allows for looping through n frameworks for testing. Provides the correct config entries ("framework") as well @@ -47,8 +44,6 @@ def framework_iterator(config=None, and yield that as second return value (otherwise yield (fw, None)). Also sets a seed (42) on the session to make the test deterministic. - with_eager_tracing: Include `eager_tracing=True` in the returned - configs, when framework=[tfe|tf2]. Yields: str: If enter_session is False: @@ -108,15 +103,7 @@ def framework_iterator(config=None, elif fw == "tf": assert not tf1.executing_eagerly() - # Additionally loop through eager_tracing=True + False, if necessary. - if fw in ["tf2", "tfe"] and with_eager_tracing: - for tracing in [True, False]: - config["eager_tracing"] = tracing - yield fw if session is False else (fw, sess) - config["eager_tracing"] = False - # Yield current framework + tf-session (if necessary). - else: - yield fw if session is False else (fw, sess) + yield fw if session is False else (fw, sess) # Exit any context we may have entered. if eager_ctx: @@ -273,6 +260,31 @@ def check(x, y, decimals=5, atol=None, rtol=None, false=False): "ERROR: x ({}) is the same as y ({})!".format(x, y) +def check_learning_achieved(tune_results, min_reward, evaluation=False): + """Throws an error if `min_reward` is not reached within tune_results. + + Checks the last iteration found in tune_results for its + "episode_reward_mean" value and compares it to `min_reward`. 
+ + Args: + tune_results: The tune.run returned results object. + min_reward (float): The min reward that must be reached. + + Raises: + ValueError: If `min_reward` not reached. + """ + # Get maximum reward of all trials + # (check if at least one trial achieved some learning) + avg_rewards = [(trial.last_result["episode_reward_mean"] + if not evaluation else + trial.last_result["evaluation"]["episode_reward_mean"]) + for trial in tune_results.trials] + best_avg_reward = max(avg_rewards) + if best_avg_reward < min_reward: + raise ValueError("`stop-reward` of {} not reached!".format(min_reward)) + print("ok") + + def check_compute_single_action(trainer, include_state=False, include_prev_action_reward=False): @@ -288,120 +300,17 @@ def check_compute_single_action(trainer, Raises: ValueError: If anything unexpected happens. """ - # Have to import this here to avoid circular dependency. - from ray.rllib.policy.sample_batch import SampleBatch - - # Some Trainers may not abide to the standard API. try: pol = trainer.get_policy() except AttributeError: pol = trainer.policy - # Get the policy's model. model = pol.model action_space = pol.action_space - def _test(what, method_to_test, obs_space, full_fetch, explore, timestep, - unsquash, clip): - call_kwargs = {} - if what is trainer: - call_kwargs["full_fetch"] = full_fetch - - obs = obs_space.sample() - if isinstance(obs_space, Box): - obs = np.clip(obs, -1.0, 1.0) - state_in = None - if include_state: - state_in = model.get_initial_state() - if not state_in: - state_in = [] - i = 0 - while f"state_in_{i}" in model.view_requirements: - state_in.append(model.view_requirements[f"state_in_{i}"] - .space.sample()) - i += 1 - action_in = action_space.sample() \ - if include_prev_action_reward else None - reward_in = 1.0 if include_prev_action_reward else None - - if method_to_test == "input_dict": - assert what is pol - - input_dict = {SampleBatch.OBS: obs} - if include_prev_action_reward: - input_dict[SampleBatch.PREV_ACTIONS] = action_in - input_dict[SampleBatch.PREV_REWARDS] = reward_in - if state_in: - for i, s in enumerate(state_in): - input_dict[f"state_in_{i}"] = s - input_dict_batched = SampleBatch( - tree.map_structure(lambda s: np.expand_dims(s, 0), input_dict)) - action = pol.compute_actions_from_input_dict( - input_dict=input_dict_batched, - explore=explore, - timestep=timestep, - **call_kwargs) - # Unbatch everything to be able to compare against single - # action below. - # ARS and ES return action batches as lists. - if isinstance(action[0], list): - action = (np.array(action[0]), action[1], action[2]) - action = tree.map_structure(lambda s: s[0], action) - - try: - action2 = pol.compute_single_action( - input_dict=input_dict, - explore=explore, - timestep=timestep, - **call_kwargs) - # Make sure these are the same, unless we have exploration - # switched on (or noisy layers). 
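`check_learning_achieved()`, moved above, just takes the best mean reward across trials and compares it to the threshold. A stand-alone sketch with stand-in trial objects:

    class Trial:  # stand-in for a tune trial
        def __init__(self, reward):
            self.last_result = {"episode_reward_mean": reward}

    trials = [Trial(90.0), Trial(150.0)]
    best = max(t.last_result["episode_reward_mean"] for t in trials)
    if best < 100.0:
        raise ValueError("`stop-reward` of 100.0 not reached!")
    print("ok")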
- if not explore and not pol.config.get("noisy"): - check(action, action2) - except TypeError: - pass - else: - action = what.compute_single_action( - obs, - state_in, - prev_action=action_in, - prev_reward=reward_in, - explore=explore, - timestep=timestep, - unsquash_action=unsquash, - clip_action=clip, - **call_kwargs) - - state_out = None - if state_in or full_fetch or what is pol: - action, state_out, _ = action - if state_out: - for si, so in zip(state_in, state_out): - check(list(si.shape), so.shape) - - # Test whether unsquash/clipping works on the Trainer's - # compute_single_action method: Both flags should force the action - # to be within the space's bounds. - if method_to_test == "single" and what == trainer: - if not action_space.contains(action) and \ - (clip or unsquash or not isinstance(action_space, Box)): - raise ValueError( - f"Returned action ({action}) of trainer/policy {what} " - f"not in Env's action_space {action_space}") - # We are operating in normalized space: Expect only smaller action - # values. - if isinstance(action_space, Box) and not unsquash and \ - what.config.get("normalize_actions") and \ - np.any(np.abs(action) > 3.0): - raise ValueError( - f"Returned action ({action}) of trainer/policy {what} " - "should be in normalized space, but seems too large/small " - "for that!") - - # Loop through: Policy vs Trainer; Different API methods to calculate - # actions; unsquash option; clip option; full fetch or not. for what in [pol, trainer]: if what is trainer: + method_to_test = trainer.compute_single_action # Get the obs-space from Workers.env (not Policy) due to possible # pre-processor up front. worker_set = getattr(trainer, "workers", @@ -414,134 +323,53 @@ def _test(what, method_to_test, obs_space, full_fetch, explore, timestep, lambda p: p.observation_space) obs_space = getattr(obs_space, "original_space", obs_space) else: + method_to_test = pol.compute_single_action obs_space = pol.observation_space - for method_to_test in ["single"] + \ - (["input_dict"] if what is pol else []): - for explore in [True, False]: - for full_fetch in ([False, True] - if what is trainer else [False]): - timestep = random.randint(0, 100000) - for unsquash in [True, False]: - for clip in ([False] if unsquash else [True, False]): - _test(what, method_to_test, obs_space, full_fetch, - explore, timestep, unsquash, clip) - - -def check_learning_achieved(tune_results, min_reward, evaluation=False): - """Throws an error if `min_reward` is not reached within tune_results. - - Checks the last iteration found in tune_results for its - "episode_reward_mean" value and compares it to `min_reward`. - - Args: - tune_results: The tune.run returned results object. - min_reward (float): The min reward that must be reached. - - Raises: - ValueError: If `min_reward` not reached. - """ - # Get maximum reward of all trials - # (check if at least one trial achieved some learning) - avg_rewards = [(trial.last_result["episode_reward_mean"] - if not evaluation else - trial.last_result["evaluation"]["episode_reward_mean"]) - for trial in tune_results.trials] - best_avg_reward = max(avg_rewards) - if best_avg_reward < min_reward: - raise ValueError("`stop-reward` of {} not reached!".format(min_reward)) - print("ok") - - -def check_train_results(train_results): - """Checks proper structure of a Trainer.train() returned dict. 
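The action-space sanity checks deleted above boil down to a containment test. Minimal form:

    import numpy as np
    from gym.spaces import Box

    action_space = Box(-1.0, 1.0, (2, ), dtype=np.float32)
    action = np.array([0.5, -0.2], dtype=np.float32)
    if not action_space.contains(action):
        raise ValueError("Returned action ({}) not in Env's action_space "
                         "({})!".format(action, action_space))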
+ for explore in [True, False]: + for full_fetch in ([False, True] if what is trainer else [False]): + call_kwargs = {} + if what is trainer: + call_kwargs["full_fetch"] = full_fetch + else: + call_kwargs["clip_actions"] = True + + obs = obs_space.sample() + if isinstance(obs_space, gym.spaces.Box): + obs = np.clip(obs, -1.0, 1.0) + state_in = None + if include_state: + state_in = model.get_initial_state() + if not state_in: + state_in = [] + i = 0 + while f"state_in_{i}" in model.view_requirements: + state_in.append(model.view_requirements[ + f"state_in_{i}"].space.sample()) + i += 1 + action_in = action_space.sample() \ + if include_prev_action_reward else None + reward_in = 1.0 if include_prev_action_reward else None + action = method_to_test( + obs, + state_in, + prev_action=action_in, + prev_reward=reward_in, + explore=explore, + **call_kwargs) - Args: - train_results: The train results dict to check. + state_out = None + if state_in or full_fetch or what is pol: + action, state_out, _ = action + if state_out: + for si, so in zip(state_in, state_out): + check(list(si.shape), so.shape) - Raises: - AssertionError: If `train_results` doesn't have the proper structure or - data in it. - """ - # Import these here to avoid circular dependencies. - from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID - from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \ - LEARNER_STATS_KEY - from ray.rllib.utils.multi_agent import check_multi_agent - - # Assert that some keys are where we would expect them. - for key in [ - "agent_timesteps_total", - "config", - "custom_metrics", - "episode_len_mean", - "episode_reward_max", - "episode_reward_mean", - "episode_reward_min", - "episodes_total", - "hist_stats", - "info", - "iterations_since_restore", - "num_healthy_workers", - "perf", - "policy_reward_max", - "policy_reward_mean", - "policy_reward_min", - "sampler_perf", - "time_since_restore", - "time_this_iter_s", - "timesteps_since_restore", - "timesteps_total", - "timers", - "time_total_s", - "training_iteration", - ]: - assert key in train_results, \ - f"'{key}' not found in `train_results` ({train_results})!" - - _, is_multi_agent = check_multi_agent(train_results["config"]) - - # Check in particular the "info" dict. - info = train_results["info"] - assert LEARNER_INFO in info, \ - f"'learner' not in train_results['infos'] ({info})!" - assert "num_steps_trained" in info,\ - f"'num_steps_trained' not in train_results['infos'] ({info})!" - - learner_info = info[LEARNER_INFO] - - # Make sure we have a default_policy key if we are not in a - # multi-agent setup. - if not is_multi_agent: - # APEX algos sometimes have an empty learner info dict (no metrics - # collected yet). - assert len(learner_info) == 0 or DEFAULT_POLICY_ID in learner_info, \ - f"'{DEFAULT_POLICY_ID}' not found in " \ - f"train_results['infos']['learner'] ({learner_info})!" - - for pid, policy_stats in learner_info.items(): - if pid == "batch_count": - continue - # Expect td-errors to be per batch-item. - if "td_error" in policy_stats: - configured_b = train_results["config"]["train_batch_size"] - actual_b = policy_stats["td_error"].shape[0] - # R2D2 case. - if (configured_b - actual_b) / actual_b > 0.1: - assert configured_b / ( - train_results["config"]["model"]["max_seq_len"] + - train_results["config"]["burn_in"]) == actual_b - - # Make sure each policy has the LEARNER_STATS_KEY under it. 
- assert LEARNER_STATS_KEY in policy_stats - learner_stats = policy_stats[LEARNER_STATS_KEY] - for key, value in learner_stats.items(): - # Min- and max-stats should be single values. - if key.startswith("min_") or key.startswith("max_"): - assert np.isscalar( - value), f"'key' value not a scalar ({value})!" - - return train_results + if not action_space.contains(action): + raise ValueError( + "Returned action ({}) of trainer/policy {} not in " + "Env's action_space " + "({})!".format(action, what, action_space)) def run_learning_tests_from_yaml( diff --git a/rllib/utils/tf_ops.py b/rllib/utils/tf_ops.py index 1b577be7ef727..20b0ea3d75f98 100644 --- a/rllib/utils/tf_ops.py +++ b/rllib/utils/tf_ops.py @@ -146,7 +146,7 @@ def zero_logps_from_actions(actions: TensorStructType) -> TensorType: # `deterministic_actions` or `stochastic_actions`). In case # actions are just [B], zeros_like works just fine here, but if # actions are [B, ...], we have to reduce logp back to just [B]. - while len(logp_.shape) > 1: + if len(logp_.shape) > 1: logp_ = logp_[:, 0] return logp_ diff --git a/rllib/utils/tf_run_builder.py b/rllib/utils/tf_run_builder.py index 28a48558f73e7..82b904bd13164 100644 --- a/rllib/utils/tf_run_builder.py +++ b/rllib/utils/tf_run_builder.py @@ -59,10 +59,7 @@ def get(self, to_fetch): _count = 0 -def run_timeline(sess, ops, debug_name, feed_dict=None, timeline_dir=None): - if feed_dict is None: - feed_dict = {} - +def run_timeline(sess, ops, debug_name, feed_dict={}, timeline_dir=None): if timeline_dir: from tensorflow.python.client import timeline diff --git a/rllib/utils/torch_ops.py b/rllib/utils/torch_ops.py index 90ccc64aad126..a27be53cc2695 100644 --- a/rllib/utils/torch_ops.py +++ b/rllib/utils/torch_ops.py @@ -48,8 +48,8 @@ def atanh(x): def concat_multi_gpu_td_errors(policy): td_error = torch.cat( [ - t.tower_stats.get("td_error", torch.tensor([0.0])).to( - policy.device) for t in policy.model_gpu_towers + getattr(t, "td_error", torch.tensor([0.0])).to(policy.device) + for t in policy.model_gpu_towers ], dim=0) policy.td_error = td_error @@ -132,7 +132,7 @@ def explained_variance(y, pred): y_var = torch.var(y, dim=[0]) diff_var = torch.var(y - pred, dim=[0]) min_ = torch.tensor([-1.0]).to(pred.device) - return torch.max(min_, 1 - (diff_var / y_var))[0] + return torch.max(min_, 1 - (diff_var / y_var)) def global_norm(tensors): diff --git a/src/mock/ray/gcs/gcs_server/gcs_node_manager.h b/src/mock/ray/gcs/gcs_server/gcs_node_manager.h index 56be36f4c87ff..483464c1ff6eb 100644 --- a/src/mock/ray/gcs/gcs_server/gcs_node_manager.h +++ b/src/mock/ray/gcs/gcs_server/gcs_node_manager.h @@ -17,7 +17,6 @@ namespace gcs { class MockGcsNodeManager : public GcsNodeManager { public: - MockGcsNodeManager() : GcsNodeManager(nullptr, nullptr) {} MOCK_METHOD(void, HandleRegisterNode, (const rpc::RegisterNodeRequest &request, rpc::RegisterNodeReply *reply, rpc::SendReplyCallback send_reply_callback), diff --git a/src/mock/ray/gcs/gcs_server/gcs_placement_group_scheduler.h b/src/mock/ray/gcs/gcs_server/gcs_placement_group_scheduler.h index 627e3357879e7..f612e6d1d2841 100644 --- a/src/mock/ray/gcs/gcs_server/gcs_placement_group_scheduler.h +++ b/src/mock/ray/gcs/gcs_server/gcs_placement_group_scheduler.h @@ -1,4 +1,4 @@ -// Copyright 2021 The Ray Authors. +// Copyright The Ray Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
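On the `explained_variance` hunk above: dropping the trailing `[0]` is correct because two-argument `torch.max(a, b)` is an element-wise maximum returning a single tensor (the `(values, indices)` tuple only comes from the reduction form `torch.max(x, dim=...)`). Quick sanity check of the restored function:

    import torch

    def explained_variance(y, pred):
        y_var = torch.var(y, dim=[0])
        diff_var = torch.var(y - pred, dim=[0])
        min_ = torch.tensor([-1.0]).to(pred.device)
        return torch.max(min_, 1 - (diff_var / y_var))

    y = torch.tensor([1.0, 2.0, 3.0])
    pred = torch.tensor([1.1, 1.9, 3.2])
    explained_variance(y, pred)  # ~0.98 for a close fit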
@@ -30,8 +30,8 @@ class MockGcsPlacementGroupSchedulerInterface public: MOCK_METHOD(void, ScheduleUnplacedBundles, (std::shared_ptr placement_group, - PGSchedulingFailureCallback failure_callback, - PGSchedulingSuccessfulCallback success_callback), + std::function)> failure_callback, + std::function)> success_callback), (override)); MOCK_METHOD((absl::flat_hash_map>), GetBundlesOnNode, (const NodeID &node_id), (override)); @@ -63,12 +63,11 @@ namespace gcs { class MockGcsScheduleStrategy : public GcsScheduleStrategy { public: - MOCK_METHOD( - ScheduleResult, Schedule, - (const std::vector> &bundles, - const std::unique_ptr &context, - GcsResourceScheduler &gcs_resource_scheduler), - (override)); + MOCK_METHOD(ScheduleMap, Schedule, + (std::vector> & bundles, + const std::unique_ptr &context, + GcsResourceScheduler &gcs_resource_scheduler), + (override)); }; } // namespace gcs @@ -79,12 +78,11 @@ namespace gcs { class MockGcsPackStrategy : public GcsPackStrategy { public: - MOCK_METHOD( - ScheduleResult, Schedule, - (const std::vector> &bundles, - const std::unique_ptr &context, - GcsResourceScheduler &gcs_resource_scheduler), - (override)); + MOCK_METHOD(ScheduleMap, Schedule, + (std::vector> & bundles, + const std::unique_ptr &context, + GcsResourceScheduler &gcs_resource_scheduler), + (override)); }; } // namespace gcs @@ -95,12 +93,11 @@ namespace gcs { class MockGcsSpreadStrategy : public GcsSpreadStrategy { public: - MOCK_METHOD( - ScheduleResult, Schedule, - (const std::vector> &bundles, - const std::unique_ptr &context, - GcsResourceScheduler &gcs_resource_scheduler), - (override)); + MOCK_METHOD(ScheduleMap, Schedule, + (std::vector> & bundles, + const std::unique_ptr &context, + GcsResourceScheduler &gcs_resource_scheduler), + (override)); }; } // namespace gcs @@ -111,12 +108,11 @@ namespace gcs { class MockGcsStrictPackStrategy : public GcsStrictPackStrategy { public: - MOCK_METHOD( - ScheduleResult, Schedule, - (const std::vector> &bundles, - const std::unique_ptr &context, - GcsResourceScheduler &gcs_resource_scheduler), - (override)); + MOCK_METHOD(ScheduleMap, Schedule, + (std::vector> & bundles, + const std::unique_ptr &context, + GcsResourceScheduler &gcs_resource_scheduler), + (override)); }; } // namespace gcs @@ -127,12 +123,11 @@ namespace gcs { class MockGcsStrictSpreadStrategy : public GcsStrictSpreadStrategy { public: - MOCK_METHOD( - ScheduleResult, Schedule, - (const std::vector> &bundles, - const std::unique_ptr &context, - GcsResourceScheduler &gcs_resource_scheduler), - (override)); + MOCK_METHOD(ScheduleMap, Schedule, + (std::vector> & bundles, + const std::unique_ptr &context, + GcsResourceScheduler &gcs_resource_scheduler), + (override)); }; } // namespace gcs @@ -165,8 +160,8 @@ class MockGcsPlacementGroupScheduler : public GcsPlacementGroupScheduler { public: MOCK_METHOD(void, ScheduleUnplacedBundles, (std::shared_ptr placement_group, - PGSchedulingFailureCallback failure_handler, - PGSchedulingSuccessfulCallback success_handler), + std::function)> failure_handler, + std::function)> success_handler), (override)); MOCK_METHOD(void, DestroyPlacementGroupBundleResourcesIfExists, (const PlacementGroupID &placement_group_id), (override)); diff --git a/src/mock/ray/gcs/gcs_server/gcs_resource_manager.h b/src/mock/ray/gcs/gcs_server/gcs_resource_manager.h index 764bee572cabc..d981be23a5472 100644 --- a/src/mock/ray/gcs/gcs_server/gcs_resource_manager.h +++ b/src/mock/ray/gcs/gcs_server/gcs_resource_manager.h @@ -17,7 +17,6 @@ namespace gcs { class 
MockGcsResourceManager : public GcsResourceManager { public: - using GcsResourceManager::GcsResourceManager; MOCK_METHOD(void, HandleGetResources, (const rpc::GetResourcesRequest &request, rpc::GetResourcesReply *reply, rpc::SendReplyCallback send_reply_callback), diff --git a/src/mock/ray/gcs/pubsub/gcs_pub_sub.h b/src/mock/ray/gcs/pubsub/gcs_pub_sub.h deleted file mode 100644 index 21e500da0a002..0000000000000 --- a/src/mock/ray/gcs/pubsub/gcs_pub_sub.h +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2021 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -namespace ray { -namespace gcs { - -class MockGcsPubSub : public GcsPubSub { - public: - MOCK_METHOD(Status, Publish, - (const std::string &channel, const std::string &id, const std::string &data, - const StatusCallback &done), - (override)); -}; - -} // namespace gcs -} // namespace ray diff --git a/src/mock/ray/gcs/store_client/in_memory_store_client.h b/src/mock/ray/gcs/store_client/in_memory_store_client.h deleted file mode 100644 index 08af16a075a17..0000000000000 --- a/src/mock/ray/gcs/store_client/in_memory_store_client.h +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2021 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -namespace ray { -namespace gcs { - -class MockInMemoryStoreClient : public InMemoryStoreClient { - public: - MOCK_METHOD(Status, AsyncPut, - (const std::string &table_name, const std::string &key, - const std::string &data, const StatusCallback &callback), - (override)); - MOCK_METHOD(Status, AsyncPutWithIndex, - (const std::string &table_name, const std::string &key, - const std::string &index_key, const std::string &data, - const StatusCallback &callback), - (override)); - MOCK_METHOD(Status, AsyncGet, - (const std::string &table_name, const std::string &key, - const OptionalItemCallback &callback), - (override)); - MOCK_METHOD(Status, AsyncGetByIndex, - (const std::string &table_name, const std::string &index_key, - (const MapCallback &callback)), - (override)); - MOCK_METHOD(Status, AsyncGetAll, - (const std::string &table_name, - (const MapCallback &callback)), - (override)); - MOCK_METHOD(Status, AsyncDelete, - (const std::string &table_name, const std::string &key, - const StatusCallback &callback), - (override)); - MOCK_METHOD(Status, AsyncDeleteWithIndex, - (const std::string &table_name, const std::string &key, - const std::string &index_key, const StatusCallback &callback), - (override)); - MOCK_METHOD(Status, AsyncBatchDelete, - (const std::string &table_name, const std::vector &keys, - const StatusCallback &callback), - (override)); - MOCK_METHOD(Status, AsyncBatchDeleteWithIndex, - (const std::string &table_name, const std::vector &keys, - const std::vector &index_keys, - const StatusCallback &callback), - (override)); - MOCK_METHOD(Status, AsyncDeleteByIndex, - (const std::string &table_name, const std::string &index_key, - const StatusCallback &callback), - (override)); - MOCK_METHOD(int, GetNextJobID, (), (override)); -}; - -} // namespace gcs -} // namespace ray diff --git a/src/mock/ray/gcs/store_client/redis_store_client.h b/src/mock/ray/gcs/store_client/redis_store_client.h deleted file mode 100644 index 153a69755d3b7..0000000000000 --- a/src/mock/ray/gcs/store_client/redis_store_client.h +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2021 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -namespace ray { -namespace gcs { - -class MockRedisStoreClient : public RedisStoreClient { - public: - MockRedisStoreClient() : RedisStoreClient(nullptr) {} - MOCK_METHOD(Status, AsyncPut, - (const std::string &table_name, const std::string &key, - const std::string &data, const StatusCallback &callback), - (override)); - MOCK_METHOD(Status, AsyncPutWithIndex, - (const std::string &table_name, const std::string &key, - const std::string &index_key, const std::string &data, - const StatusCallback &callback), - (override)); - MOCK_METHOD(Status, AsyncGet, - (const std::string &table_name, const std::string &key, - const OptionalItemCallback &callback), - (override)); - MOCK_METHOD(Status, AsyncGetByIndex, - (const std::string &table_name, const std::string &index_key, - (const MapCallback &callback)), - (override)); - MOCK_METHOD(Status, AsyncGetAll, - (const std::string &table_name, - (const MapCallback &callback)), - (override)); - MOCK_METHOD(Status, AsyncDelete, - (const std::string &table_name, const std::string &key, - const StatusCallback &callback), - (override)); - MOCK_METHOD(Status, AsyncDeleteWithIndex, - (const std::string &table_name, const std::string &key, - const std::string &index_key, const StatusCallback &callback), - (override)); - MOCK_METHOD(Status, AsyncBatchDelete, - (const std::string &table_name, const std::vector &keys, - const StatusCallback &callback), - (override)); - MOCK_METHOD(Status, AsyncBatchDeleteWithIndex, - (const std::string &table_name, const std::vector &keys, - const std::vector &index_keys, - const StatusCallback &callback), - (override)); - MOCK_METHOD(Status, AsyncDeleteByIndex, - (const std::string &table_name, const std::string &index_key, - const StatusCallback &callback), - (override)); - MOCK_METHOD(int, GetNextJobID, (), (override)); -}; - -} // namespace gcs -} // namespace ray diff --git a/src/mock/ray/gcs/store_client/store_client.h b/src/mock/ray/gcs/store_client/store_client.h deleted file mode 100644 index 6f4e3b5382735..0000000000000 --- a/src/mock/ray/gcs/store_client/store_client.h +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2021 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -namespace ray { -namespace gcs { - -class MockStoreClient : public StoreClient { - public: - MOCK_METHOD(Status, AsyncPut, - (const std::string &table_name, const std::string &key, - const std::string &data, const StatusCallback &callback), - (override)); - MOCK_METHOD(Status, AsyncPutWithIndex, - (const std::string &table_name, const std::string &key, - const std::string &index_key, const std::string &data, - const StatusCallback &callback), - (override)); - MOCK_METHOD(Status, AsyncGet, - (const std::string &table_name, const std::string &key, - const OptionalItemCallback &callback), - (override)); - MOCK_METHOD(Status, AsyncGetByIndex, - (const std::string &table_name, const std::string &index_key, - (const MapCallback &callback)), - (override)); - MOCK_METHOD(Status, AsyncGetAll, - (const std::string &table_name, - (const MapCallback &callback)), - (override)); - MOCK_METHOD(Status, AsyncDelete, - (const std::string &table_name, const std::string &key, - const StatusCallback &callback), - (override)); - MOCK_METHOD(Status, AsyncDeleteWithIndex, - (const std::string &table_name, const std::string &key, - const std::string &index_key, const StatusCallback &callback), - (override)); - MOCK_METHOD(Status, AsyncBatchDelete, - (const std::string &table_name, const std::vector &keys, - const StatusCallback &callback), - (override)); - MOCK_METHOD(Status, AsyncBatchDeleteWithIndex, - (const std::string &table_name, const std::vector &keys, - const std::vector &index_keys, - const StatusCallback &callback), - (override)); - MOCK_METHOD(Status, AsyncDeleteByIndex, - (const std::string &table_name, const std::string &index_key, - const StatusCallback &callback), - (override)); - MOCK_METHOD(int, GetNextJobID, (), (override)); -}; - -} // namespace gcs -} // namespace ray diff --git a/src/mock/ray/pubsub/publisher.h b/src/mock/ray/pubsub/publisher.h deleted file mode 100644 index 7094a9afadeac..0000000000000 --- a/src/mock/ray/pubsub/publisher.h +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright 2021 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -namespace ray { -namespace pubsub { -namespace pub_internal { - -template -class MockSubscriptionIndex : public SubscriptionIndex { - public: -}; - -} // namespace pub_internal -} // namespace pubsub -} // namespace ray - -namespace ray { -namespace pubsub { -namespace pub_internal { - -class MockLongPollConnection : public LongPollConnection { - public: -}; - -} // namespace pub_internal -} // namespace pubsub -} // namespace ray - -namespace ray { -namespace pubsub { -namespace pub_internal { - -class MockSubscriber : public Subscriber { - public: -}; - -} // namespace pub_internal -} // namespace pubsub -} // namespace ray - -namespace ray { -namespace pubsub { - -class MockPublisherInterface : public PublisherInterface { - public: - MOCK_METHOD(bool, RegisterSubscription, - (const rpc::ChannelType channel_type, const SubscriberID &subscriber_id, - const std::string &key_id_binary), - (override)); - MOCK_METHOD(void, Publish, - (const rpc::ChannelType channel_type, const rpc::PubMessage &pub_message, - const std::string &key_id_binary), - (override)); - MOCK_METHOD(void, PublishFailure, - (const rpc::ChannelType channel_type, const std::string &key_id_binary), - (override)); - MOCK_METHOD(bool, UnregisterSubscription, - (const rpc::ChannelType channel_type, const SubscriberID &subscriber_id, - const std::string &key_id_binary), - (override)); -}; - -} // namespace pubsub -} // namespace ray - -namespace ray { -namespace pubsub { - -class MockPublisher : public Publisher { - public: - MOCK_METHOD(bool, RegisterSubscription, - (const rpc::ChannelType channel_type, const SubscriberID &subscriber_id, - const std::string &key_id_binary), - (override)); - MOCK_METHOD(void, Publish, - (const rpc::ChannelType channel_type, const rpc::PubMessage &pub_message, - const std::string &key_id_binary), - (override)); - MOCK_METHOD(void, PublishFailure, - (const rpc::ChannelType channel_type, const std::string &key_id_binary), - (override)); - MOCK_METHOD(bool, UnregisterSubscription, - (const rpc::ChannelType channel_type, const SubscriberID &subscriber_id, - const std::string &key_id_binary), - (override)); -}; - -} // namespace pubsub -} // namespace ray diff --git a/src/mock/ray/pubsub/subscriber.h b/src/mock/ray/pubsub/subscriber.h deleted file mode 100644 index 38dc5f32afb65..0000000000000 --- a/src/mock/ray/pubsub/subscriber.h +++ /dev/null @@ -1,155 +0,0 @@ -// Copyright 2021 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -namespace ray { -namespace pubsub { - -template -class MockSubscriptionInfo : public SubscriptionInfo { - public: -}; - -} // namespace pubsub -} // namespace ray - -namespace ray { -namespace pubsub { - -class MockSubscribeChannelInterface : public SubscribeChannelInterface { - public: - MOCK_METHOD(void, Subscribe, - (const rpc::Address &publisher_address, const std::string &key_id_binary, - SubscriptionCallback subscription_callback, - SubscriptionFailureCallback subscription_failure_callback), - (override)); - MOCK_METHOD(bool, Unsubscribe, - (const rpc::Address &publisher_address, const std::string &key_id_binary), - (override)); - MOCK_METHOD(void, HandlePublishedMessage, - (const rpc::Address &publisher_address, const rpc::PubMessage &pub_message), - (const, override)); - MOCK_METHOD(void, HandlePublisherFailure, (const rpc::Address &publisher_address), - (override)); - MOCK_METHOD(void, HandlePublisherFailure, - (const rpc::Address &publisher_address, const std::string &key_id_binary), - (override)); - MOCK_METHOD(bool, SubscriptionExists, (const PublisherID &publisher_id), (override)); - MOCK_METHOD(const rpc::ChannelType, GetChannelType, (), (const, override)); - MOCK_METHOD(bool, CheckNoLeaks, (), (const, override)); - MOCK_METHOD(std::string, DebugString, (), (const, override)); -}; - -} // namespace pubsub -} // namespace ray - -namespace ray { -namespace pubsub { - -template -class MockSubscriberChannel : public SubscriberChannel { - public: - MOCK_METHOD(void, Subscribe, - (const rpc::Address &publisher_address, const std::string &key_id, - SubscriptionCallback subscription_callback, - SubscriptionFailureCallback subscription_failure_callback), - (override)); - MOCK_METHOD(bool, Unsubscribe, - (const rpc::Address &publisher_address, const std::string &key_id), - (override)); - MOCK_METHOD(bool, CheckNoLeaks, (), (const, override)); - MOCK_METHOD(void, HandlePublishedMessage, - (const rpc::Address &publisher_address, const rpc::PubMessage &pub_message), - (const, override)); - MOCK_METHOD(void, HandlePublisherFailure, (const rpc::Address &publisher_address), - (override)); - MOCK_METHOD(void, HandlePublisherFailure, - (const rpc::Address &publisher_address, const std::string &key_id_binary), - (override)); - MOCK_METHOD(bool, SubscriptionExists, (const PublisherID &publisher_id), (override)); - MOCK_METHOD(const rpc::ChannelType, GetChannelType, (), (const, override)); - MOCK_METHOD(std::string, DebugString, (), (const, override)); -}; - -} // namespace pubsub -} // namespace ray - -namespace ray { -namespace pubsub { - -class MockWaitForObjectEvictionChannel : public WaitForObjectEvictionChannel { - public: -}; - -} // namespace pubsub -} // namespace ray - -namespace ray { -namespace pubsub { - -class MockWaitForRefRemovedChannel : public WaitForRefRemovedChannel { - public: -}; - -} // namespace pubsub -} // namespace ray - -namespace ray { -namespace pubsub { - -class MockObjectLocationsChannel : public ObjectLocationsChannel { - public: -}; - -} // namespace pubsub -} // namespace ray - -namespace ray { -namespace pubsub { - -class MockSubscriberInterface : public SubscriberInterface { - public: - MOCK_METHOD(void, Subscribe, - (std::unique_ptr sub_message, - const rpc::ChannelType channel_type, const rpc::Address &publisher_address, - const std::string &key_id_binary, - SubscriptionCallback subscription_callback, - SubscriptionFailureCallback subscription_failure_callback), - (override)); - MOCK_METHOD(bool, Unsubscribe, - (const rpc::ChannelType channel_type, const 
rpc::Address &publisher_address, - const std::string &key_id_binary), - (override)); - MOCK_METHOD(std::string, DebugString, (), (const, override)); -}; - -} // namespace pubsub -} // namespace ray - -namespace ray { -namespace pubsub { - -class MockSubscriberClientInterface : public SubscriberClientInterface { - public: - MOCK_METHOD(void, PubsubLongPolling, - (const rpc::PubsubLongPollingRequest &request, - const rpc::ClientCallback &callback), - (override)); - MOCK_METHOD(void, PubsubCommandBatch, - (const rpc::PubsubCommandBatchRequest &request, - const rpc::ClientCallback &callback), - (override)); -}; - -} // namespace pubsub -} // namespace ray diff --git a/src/mock/ray/raylet/node_manager.h b/src/mock/ray/raylet/node_manager.h index 1ce3563ba450d..7edc1c9916d07 100644 --- a/src/mock/ray/raylet/node_manager.h +++ b/src/mock/ray/raylet/node_manager.h @@ -67,11 +67,6 @@ class MockNodeManager : public NodeManager { rpc::RequestWorkerLeaseReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleReportWorkerBacklog, - (const rpc::ReportWorkerBacklogRequest &request, - rpc::ReportWorkerBacklogReply *reply, - rpc::SendReplyCallback send_reply_callback), - (override)); MOCK_METHOD(void, HandleReturnWorker, (const rpc::ReturnWorkerRequest &request, rpc::ReturnWorkerReply *reply, rpc::SendReplyCallback send_reply_callback), diff --git a/src/mock/ray/raylet/scheduling/cluster_task_manager_interface.h b/src/mock/ray/raylet/scheduling/cluster_task_manager_interface.h index 3c5d4498af18b..498a5088b7194 100644 --- a/src/mock/ray/raylet/scheduling/cluster_task_manager_interface.h +++ b/src/mock/ray/raylet/scheduling/cluster_task_manager_interface.h @@ -33,6 +33,8 @@ class MockClusterTaskManagerInterface : public ClusterTaskManagerInterface { (const, override)); MOCK_METHOD(void, TaskFinished, (std::shared_ptr worker, RayTask *task), (override)); + MOCK_METHOD(void, ReturnWorkerResources, (std::shared_ptr worker), + (override)); MOCK_METHOD(bool, CancelTask, (const TaskID &task_id, bool runtime_env_setup_failed), (override)); MOCK_METHOD(void, QueueAndScheduleTask, diff --git a/src/mock/ray/raylet_client/raylet_client.h b/src/mock/ray/raylet_client/raylet_client.h index c2dc3dd43097c..cafd952e5d6e4 100644 --- a/src/mock/ray/raylet_client/raylet_client.h +++ b/src/mock/ray/raylet_client/raylet_client.h @@ -35,12 +35,6 @@ class MockWorkerLeaseInterface : public WorkerLeaseInterface { const ray::rpc::ClientCallback &callback, const int64_t backlog_size), (override)); - MOCK_METHOD( - void, RequestWorkerLease, - (const rpc::TaskSpec &task_spec, - const ray::rpc::ClientCallback &callback, - const int64_t backlog_size), - (override)); MOCK_METHOD(ray::Status, ReturnWorker, (int worker_port, const WorkerID &worker_id, bool disconnect_worker), (override)); @@ -72,7 +66,7 @@ class MockResourceReserveInterface : public ResourceReserveInterface { (override)); MOCK_METHOD( void, CancelResourceReserve, - (const BundleSpecification &bundle_spec, + (BundleSpecification & bundle_spec, const ray::rpc::ClientCallback &callback), (override)); MOCK_METHOD(void, ReleaseUnusedBundles, @@ -112,27 +106,41 @@ class MockResourceTrackingInterface : public ResourceTrackingInterface { namespace ray { class MockRayletClientInterface : public RayletClientInterface { + public: + MOCK_METHOD(void, GetSystemConfig, + (const rpc::ClientCallback &callback), + (override)); + MOCK_METHOD(void, GetGcsServerAddress, + (const rpc::ClientCallback &callback), + (override)); +}; + +} // namespace ray + 
+namespace ray { +namespace raylet { + +class MockRayletConnection : public RayletConnection { + public: +}; + +} // namespace raylet +} // namespace ray + +namespace ray { +namespace raylet { + +class MockRayletClient : public RayletClient { public: MOCK_METHOD(ray::Status, WaitForDirectActorCallArgs, (const std::vector &references, int64_t tag), (override)); - MOCK_METHOD(void, ReportWorkerBacklog, - (const WorkerID &worker_id, - const std::vector &backlog_reports), - (override)); MOCK_METHOD( void, RequestWorkerLease, (const ray::TaskSpecification &resource_spec, const ray::rpc::ClientCallback &callback, const int64_t backlog_size), (override)); - MOCK_METHOD( - void, RequestWorkerLease, - (const rpc::TaskSpec &resource_spec, - const ray::rpc::ClientCallback &callback, - const int64_t backlog_size), - (override)); - MOCK_METHOD(ray::Status, ReturnWorker, (int worker_port, const WorkerID &worker_id, bool disconnect_worker), (override)); @@ -156,7 +164,7 @@ class MockRayletClientInterface : public RayletClientInterface { (override)); MOCK_METHOD( void, CancelResourceReserve, - (const BundleSpecification &bundle_spec, + (BundleSpecification & bundle_spec, const ray::rpc::ClientCallback &callback), (override)); MOCK_METHOD(void, ReleaseUnusedBundles, @@ -183,4 +191,5 @@ class MockRayletClientInterface : public RayletClientInterface { (override)); }; +} // namespace raylet } // namespace ray diff --git a/src/mock/ray/rpc/worker/core_worker_client.h b/src/mock/ray/rpc/worker/core_worker_client.h deleted file mode 100644 index a4646cef99e16..0000000000000 --- a/src/mock/ray/rpc/worker/core_worker_client.h +++ /dev/null @@ -1,123 +0,0 @@ -// Copyright 2021 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -namespace ray { -namespace rpc { - -class MockWorkerAddress : public WorkerAddress { - public: -}; - -} // namespace rpc -} // namespace ray - -namespace ray { -namespace rpc { - -class MockCoreWorkerClientInterface : public ray::pubsub::MockSubscriberClientInterface, - public CoreWorkerClientInterface { - public: - MOCK_METHOD(const rpc::Address &, Addr, (), (const, override)); - MOCK_METHOD(void, PushActorTask, - (std::unique_ptr request, bool skip_queue, - const ClientCallback &callback), - (override)); - MOCK_METHOD(void, PushNormalTask, - (std::unique_ptr request, - const ClientCallback &callback), - (override)); - MOCK_METHOD(void, StealTasks, - (std::unique_ptr request, - const ClientCallback &callback), - (override)); - MOCK_METHOD(void, DirectActorCallArgWaitComplete, - (const DirectActorCallArgWaitCompleteRequest &request, - const ClientCallback &callback), - (override)); - MOCK_METHOD(void, GetObjectStatus, - (const GetObjectStatusRequest &request, - const ClientCallback &callback), - (override)); - MOCK_METHOD(void, WaitForActorOutOfScope, - (const WaitForActorOutOfScopeRequest &request, - const ClientCallback &callback), - (override)); - MOCK_METHOD(void, PubsubLongPolling, - (const PubsubLongPollingRequest &request, - const ClientCallback &callback), - (override)); - MOCK_METHOD(void, PubsubCommandBatch, - (const PubsubCommandBatchRequest &request, - const ClientCallback &callback), - (override)); - MOCK_METHOD(void, UpdateObjectLocationBatch, - (const UpdateObjectLocationBatchRequest &request, - const ClientCallback &callback), - (override)); - MOCK_METHOD(void, GetObjectLocationsOwner, - (const GetObjectLocationsOwnerRequest &request, - const ClientCallback &callback), - (override)); - MOCK_METHOD(void, KillActor, - (const KillActorRequest &request, - const ClientCallback &callback), - (override)); - MOCK_METHOD(void, CancelTask, - (const CancelTaskRequest &request, - const ClientCallback &callback), - (override)); - MOCK_METHOD(void, RemoteCancelTask, - (const RemoteCancelTaskRequest &request, - const ClientCallback &callback), - (override)); - MOCK_METHOD(void, GetCoreWorkerStats, - (const GetCoreWorkerStatsRequest &request, - const ClientCallback &callback), - (override)); - MOCK_METHOD(void, LocalGC, - (const LocalGCRequest &request, - const ClientCallback &callback), - (override)); - MOCK_METHOD(void, SpillObjects, - (const SpillObjectsRequest &request, - const ClientCallback &callback), - (override)); - MOCK_METHOD(void, RestoreSpilledObjects, - (const RestoreSpilledObjectsRequest &request, - const ClientCallback &callback), - (override)); - MOCK_METHOD(void, DeleteSpilledObjects, - (const DeleteSpilledObjectsRequest &request, - const ClientCallback &callback), - (override)); - MOCK_METHOD(void, AddSpilledUrl, - (const AddSpilledUrlRequest &request, - const ClientCallback &callback), - (override)); - MOCK_METHOD(void, PlasmaObjectReady, - (const PlasmaObjectReadyRequest &request, - const ClientCallback &callback), - (override)); - MOCK_METHOD(void, Exit, - (const ExitRequest &request, const ClientCallback &callback), - (override)); - MOCK_METHOD(void, AssignObjectOwner, - (const AssignObjectOwnerRequest &request, - const ClientCallback &callback), - (override)); - MOCK_METHOD(int64_t, ClientProcessedUpToSeqno, (), (override)); -}; - -} // namespace rpc -} // namespace ray diff --git a/src/mock/ray/rpc/worker/core_worker_client_pool.h b/src/mock/ray/rpc/worker/core_worker_client_pool.h deleted file mode 100644 index d4e1ec607e5a2..0000000000000 --- 
a/src/mock/ray/rpc/worker/core_worker_client_pool.h +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright 2021 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -namespace ray { -namespace rpc { - -class MockCoreWorkerClientPool : public CoreWorkerClientPool { - public: -}; - -} // namespace rpc -} // namespace ray diff --git a/src/ray/common/bundle_spec.cc b/src/ray/common/bundle_spec.cc index c5b4a711e0275..339a492360d21 100644 --- a/src/ray/common/bundle_spec.cc +++ b/src/ray/common/bundle_spec.cc @@ -74,10 +74,6 @@ PlacementGroupID BundleSpecification::PlacementGroupId() const { return PlacementGroupID::FromBinary(message_->bundle_id().placement_group_id()); } -NodeID BundleSpecification::NodeId() const { - return NodeID::FromBinary(message_->node_id()); -} - int64_t BundleSpecification::Index() const { return message_->bundle_id().bundle_index(); } @@ -93,19 +89,16 @@ std::string BundleSpecification::DebugString() const { std::string FormatPlacementGroupResource(const std::string &original_resource_name, const PlacementGroupID &group_id, int64_t bundle_index) { - std::stringstream os; + std::string str; if (bundle_index >= 0) { - os << original_resource_name << kGroupKeyword << std::to_string(bundle_index) << "_" - << group_id.Hex(); + str = original_resource_name + "_group_" + std::to_string(bundle_index) + "_" + + group_id.Hex(); } else { RAY_CHECK(bundle_index == -1) << "Invalid index " << bundle_index; - os << original_resource_name << kGroupKeyword << group_id.Hex(); + str = original_resource_name + "_group_" + group_id.Hex(); } - std::string result = os.str(); - RAY_DCHECK(GetOriginalResourceName(result) == original_resource_name) - << "Generated: " << GetOriginalResourceName(result) - << " Original: " << original_resource_name; - return result; + RAY_CHECK(GetOriginalResourceName(str) == original_resource_name) << str; + return str; } std::string FormatPlacementGroupResource(const std::string &original_resource_name, @@ -116,12 +109,12 @@ std::string FormatPlacementGroupResource(const std::string &original_resource_na bool IsBundleIndex(const std::string &resource, const PlacementGroupID &group_id, const int bundle_index) { - return resource.find(kGroupKeyword + std::to_string(bundle_index) + "_" + - group_id.Hex()) != std::string::npos; + return resource.find("_group_" + std::to_string(bundle_index) + "_" + group_id.Hex()) != + std::string::npos; } std::string GetOriginalResourceName(const std::string &resource) { - auto idx = resource.find(kGroupKeyword); + auto idx = resource.find("_group_"); RAY_CHECK(idx >= 0) << "This isn't a placement group resource " << resource; return resource.substr(0, idx); } diff --git a/src/ray/common/bundle_spec.h b/src/ray/common/bundle_spec.h index bca5396fdc71a..8437704509b58 100644 --- a/src/ray/common/bundle_spec.h +++ b/src/ray/common/bundle_spec.h @@ -32,9 +32,6 @@ typedef std::function ScheduleBundleCallback; /// address and the raylet's port. 
typedef std::function SpillbackBundleCallback; -const std::string kGroupKeyword = "_group_"; -const size_t kGroupKeywordSize = kGroupKeyword.size(); - class BundleSpecification : public MessageWrapper { public: /// Construct from a protobuf message object. @@ -57,9 +54,6 @@ class BundleSpecification : public MessageWrapper { // Return the Placement Group id which the Bundle belong to. PlacementGroupID PlacementGroupId() const; - // Get a node ID that this bundle is scheduled on. - NodeID NodeId() const; - // Return the index of the bundle. int64_t Index() const; diff --git a/src/ray/common/client_connection.cc b/src/ray/common/client_connection.cc index 7eb51a953e215..73743820b2b9b 100644 --- a/src/ray/common/client_connection.cc +++ b/src/ray/common/client_connection.cc @@ -19,7 +19,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/src/ray/common/constants.h b/src/ray/common/constants.h index 19180ef356b38..780c1b70d3098 100644 --- a/src/ray/common/constants.h +++ b/src/ray/common/constants.h @@ -51,6 +51,3 @@ constexpr int kMessagePackOffset = 9; /// Filename of "shim process" that sets up Python worker environment. /// Should be kept in sync with SETUP_WORKER_FILENAME in ray.ray_constants. constexpr char kSetupWorkerFilename[] = "setup_worker.py"; - -/// The version of Ray -constexpr char kRayVersion[] = "2.0.0.dev0"; diff --git a/src/ray/common/id.h b/src/ray/common/id.h index 889128e81df11..0fc5d45599392 100644 --- a/src/ray/common/id.h +++ b/src/ray/common/id.h @@ -492,7 +492,6 @@ std::string BaseID::Hex() const { constexpr char hex[] = "0123456789abcdef"; const uint8_t *id = Data(); std::string result; - result.reserve(T::Size()); for (size_t i = 0; i < T::Size(); i++) { unsigned int val = id[i]; result.push_back(hex[val >> 4]); diff --git a/src/ray/common/network_util.h b/src/ray/common/network_util.h index 8f268ec46b389..08bef7ae873af 100644 --- a/src/ray/common/network_util.h +++ b/src/ray/common/network_util.h @@ -16,7 +16,7 @@ #include #include -#include +#include #include #include #include diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index 0a6a61357b79f..53e0bf4d72450 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -183,7 +183,8 @@ RAY_CONFIG(int64_t, worker_register_timeout_seconds, 30) RAY_CONFIG(int64_t, redis_db_connect_retries, 50) RAY_CONFIG(int64_t, redis_db_connect_wait_milliseconds, 100) -/// The object manager's global timer interval in milliseconds. +/// Timeout, in milliseconds, to wait before retrying a failed pull in the +/// ObjectManager. RAY_CONFIG(int, object_manager_timer_freq_ms, 100) /// Timeout, in milliseconds, to wait before retrying a failed pull in the @@ -220,8 +221,14 @@ RAY_CONFIG(int32_t, maximum_profile_table_rows_count, 10 * 1000) /// message. RAY_CONFIG(uint32_t, object_store_get_max_ids_to_print_in_warning, 20) +// TODO: fix win32 timeout in ci and unify these two. +#ifdef _MSC_VER /// Number of threads used by rpc server in gcs server. RAY_CONFIG(uint32_t, gcs_server_rpc_server_thread_num, 1) +#else +/// Number of threads used by rpc server in gcs server. +RAY_CONFIG(uint32_t, gcs_server_rpc_server_thread_num, 8) +#endif /// Allow up to 5 seconds for connecting to gcs service. /// Note: this only takes effect when gcs service is enabled. 
RAY_CONFIG(int64_t, gcs_service_connect_retries, 50) @@ -234,10 +241,8 @@ RAY_CONFIG(uint64_t, gcs_redis_heartbeat_interval_milliseconds, 100) RAY_CONFIG(uint32_t, gcs_lease_worker_retry_interval_ms, 200) /// Duration to wait between retries for creating actor in gcs server. RAY_CONFIG(uint32_t, gcs_create_actor_retry_interval_ms, 200) -/// Exponential backoff params for gcs to retry creating a placement group -RAY_CONFIG(uint32_t, gcs_create_placement_group_retry_min_interval_ms, 200) -RAY_CONFIG(uint32_t, gcs_create_placement_group_retry_max_interval_ms, 5000) -RAY_CONFIG(double, gcs_create_placement_group_retry_multiplier, 1.5); +/// Duration to wait between retries for creating placement group in gcs server. +RAY_CONFIG(uint32_t, gcs_create_placement_group_retry_interval_ms, 200) /// Maximum number of destroyed actors in GCS server memory cache. RAY_CONFIG(uint32_t, maximum_gcs_destroyed_actor_cached_count, 100000) /// Maximum number of dead nodes in GCS server memory cache. @@ -306,18 +311,12 @@ RAY_CONFIG(int64_t, task_rpc_inlined_bytes_limit, 10 * 1024 * 1024) /// pipelining task submission. RAY_CONFIG(uint32_t, max_tasks_in_flight_per_worker, 1) -/// Maximum number of pending lease requests per scheduling category -RAY_CONFIG(uint64_t, max_pending_lease_requests_per_scheduling_category, 10) - /// Interval to restart dashboard agent after the process exit. RAY_CONFIG(uint32_t, agent_restart_interval_ms, 1000) /// Wait timeout for dashboard agent register. RAY_CONFIG(uint32_t, agent_register_timeout_ms, 30 * 1000) -/// Max restart count for the dashboard agent. -RAY_CONFIG(uint32_t, agent_max_restart_count, 5) - /// If the agent manager fails to communicate with the dashboard agent, we will retry /// after this interval. RAY_CONFIG(uint32_t, agent_manager_retry_interval_ms, 1000); @@ -326,8 +325,12 @@ RAY_CONFIG(uint32_t, agent_manager_retry_interval_ms, 1000); /// load reported by each raylet. RAY_CONFIG(int64_t, max_resource_shapes_per_load_report, 100) +/// If true, the worker's queue backlog size will be propagated to the heartbeat batch +/// data. +RAY_CONFIG(bool, report_worker_backlog, true) + /// The timeout for synchronous GCS requests in seconds. -RAY_CONFIG(int64_t, gcs_server_request_timeout_seconds, 60) +RAY_CONFIG(int64_t, gcs_server_request_timeout_seconds, 5) /// Whether to enable worker prestarting: https://github.com/ray-project/ray/issues/12052 RAY_CONFIG(bool, enable_worker_prestart, true) @@ -475,7 +478,7 @@ RAY_CONFIG(int64_t, grpc_keepalive_time_ms, 10000); RAY_CONFIG(int64_t, grpc_keepalive_timeout_ms, 20000); /// Whether to use log reporter in event framework -RAY_CONFIG(bool, event_log_reporter_enabled, true) +RAY_CONFIG(bool, event_log_reporter_enabled, false) /// Whether to use log reporter in event framework RAY_CONFIG(bool, actor_register_async, true) @@ -488,11 +491,3 @@ RAY_CONFIG(bool, scheduler_avoid_gpu_nodes, true) /// Whether to skip running local GC in runtime env. RAY_CONFIG(bool, runtime_env_skip_local_gc, false) - -/// Whether or not use TLS. -RAY_CONFIG(int64_t, USE_TLS, 0) - -/// Location of TLS credentials -RAY_CONFIG(std::string, TLS_SERVER_CERT, "") -RAY_CONFIG(std::string, TLS_SERVER_KEY, "") -RAY_CONFIG(std::string, TLS_CA_CERT, "") diff --git a/src/ray/common/ray_internal_flag_def.h b/src/ray/common/ray_internal_flag_def.h index 0f42d63d3f1ef..20f1ef8ccc3e3 100644 --- a/src/ray/common/ray_internal_flag_def.h +++ b/src/ray/common/ray_internal_flag_def.h @@ -27,6 +27,3 @@ RAY_INTERNAL_FLAG(std::string, JOB_ID, "") /// Raylet process ID. 
RAY_INTERNAL_FLAG(std::string, RAYLET_PID, "") - -/// Override the random node ID for testing. -RAY_INTERNAL_FLAG(std::string, OVERRIDE_NODE_ID_FOR_TESTING, "") diff --git a/src/ray/common/runtime_env_manager.cc b/src/ray/common/runtime_env_manager.cc index 9e39488fa9149..2ec95cdecee8f 100644 --- a/src/ray/common/runtime_env_manager.cc +++ b/src/ray/common/runtime_env_manager.cc @@ -13,7 +13,6 @@ // limitations under the License. #include "ray/common/runtime_env_manager.h" - #include "ray/util/logging.h" namespace ray { @@ -21,12 +20,17 @@ void RuntimeEnvManager::AddURIReference(const std::string &hex_id, const rpc::RuntimeEnv &runtime_env) { const auto &uris = runtime_env.uris(); for (const auto &uri : uris) { - if (unused_uris_.count(uri)) { - unused_uris_.erase(uri); - } - uri_reference_[uri]++; - id_to_uris_[hex_id].push_back(uri); + AddURIReference(hex_id, uri); + } +} + +void RuntimeEnvManager::AddURIReference(const std::string &hex_id, + const std::string &uri) { + if (unused_uris_.count(uri)) { + unused_uris_.erase(uri); } + uri_reference_[uri]++; + id_to_uris_[hex_id].push_back(uri); } const std::vector &RuntimeEnvManager::GetReferences( diff --git a/src/ray/common/runtime_env_manager.h b/src/ray/common/runtime_env_manager.h index f9c59d74784bb..510aa5fe53aa9 100644 --- a/src/ray/common/runtime_env_manager.h +++ b/src/ray/common/runtime_env_manager.h @@ -13,7 +13,6 @@ // limitations under the License. #pragma once #include - #include "ray/common/id.h" #include "src/ray/protobuf/common.pb.h" @@ -38,6 +37,12 @@ class RuntimeEnvManager { /// \param[in] runtime_env The runtime env used by the id. void AddURIReference(const std::string &hex_id, const rpc::RuntimeEnv &runtime_env); + /// Increase the reference count of the given URI for the given id. + /// + /// \param[in] hex_id The id of the runtime env. It can be an actor or job id. + /// \param[in] uri The URI referenced by the id. + void AddURIReference(const std::string &hex_id, const std::string &uri); + /// Get the reference of URIs by id. /// /// \param[in] hex_id The id of to look. diff --git a/src/ray/common/task/task.cc b/src/ray/common/task/task.cc index 4765751afa3fc..291829e36f567 100644 --- a/src/ray/common/task/task.cc +++ b/src/ray/common/task/task.cc @@ -18,9 +18,10 @@ namespace ray { -RayTask::RayTask(const rpc::Task &message) +RayTask::RayTask(const rpc::Task &message, int64_t backlog_size) : task_spec_(message.task_spec()), - task_execution_spec_(message.task_execution_spec()) { + task_execution_spec_(message.task_execution_spec()), + backlog_size_(backlog_size) { ComputeDependencies(); } @@ -49,6 +50,10 @@ void RayTask::CopyTaskExecutionSpec(const RayTask &task) { task_execution_spec_ = task.task_execution_spec_; } +void RayTask::SetBacklogSize(int64_t backlog_size) { backlog_size_ = backlog_size; } + +int64_t RayTask::BacklogSize() const { return backlog_size_; } + std::string RayTask::DebugString() const { std::ostringstream stream; stream << "task_spec={" << task_spec_.DebugString() << "}, task_execution_spec={" diff --git a/src/ray/common/task/task.h b/src/ray/common/task/task.h index 52c0e9246dab2..c21ec9c94da8e 100644 --- a/src/ray/common/task/task.h +++ b/src/ray/common/task/task.h @@ -47,7 +47,9 @@ class RayTask { /// Construct a `RayTask` object from a protobuf message. /// /// \param message The protobuf message. - explicit RayTask(const rpc::Task &message); + /// \param backlog_size The size of the task owner's backlog for this + /// task's shape.
+ explicit RayTask(const rpc::Task &message, int64_t backlog_size = -1); /// Construct a `RayTask` object from a `TaskSpecification` and a /// `TaskExecutionSpecification`. @@ -101,6 +103,10 @@ class RayTask { /// Returns the cancellation task callback, or nullptr. const CancelTaskCallback &OnCancellation() const { return on_cancellation_; } + void SetBacklogSize(int64_t backlog_size); + + int64_t BacklogSize() const; + std::string DebugString() const; private: @@ -127,6 +133,8 @@ class RayTask { /// For direct task calls, overrides the cancellation behaviour to send an /// RPC back to the submitting worker. mutable CancelTaskCallback on_cancellation_ = nullptr; + /// The size of the core worker's backlog when this task was submitted. + int64_t backlog_size_ = -1; }; } // namespace ray diff --git a/src/ray/common/task/task_spec.cc b/src/ray/common/task/task_spec.cc index 0c3d77beb5993..353406fd3c820 100644 --- a/src/ray/common/task/task_spec.cc +++ b/src/ray/common/task/task_spec.cc @@ -132,10 +132,8 @@ ray::FunctionDescriptor TaskSpecification::FunctionDescriptor() const { return ray::FunctionDescriptorBuilder::FromProto(message_->function_descriptor()); } -rpc::RuntimeEnv TaskSpecification::RuntimeEnv() const { return message_->runtime_env(); } - std::string TaskSpecification::SerializedRuntimeEnv() const { - return message_->runtime_env().serialized_runtime_env(); + return message_->serialized_runtime_env(); } bool TaskSpecification::HasRuntimeEnv() const { @@ -147,7 +145,8 @@ int TaskSpecification::GetRuntimeEnvHash() const { if (RayConfig::instance().worker_resource_limits_enabled()) { required_resource = GetRequiredResources().GetResourceMap(); } - WorkerCacheKey env = {SerializedRuntimeEnv(), required_resource}; + WorkerCacheKey env = {OverrideEnvironmentVariables(), SerializedRuntimeEnv(), + required_resource}; return env.IntHash(); } @@ -240,6 +239,11 @@ std::string TaskSpecification::GetDebuggerBreakpoint() const { return message_->debugger_breakpoint(); } +std::unordered_map +TaskSpecification::OverrideEnvironmentVariables() const { + return MapFromProtobuf(message_->override_environment_variables()); +} + bool TaskSpecification::IsDriverTask() const { return message_->type() == TaskType::DRIVER_TASK; } @@ -394,9 +398,11 @@ std::string TaskSpecification::CallSiteString() const { } WorkerCacheKey::WorkerCacheKey( + const std::unordered_map override_environment_variables, const std::string serialized_runtime_env, const std::unordered_map required_resources) - : serialized_runtime_env(serialized_runtime_env), + : override_environment_variables(override_environment_variables), + serialized_runtime_env(serialized_runtime_env), required_resources(std::move(required_resources)) {} bool WorkerCacheKey::operator==(const WorkerCacheKey &k) const { @@ -405,7 +411,8 @@ bool WorkerCacheKey::operator==(const WorkerCacheKey &k) const { } bool WorkerCacheKey::EnvIsEmpty() const { - return (serialized_runtime_env == "" || serialized_runtime_env == "{}") && + return override_environment_variables.size() == 0 && + (serialized_runtime_env == "" || serialized_runtime_env == "{}") && required_resources.empty(); } @@ -417,6 +424,19 @@ std::size_t WorkerCacheKey::Hash() const { // runtime envs. hash_ = 0; } else { + std::vector> env_vars( + override_environment_variables.begin(), override_environment_variables.end()); + // The environment doesn't depend on the order of the variables, so the hash should not + // either. Sort the variables so different permutations yield the same hash.
+ std::sort(env_vars.begin(), env_vars.end()); + for (auto &pair : env_vars) { + // TODO(architkulkarni): boost::hash_combine isn't guaranteed to be equal during + // separate runs of a program, which may cause problems if these hashes are + // communicated between different Raylets and compared. + boost::hash_combine(hash_, pair.first); + boost::hash_combine(hash_, pair.second); + } + boost::hash_combine(hash_, serialized_runtime_env); std::vector> resource_vars( diff --git a/src/ray/common/task/task_spec.h b/src/ray/common/task/task_spec.h index 24dbf4afbae21..8b10b163cc3cc 100644 --- a/src/ray/common/task/task_spec.h +++ b/src/ray/common/task/task_spec.h @@ -100,8 +100,6 @@ class TaskSpecification : public MessageWrapper { ray::FunctionDescriptor FunctionDescriptor() const; - [[nodiscard]] rpc::RuntimeEnv RuntimeEnv() const; - std::string SerializedRuntimeEnv() const; bool HasRuntimeEnv() const; @@ -172,6 +170,8 @@ std::string GetDebuggerBreakpoint() const; + std::unordered_map OverrideEnvironmentVariables() const; + bool IsDriverTask() const; Language GetLanguage() const; @@ -254,7 +254,7 @@ class TaskSpecification : public MessageWrapper { /// Field storing required placement resources. Initialized in constructor. std::shared_ptr required_placement_resources_; /// Cached scheduling class of this task. - SchedulingClass sched_cls_id_ = 0; + SchedulingClass sched_cls_id_; /// Below static fields could be mutated in `ComputeResources` concurrently due to /// multi-threading, we need a mutex to protect it. @@ -275,10 +275,13 @@ class WorkerCacheKey { /// Create a cache key with the given environment variable overrides and serialized /// runtime_env. /// + /// \param override_environment_variables The environment variable overrides set in this /// worker. \param serialized_runtime_env The JSON-serialized runtime env for this /// worker. \param required_resources The required resource. - WorkerCacheKey(const std::string serialized_runtime_env, - const std::unordered_map required_resources); + WorkerCacheKey( + const std::unordered_map override_environment_variables, + const std::string serialized_runtime_env, + const std::unordered_map required_resources); bool operator==(const WorkerCacheKey &k) const; @@ -290,7 +293,8 @@ class WorkerCacheKey { /// Get the hash for this worker's environment. /// - /// \return The hash of the serialized runtime_env. + /// \return The hash of the override_environment_variables and the serialized + /// runtime_env. std::size_t Hash() const; /// Get the int-valued hash for this worker's environment, useful for portability in @@ -300,6 +304,8 @@ int IntHash() const; private: + /// The environment variable overrides for this worker. + const std::unordered_map override_environment_variables; /// The JSON-serialized runtime env for this worker. const std::string serialized_runtime_env; /// The required resources for this worker.
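Note: the order-insensitive hashing restored in the task_spec.cc hunk above can be read in isolation as the sketch below. This is a minimal illustration rather than code from this patch; it assumes only the standard library plus <boost/functional/hash.hpp>, and ComputeEnvHash is a hypothetical name.

#include <algorithm>
#include <cstddef>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <boost/functional/hash.hpp>

// Hash a set of environment-variable overrides independently of the
// unordered_map's iteration order: copy the pairs into a vector, sort
// them, then fold each key and value into the seed.
std::size_t ComputeEnvHash(
    const std::unordered_map<std::string, std::string> &env_vars) {
  std::vector<std::pair<std::string, std::string>> sorted(env_vars.begin(),
                                                          env_vars.end());
  std::sort(sorted.begin(), sorted.end());
  std::size_t seed = 0;
  for (const auto &kv : sorted) {
    // As the TODO in the patch notes, boost::hash_combine is not
    // guaranteed to be stable across separate runs or builds, so these
    // values should not be compared across processes.
    boost::hash_combine(seed, kv.first);
    boost::hash_combine(seed, kv.second);
  }
  return seed;
}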
diff --git a/src/ray/common/task/task_util.h b/src/ray/common/task/task_util.h index 57ee5b811663e..c011829c2603d 100644 --- a/src/ray/common/task/task_util.h +++ b/src/ray/common/task/task_util.h @@ -106,7 +106,8 @@ class TaskSpecBuilder { const BundleID &bundle_id, bool placement_group_capture_child_tasks, const std::string &debugger_breakpoint, const std::string &serialized_runtime_env = "{}", - const std::vector &runtime_env_uris = {}, + const std::unordered_map &override_environment_variables = + {}, const std::string &concurrency_group_name = "") { message_->set_type(TaskType::NORMAL_TASK); message_->set_name(name); @@ -128,11 +129,11 @@ class TaskSpecBuilder { message_->set_placement_group_capture_child_tasks( placement_group_capture_child_tasks); message_->set_debugger_breakpoint(debugger_breakpoint); - message_->mutable_runtime_env()->set_serialized_runtime_env(serialized_runtime_env); - for (const std::string &uri : runtime_env_uris) { - message_->mutable_runtime_env()->add_uris(uri); - } + message_->set_serialized_runtime_env(serialized_runtime_env); message_->set_concurrency_group_name(concurrency_group_name); + for (const auto &env : override_environment_variables) { + (*message_->mutable_override_environment_variables())[env.first] = env.second; + } return *this; } diff --git a/src/ray/core_worker/common.h b/src/ray/core_worker/common.h index dfb4fd9a39f28..016d16ddc8851 100644 --- a/src/ray/core_worker/common.h +++ b/src/ray/core_worker/common.h @@ -60,13 +60,14 @@ struct TaskOptions { std::unordered_map &resources, const std::string &concurrency_group_name = "", const std::string &serialized_runtime_env = "{}", - const std::vector &runtime_env_uris = {}) + const std::unordered_map + &override_environment_variables = {}) : name(name), num_returns(num_returns), resources(resources), concurrency_group_name(concurrency_group_name), serialized_runtime_env(serialized_runtime_env), - runtime_env_uris(runtime_env_uris) {} + override_environment_variables(override_environment_variables) {} /// The name of this task. std::string name; @@ -76,10 +77,12 @@ struct TaskOptions { std::unordered_map resources; /// The name of the concurrency group in which this task will be executed. std::string concurrency_group_name; - // Runtime Env used by this task. Propagated to child actors and tasks. + // Runtime Env used by this task. Propagated to child actors and tasks. std::string serialized_runtime_env; - // URIs contained in the runtime_env. - std::vector runtime_env_uris; + /// Environment variables to update for this task. Maps a variable name to its + /// value. Can override existing environment variables and introduce new ones. + /// Propagated to child actors and/or tasks. + const std::unordered_map override_environment_variables; }; /// Options for actor creation tasks. 
@@ -94,7 +97,8 @@ struct ActorCreationOptions { BundleID placement_options = std::make_pair(PlacementGroupID::Nil(), -1), bool placement_group_capture_child_tasks = true, const std::string &serialized_runtime_env = "{}", - const std::vector &runtime_env_uris = {}, + const std::unordered_map &override_environment_variables = + {}, const std::vector &concurrency_groups = {}) : max_restarts(max_restarts), max_task_retries(max_task_retries), @@ -109,7 +113,7 @@ struct ActorCreationOptions { placement_options(placement_options), placement_group_capture_child_tasks(placement_group_capture_child_tasks), serialized_runtime_env(serialized_runtime_env), - runtime_env_uris(runtime_env_uris), + override_environment_variables(override_environment_variables), concurrency_groups(concurrency_groups.begin(), concurrency_groups.end()){}; /// Maximum number of times that the actor should be restarted if it dies @@ -151,8 +155,10 @@ struct ActorCreationOptions { bool placement_group_capture_child_tasks = true; // Runtime Env used by this actor. Propagated to child actors and tasks. std::string serialized_runtime_env; - // URIs contained in the runtime_env. - std::vector runtime_env_uris; + /// Environment variables to update for this actor. Maps a variable name to its + /// value. Can override existing environment variables and introduce new ones. + /// Propagated to child actors and/or tasks. + const std::unordered_map override_environment_variables; /// The actor concurrency groups to indicate how this actor perform its /// methods concurrently. const std::vector concurrency_groups; diff --git a/src/ray/core_worker/context.cc b/src/ray/core_worker/context.cc index ab8f6c1884764..37e7797e62676 100644 --- a/src/ray/core_worker/context.cc +++ b/src/ray/core_worker/context.cc @@ -168,7 +168,12 @@ bool WorkerContext::ShouldCaptureChildTasksInPlacementGroup() const { } const std::string &WorkerContext::GetCurrentSerializedRuntimeEnv() const { - return runtime_env_.serialized_runtime_env(); + return serialized_runtime_env_; +} + +const std::unordered_map + &WorkerContext::GetCurrentOverrideEnvironmentVariables() const { + return override_environment_variables_; } void WorkerContext::SetCurrentTaskId(const TaskID &task_id) { @@ -181,9 +186,10 @@ void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) { if (task_spec.IsNormalTask()) { current_task_is_direct_call_ = true; // TODO(architkulkarni): Once workers are cached by runtime env, we should - // only set runtime_env_ once and then RAY_CHECK that we + // only set serialized_runtime_env_ once and then RAY_CHECK that we // never see a new one. 
- runtime_env_ = task_spec.RuntimeEnv(); + serialized_runtime_env_ = task_spec.SerializedRuntimeEnv(); + override_environment_variables_ = task_spec.OverrideEnvironmentVariables(); } else if (task_spec.IsActorCreationTask()) { RAY_CHECK(current_actor_id_.IsNil()); current_actor_id_ = task_spec.ActorCreationId(); @@ -193,7 +199,8 @@ void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) { is_detached_actor_ = task_spec.IsDetachedActor(); current_actor_placement_group_id_ = task_spec.PlacementGroupBundleId().first; placement_group_capture_child_tasks_ = task_spec.PlacementGroupCaptureChildTasks(); - runtime_env_ = task_spec.RuntimeEnv(); + serialized_runtime_env_ = task_spec.SerializedRuntimeEnv(); + override_environment_variables_ = task_spec.OverrideEnvironmentVariables(); } else if (task_spec.IsActorTask()) { RAY_CHECK(current_actor_id_ == task_spec.ActorId()); } else { diff --git a/src/ray/core_worker/context.h b/src/ray/core_worker/context.h index 3c5f35718235a..a403ee367c973 100644 --- a/src/ray/core_worker/context.h +++ b/src/ray/core_worker/context.h @@ -42,6 +42,9 @@ class WorkerContext { const std::string &GetCurrentSerializedRuntimeEnv() const; + const std::unordered_map + &GetCurrentOverrideEnvironmentVariables() const; + // TODO(edoakes): remove this once Python core worker uses the task interfaces. void SetCurrentTaskId(const TaskID &task_id); @@ -95,8 +98,10 @@ class WorkerContext { PlacementGroupID current_actor_placement_group_id_; // Whether or not we should implicitly capture parent's placement group. bool placement_group_capture_child_tasks_; - // The runtime env for the current actor or task. - rpc::RuntimeEnv runtime_env_; + // The JSON-serialized runtime env for the current actor or task. + std::string serialized_runtime_env_ = "{}"; + // The environment variable overrides for the current actor or task. + std::unordered_map override_environment_variables_; /// The id of the (main) thread that constructed this worker context. boost::thread::id main_thread_id_; diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 0ceb78c7405b8..e9251cbf990ac 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -27,11 +27,34 @@ namespace ray { namespace core { -namespace { // Duration between internal book-keeping heartbeats. const uint64_t kInternalHeartbeatMillis = 1000; +void BuildCommonTaskSpec( + TaskSpecBuilder &builder, const JobID &job_id, const TaskID &task_id, + const std::string name, const TaskID ¤t_task_id, const uint64_t task_index, + const TaskID &caller_id, const rpc::Address &address, const RayFunction &function, + const std::vector> &args, uint64_t num_returns, + const std::unordered_map &required_resources, + const std::unordered_map &required_placement_resources, + const BundleID &bundle_id, bool placement_group_capture_child_tasks, + const std::string debugger_breakpoint, const std::string &serialized_runtime_env, + const std::unordered_map &override_environment_variables, + const std::string &concurrency_group_name = "") { + // Build common task spec. + builder.SetCommonTaskSpec( + task_id, name, function.GetLanguage(), function.GetFunctionDescriptor(), job_id, + current_task_id, task_index, caller_id, address, num_returns, required_resources, + required_placement_resources, bundle_id, placement_group_capture_child_tasks, + debugger_breakpoint, serialized_runtime_env, override_environment_variables, + concurrency_group_name); + // Set task arguments. 
+ for (const auto &arg : args) { + builder.AddArg(*arg); + } +} + JobID GetProcessJobID(const CoreWorkerOptions &options) { if (options.worker_type == WorkerType::DRIVER) { RAY_CHECK(!options.job_id.IsNil()); @@ -66,16 +89,6 @@ ObjectLocation CreateObjectLocation(const rpc::GetObjectLocationsOwnerReply &rep /// The global instance of `CoreWorkerProcess`. std::unique_ptr core_worker_process; -/// Teriminate the process without cleaning up the resources. -/// It will flush the log if logging_enabled is set to true. -void QuickExit(bool logging_enabled) { - if (logging_enabled) { - RayLog::ShutDownRayLog(); - } - _Exit(1); -} -} // namespace - thread_local std::weak_ptr CoreWorkerProcess::current_core_worker_; void CoreWorkerProcess::Initialize(const CoreWorkerOptions &options) { @@ -90,11 +103,10 @@ void CoreWorkerProcess::Shutdown() { } RAY_CHECK(core_worker_process->options_.worker_type == WorkerType::DRIVER) << "The `Shutdown` interface is for driver only."; - auto global_worker = core_worker_process->GetGlobalWorker(); - RAY_CHECK(global_worker); - global_worker->Disconnect(); - global_worker->Shutdown(); - core_worker_process->RemoveWorker(global_worker); + RAY_CHECK(core_worker_process->global_worker_); + core_worker_process->global_worker_->Disconnect(); + core_worker_process->global_worker_->Shutdown(); + core_worker_process->RemoveWorker(core_worker_process->global_worker_); core_worker_process.reset(); } @@ -135,8 +147,18 @@ CoreWorkerProcess::CoreWorkerProcess(const CoreWorkerOptions &options) // NOTE(kfstorm): any initialization depending on RayConfig must happen after this line. InitializeSystemConfig(); - if (ShouldCreateGlobalWorkerOnConstruction()) { - CreateWorker(); + if (options_.num_workers == 1) { + // We need to create the worker instance here if: + // 1. This is a driver process. In this case, the driver is ready to use right after + // the CoreWorkerProcess::Initialize. + // 2. This is a Python worker process. In this case, Python will invoke some core + // worker APIs before `CoreWorkerProcess::RunTaskExecutionLoop` is called. So we need + // to create the worker instance here. One example of invocations is + // https://github.com/ray-project/ray/blob/45ce40e5d44801193220d2c546be8de0feeef988/python/ray/worker.py#L1281. + if (options_.worker_type == WorkerType::DRIVER || + options_.language == Language::PYTHON) { + CreateWorker(); + } } // Assume stats module will be initialized exactly once in once process. @@ -146,7 +168,7 @@ CoreWorkerProcess::CoreWorkerProcess(const CoreWorkerOptions &options) // Initialize stats in core worker global tags. const ray::stats::TagsType global_tags = { {ray::stats::ComponentKey, "core_worker"}, - {ray::stats::VersionKey, kRayVersion}, + {ray::stats::VersionKey, "2.0.0.dev0"}, {ray::stats::NodeAddressKey, options_.node_ip_address}}; // NOTE(lingxuan.zlx): We assume RayConfig is initialized before it's used. @@ -234,23 +256,11 @@ void CoreWorkerProcess::InitializeSystemConfig() { RayConfig::instance().initialize(promise.get_future().get()); } -bool CoreWorkerProcess::ShouldCreateGlobalWorkerOnConstruction() const { - // We need to create the worker instance here if: - // 1. This is a driver process. In this case, the driver is ready to use right after - // the CoreWorkerProcess::Initialize. - // 2. This is a Python worker process. In this case, Python will invoke some core - // worker APIs before `CoreWorkerProcess::RunTaskExecutionLoop` is called. So we need - // to create the worker instance here. 
One example of invocations is - // https://github.com/ray-project/ray/blob/45ce40e5d44801193220d2c546be8de0feeef988/python/ray/worker.py#L1281. - return options_.num_workers == 1 && (options_.worker_type == WorkerType::DRIVER || - options_.language == Language::PYTHON); -} - std::shared_ptr CoreWorkerProcess::TryGetWorker(const WorkerID &worker_id) { if (!core_worker_process) { return nullptr; } - absl::ReaderMutexLock workers_lock(&core_worker_process->mutex_); + absl::ReaderMutexLock workers_lock(&core_worker_process->worker_map_mutex_); auto it = core_worker_process->workers_.find(worker_id); if (it != core_worker_process->workers_.end()) { return it->second; @@ -261,19 +271,8 @@ std::shared_ptr CoreWorkerProcess::TryGetWorker(const WorkerID &work CoreWorker &CoreWorkerProcess::GetCoreWorker() { EnsureInitialized(); if (core_worker_process->options_.num_workers == 1) { - auto global_worker = core_worker_process->GetGlobalWorker(); - if (core_worker_process->ShouldCreateGlobalWorkerOnConstruction() && !global_worker) { - // This could only happen when the worker has already been shutdown. - // In this case, we should exit without crashing. - // TODO (scv119): A better solution could be returning error code - // and handling it at language frontend. - RAY_LOG(ERROR) << "The global worker has already been shutdown. This happens when " - "the language frontend accesses the Ray's worker after it is " - "shutdown. The process will exit"; - QuickExit(core_worker_process->options_.enable_logging); - } - RAY_CHECK(global_worker) << "global_worker_ must not be NULL"; - return *global_worker; + RAY_CHECK(core_worker_process->global_worker_) << "global_worker_ must not be NULL"; + return *core_worker_process->global_worker_; } auto ptr = current_core_worker_.lock(); RAY_CHECK(ptr != nullptr) @@ -284,7 +283,7 @@ CoreWorker &CoreWorkerProcess::GetCoreWorker() { void CoreWorkerProcess::SetCurrentThreadWorkerId(const WorkerID &worker_id) { EnsureInitialized(); if (core_worker_process->options_.num_workers == 1) { - RAY_CHECK(core_worker_process->GetGlobalWorker()->GetWorkerID() == worker_id); + RAY_CHECK(core_worker_process->global_worker_->GetWorkerID() == worker_id); return; } current_core_worker_ = core_worker_process->GetWorker(worker_id); @@ -292,28 +291,23 @@ void CoreWorkerProcess::SetCurrentThreadWorkerId(const WorkerID &worker_id) { std::shared_ptr CoreWorkerProcess::GetWorker( const WorkerID &worker_id) const { - absl::ReaderMutexLock lock(&mutex_); + absl::ReaderMutexLock lock(&worker_map_mutex_); auto it = workers_.find(worker_id); RAY_CHECK(it != workers_.end()) << "Worker " << worker_id << " not found."; return it->second; } -std::shared_ptr CoreWorkerProcess::GetGlobalWorker() { - absl::ReaderMutexLock lock(&mutex_); - return global_worker_; -} - std::shared_ptr CoreWorkerProcess::CreateWorker() { auto worker = std::make_shared( options_, global_worker_id_ != WorkerID::Nil() ? 
global_worker_id_ : WorkerID::FromRandom()); RAY_LOG(DEBUG) << "Worker " << worker->GetWorkerID() << " is created."; - absl::WriterMutexLock lock(&mutex_); if (options_.num_workers == 1) { global_worker_ = worker; } current_core_worker_ = worker; + absl::MutexLock lock(&worker_map_mutex_); workers_.emplace(worker->GetWorkerID(), worker); RAY_CHECK(workers_.size() <= static_cast(options_.num_workers)); return worker; @@ -321,7 +315,6 @@ std::shared_ptr CoreWorkerProcess::CreateWorker() { void CoreWorkerProcess::RemoveWorker(std::shared_ptr worker) { worker->WaitForShutdown(); - absl::WriterMutexLock lock(&mutex_); if (global_worker_) { RAY_CHECK(global_worker_ == worker); } else { @@ -329,6 +322,7 @@ void CoreWorkerProcess::RemoveWorker(std::shared_ptr worker) { } current_core_worker_.reset(); { + absl::MutexLock lock(&worker_map_mutex_); workers_.erase(worker->GetWorkerID()); RAY_LOG(INFO) << "Removed worker " << worker->GetWorkerID(); } @@ -342,10 +336,9 @@ void CoreWorkerProcess::RunTaskExecutionLoop() { RAY_CHECK(core_worker_process->options_.worker_type == WorkerType::WORKER); if (core_worker_process->options_.num_workers == 1) { // Run the task loop in the current thread only if the number of workers is 1. - auto worker = core_worker_process->GetGlobalWorker(); - if (!worker) { - worker = core_worker_process->CreateWorker(); - } + auto worker = core_worker_process->global_worker_ + ? core_worker_process->global_worker_ + : core_worker_process->CreateWorker(); worker->RunTaskExecutionLoop(); core_worker_process->RemoveWorker(worker); } else { @@ -377,9 +370,9 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ periodical_runner_(io_service_), task_queue_length_(0), num_executed_tasks_(0), + task_execution_service_work_(task_execution_service_), resource_ids_(new ResourceMappingType()), - grpc_service_(io_service_, *this), - task_execution_service_work_(task_execution_service_) { + grpc_service_(io_service_, *this) { RAY_LOG(DEBUG) << "Constructing CoreWorker, worker_id: " << worker_id; // Initialize task receivers. @@ -416,8 +409,11 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ // Avoid using FATAL log or RAY_CHECK here because they may create a core dump file. RAY_LOG(ERROR) << "Failed to register worker " << worker_id << " to Raylet. " << raylet_client_status; + if (options_.enable_logging) { + RayLog::ShutDownRayLog(); + } // Quit the process immediately. - QuickExit(options_.enable_logging); + _Exit(1); } connected_ = true; @@ -431,8 +427,7 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ // Start RPC server after all the task receivers are properly initialized and we have // our assigned port from the raylet. 
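// The initializer-list reorder above keeps the list in declaration order:
// C++ initializes members in the order they are declared, not the order they
// appear in the initializer list, so a mismatched list invites -Wreorder
// warnings and use-before-init bugs. A tiny standalone demonstration, not
// taken from this patch:
#include <iostream>

static int Trace(const char *name, int value) {
  std::cout << name << " initialized\n";
  return value;
}

struct Demo {
  Demo() : a_(Trace("a_", 1)), b_(Trace("b_", 2)) {}
  int b_;  // declared first, so "b_ initialized" prints before "a_"
  int a_;
};

int main() { Demo d; }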
core_worker_server_ = std::make_unique( - WorkerTypeString(options_.worker_type), assigned_port, - options_.node_ip_address == "127.0.0.1"); + WorkerTypeString(options_.worker_type), assigned_port); core_worker_server_->RegisterService(grpc_service_); core_worker_server_->Run(); @@ -531,6 +526,10 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ options_.worker_type != WorkerType::RESTORE_WORKER), /*get_current_call_site=*/boost::bind(&CoreWorker::CurrentCallSite, this))); memory_store_.reset(new CoreWorkerMemoryStore( + [this](const RayObject &object, const ObjectID &object_id) { + PutObjectIntoPlasma(object, object_id); + return Status::OK(); + }, reference_counter_, local_raylet_client_, options_.check_signals, [this](const RayObject &obj) { // Run this on the event loop to avoid calling back into the language runtime @@ -657,8 +656,7 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ std::move(lease_policy), memory_store_, task_manager_, local_raylet_id, RayConfig::instance().worker_lease_timeout_milliseconds(), actor_creator_, RayConfig::instance().max_tasks_in_flight_per_worker(), - boost::asio::steady_timer(io_service_), - RayConfig::instance().max_pending_lease_requests_per_scheduling_category()); + boost::asio::steady_timer(io_service_)); auto report_locality_data_callback = [this](const ObjectID &object_id, const absl::flat_hash_set &locations, uint64_t object_size) { @@ -739,11 +737,6 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ }, event_stats_print_interval_ms); } - - // Set event context for current core worker thread. - RayEventContext::Instance().SetEventContext( - ray::rpc::Event_SourceType::Event_SourceType_CORE_WORKER, - {{"worker_id", worker_id.Hex()}}); } void CoreWorker::Shutdown() { @@ -940,25 +933,17 @@ void CoreWorker::RegisterToGcs() { } void CoreWorker::CheckForRayletFailure() { - bool should_shutdown = false; // When running worker process in container, the worker parent process is not raylet. // So we add RAY_RAYLET_PID enviroment to ray worker process. if (auto env_pid = RayConfig::instance().RAYLET_PID(); !env_pid.empty()) { auto pid = static_cast(std::stoi(env_pid)); if (!IsProcessAlive(pid)) { RAY_LOG(ERROR) << "Raylet failed. Shutting down. Raylet PID: " << pid; - should_shutdown = true; + Shutdown(); } } else if (!IsParentProcessAlive()) { RAY_LOG(ERROR) << "Raylet failed. Shutting down."; - should_shutdown = true; - } - if (should_shutdown) { - if (options_.worker_type == WorkerType::WORKER) { - task_execution_service_.post([this]() { Shutdown(); }, "CoreWorker.Shutdown"); - } else { - Shutdown(); - } + Shutdown(); } } @@ -986,12 +971,6 @@ void CoreWorker::InternalHeartbeat() { direct_actor_submitter_->CheckTimeoutTasks(); } - // Periodically report the lastest backlog so that - // local raylet will have the eventually consistent view of worker backlogs - // even in cases where backlog reports from direct_task_transport - // are lost or reordered. - direct_task_submitter_->ReportWorkerBacklog(); - // Check for unhandled exceptions to raise after a timeout on the driver. // Only do this for TTY, since shells like IPython sometimes save references // to the result and prevent normal result deletion from handling. 
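// The memory store above receives a put-into-plasma callback instead of a
// pointer back to the CoreWorker, keeping the store decoupled from its owner.
// A minimal sketch of that injection pattern with simplified types; Store and
// Owner are illustrative, not Ray classes.
#include <functional>
#include <iostream>
#include <string>
#include <utility>

class Store {
 public:
  explicit Store(std::function<void(const std::string &)> spill)
      : spill_(std::move(spill)) {}
  void Promote(const std::string &id) {
    if (spill_) spill_(id);  // delegate plasma placement to the owner
  }

 private:
  std::function<void(const std::string &)> spill_;
};

class Owner {
 public:
  Owner() : store_([this](const std::string &id) { PutIntoPlasma(id); }) {}
  void PutIntoPlasma(const std::string &id) { std::cout << "plasma put " << id << "\n"; }
  Store store_;
};

int main() {
  Owner owner;
  owner.store_.Promote("obj");
  return 0;
}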
@@ -1013,6 +992,36 @@ CoreWorker::GetAllReferenceCounts() const {
   return counts;
 }
 
+void CoreWorker::PutObjectIntoPlasma(const RayObject &object, const ObjectID &object_id) {
+  bool object_exists;
+  // This call will only be used by PromoteObjectToPlasma, which means that the
+  // object will always be owned by us.
+  RAY_CHECK_OK(plasma_store_provider_->Put(
+      object, object_id, /* owner_address = */ rpc_address_, &object_exists));
+  if (!object_exists) {
+    // Tell the raylet to pin the object **after** it is created.
+    RAY_LOG(DEBUG) << "Pinning put object " << object_id;
+    local_raylet_client_->PinObjectIDs(
+        rpc_address_, {object_id},
+        [this, object_id](const Status &status, const rpc::PinObjectIDsReply &reply) {
+          // Only release the object once the raylet has responded to avoid the race
+          // condition that the object could be evicted before the raylet pins it.
+          if (!plasma_store_provider_->Release(object_id).ok()) {
+            RAY_LOG(ERROR) << "Failed to release ObjectID (" << object_id
+                           << "), might cause a leak in plasma.";
+          }
+        });
+  }
+  RAY_CHECK(memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), object_id));
+}
+
+void CoreWorker::PromoteObjectToPlasma(const ObjectID &object_id) {
+  auto value = memory_store_->GetOrPromoteToPlasma(object_id);
+  if (value) {
+    PutObjectIntoPlasma(*value, object_id);
+  }
+}
+
 const rpc::Address &CoreWorker::GetRpcAddress() const { return rpc_address_; }
 
 rpc::Address CoreWorker::GetOwnerAddress(const ObjectID &object_id) const {
@@ -1052,6 +1061,7 @@ void CoreWorker::GetOwnershipInfo(const ObjectID &object_id, rpc::Address *owner
          "which task will create them. "
          "If this was not how your object ID was generated, please file an issue "
          "at https://github.com/ray-project/ray/issues/";
+  RAY_LOG(DEBUG) << "Promoted object to plasma " << object_id;
 
   rpc::GetObjectStatusReply object_status;
   // Optimization: if the object exists, serialize and inline its status. This also
@@ -1625,37 +1635,6 @@ std::unordered_map<std::string, double> AddPlacementGroupConstraint(
   return resources;
 }
 
-void CoreWorker::BuildCommonTaskSpec(
-    TaskSpecBuilder &builder, const JobID &job_id, const TaskID &task_id,
-    const std::string &name, const TaskID &current_task_id, uint64_t task_index,
-    const TaskID &caller_id, const rpc::Address &address, const RayFunction &function,
-    const std::vector<std::unique_ptr<TaskArg>> &args, uint64_t num_returns,
-    const std::unordered_map<std::string, double> &required_resources,
-    const std::unordered_map<std::string, double> &required_placement_resources,
-    const BundleID &bundle_id, bool placement_group_capture_child_tasks,
-    const std::string &debugger_breakpoint, const std::string &serialized_runtime_env,
-    const std::vector<std::string> &runtime_env_uris,
-    const std::string &concurrency_group_name) {
-  // Build common task spec.
-  builder.SetCommonTaskSpec(
-      task_id, name, function.GetLanguage(), function.GetFunctionDescriptor(), job_id,
-      current_task_id, task_index, caller_id, address, num_returns, required_resources,
-      required_placement_resources, bundle_id, placement_group_capture_child_tasks,
-      debugger_breakpoint,
-      // TODO(SongGuyang): Move the logic of `prepare_runtime_env` from Python to Core
-      // Worker. A common process is needed.
-      // If runtime env is not provided, use job config. Only for Java and C++ because it
-      // has been set in Python by `prepare_runtime_env`.
-      (serialized_runtime_env.empty() || serialized_runtime_env == "{}")
-          ? job_config_->runtime_env().serialized_runtime_env()
-          : serialized_runtime_env,
-      runtime_env_uris, concurrency_group_name);
-  // Set task arguments.
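// PutObjectIntoPlasma above releases its local reference only inside the
// PinObjectIDs reply callback; releasing earlier would let the object be
// evicted before the raylet pins it. A minimal sketch of that ordering with a
// fake asynchronous pin RPC; all names here are illustrative.
#include <functional>
#include <iostream>
#include <string>

using PinReply = std::function<void()>;

void PinThenRelease(const std::string &object_id,
                    const std::function<void(PinReply)> &pin_rpc,
                    const std::function<void()> &release) {
  // Hold the local reference across the async pin; release strictly after
  // the raylet has acknowledged.
  pin_rpc([release, object_id]() {
    std::cout << "raylet pinned " << object_id << "\n";
    release();
  });
}

int main() {
  PinThenRelease(
      "obj", [](PinReply done) { done(); },  // fake RPC that replies immediately
      []() { std::cout << "released local ref\n"; });
  return 0;
}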
- for (const auto &arg : args) { - builder.AddArg(*arg); - } -} - std::vector CoreWorker::SubmitTask( const RayFunction &function, const std::vector> &args, const TaskOptions &task_options, int max_retries, bool retry_exceptions, @@ -1673,13 +1652,21 @@ std::vector CoreWorker::SubmitTask( auto task_name = task_options.name.empty() ? function.GetFunctionDescriptor()->DefaultTaskName() : task_options.name; + // Propagate existing environment variable overrides, but override them with any new + // ones + std::unordered_map current_override_environment_variables = + worker_context_.GetCurrentOverrideEnvironmentVariables(); + std::unordered_map override_environment_variables = + task_options.override_environment_variables; + override_environment_variables.insert(current_override_environment_variables.begin(), + current_override_environment_variables.end()); // TODO(ekl) offload task building onto a thread pool for performance - BuildCommonTaskSpec(builder, worker_context_.GetCurrentJobID(), task_id, task_name, - worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(), - rpc_address_, function, args, task_options.num_returns, - constrained_resources, required_resources, placement_options, - placement_group_capture_child_tasks, debugger_breakpoint, - task_options.serialized_runtime_env, task_options.runtime_env_uris); + BuildCommonTaskSpec( + builder, worker_context_.GetCurrentJobID(), task_id, task_name, + worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(), rpc_address_, + function, args, task_options.num_returns, constrained_resources, required_resources, + placement_options, placement_group_capture_child_tasks, debugger_breakpoint, + task_options.serialized_runtime_env, override_environment_variables); builder.SetNormalTaskSpec(max_retries, retry_exceptions); TaskSpecification task_spec = builder.Build(); RAY_LOG(DEBUG) << "Submit task " << task_spec.DebugString(); @@ -1715,7 +1702,12 @@ Status CoreWorker::CreateActor(const RayFunction &function, const JobID job_id = worker_context_.GetCurrentJobID(); // Propagate existing environment variable overrides, but override them with any new // ones - std::vector return_ids; + std::unordered_map current_override_environment_variables = + worker_context_.GetCurrentOverrideEnvironmentVariables(); + std::unordered_map override_environment_variables = + actor_creation_options.override_environment_variables; + override_environment_variables.insert(current_override_environment_variables.begin(), + current_override_environment_variables.end()); TaskSpecBuilder builder; auto new_placement_resources = AddPlacementGroupConstraint(actor_creation_options.placement_resources, @@ -1736,7 +1728,7 @@ Status CoreWorker::CreateActor(const RayFunction &function, actor_creation_options.placement_group_capture_child_tasks, "", /* debugger_breakpoint */ actor_creation_options.serialized_runtime_env, - actor_creation_options.runtime_env_uris); + override_environment_variables); auto actor_handle = std::make_unique( actor_id, GetCallerId(), rpc_address_, job_id, @@ -1913,6 +1905,7 @@ std::vector CoreWorker::SubmitActorTask( const auto task_name = task_options.name.empty() ? 
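// The environment-variable merges above rely on unordered_map::insert not
// overwriting existing keys: the task's own overrides go into the map first,
// the parent's are inserted afterwards, so the child wins on any conflict. A
// standalone check of that precedence:
#include <cassert>
#include <string>
#include <unordered_map>

int main() {
  std::unordered_map<std::string, std::string> parent = {{"A", "1"}, {"B", "2"}};
  std::unordered_map<std::string, std::string> merged = {{"B", "child"}};
  merged.insert(parent.begin(), parent.end());  // insert() keeps the existing "B"
  assert(merged["A"] == "1");
  assert(merged["B"] == "child");
  return 0;
}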
function.GetFunctionDescriptor()->DefaultTaskName() : task_options.name; + const std::unordered_map override_environment_variables = {}; BuildCommonTaskSpec(builder, actor_handle->CreationJobID(), actor_task_id, task_name, worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(), rpc_address_, function, args, num_returns, task_options.resources, @@ -1920,7 +1913,7 @@ std::vector CoreWorker::SubmitActorTask( true, /* placement_group_capture_child_tasks */ "", /* debugger_breakpoint */ "{}", /* serialized_runtime_env */ - {}, /* runtime_env_uris */ + override_environment_variables, task_options.concurrency_group_name); // NOTE: placement_group_capture_child_tasks and runtime_env will // be ignored in the actor because we should always follow the actor's option. @@ -2191,14 +2184,6 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec, task_queue_length_ -= 1; num_executed_tasks_ += 1; - // Modify the worker's per function counters. - std::string func_name = task_spec.FunctionDescriptor()->CallString(); - { - absl::MutexLock l(&task_counter_.tasks_counter_mutex_); - task_counter_.Add(TaskCounter::kPending, func_name, -1); - task_counter_.Add(TaskCounter::kRunning, func_name, 1); - } - if (!options_.is_local_mode) { worker_context_.SetCurrentTask(task_spec); SetCurrentTaskId(task_spec.TaskId()); @@ -2294,16 +2279,8 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec, resource_ids_.reset(new ResourceMappingType()); } } - - // Modify the worker's per function counters. - { - absl::MutexLock l(&task_counter_.tasks_counter_mutex_); - task_counter_.Add(TaskCounter::kRunning, func_name, -1); - task_counter_.Add(TaskCounter::kFinished, func_name, 1); - } - - RAY_LOG(DEBUG) << "Finished executing task " << task_spec.TaskId() - << ", status=" << status; + RAY_LOG(INFO) << "Finished executing task " << task_spec.TaskId() + << ", status=" << status; if (status.IsCreationTaskError()) { Exit(rpc::WorkerExitType::CREATION_TASK_ERROR, creation_task_exception_pb_bytes); } else if (status.IsIntentionalSystemExit()) { @@ -2470,15 +2447,8 @@ void CoreWorker::HandlePushTask(const rpc::PushTaskRequest &request, return; } - // Increment the task_queue_length and per function counter. + // Increment the task_queue_length task_queue_length_ += 1; - std::string func_name = - FunctionDescriptorBuilder::FromProto(request.task_spec().function_descriptor()) - ->CallString(); - { - absl::MutexLock l(&task_counter_.tasks_counter_mutex_); - task_counter_.Add(TaskCounter::kPending, func_name, 1); - } // For actor tasks, we just need to post a HandleActorTask instance to the task // execution service. @@ -2885,10 +2855,13 @@ void CoreWorker::HandleCancelTask(const rpc::CancelTaskRequest &request, << " has received a force kill request after the cancellation. Killing " "a worker..."; Disconnect(); - // NOTE(hchen): Use `QuickExit()` to force-exit this process without doing cleanup. + if (options_.enable_logging) { + RayLog::ShutDownRayLog(); + } + // NOTE(hchen): Use `_Exit()` to force-exit this process without doing cleanup. // `exit()` will destruct static objects in an incorrect order, which will lead to // core dumps. - QuickExit(options_.enable_logging); + _Exit(1); } } @@ -2921,10 +2894,13 @@ void CoreWorker::HandleKillActor(const rpc::KillActorRequest &request, "please create the Java actor with some dynamic options to make it being " "hosted in a dedicated worker process."; } - // NOTE(hchen): Use `QuickExit()` to force-exit this process without doing cleanup. 
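// The replacement below inlines what QuickExit() did: optionally flush logs,
// then std::_Exit. Unlike exit(), _Exit runs no atexit handlers and no static
// destructors, which is exactly why the surrounding comment warns about
// destruction order. A tiny standalone demonstration:
#include <cstdio>
#include <cstdlib>

int main() {
  std::atexit([] { std::puts("atexit handler"); });  // never runs below
  std::puts("before _Exit");
  std::_Exit(1);  // exit(1) would have printed "atexit handler" first
}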
+ if (options_.enable_logging) { + RayLog::ShutDownRayLog(); + } + // NOTE(hchen): Use `_Exit()` to force-exit this process without doing cleanup. // `exit()` will destruct static objects in an incorrect order, which will lead to // core dumps. - QuickExit(options_.enable_logging); + _Exit(1); } else { Exit(rpc::WorkerExitType::INTENDED_EXIT); } @@ -3081,16 +3057,15 @@ void CoreWorker::HandleExit(const rpc::ExitRequest &request, rpc::ExitReply *rep // any object pinning RPCs in flight. bool is_idle = !own_objects && pins_in_flight == 0; reply->set_success(is_idle); - send_reply_callback( - Status::OK(), - [this, is_idle]() { - // If the worker is idle, we exit. - if (is_idle) { - Exit(rpc::WorkerExitType::IDLE_EXIT); - } - }, - // We need to kill it regardless if the RPC failed. - [this]() { Exit(rpc::WorkerExitType::INTENDED_EXIT); }); + send_reply_callback(Status::OK(), + [this, is_idle]() { + // If the worker is idle, we exit. + if (is_idle) { + Exit(rpc::WorkerExitType::IDLE_EXIT); + } + }, + // We need to kill it regardless if the RPC failed. + [this]() { Exit(rpc::WorkerExitType::INTENDED_EXIT); }); } void CoreWorker::HandleAssignObjectOwner(const rpc::AssignObjectOwnerRequest &request, @@ -3216,25 +3191,6 @@ std::shared_ptr CoreWorker::GetGcsClient() const { return gcs_cl bool CoreWorker::IsExiting() const { return exiting_; } -std::unordered_map> CoreWorker::GetActorCallStats() - const { - absl::MutexLock l(&task_counter_.tasks_counter_mutex_); - std::unordered_map> total_counts; - - for (const auto &count : task_counter_.pending_tasks_counter_map_) { - total_counts[count.first].resize(3, 0); - total_counts[count.first][0] = count.second; - } - for (const auto &count : task_counter_.running_tasks_counter_map_) { - total_counts[count.first][1] = count.second; - } - for (const auto &count : task_counter_.finished_tasks_counter_map_) { - total_counts[count.first][2] = count.second; - } - - return total_counts; -} - Status CoreWorker::WaitForActorRegistered(const std::vector &ids) { std::vector actor_ids; for (const auto &id : ids) { diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 883a1b013ff81..3ef1e2476f6d2 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -294,29 +294,23 @@ class CoreWorkerProcess { void InitializeSystemConfig(); - /// Check that if the global worker should be created on construction. - bool ShouldCreateGlobalWorkerOnConstruction() const; - /// Get the `CoreWorker` instance by worker ID. /// /// \param[in] workerId The worker ID. /// \return The `CoreWorker` instance. std::shared_ptr GetWorker(const WorkerID &worker_id) const - LOCKS_EXCLUDED(mutex_); + LOCKS_EXCLUDED(worker_map_mutex_); /// Create a new `CoreWorker` instance. /// /// \return The newly created `CoreWorker` instance. - std::shared_ptr CreateWorker() LOCKS_EXCLUDED(mutex_); + std::shared_ptr CreateWorker() LOCKS_EXCLUDED(worker_map_mutex_); /// Remove an existing `CoreWorker` instance. /// /// \param[in] The existing `CoreWorker` instance. /// \return Void. - void RemoveWorker(std::shared_ptr worker) LOCKS_EXCLUDED(mutex_); - - /// Get the `GlobalWorker` instance, if the number of workers is 1. - std::shared_ptr GetGlobalWorker() LOCKS_EXCLUDED(mutex_); + void RemoveWorker(std::shared_ptr worker) LOCKS_EXCLUDED(worker_map_mutex_); /// The various options. 
const CoreWorkerOptions options_; @@ -326,16 +320,17 @@ class CoreWorkerProcess { static thread_local std::weak_ptr current_core_worker_; /// The only core worker instance, if the number of workers is 1. - std::shared_ptr global_worker_ GUARDED_BY(mutex_); + std::shared_ptr global_worker_; /// The worker ID of the global worker, if the number of workers is 1. const WorkerID global_worker_id_; /// Map from worker ID to worker. - std::unordered_map> workers_ GUARDED_BY(mutex_); + std::unordered_map> workers_ + GUARDED_BY(worker_map_mutex_); - /// To protect access to workers_ and global_worker_ - mutable absl::Mutex mutex_; + /// To protect accessing the `workers_` map. + mutable absl::Mutex worker_map_mutex_; }; /// The root class that contains all the core and language-independent functionalities @@ -445,6 +440,22 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// (local, submitted_task) reference counts. For debugging purposes. std::unordered_map> GetAllReferenceCounts() const; + /// Put an object into plasma. It's a version of Put that directly put the + /// object into plasma and also pin the object. + /// + /// \param[in] The ray object. + /// \param[in] object_id The object ID to serialize. + /// appended to the serialized object ID. + void PutObjectIntoPlasma(const RayObject &object, const ObjectID &object_id); + + /// Promote an object to plasma. If the + /// object already exists locally, it will be put into the plasma store. If + /// it doesn't yet exist, it will be spilled to plasma once available. + /// + /// \param[in] object_id The object ID to serialize. + /// appended to the serialized object ID. + void PromoteObjectToPlasma(const ObjectID &object_id); + /// Get the RPC address of this worker. /// /// \param[out] The RPC address of this worker. @@ -1033,24 +1044,7 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// Return true if the core worker is in the exit process. bool IsExiting() const; - /// Retrieve the current statistics about tasks being received and executing. - /// \return an unordered_map mapping function name to list of (num_received, - /// num_executing, num_executed). It is a std map instead of absl due to its - /// interface with language bindings. - std::unordered_map> GetActorCallStats() const; - private: - void BuildCommonTaskSpec( - TaskSpecBuilder &builder, const JobID &job_id, const TaskID &task_id, - const std::string &name, const TaskID ¤t_task_id, uint64_t task_index, - const TaskID &caller_id, const rpc::Address &address, const RayFunction &function, - const std::vector> &args, uint64_t num_returns, - const std::unordered_map &required_resources, - const std::unordered_map &required_placement_resources, - const BundleID &bundle_id, bool placement_group_capture_child_tasks, - const std::string &debugger_breakpoint, const std::string &serialized_runtime_env, - const std::vector &runtime_env_uris, - const std::string &concurrency_group_name = ""); void SetCurrentTaskId(const TaskID &task_id); void SetActorId(const ActorID &actor_id); @@ -1372,6 +1366,12 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// Number of executed tasks. std::atomic num_executed_tasks_; + /// Event loop where tasks are processed. + instrumented_io_context task_execution_service_; + + /// The asio work to keep task_execution_service_ alive. + boost::asio::io_service::work task_execution_service_work_; + /// Profiler including a background thread that pushes profiling events to the GCS. 
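// task_execution_service_work_ above is the classic asio "work guard": it
// keeps task_execution_service_.run() from returning while the queue happens
// to be empty. A minimal standalone sketch using plain boost::asio rather
// than Ray's instrumented_io_context:
#include <boost/asio.hpp>
#include <iostream>

int main() {
  boost::asio::io_service io;
  boost::asio::io_service::work work(io);  // without this, run() returns once idle
  io.post([] { std::cout << "task executed\n"; });
  io.poll();  // drain what is queued; run() would block waiting for more work
  return 0;
}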
std::shared_ptr profiler_; @@ -1390,14 +1390,6 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { // Interface that receives tasks from direct actor calls. std::unique_ptr direct_task_receiver_; - /// Event loop where tasks are processed. - /// task_execution_service_ should be destructed first to avoid - /// issues like https://github.com/ray-project/ray/issues/18857 - instrumented_io_context task_execution_service_; - - /// The asio work to keep task_execution_service_ alive. - boost::asio::io_service::work task_execution_service_work_; - // Queue of tasks to resubmit when the specified time passes. std::deque> to_resubmit_ GUARDED_BY(mutex_); @@ -1416,47 +1408,14 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { void PlasmaCallback(SetResultCallback success, std::shared_ptr ray_object, ObjectID object_id, void *py_future); - /// we are shutting down and not running further tasks. - /// when exiting_ is set to true HandlePushTask becomes no-op. - std::atomic exiting_ = false; + /// Whether we are shutting down and not running further tasks. + bool exiting_ = false; int64_t max_direct_call_object_size_; friend class CoreWorkerTest; std::unique_ptr job_config_; - - /// Simple container for per function task counters. The counters will be - /// keyed by the function name in task spec. - struct TaskCounter { - /// A task can only be one of the following state. Received state in particular - /// covers from the point of RPC call to beginning execution. - enum TaskStatusType { kPending, kRunning, kFinished }; - - /// This mutex should be used by caller to ensure consistency when transitioning - /// a task's state. - mutable absl::Mutex tasks_counter_mutex_; - absl::flat_hash_map pending_tasks_counter_map_ - GUARDED_BY(tasks_counter_mutex_); - absl::flat_hash_map running_tasks_counter_map_ - GUARDED_BY(tasks_counter_mutex_); - absl::flat_hash_map finished_tasks_counter_map_ - GUARDED_BY(tasks_counter_mutex_); - - void Add(TaskStatusType type, const std::string &func_name, int value) { - tasks_counter_mutex_.AssertHeld(); - if (type == kPending) { - pending_tasks_counter_map_[func_name] += value; - } else if (type == kRunning) { - running_tasks_counter_map_[func_name] += value; - } else if (type == kFinished) { - finished_tasks_counter_map_[func_name] += value; - } else { - RAY_CHECK(false) << "This line should not be reached."; - } - } - }; - TaskCounter task_counter_; }; } // namespace core diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc index 6af083669f5ed..70a2626847574 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc @@ -217,9 +217,10 @@ Java_io_ray_runtime_object_NativeObjectStore_nativeGetOwnerAddress(JNIEnv *env, } JNIEXPORT jbyteArray JNICALL -Java_io_ray_runtime_object_NativeObjectStore_nativeGetOwnershipInfo(JNIEnv *env, jclass, - jbyteArray objectId) { +Java_io_ray_runtime_object_NativeObjectStore_nativePromoteAndGetOwnershipInfo( + JNIEnv *env, jclass, jbyteArray objectId) { auto object_id = JavaByteArrayToId(env, objectId); + CoreWorkerProcess::GetCoreWorker().PromoteObjectToPlasma(object_id); rpc::Address address; // TODO(ekl) send serialized object status to Java land. 
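// The renamed JNI entry point above promotes the object before collecting
// ownership info, so a reference serialized for Java always points at a
// plasma-backed object. A schematic of the call order with simplified
// stand-in types; FakeCoreWorker is illustrative, not the real JNI surface.
#include <iostream>
#include <string>

struct FakeCoreWorker {
  void PromoteObjectToPlasma(const std::string &id) { std::cout << "promote " << id << "\n"; }
  std::string GetOwnershipInfo(const std::string &id) { return "owner-of-" + id; }
};

std::string PromoteAndGetOwnershipInfo(FakeCoreWorker &worker, const std::string &object_id) {
  worker.PromoteObjectToPlasma(object_id);    // must happen before serialization
  return worker.GetOwnershipInfo(object_id);  // then hand ownership info to Java
}

int main() {
  FakeCoreWorker worker;
  std::cout << PromoteAndGetOwnershipInfo(worker, "obj") << "\n";
  return 0;
}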
std::string serialized_object_status; diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h index 9358f4473c228..8001bbf20df06 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h +++ b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h @@ -105,12 +105,13 @@ Java_io_ray_runtime_object_NativeObjectStore_nativeGetOwnerAddress(JNIEnv *, jcl /* * Class: io_ray_runtime_object_NativeObjectStore - * Method: nativeGetOwnershipInfo + * Method: nativePromoteAndGetOwnershipInfo * Signature: ([B)[B */ JNIEXPORT jbyteArray JNICALL -Java_io_ray_runtime_object_NativeObjectStore_nativeGetOwnershipInfo(JNIEnv *, jclass, - jbyteArray); +Java_io_ray_runtime_object_NativeObjectStore_nativePromoteAndGetOwnershipInfo(JNIEnv *, + jclass, + jbyteArray); /* * Class: io_ray_runtime_object_NativeObjectStore diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc index 56a0ad473c64d..dd05bc76aa6e0 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc @@ -235,9 +235,9 @@ inline ActorCreationOptions ToActorCreationOptions(JNIEnv *env, ray_namespace, /*is_asyncio=*/false, placement_options, - /*placement_group_capture_child_tasks=*/true, - /*serialized_runtime_env=*/"{}", - /*runtime_env_uris=*/{}, + true, + "{}", + {}, concurrency_groups}; return actor_creation_options; } diff --git a/src/ray/core_worker/reference_count.h b/src/ray/core_worker/reference_count.h index e0a9a783dd657..58c67a2010213 100644 --- a/src/ray/core_worker/reference_count.h +++ b/src/ray/core_worker/reference_count.h @@ -14,6 +14,8 @@ #pragma once +#include + #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" diff --git a/src/ray/core_worker/reference_count_test.cc b/src/ray/core_worker/reference_count_test.cc index c265bc7af753b..5877d7f654dfc 100644 --- a/src/ray/core_worker/reference_count_test.cc +++ b/src/ray/core_worker/reference_count_test.cc @@ -16,9 +16,9 @@ #include -#include "absl/functional/bind_front.h" #include "gmock/gmock.h" #include "gtest/gtest.h" + #include "ray/common/asio/instrumented_io_context.h" #include "ray/common/asio/periodical_runner.h" #include "ray/common/ray_object.h" @@ -270,7 +270,7 @@ class MockWorkerClient : public MockCoreWorkerClientInterface { auto borrower_callback = [=]() { auto ref_removed_callback = - absl::bind_front(&ReferenceCounter::HandleRefRemoved, &rc_); + boost::bind(&ReferenceCounter::HandleRefRemoved, &rc_, _1); rc_.SetRefRemovedCallback(object_id, contained_in_id, owner_address, ref_removed_callback); }; @@ -656,7 +656,7 @@ TEST(MemoryStoreIntegrationTest, TestSimple) { auto subscriber = std::make_shared(); auto rc = std::shared_ptr(new ReferenceCounter( rpc::WorkerAddress(rpc::Address()), publisher.get(), subscriber.get())); - CoreWorkerMemoryStore store(rc); + CoreWorkerMemoryStore store(nullptr, rc); // Tests putting an object with no references is ignored. 
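// The revert above swaps absl::bind_front back to boost::bind; both produce a
// callable that binds the ReferenceCounter and forwards the one remaining
// argument. A standalone comparison using a simplified member function; this
// sketch uses the namespaced boost::placeholders::_1 rather than the bare _1
// seen in the test.
#include <boost/bind/bind.hpp>
#include <iostream>

struct Counter {
  void HandleRemoved(int id) { std::cout << "removed " << id << "\n"; }
};

int main() {
  Counter rc;
  auto cb = boost::bind(&Counter::HandleRemoved, &rc, boost::placeholders::_1);
  cb(42);  // equivalent to absl::bind_front(&Counter::HandleRemoved, &rc)(42)
  return 0;
}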
RAY_CHECK(store.Put(buffer, id2)); diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.cc b/src/ray/core_worker/store_provider/memory_store/memory_store.cc index b32b612166820..680c9c13616bc 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.cc +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.cc @@ -139,11 +139,13 @@ std::shared_ptr GetRequest::Get(const ObjectID &object_id) const { } CoreWorkerMemoryStore::CoreWorkerMemoryStore( + std::function store_in_plasma, std::shared_ptr counter, std::shared_ptr raylet_client, std::function check_signals, std::function unhandled_exception_handler) - : ref_counter_(std::move(counter)), + : store_in_plasma_(store_in_plasma), + ref_counter_(counter), raylet_client_(raylet_client), check_signals_(check_signals), unhandled_exception_handler_(unhandled_exception_handler) {} @@ -184,6 +186,24 @@ std::shared_ptr CoreWorkerMemoryStore::GetIfExists(const ObjectID &ob return ptr; } +std::shared_ptr CoreWorkerMemoryStore::GetOrPromoteToPlasma( + const ObjectID &object_id) { + absl::MutexLock lock(&mu_); + auto iter = objects_.find(object_id); + if (iter != objects_.end()) { + auto obj = iter->second; + obj->SetAccessed(); + if (obj->IsInPlasmaError()) { + return nullptr; + } + return obj; + } + RAY_CHECK(store_in_plasma_ != nullptr) + << "Cannot promote object without plasma provider callback."; + promoted_to_plasma_.insert(object_id); + return nullptr; +} + bool CoreWorkerMemoryStore::Put(const RayObject &object, const ObjectID &object_id) { std::vector)>> async_callbacks; auto object_entry = std::make_shared(object.GetData(), object.GetMetadata(), @@ -192,6 +212,7 @@ bool CoreWorkerMemoryStore::Put(const RayObject &object, const ObjectID &object_ // TODO(edoakes): we should instead return a flag to the caller to put the object in // plasma. + bool should_put_in_plasma = false; { absl::MutexLock lock(&mu_); @@ -207,6 +228,15 @@ bool CoreWorkerMemoryStore::Put(const RayObject &object, const ObjectID &object_ object_async_get_requests_.erase(async_callback_it); } + auto promoted_it = promoted_to_plasma_.find(object_id); + if (promoted_it != promoted_to_plasma_.end()) { + RAY_CHECK(store_in_plasma_ != nullptr); + // Only need to promote to plasma if it wasn't already put into plasma + // by the task that created the object. + should_put_in_plasma = !object.IsInPlasmaError(); + promoted_to_plasma_.erase(promoted_it); + } + bool should_add_entry = true; auto object_request_iter = object_get_requests_.find(object_id); if (object_request_iter != object_get_requests_.end()) { @@ -238,6 +268,14 @@ bool CoreWorkerMemoryStore::Put(const RayObject &object, const ObjectID &object_ } } + // Must be called without holding the lock because store_in_plasma_ goes + // through the regular CoreWorker::Put() codepath, which calls into the + // in-memory store (would cause deadlock). + if (should_put_in_plasma) { + store_in_plasma_(object, object_id); + stored_in_direct_memory = false; + } + // It's important for performance to run the callbacks outside the lock. for (const auto &cb : async_callbacks) { cb(object_entry); diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.h b/src/ray/core_worker/store_provider/memory_store/memory_store.h index 70bebac7f01a5..542fac1ea2ea6 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.h +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.h @@ -44,10 +44,12 @@ class CoreWorkerMemoryStore { public: /// Create a memory store. 
/// + /// \param[in] store_in_plasma If not null, this is used to spill to plasma. /// \param[in] counter If not null, this enables ref counting for local objects, /// and the `remove_after_get` flag for Get() will be ignored. /// \param[in] raylet_client If not null, used to notify tasks blocked / unblocked. CoreWorkerMemoryStore( + std::function store_in_plasma = nullptr, std::shared_ptr counter = nullptr, std::shared_ptr raylet_client = nullptr, std::function check_signals = nullptr, @@ -102,6 +104,14 @@ class CoreWorkerMemoryStore { void GetAsync(const ObjectID &object_id, std::function)> callback); + /// Get a single object if available. If the object is not local yet, or if the object + /// is local but is ErrorType::OBJECT_IN_PLASMA, then nullptr will be returned, and + /// the store will ensure the object is promoted to plasma once available. + /// + /// \param[in] object_id The object id to get. + /// \return pointer to the local object, or nullptr if promoted to plasma. + std::shared_ptr GetOrPromoteToPlasma(const ObjectID &object_id); + /// Delete a list of objects from the object store. /// NOTE(swang): Objects that contain IsInPlasmaError will not be /// deleted from the in-memory store. Instead, any future Get @@ -177,6 +187,9 @@ class CoreWorkerMemoryStore { /// properly. void EraseObjectAndUpdateStats(const ObjectID &object_id) EXCLUSIVE_LOCKS_REQUIRED(mu_); + /// Optional callback for putting objects into the plasma store. + std::function store_in_plasma_; + /// If enabled, holds a reference to local worker ref counter. TODO(ekl) make this /// mandatory once Java is supported. std::shared_ptr ref_counter_ = nullptr; @@ -187,6 +200,9 @@ class CoreWorkerMemoryStore { /// Protects the data structures below. mutable absl::Mutex mu_; + /// Set of objects that should be promoted to plasma once available. + absl::flat_hash_set promoted_to_plasma_ GUARDED_BY(mu_); + /// Map from object ID to `RayObject`. /// NOTE: This map should be modified by EmplaceObjectAndUpdateStats and /// EraseObjectAndUpdateStats. 
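// Put() in memory_store.cc records should_put_in_plasma while holding mu_ but
// invokes store_in_plasma_ only after dropping the lock, since the callback
// re-enters the store through the regular Put() path. A minimal sketch of the
// decide-under-lock, act-outside-lock pattern with simplified types:
#include <functional>
#include <mutex>
#include <set>
#include <utility>

class PromotingStore {
 public:
  explicit PromotingStore(std::function<void(int)> spill) : spill_(std::move(spill)) {}

  void RequestPromotion(int object_id) {
    std::lock_guard<std::mutex> lock(mu_);
    pending_promotions_.insert(object_id);
  }

  void Put(int object_id) {
    bool should_spill = false;
    {
      std::lock_guard<std::mutex> lock(mu_);
      should_spill = pending_promotions_.erase(object_id) > 0;
    }
    // Called without mu_ held: spill_ may safely call back into this store.
    if (should_spill) spill_(object_id);
  }

 private:
  std::mutex mu_;
  std::set<int> pending_promotions_;
  std::function<void(int)> spill_;
};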
diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index 3e0ddd631d45f..29d95cb8fa9b8 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -16,7 +16,7 @@ #include #include -#include +#include #include #include "absl/container/flat_hash_map.h" diff --git a/src/ray/core_worker/test/direct_task_transport_mock_test.cc b/src/ray/core_worker/test/direct_task_transport_mock_test.cc index 8312d79a0bc43..0af5c20c4eb15 100644 --- a/src/ray/core_worker/test/direct_task_transport_mock_test.cc +++ b/src/ray/core_worker/test/direct_task_transport_mock_test.cc @@ -28,7 +28,7 @@ using namespace ::testing; class DirectTaskTransportTest : public ::testing::Test { public: void SetUp() override { - raylet_client = std::make_shared(); + raylet_client = std::make_shared(); task_finisher = std::make_shared(); actor_creator = std::make_shared(); lease_policy = std::make_shared(); @@ -57,7 +57,7 @@ class DirectTaskTransportTest : public ::testing::Test { } std::unique_ptr task_submitter; - std::shared_ptr raylet_client; + std::shared_ptr raylet_client; std::shared_ptr task_finisher; std::shared_ptr actor_creator; std::shared_ptr lease_policy; diff --git a/src/ray/core_worker/test/direct_task_transport_test.cc b/src/ray/core_worker/test/direct_task_transport_test.cc index b631b1d372177..473136255bc72 100644 --- a/src/ray/core_worker/test/direct_task_transport_test.cc +++ b/src/ray/core_worker/test/direct_task_transport_test.cc @@ -153,19 +153,6 @@ class MockRayletClient : public WorkerLeaseInterface { return Status::OK(); } - void ReportWorkerBacklog( - const WorkerID &worker_id, - const std::vector &backlog_reports) override { - reported_backlog_size = 0; - reported_backlogs.clear(); - for (const auto &backlog_report : backlog_reports) { - reported_backlog_size += backlog_report.backlog_size(); - const TaskSpecification resource_spec(backlog_report.resource_spec()); - const SchedulingClass scheduling_class = resource_spec.GetSchedulingClass(); - reported_backlogs[scheduling_class] = backlog_report.backlog_size(); - } - } - void RequestWorkerLease( const TaskSpecification &resource_spec, const rpc::ClientCallback &callback, @@ -174,14 +161,6 @@ class MockRayletClient : public WorkerLeaseInterface { callbacks.push_back(callback); } - void RequestWorkerLease( - const rpc::TaskSpec &task_spec, - const ray::rpc::ClientCallback &callback, - const int64_t backlog_size = -1) override { - num_workers_requested += 1; - callbacks.push_back(callback); - } - void ReleaseUnusedWorkers( const std::vector &workers_in_use, const rpc::ClientCallback &callback) override {} @@ -243,8 +222,6 @@ class MockRayletClient : public WorkerLeaseInterface { int num_workers_returned = 0; int num_workers_disconnected = 0; int num_leases_canceled = 0; - int reported_backlog_size = 0; - std::map reported_backlogs; std::list> callbacks = {}; std::list> cancel_callbacks = {}; }; @@ -269,18 +246,11 @@ class MockActorCreator : public ActorCreatorInterface { } void AsyncWaitForActorRegisterFinish(const ActorID &, - gcs::StatusCallback callback) override { - callbacks.push_back(callback); - } + gcs::StatusCallback callback) override {} - [[nodiscard]] bool IsActorInRegistering(const ActorID &actor_id) const override { - return actor_pending; - } + bool IsActorInRegistering(const ActorID &actor_id) const override { return false; } ~MockActorCreator() {} - - std::list callbacks; - bool actor_pending = false; }; class MockLeasePolicy : public 
LeasePolicyInterface { @@ -302,87 +272,40 @@ class MockLeasePolicy : public LeasePolicyInterface { int num_lease_policy_consults = 0; }; -TEST(LocalDependencyResolverTest, TestNoDependencies) { - auto store = std::make_shared(); - auto task_finisher = std::make_shared(); - MockActorCreator actor_creator; - LocalDependencyResolver resolver(*store, *task_finisher, actor_creator); - TaskSpecification task; - bool ok = false; - resolver.ResolveDependencies(task, [&ok](Status) { ok = true; }); - ASSERT_TRUE(ok); - ASSERT_EQ(task_finisher->num_inlined_dependencies, 0); -} - -TEST(LocalDependencyResolverTest, TestActorAndObjectDependencies1) { - // Actor dependency resolved first. - auto store = std::make_shared(); - auto task_finisher = std::make_shared(); - MockActorCreator actor_creator; - LocalDependencyResolver resolver(*store, *task_finisher, actor_creator); - TaskSpecification task; - ObjectID obj = ObjectID::FromRandom(); - task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj.Binary()); - - ActorID actor_id = ActorID::Of(JobID::FromInt(0), TaskID::Nil(), 0); - ObjectID actor_handle_id = ObjectID::ForActorHandle(actor_id); - task.GetMutableMessage().add_args()->add_nested_inlined_refs()->set_object_id( - actor_handle_id.Binary()); - - int num_resolved = 0; - actor_creator.actor_pending = true; - resolver.ResolveDependencies(task, [&](const Status &) { num_resolved++; }); - ASSERT_EQ(num_resolved, 0); - ASSERT_EQ(resolver.NumPendingTasks(), 1); +TEST(TestMemoryStore, TestPromoteToPlasma) { + size_t num_plasma_puts = 0; + auto mem = std::make_shared( + [&](const RayObject &obj, const ObjectID &obj_id) { num_plasma_puts += 1; }); + ObjectID obj1 = ObjectID::FromRandom(); + ObjectID obj2 = ObjectID::FromRandom(); + auto data = GenerateRandomObject(); + ASSERT_TRUE(mem->Put(*data, obj1)); - for (const auto &cb : actor_creator.callbacks) { - cb(Status()); - } - ASSERT_EQ(num_resolved, 0); + // Test getting an already existing object. + ASSERT_TRUE(mem->GetOrPromoteToPlasma(obj1) != nullptr); + ASSERT_TRUE(num_plasma_puts == 0); - std::string meta = std::to_string(static_cast(rpc::ErrorType::OBJECT_IN_PLASMA)); - auto metadata = const_cast(reinterpret_cast(meta.data())); - auto meta_buffer = std::make_shared(metadata, meta.size()); - auto data = RayObject(nullptr, meta_buffer, std::vector()); - ASSERT_TRUE(store->Put(data, obj)); - ASSERT_EQ(num_resolved, 1); + // Testing getting an object that doesn't exist yet causes promotion. + ASSERT_TRUE(mem->GetOrPromoteToPlasma(obj2) == nullptr); + ASSERT_TRUE(num_plasma_puts == 0); + ASSERT_FALSE(mem->Put(*data, obj2)); + ASSERT_TRUE(num_plasma_puts == 1); - ASSERT_EQ(resolver.NumPendingTasks(), 0); + // The next time you get it, it's already there so no need to promote. + ASSERT_TRUE(mem->GetOrPromoteToPlasma(obj2) != nullptr); + ASSERT_TRUE(num_plasma_puts == 1); } -TEST(LocalDependencyResolverTest, TestActorAndObjectDependencies2) { - // Object dependency resolved first. 
+TEST(LocalDependencyResolverTest, TestNoDependencies) { auto store = std::make_shared(); auto task_finisher = std::make_shared(); MockActorCreator actor_creator; LocalDependencyResolver resolver(*store, *task_finisher, actor_creator); TaskSpecification task; - ObjectID obj = ObjectID::FromRandom(); - task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj.Binary()); - - ActorID actor_id = ActorID::Of(JobID::FromInt(0), TaskID::Nil(), 0); - ObjectID actor_handle_id = ObjectID::ForActorHandle(actor_id); - task.GetMutableMessage().add_args()->add_nested_inlined_refs()->set_object_id( - actor_handle_id.Binary()); - - int num_resolved = 0; - actor_creator.actor_pending = true; - resolver.ResolveDependencies(task, [&](const Status &) { num_resolved++; }); - ASSERT_EQ(num_resolved, 0); - ASSERT_EQ(resolver.NumPendingTasks(), 1); - - std::string meta = std::to_string(static_cast(rpc::ErrorType::OBJECT_IN_PLASMA)); - auto metadata = const_cast(reinterpret_cast(meta.data())); - auto meta_buffer = std::make_shared(metadata, meta.size()); - auto data = RayObject(nullptr, meta_buffer, std::vector()); - ASSERT_EQ(num_resolved, 0); - ASSERT_TRUE(store->Put(data, obj)); - - for (const auto &cb : actor_creator.callbacks) { - cb(Status()); - } - ASSERT_EQ(num_resolved, 1); - ASSERT_EQ(resolver.NumPendingTasks(), 0); + bool ok = false; + resolver.ResolveDependencies(task, [&ok](Status) { ok = true; }); + ASSERT_TRUE(ok); + ASSERT_EQ(task_finisher->num_inlined_dependencies, 0); } TEST(LocalDependencyResolverTest, TestHandlePlasmaPromotion) { @@ -640,78 +563,9 @@ TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) { auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter( - address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, - NodeID::Nil(), kLongTimeout, actor_creator, 1, absl::nullopt, 2); - - TaskSpecification task1 = BuildEmptyTaskSpec(); - TaskSpecification task2 = BuildEmptyTaskSpec(); - TaskSpecification task3 = BuildEmptyTaskSpec(); - - ASSERT_TRUE(submitter.SubmitTask(task1).ok()); - ASSERT_TRUE(submitter.SubmitTask(task2).ok()); - ASSERT_TRUE(submitter.SubmitTask(task3).ok()); - ASSERT_EQ(lease_policy->num_lease_policy_consults, 2); - ASSERT_EQ(raylet_client->num_workers_requested, 2); - ASSERT_EQ(raylet_client->num_workers_returned, 0); - ASSERT_EQ(raylet_client->reported_backlog_size, 0); - ASSERT_EQ(worker_client->callbacks.size(), 0); - - // Trigger the periodic backlog report - submitter.ReportWorkerBacklog(); - ASSERT_EQ(raylet_client->reported_backlog_size, 1); - - // Task 1 is pushed; worker 3 is requested. - ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1000, NodeID::Nil())); - ASSERT_EQ(worker_client->callbacks.size(), 1); - ASSERT_EQ(lease_policy->num_lease_policy_consults, 3); - ASSERT_EQ(raylet_client->num_workers_requested, 3); - ASSERT_EQ(raylet_client->reported_backlog_size, 0); - - // Task 2 is pushed; no more workers requested. - ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, NodeID::Nil())); - ASSERT_EQ(worker_client->callbacks.size(), 2); - ASSERT_EQ(lease_policy->num_lease_policy_consults, 3); - ASSERT_EQ(raylet_client->num_workers_requested, 3); - ASSERT_EQ(raylet_client->reported_backlog_size, 0); - - // Task 3 is pushed; no more workers requested. 
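// TestHandlePlasmaPromotion and the removed resolver tests build a RayObject
// whose metadata is the numeric OBJECT_IN_PLASMA error code to force the
// promotion path. A standalone sketch of the marker construction; the enum
// value here is assumed for illustration, not Ray's actual proto value.
#include <cassert>
#include <string>

enum class ErrorType { OBJECT_IN_PLASMA = 2 };  // assumed value for illustration

std::string MakeInPlasmaMarkerMetadata() {
  // The metadata payload is simply the decimal error code rendered as text.
  return std::to_string(static_cast<int>(ErrorType::OBJECT_IN_PLASMA));
}

int main() {
  assert(MakeInPlasmaMarkerMetadata() == "2");
  return 0;
}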
- ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1002, NodeID::Nil())); - ASSERT_EQ(worker_client->callbacks.size(), 3); - ASSERT_EQ(lease_policy->num_lease_policy_consults, 3); - ASSERT_EQ(raylet_client->num_workers_requested, 3); - ASSERT_EQ(raylet_client->reported_backlog_size, 0); - - // All workers returned. - while (!worker_client->callbacks.empty()) { - ASSERT_TRUE(worker_client->ReplyPushTask()); - } - ASSERT_EQ(raylet_client->num_workers_returned, 3); - ASSERT_EQ(raylet_client->num_workers_disconnected, 0); - ASSERT_EQ(task_finisher->num_tasks_complete, 3); - ASSERT_EQ(task_finisher->num_tasks_failed, 0); - ASSERT_EQ(raylet_client->num_leases_canceled, 0); - ASSERT_EQ(raylet_client->reported_backlog_size, 0); - ASSERT_FALSE(raylet_client->ReplyCancelWorkerLease()); - - // Check that there are no entries left in the scheduling_key_entries_ hashmap. These - // would otherwise cause a memory leak. - ASSERT_TRUE(submitter.CheckNoSchedulingKeyEntriesPublic()); -} - -TEST(DirectTaskTransportTest, TestSubmitMultipleTasks) { - rpc::Address address; - auto raylet_client = std::make_shared(); - auto worker_client = std::make_shared(); - auto store = std::make_shared(); - auto client_pool = std::make_shared( - [&](const rpc::Address &addr) { return worker_client; }); - auto task_finisher = std::make_shared(); - auto actor_creator = std::make_shared(); - auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter( - address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, - NodeID::Nil(), kLongTimeout, actor_creator, 1, absl::nullopt, 1); + CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, + lease_policy, store, task_finisher, + NodeID::Nil(), kLongTimeout, actor_creator); TaskSpecification task1 = BuildEmptyTaskSpec(); TaskSpecification task2 = BuildEmptyTaskSpec(); @@ -722,21 +576,18 @@ TEST(DirectTaskTransportTest, TestSubmitMultipleTasks) { ASSERT_TRUE(submitter.SubmitTask(task3).ok()); ASSERT_EQ(lease_policy->num_lease_policy_consults, 1); ASSERT_EQ(raylet_client->num_workers_requested, 1); - ASSERT_EQ(raylet_client->reported_backlog_size, 0); // Task 1 is pushed; worker 2 is requested. ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1000, NodeID::Nil())); ASSERT_EQ(worker_client->callbacks.size(), 1); ASSERT_EQ(lease_policy->num_lease_policy_consults, 2); ASSERT_EQ(raylet_client->num_workers_requested, 2); - ASSERT_EQ(raylet_client->reported_backlog_size, 1); // Task 2 is pushed; worker 3 is requested. ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, NodeID::Nil())); ASSERT_EQ(worker_client->callbacks.size(), 2); ASSERT_EQ(lease_policy->num_lease_policy_consults, 3); ASSERT_EQ(raylet_client->num_workers_requested, 3); - ASSERT_EQ(raylet_client->reported_backlog_size, 0); // Task 3 is pushed; no more workers requested. ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1002, NodeID::Nil())); @@ -753,7 +604,6 @@ TEST(DirectTaskTransportTest, TestSubmitMultipleTasks) { ASSERT_EQ(task_finisher->num_tasks_complete, 3); ASSERT_EQ(task_finisher->num_tasks_failed, 0); ASSERT_EQ(raylet_client->num_leases_canceled, 0); - ASSERT_EQ(raylet_client->reported_backlog_size, 0); ASSERT_FALSE(raylet_client->ReplyCancelWorkerLease()); // Check that there are no entries left in the scheduling_key_entries_ hashmap. 
These @@ -771,9 +621,9 @@ TEST(DirectTaskTransportTest, TestReuseWorkerLease) { auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter( - address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, - NodeID::Nil(), kLongTimeout, actor_creator, 1, absl::nullopt, 1); + CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, + lease_policy, store, task_finisher, + NodeID::Nil(), kLongTimeout, actor_creator); TaskSpecification task1 = BuildEmptyTaskSpec(); TaskSpecification task2 = BuildEmptyTaskSpec(); @@ -834,9 +684,9 @@ TEST(DirectTaskTransportTest, TestRetryLeaseCancellation) { auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter( - address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, - NodeID::Nil(), kLongTimeout, actor_creator, 1, absl::nullopt, 1); + CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, + lease_policy, store, task_finisher, + NodeID::Nil(), kLongTimeout, actor_creator); TaskSpecification task1 = BuildEmptyTaskSpec(); TaskSpecification task2 = BuildEmptyTaskSpec(); TaskSpecification task3 = BuildEmptyTaskSpec(); @@ -947,9 +797,9 @@ TEST(DirectTaskTransportTest, TestWorkerNotReusedOnError) { auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter( - address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, - NodeID::Nil(), kLongTimeout, actor_creator, 1, absl::nullopt, 1); + CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, + lease_policy, store, task_finisher, + NodeID::Nil(), kLongTimeout, actor_creator); TaskSpecification task1 = BuildEmptyTaskSpec(); TaskSpecification task2 = BuildEmptyTaskSpec(); @@ -1163,9 +1013,9 @@ void TestSchedulingKey(const std::shared_ptr store, auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter( - address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, - NodeID::Nil(), kLongTimeout, actor_creator, 1, absl::nullopt, 1); + CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, + lease_policy, store, task_finisher, + NodeID::Nil(), kLongTimeout, actor_creator); ASSERT_TRUE(submitter.SubmitTask(same1).ok()); ASSERT_TRUE(submitter.SubmitTask(same2).ok()); @@ -1280,65 +1130,6 @@ TEST(DirectTaskTransportTest, TestSchedulingKeys) { TestSchedulingKey(store, same_deps_1, same_deps_2, different_deps); } -TEST(DirectTaskTransportTest, TestBacklogReport) { - rpc::Address address; - auto raylet_client = std::make_shared(); - auto worker_client = std::make_shared(); - auto store = std::make_shared(); - auto client_pool = std::make_shared( - [&](const rpc::Address &addr) { return worker_client; }); - auto task_finisher = std::make_shared(); - auto actor_creator = std::make_shared(); - auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter( - address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, - NodeID::Nil(), kLongTimeout, actor_creator, 1, absl::nullopt, 1); - - TaskSpecification task1 = BuildEmptyTaskSpec(); - - std::unordered_map 
resources1({{"a", 1.0}}); - std::unordered_map resources2({{"b", 2.0}}); - FunctionDescriptor descriptor1 = - FunctionDescriptorBuilder::BuildPython("a", "", "", ""); - FunctionDescriptor descriptor2 = - FunctionDescriptorBuilder::BuildPython("b", "", "", ""); - ObjectID plasma1 = ObjectID::FromRandom(); - ObjectID plasma2 = ObjectID::FromRandom(); - // Force plasma objects to be promoted. - std::string meta = std::to_string(static_cast(rpc::ErrorType::OBJECT_IN_PLASMA)); - auto metadata = const_cast(reinterpret_cast(meta.data())); - auto meta_buffer = std::make_shared(metadata, meta.size()); - auto plasma_data = RayObject(nullptr, meta_buffer, std::vector()); - ASSERT_TRUE(store->Put(plasma_data, plasma1)); - ASSERT_TRUE(store->Put(plasma_data, plasma2)); - - // Same SchedulingClass, different SchedulingKey - TaskSpecification task2 = BuildTaskSpec(resources1, descriptor1); - task2.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id( - plasma1.Binary()); - TaskSpecification task3 = BuildTaskSpec(resources1, descriptor1); - task3.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id( - plasma2.Binary()); - TestSchedulingKey(store, task2, task2, task3); - - TaskSpecification task4 = BuildTaskSpec(resources2, descriptor2); - - ASSERT_TRUE(submitter.SubmitTask(task1).ok()); - // One is requested and one is in the backlog for each SchedulingKey - ASSERT_TRUE(submitter.SubmitTask(task2).ok()); - ASSERT_TRUE(submitter.SubmitTask(task2).ok()); - ASSERT_TRUE(submitter.SubmitTask(task3).ok()); - ASSERT_TRUE(submitter.SubmitTask(task3).ok()); - ASSERT_TRUE(submitter.SubmitTask(task4).ok()); - ASSERT_TRUE(submitter.SubmitTask(task4).ok()); - - submitter.ReportWorkerBacklog(); - ASSERT_EQ(raylet_client->reported_backlogs.size(), 3); - ASSERT_EQ(raylet_client->reported_backlogs[task1.GetSchedulingClass()], 0); - ASSERT_EQ(raylet_client->reported_backlogs[task2.GetSchedulingClass()], 2); - ASSERT_EQ(raylet_client->reported_backlogs[task4.GetSchedulingClass()], 1); -} - TEST(DirectTaskTransportTest, TestWorkerLeaseTimeout) { rpc::Address address; auto raylet_client = std::make_shared(); @@ -1349,10 +1140,10 @@ TEST(DirectTaskTransportTest, TestWorkerLeaseTimeout) { auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter( - address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, - NodeID::Nil(), - /*lease_timeout_ms=*/5, actor_creator, 1, absl::nullopt, 1); + CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, + lease_policy, store, task_finisher, + NodeID::Nil(), + /*lease_timeout_ms=*/5, actor_creator); TaskSpecification task1 = BuildEmptyTaskSpec(); TaskSpecification task2 = BuildEmptyTaskSpec(); TaskSpecification task3 = BuildEmptyTaskSpec(); @@ -1531,8 +1322,7 @@ TEST(DirectTaskTransportTest, TestPipeliningConcurrentWorkerLeases) { uint32_t max_tasks_in_flight_per_worker = 10; CoreWorkerDirectTaskSubmitter submitter( address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, - NodeID::Nil(), kLongTimeout, actor_creator, max_tasks_in_flight_per_worker, - absl::nullopt, 1); + NodeID::Nil(), kLongTimeout, actor_creator, max_tasks_in_flight_per_worker); // Prepare 20 tasks and save them in a vector. 
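// With max_tasks_in_flight_per_worker = 10, the twenty tasks prepared below
// can keep at most two leases busy at once. A back-of-the-envelope helper for
// that bound; this is illustrative arithmetic, not the submitter's actual
// scheduling logic.
#include <cassert>
#include <cstdint>

uint32_t MaxUsefulLeases(uint32_t queued_tasks, uint32_t max_in_flight) {
  return (queued_tasks + max_in_flight - 1) / max_in_flight;  // ceiling division
}

int main() {
  assert(MaxUsefulLeases(20, 10) == 2);
  assert(MaxUsefulLeases(30, 10) == 3);
  return 0;
}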
std::vector tasks; @@ -1606,8 +1396,7 @@ TEST(DirectTaskTransportTest, TestPipeliningReuseWorkerLease) { uint32_t max_tasks_in_flight_per_worker = 10; CoreWorkerDirectTaskSubmitter submitter( address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, - NodeID::Nil(), kLongTimeout, actor_creator, max_tasks_in_flight_per_worker, - absl::nullopt, 2); + NodeID::Nil(), kLongTimeout, actor_creator, max_tasks_in_flight_per_worker); // prepare 30 tasks and save them in a vector std::vector tasks; @@ -1616,16 +1405,16 @@ TEST(DirectTaskTransportTest, TestPipeliningReuseWorkerLease) { } ASSERT_EQ(tasks.size(), 30); - // Submit the 30 tasks and check that two workers are requested + // Submit the 30 tasks and check that one worker is requested for (auto task : tasks) { ASSERT_TRUE(submitter.SubmitTask(task).ok()); } - ASSERT_EQ(raylet_client->num_workers_requested, 2); + ASSERT_EQ(raylet_client->num_workers_requested, 1); // Task 1-10 are pushed, and a new worker is requested. ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1000, NodeID::Nil())); ASSERT_EQ(worker_client->callbacks.size(), 10); - ASSERT_EQ(raylet_client->num_workers_requested, 3); + ASSERT_EQ(raylet_client->num_workers_requested, 2); // The lease is not cancelled, as there is more work to do ASSERT_EQ(raylet_client->num_leases_canceled, 0); @@ -1652,7 +1441,7 @@ TEST(DirectTaskTransportTest, TestPipeliningReuseWorkerLease) { ASSERT_EQ(raylet_client->num_workers_returned, 1); ASSERT_EQ(worker_client->callbacks.size(), 0); ASSERT_EQ(task_finisher->num_tasks_complete, 30); - ASSERT_EQ(raylet_client->num_leases_canceled, 2); + ASSERT_EQ(raylet_client->num_leases_canceled, 1); ASSERT_TRUE(raylet_client->ReplyCancelWorkerLease()); // The second lease request is returned immediately. @@ -1662,19 +1451,8 @@ TEST(DirectTaskTransportTest, TestPipeliningReuseWorkerLease) { ASSERT_EQ(raylet_client->num_workers_disconnected, 0); ASSERT_EQ(task_finisher->num_tasks_complete, 30); ASSERT_EQ(task_finisher->num_tasks_failed, 0); - ASSERT_EQ(raylet_client->num_workers_requested, 3); - ASSERT_EQ(raylet_client->num_leases_canceled, 3); - ASSERT_TRUE(raylet_client->ReplyCancelWorkerLease()); - - // The third lease request is returned immediately. - ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, NodeID::Nil())); - ASSERT_EQ(worker_client->callbacks.size(), 0); - ASSERT_EQ(raylet_client->num_workers_returned, 3); - ASSERT_EQ(raylet_client->num_workers_disconnected, 0); - ASSERT_EQ(task_finisher->num_tasks_complete, 30); - ASSERT_EQ(task_finisher->num_tasks_failed, 0); - ASSERT_EQ(raylet_client->num_leases_canceled, 3); - ASSERT_TRUE(raylet_client->ReplyCancelWorkerLease()); + ASSERT_EQ(raylet_client->num_leases_canceled, 1); + ASSERT_FALSE(raylet_client->ReplyCancelWorkerLease()); // Check that there are no entries left in the scheduling_key_entries_ hashmap. These // would otherwise cause a memory leak. 
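The assertions above encode the lease-request arithmetic of the single-pending-lease model this series reverts to: at most one RequestWorkerLease is outstanding per scheduling key, and each granted lease absorbs up to max_tasks_in_flight_per_worker queued tasks before exactly one follow-up request goes out. The sketch below is illustrative only (a toy model, not Ray code; every name in it is hypothetical) and reproduces the 1 -> 2 -> 3 progression of num_workers_requested that tests such as TestPipeliningConcurrentWorkerLeases check.

    #include <algorithm>
    #include <cassert>
    #include <cstdint>

    int main() {
      const uint32_t max_tasks_in_flight_per_worker = 10;
      uint32_t queued = 30;            // tasks submitted before any lease is granted
      uint32_t workers_requested = 1;  // only one lease request may be pending

      // Granting a lease pushes up to max_tasks_in_flight_per_worker tasks to the
      // new worker; if tasks remain queued, exactly one follow-up request is made.
      auto grant_lease = [&] {
        queued -= std::min(queued, max_tasks_in_flight_per_worker);
        if (queued > 0) workers_requested += 1;
      };

      grant_lease();  // 20 tasks left -> second request issued
      assert(workers_requested == 2);
      grant_lease();  // 10 tasks left -> third request issued
      grant_lease();  // queue drained -> no further request
      assert(workers_requested == 3);
      return 0;
    }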
@@ -1698,8 +1476,7 @@ TEST(DirectTaskTransportTest, TestPipeliningNumberOfWorkersRequested) {
   uint32_t max_tasks_in_flight_per_worker = 10;
   CoreWorkerDirectTaskSubmitter submitter(
       address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher,
-      NodeID::Nil(), kLongTimeout, actor_creator, max_tasks_in_flight_per_worker,
-      absl::nullopt, 1);
+      NodeID::Nil(), kLongTimeout, actor_creator, max_tasks_in_flight_per_worker);
 
   // prepare 30 tasks and save them in a vector
   std::vector<TaskSpecification> tasks;
@@ -1884,8 +1661,7 @@ TEST(DirectTaskTransportTest, TestStealingTasks) {
   uint32_t max_tasks_in_flight_per_worker = 10;
   CoreWorkerDirectTaskSubmitter submitter(
       address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher,
-      NodeID::Nil(), kLongTimeout, actor_creator, max_tasks_in_flight_per_worker,
-      absl::nullopt, 1);
+      NodeID::Nil(), kLongTimeout, actor_creator, max_tasks_in_flight_per_worker);
 
   // prepare 20 tasks and save them in a vector
   std::vector<TaskSpecification> tasks;
@@ -2065,8 +1841,7 @@ TEST(DirectTaskTransportTest, TestNoStealingByExpiredWorker) {
   uint32_t max_tasks_in_flight_per_worker = 10;
   CoreWorkerDirectTaskSubmitter submitter(
       address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher,
-      NodeID::Nil(), 1000, actor_creator, max_tasks_in_flight_per_worker, absl::nullopt,
-      1);
+      NodeID::Nil(), 1000, actor_creator, max_tasks_in_flight_per_worker);
 
   // prepare 30 tasks and save them in a vector
   std::vector<TaskSpecification> tasks;
@@ -2204,24 +1979,23 @@ TEST(DirectTaskTransportTest, TestNoWorkerRequestedIfStealingUnavailable) {
   uint32_t max_tasks_in_flight_per_worker = 10;
   CoreWorkerDirectTaskSubmitter submitter(
       address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher,
-      NodeID::Nil(), kLongTimeout, actor_creator, max_tasks_in_flight_per_worker,
-      absl::nullopt, 2);
+      NodeID::Nil(), kLongTimeout, actor_creator, max_tasks_in_flight_per_worker);
 
   // prepare 10 tasks and save them in a vector
   std::vector<TaskSpecification> tasks;
   for (int i = 0; i < 10; i++) {
     tasks.push_back(BuildEmptyTaskSpec());
   }
   ASSERT_EQ(tasks.size(), 10);
 
   // submit all tasks
   for (int i = 1; i <= 10; i++) {
     auto task = tasks.front();
     ASSERT_TRUE(submitter.SubmitTask(task).ok());
     tasks.erase(tasks.begin());
   }
   ASSERT_EQ(tasks.size(), 0);
-  ASSERT_EQ(raylet_client->num_workers_requested, 2);
+  ASSERT_EQ(raylet_client->num_workers_requested, 1);
   ASSERT_EQ(task_finisher->num_tasks_complete, 0);
   ASSERT_EQ(task_finisher->num_tasks_failed, 0);
   ASSERT_EQ(raylet_client->num_leases_canceled, 0);
@@ -2232,7 +2006,7 @@ TEST(DirectTaskTransportTest, TestNoWorkerRequestedIfStealingUnavailable) {
   std::string worker1_id = "worker1_ID_abcdefghijklmnopq";
   ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, NodeID::Nil(), false,
                                               worker1_id));
-  ASSERT_EQ(raylet_client->num_workers_requested, 3);
+  ASSERT_EQ(raylet_client->num_workers_requested, 2);
   ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
   ASSERT_EQ(raylet_client->num_workers_returned, 0);
   ASSERT_EQ(task_finisher->num_tasks_complete, 0);
@@ -2246,7 +2020,7 @@ TEST(DirectTaskTransportTest, TestNoWorkerRequestedIfStealingUnavailable) {
     ASSERT_TRUE(worker_client->ReplyPushTask());
   }
 
-  ASSERT_EQ(raylet_client->num_workers_requested, 3);
+  ASSERT_EQ(raylet_client->num_workers_requested, 2);
   ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
   ASSERT_EQ(raylet_client->num_workers_returned, 0);
   ASSERT_EQ(task_finisher->num_tasks_complete, 9);
@@ -2262,23 +2036,23 @@ TEST(DirectTaskTransportTest, TestNoWorkerRequestedIfStealingUnavailable) {
                                               worker2_id));
 
   // Check that no more workers are requested now that there are no more stealable tasks.
-  ASSERT_EQ(raylet_client->num_workers_requested, 3);
+  ASSERT_EQ(raylet_client->num_workers_requested, 2);
   ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
   ASSERT_EQ(raylet_client->num_workers_returned, 1);
   ASSERT_EQ(task_finisher->num_tasks_complete, 9);
   ASSERT_EQ(task_finisher->num_tasks_failed, 0);
-  ASSERT_EQ(raylet_client->num_leases_canceled, 1);
+  ASSERT_EQ(raylet_client->num_leases_canceled, 0);
   ASSERT_EQ(worker_client->callbacks.size(), 1);
   ASSERT_EQ(worker_client->steal_callbacks.size(), 0);
 
   // Last task runs and first worker is returned
   ASSERT_TRUE(worker_client->ReplyPushTask());
-  ASSERT_EQ(raylet_client->num_workers_requested, 3);
+  ASSERT_EQ(raylet_client->num_workers_requested, 2);
   ASSERT_EQ(raylet_client->num_workers_returned, 2);
   ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
   ASSERT_EQ(task_finisher->num_tasks_complete, 10);
   ASSERT_EQ(task_finisher->num_tasks_failed, 0);
-  ASSERT_EQ(raylet_client->num_leases_canceled, 2);
+  ASSERT_EQ(raylet_client->num_leases_canceled, 0);
   ASSERT_EQ(worker_client->callbacks.size(), 0);
   ASSERT_EQ(worker_client->steal_callbacks.size(), 0);
 }
diff --git a/src/ray/core_worker/test/memory_store_test.cc b/src/ray/core_worker/test/memory_store_test.cc
index feee9973db850..84a7c8f7996ac 100644
--- a/src/ray/core_worker/test/memory_store_test.cc
+++ b/src/ray/core_worker/test/memory_store_test.cc
@@ -12,9 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "absl/synchronization/mutex.h"
+
 #include "ray/core_worker/store_provider/memory_store/memory_store.h"
 
-#include "absl/synchronization/mutex.h"
 #include "gtest/gtest.h"
 #include "ray/common/test_util.h"
 
@@ -28,7 +29,8 @@ TEST(TestMemoryStore, TestReportUnhandledErrors) {
   std::shared_ptr<CoreWorkerMemoryStore> provider =
       std::make_shared<CoreWorkerMemoryStore>(
-          nullptr, nullptr, nullptr, [&](const RayObject &obj) { unhandled_count++; });
+          nullptr, nullptr, nullptr, nullptr,
+          [&](const RayObject &obj) { unhandled_count++; });
   RayObject obj1(rpc::ErrorType::TASK_EXECUTION_EXCEPTION);
   RayObject obj2(rpc::ErrorType::TASK_EXECUTION_EXCEPTION);
   auto id1 = ObjectID::FromRandom();
@@ -50,7 +52,7 @@ TEST(TestMemoryStore, TestReportUnhandledErrors) {
   RAY_CHECK(provider->Put(obj1, id1));
   RAY_CHECK(provider->Put(obj1, id2));
   RAY_UNUSED(provider->Get({id1}, 1, 100, context, false, &results));
-  RAY_UNUSED(provider->Get({id2}, 1, 100, context, false, &results));
+  provider->GetOrPromoteToPlasma(id2);
   provider->Delete({id1, id2});
 
   ASSERT_EQ(unhandled_count, 0);
@@ -66,7 +68,8 @@ TEST(TestMemoryStore, TestMemoryStoreStats) {
   /// Simple validation for test memory store stats.
   std::shared_ptr<CoreWorkerMemoryStore> provider =
-      std::make_shared<CoreWorkerMemoryStore>(nullptr, nullptr, nullptr, nullptr);
+      std::make_shared<CoreWorkerMemoryStore>(nullptr, nullptr, nullptr, nullptr,
+                                              nullptr);
 
   // Iterate through the memory store and compare the values that are obtained by
   // GetMemoryStoreStatisticalData.
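The memory-store hunks above thread an extra constructor argument through CoreWorkerMemoryStore and exercise its unhandled-error accounting: an error object that is deleted without ever being consumed fires the callback passed as the last constructor argument, while reading it (or promoting it to plasma) first suppresses the callback. The toy store below sketches that contract only; the types and member names are invented for illustration and are not Ray's.

    #include <functional>
    #include <iostream>
    #include <unordered_map>

    struct ToyErrorStore {
      std::function<void()> on_unhandled;     // like the ctor's last argument
      std::unordered_map<int, bool> handled;  // object id -> was it consumed?

      void Put(int id) { handled[id] = false; }
      void Get(int id) { handled[id] = true; }  // consuming marks it handled
      void Delete(int id) {
        if (!handled[id]) on_unhandled();  // fires only for unread error objects
        handled.erase(id);
      }
    };

    int main() {
      int unhandled_count = 0;
      ToyErrorStore store{[&] { unhandled_count++; }};
      store.Put(1);
      store.Put(2);
      store.Get(1);     // consumed before deletion
      store.Delete(1);  // no callback
      store.Delete(2);  // callback fires
      std::cout << unhandled_count << "\n";  // prints 1
      return 0;
    }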
diff --git a/src/ray/core_worker/transport/dependency_resolver.cc b/src/ray/core_worker/transport/dependency_resolver.cc
index da52aff657627..3948c3732f1c4 100644
--- a/src/ray/core_worker/transport/dependency_resolver.cc
+++ b/src/ray/core_worker/transport/dependency_resolver.cc
@@ -139,13 +139,11 @@ void LocalDependencyResolver::ResolveDependencies(
 
   for (const auto &actor_id : state->actor_dependencies) {
     actor_creator_.AsyncWaitForActorRegisterFinish(
-        actor_id, [this, state, on_complete](const Status &status) {
+        actor_id, [state, on_complete](Status status) {
           if (!status.ok()) {
             state->status = status;
           }
-          if (--state->actor_dependencies_remaining == 0 &&
-              state->obj_dependencies_remaining == 0) {
-            num_pending_--;
+          if (--state->actor_dependencies_remaining == 0) {
             on_complete(state->status);
           }
         });
diff --git a/src/ray/core_worker/transport/direct_task_transport.cc b/src/ray/core_worker/transport/direct_task_transport.cc
index 912abca3b7aed..508f205337566 100644
--- a/src/ray/core_worker/transport/direct_task_transport.cc
+++ b/src/ray/core_worker/transport/direct_task_transport.cc
@@ -374,14 +374,15 @@ void CoreWorkerDirectTaskSubmitter::CancelWorkerLeaseIfNeeded(
   RAY_LOG(DEBUG)
      << "Task queue is empty, and there are no stealable tasks; canceling lease request";
 
-  for (auto &pending_lease_request : scheduling_key_entry.pending_lease_requests) {
+  auto &pending_lease_request = scheduling_key_entry.pending_lease_request;
+  if (pending_lease_request.first) {
     // There is an in-flight lease request. Cancel it.
-    auto lease_client = GetOrConnectLeaseClient(&pending_lease_request.second);
-    auto &task_id = pending_lease_request.first;
-    RAY_LOG(DEBUG) << "Canceling lease request " << task_id;
+    auto &lease_client = pending_lease_request.first;
+    auto &lease_id = pending_lease_request.second;
+    RAY_LOG(DEBUG) << "Canceling lease request " << lease_id;
     lease_client->CancelWorkerLease(
-        task_id, [this, scheduling_key](const Status &status,
-                                        const rpc::CancelWorkerLeaseReply &reply) {
+        lease_id, [this, scheduling_key](const Status &status,
+                                         const rpc::CancelWorkerLeaseReply &reply) {
           absl::MutexLock lock(&mu_);
           if (status.ok() && !reply.success()) {
             // The cancellation request can fail if the raylet does not have
@@ -422,58 +423,15 @@ CoreWorkerDirectTaskSubmitter::GetOrConnectLeaseClient(
   return lease_client;
 }
 
-void CoreWorkerDirectTaskSubmitter::ReportWorkerBacklog() {
-  absl::MutexLock lock(&mu_);
-  ReportWorkerBacklogInternal();
-}
-
-void CoreWorkerDirectTaskSubmitter::ReportWorkerBacklogInternal() {
-  absl::flat_hash_map<SchedulingClass, std::pair<TaskSpecification, int64_t>> backlogs;
-  for (auto &scheduling_key_and_entry : scheduling_key_entries_) {
-    const SchedulingClass scheduling_class = std::get<0>(scheduling_key_and_entry.first);
-    if (backlogs.find(scheduling_class) == backlogs.end()) {
-      backlogs[scheduling_class].first = scheduling_key_and_entry.second.resource_spec;
-      backlogs[scheduling_class].second = 0;
-    }
-    // We report backlog size per scheduling class, not per scheduling key,
-    // so we need to aggregate the backlog sizes of different scheduling keys
-    // with the same scheduling class.
-    backlogs[scheduling_class].second += scheduling_key_and_entry.second.BacklogSize();
-    scheduling_key_and_entry.second.last_reported_backlog_size =
-        scheduling_key_and_entry.second.BacklogSize();
-  }
-
-  std::vector<rpc::WorkerBacklogReport> backlog_reports;
-  for (const auto &backlog : backlogs) {
-    rpc::WorkerBacklogReport backlog_report;
-    backlog_report.mutable_resource_spec()->CopyFrom(backlog.second.first.GetMessage());
-    backlog_report.set_backlog_size(backlog.second.second);
-    backlog_reports.emplace_back(backlog_report);
-  }
-  local_lease_client_->ReportWorkerBacklog(WorkerID::FromBinary(rpc_address_.worker_id()),
-                                           backlog_reports);
-}
-
-void CoreWorkerDirectTaskSubmitter::ReportWorkerBacklogIfNeeded(
-    const SchedulingKey &scheduling_key) {
-  const auto &scheduling_key_entry = scheduling_key_entries_[scheduling_key];
-
-  if (scheduling_key_entry.last_reported_backlog_size !=
-      scheduling_key_entry.BacklogSize()) {
-    ReportWorkerBacklogInternal();
-  }
-}
-
 void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded(
     const SchedulingKey &scheduling_key, const rpc::Address *raylet_address) {
   auto &scheduling_key_entry = scheduling_key_entries_[scheduling_key];
+  auto &pending_lease_request = scheduling_key_entry.pending_lease_request;
 
-  if (scheduling_key_entry.pending_lease_requests.size() ==
-      max_pending_lease_requests_per_scheduling_category_) {
+  if (pending_lease_request.first) {
+    // There's already an outstanding lease request for this type of task.
     return;
   }
-  RAY_CHECK(scheduling_key_entry.pending_lease_requests.size() <
-            max_pending_lease_requests_per_scheduling_category_);
 
   // Check whether we really need a new worker or whether we have
   // enough room in an existing worker's pipeline to send the new tasks. If the pipelines
@@ -486,7 +444,7 @@ void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded(
     return;
   }
 
-  const auto &task_queue = scheduling_key_entry.task_queue;
+  auto &task_queue = scheduling_key_entry.task_queue;
   // Check if the task queue is empty. If that is the case, it only makes sense to
   // consider requesting a new worker if work stealing is enabled, and there is at least a
   // worker with stealable tasks. If work stealing is not enabled, or there is no tasks
@@ -503,18 +461,15 @@ void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded(
     }
     return;
-  } else if (scheduling_key_entry.task_queue.size() <=
-             scheduling_key_entry.pending_lease_requests.size()) {
-    // All tasks have corresponding pending leases, no need to request more
-    return;
   }
 
-  num_leases_requested_++;
   // Create a TaskSpecification with an overwritten TaskID to make sure we don't reuse the
   // same TaskID to request a worker
+  num_leases_requested_++;
   auto resource_spec_msg = scheduling_key_entry.resource_spec.GetMutableMessage();
   resource_spec_msg.set_task_id(TaskID::ForFakeTask().Binary());
-  const TaskSpecification resource_spec = TaskSpecification(resource_spec_msg);
+  TaskSpecification resource_spec = TaskSpecification(resource_spec_msg);
+
   rpc::Address best_node_address;
   if (raylet_address == nullptr) {
     // If no raylet address is given, find the best worker for our next lease request.
@@ -523,17 +478,22 @@ void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded(
   }
 
   auto lease_client = GetOrConnectLeaseClient(raylet_address);
-  const TaskID task_id = resource_spec.TaskId();
+  TaskID task_id = resource_spec.TaskId();
+  // Subtract 1 so we don't double count the task we are requesting for.
+  int64_t queue_size = task_queue.size() - 1;
 
   lease_client->RequestWorkerLease(
       resource_spec,
-      [this, scheduling_key, task_id, raylet_address = *raylet_address](
-          const Status &status, const rpc::RequestWorkerLeaseReply &reply) {
+      [this, scheduling_key](const Status &status,
+                             const rpc::RequestWorkerLeaseReply &reply) {
        absl::MutexLock lock(&mu_);
         auto &scheduling_key_entry = scheduling_key_entries_[scheduling_key];
-        auto lease_client = GetOrConnectLeaseClient(&raylet_address);
-        scheduling_key_entry.pending_lease_requests.erase(task_id);
+        auto &pending_lease_request = scheduling_key_entry.pending_lease_request;
+        RAY_CHECK(pending_lease_request.first);
+        auto lease_client = std::move(pending_lease_request.first);
+        const auto task_id = pending_lease_request.second;
+        pending_lease_request = std::make_pair(nullptr, TaskID::Nil());
 
         if (status.ok()) {
           if (reply.runtime_env_setup_failed()) {
@@ -591,9 +551,8 @@ void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded(
           RAY_LOG(FATAL) << status.ToString();
         }
       },
-      task_queue.size());
-  scheduling_key_entry.pending_lease_requests.emplace(task_id, *raylet_address);
-  ReportWorkerBacklogIfNeeded(scheduling_key);
+      queue_size);
+  pending_lease_request = std::make_pair(lease_client, task_id);
 }
 
 void CoreWorkerDirectTaskSubmitter::PushNormalTask(
diff --git a/src/ray/core_worker/transport/direct_task_transport.h b/src/ray/core_worker/transport/direct_task_transport.h
index 7731d95ad6626..25f7a18912795 100644
--- a/src/ray/core_worker/transport/direct_task_transport.h
+++ b/src/ray/core_worker/transport/direct_task_transport.h
@@ -66,9 +66,7 @@ class CoreWorkerDirectTaskSubmitter {
       int64_t lease_timeout_ms, std::shared_ptr<ActorCreatorInterface> actor_creator,
       uint32_t max_tasks_in_flight_per_worker =
          ::RayConfig::instance().max_tasks_in_flight_per_worker(),
-      absl::optional<boost::asio::steady_timer> cancel_timer = absl::nullopt,
-      uint64_t max_pending_lease_requests_per_scheduling_category =
-          ::RayConfig::instance().max_pending_lease_requests_per_scheduling_category())
+      absl::optional<boost::asio::steady_timer> cancel_timer = absl::nullopt)
       : rpc_address_(rpc_address),
         local_lease_client_(lease_client),
         lease_client_factory_(lease_client_factory),
@@ -80,8 +78,6 @@ class CoreWorkerDirectTaskSubmitter {
         actor_creator_(actor_creator),
         client_cache_(core_worker_client_pool),
         max_tasks_in_flight_per_worker_(max_tasks_in_flight_per_worker),
-        max_pending_lease_requests_per_scheduling_category_(
-            max_pending_lease_requests_per_scheduling_category),
         cancel_retry_timer_(std::move(cancel_timer)) {}
 
   /// Schedule a task for direct submission to a worker.
@@ -111,11 +107,6 @@ class CoreWorkerDirectTaskSubmitter {
     return num_leases_requested_;
   }
 
-  /// Report worker backlog information to the local raylet.
-  /// Since each worker only reports to its local raylet,
-  /// we avoid double counting backlogs in the autoscaler.
-  void ReportWorkerBacklog();
-
  private:
   /// Schedule more work onto an idle worker or return it back to the raylet if
   /// no more tasks are queued for submission.
   /// If an error was encountered
@@ -136,14 +127,6 @@ class CoreWorkerDirectTaskSubmitter {
   std::shared_ptr<WorkerLeaseInterface> GetOrConnectLeaseClient(
       const rpc::Address *raylet_address) EXCLUSIVE_LOCKS_REQUIRED(mu_);
 
-  /// Report worker backlog information to the local raylet.
-  void ReportWorkerBacklogInternal() EXCLUSIVE_LOCKS_REQUIRED(mu_);
-
-  /// Report the backlog if the backlog size has changed for this scheduling key
-  /// since the last report.
-  void ReportWorkerBacklogIfNeeded(const SchedulingKey &scheduling_key)
-      EXCLUSIVE_LOCKS_REQUIRED(mu_);
-
   /// Request a new worker from the raylet if no such requests are currently in
   /// flight and there are tasks queued. If a raylet address is provided, then
   /// the worker should be requested from the raylet at that address. Else, the
@@ -254,9 +237,6 @@ class CoreWorkerDirectTaskSubmitter {
   // worker using a single lease.
   const uint32_t max_tasks_in_flight_per_worker_;
 
-  // Max number of pending lease requests per SchedulingKey.
-  const uint64_t max_pending_lease_requests_per_scheduling_category_;
-
   /// A LeaseEntry struct is used to condense the metadata about a single executor:
   /// (1) The lease client through which the worker should be returned
   /// (2) The expiration time of a worker's lease.
@@ -316,7 +296,8 @@ class CoreWorkerDirectTaskSubmitter {
   struct SchedulingKeyEntry {
     // Keep track of pending worker lease requests to the raylet.
-    absl::flat_hash_map<TaskID, rpc::Address> pending_lease_requests;
+    std::pair<std::shared_ptr<WorkerLeaseInterface>, TaskID> pending_lease_request =
+        std::make_pair(nullptr, TaskID::Nil());
     TaskSpecification resource_spec = TaskSpecification();
     // Tasks that are queued for execution. We keep an individual queue per
     // scheduling class to ensure fairness.
@@ -327,12 +308,11 @@ class CoreWorkerDirectTaskSubmitter {
         absl::flat_hash_set<rpc::WorkerAddress>();
     // Keep track of how many tasks with this SchedulingKey are in flight, in total
     uint32_t total_tasks_in_flight = 0;
-    int64_t last_reported_backlog_size = 0;
 
     // Check whether it's safe to delete this SchedulingKeyEntry from the
     // scheduling_key_entries_ hashmap.
     inline bool CanDelete() const {
-      if (pending_lease_requests.empty() && task_queue.empty() &&
+      if (!pending_lease_request.first && task_queue.empty() &&
           active_workers.size() == 0 && total_tasks_in_flight == 0) {
         return true;
       }
@@ -359,18 +339,6 @@ class CoreWorkerDirectTaskSubmitter {
       // If any worker has more than one task in flight, then that task can be stolen.
       return total_tasks_in_flight > active_workers.size();
     }
-
-    // Get the current backlog size for this scheduling key.
-    [[nodiscard]] inline int64_t BacklogSize() const {
-      if (task_queue.size() < pending_lease_requests.size()) {
-        // During work stealing we may have more pending lease requests than the number
-        // of queued tasks.
-        return 0;
-      }
-
-      // Subtract tasks with pending lease requests so we don't double count them.
-      return task_queue.size() - pending_lease_requests.size();
-    }
   };
 
   // For each Scheduling Key, scheduling_key_entries_ contains a SchedulingKeyEntry struct
diff --git a/src/ray/gcs/asio.h b/src/ray/gcs/asio.h
index d37083986ae1e..fdcbbbf3cc3ef 100644
--- a/src/ray/gcs/asio.h
+++ b/src/ray/gcs/asio.h
@@ -38,7 +38,7 @@
 #include 
 #include 
-#include 
+#include 
 #include 
 #include 
diff --git a/src/ray/gcs/gcs_client/service_based_accessor.cc b/src/ray/gcs/gcs_client/service_based_accessor.cc
index 6e54cb6b4f047..e3bdcd96d79ab 100644
--- a/src/ray/gcs/gcs_client/service_based_accessor.cc
+++ b/src/ray/gcs/gcs_client/service_based_accessor.cc
@@ -731,7 +731,7 @@ Status ServiceBasedNodeResourceInfoAccessor::AsyncUpdateResources(
     });
   };
 
-  sequencer_.Post(node_id, std::move(operation));
+  sequencer_.Post(node_id, operation);
   return Status::OK();
 }
diff --git a/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc b/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc
index a4950fabb0f14..223ee7ca71b52 100644
--- a/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc
+++ b/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc
@@ -36,7 +36,6 @@ class GlobalStateAccessorTest : public ::testing::Test {
     config.grpc_server_name = "MockedGcsServer";
     config.grpc_server_thread_num = 1;
     config.redis_address = "127.0.0.1";
-    config.node_ip_address = "127.0.0.1";
     config.enable_sharding_conn = false;
     config.redis_port = TEST_REDIS_SERVER_PORTS.front();
diff --git a/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc b/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc
index 0adf74b5c4e8b..0e51ca7b84cce 100644
--- a/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc
+++ b/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc
@@ -47,7 +47,6 @@ class ServiceBasedGcsClientTest : public ::testing::Test {
    config_.grpc_server_name = "MockedGcsServer";
     config_.grpc_server_thread_num = 1;
     config_.redis_address = "127.0.0.1";
-    config_.node_ip_address = "127.0.0.1";
     config_.enable_sharding_conn = false;
     config_.redis_port = TEST_REDIS_SERVER_PORTS.front();
     // Tests legacy code paths. The poller and broadcaster have their own dedicated unit
diff --git a/src/ray/gcs/gcs_server/gcs_actor_distribution.cc b/src/ray/gcs/gcs_server/gcs_actor_distribution.cc
index bec0fb7b89f7c..6eb523cdf730b 100644
--- a/src/ray/gcs/gcs_server/gcs_actor_distribution.cc
+++ b/src/ray/gcs/gcs_server/gcs_actor_distribution.cc
@@ -13,7 +13,6 @@
 // limitations under the License.
 
 #include "ray/gcs/gcs_server/gcs_actor_distribution.h"
-
 #include "ray/util/event.h"
 
 namespace ray {
@@ -50,9 +49,6 @@ GcsBasedActorScheduler::GcsBasedActorScheduler(
      gcs_resource_scheduler_(std::move(gcs_resource_scheduler)) {}
 
 NodeID GcsBasedActorScheduler::SelectNode(std::shared_ptr<GcsActor> actor) {
-  if (actor->GetActorWorkerAssignment()) {
-    ResetActorWorkerAssignment(actor.get());
-  }
   // TODO(Chong-Li): Java actors may not need a sole assignment (worker process).
   bool need_sole_actor_worker_assignment = true;
   if (auto selected_actor_worker_assignment = SelectOrAllocateActorWorkerAssignment(
@@ -225,31 +221,5 @@ void GcsBasedActorScheduler::HandleWorkerLeaseRejectedReply(
   Reschedule(actor);
 }
 
-void GcsBasedActorScheduler::AddResourcesChangedListener(std::function<void()> listener) {
-  RAY_CHECK(listener != nullptr);
-  resource_changed_listeners_.emplace_back(std::move(listener));
-}
-
-void GcsBasedActorScheduler::NotifyClusterResourcesChanged() {
-  for (auto &listener : resource_changed_listeners_) {
-    listener();
-  }
-}
-
-void GcsBasedActorScheduler::ResetActorWorkerAssignment(GcsActor *actor) {
-  if (gcs_resource_manager_->ReleaseResources(
-          actor->GetActorWorkerAssignment()->GetNodeID(),
-          actor->GetActorWorkerAssignment()->GetResources())) {
-    NotifyClusterResourcesChanged();
-  };
-  actor->SetActorWorkerAssignment(nullptr);
-}
-
-void GcsBasedActorScheduler::OnActorDestruction(std::shared_ptr<GcsActor> actor) {
-  if (actor && actor->GetActorWorkerAssignment()) {
-    ResetActorWorkerAssignment(actor.get());
-  }
-}
-
 }  // namespace gcs
 }  // namespace ray
\ No newline at end of file
diff --git a/src/ray/gcs/gcs_server/gcs_actor_distribution.h b/src/ray/gcs/gcs_server/gcs_actor_distribution.h
index 55f0f492e9a74..b8e2b6b2bd6d4 100644
--- a/src/ray/gcs/gcs_server/gcs_actor_distribution.h
+++ b/src/ray/gcs/gcs_server/gcs_actor_distribution.h
@@ -93,14 +93,6 @@ class GcsBasedActorScheduler : public GcsActorScheduler {
 
   virtual ~GcsBasedActorScheduler() = default;
 
-  /// Handle the destruction of an actor.
-  ///
-  /// \param actor The actor to be destroyed.
-  void OnActorDestruction(std::shared_ptr<GcsActor> actor) override;
-
-  /// Add resources changed event handler.
-  void AddResourcesChangedListener(std::function<void()> listener);
-
  protected:
   /// Select a node for the actor based on cluster resources.
   ///
@@ -151,17 +143,8 @@ class GcsBasedActorScheduler : public GcsActorScheduler {
   void HandleWorkerLeaseRejectedReply(std::shared_ptr<GcsActor> actor,
                                       const rpc::RequestWorkerLeaseReply &reply);
 
-  /// Reset the actor's current assignment, while releasing acquired resources.
-  void ResetActorWorkerAssignment(GcsActor *actor);
-
-  /// Notify that the cluster resources are changed.
-  void NotifyClusterResourcesChanged();
-
   std::shared_ptr<GcsResourceManager> gcs_resource_manager_;
 
-  /// The resource changed listeners.
-  std::vector<std::function<void()>> resource_changed_listeners_;
-
   /// Gcs resource scheduler
   std::shared_ptr<GcsResourceScheduler> gcs_resource_scheduler_;
 };
diff --git a/src/ray/gcs/gcs_server/gcs_actor_manager.cc b/src/ray/gcs/gcs_server/gcs_actor_manager.cc
index 48a469f1cb377..c9f8c62375a6f 100644
--- a/src/ray/gcs/gcs_server/gcs_actor_manager.cc
+++ b/src/ray/gcs/gcs_server/gcs_actor_manager.cc
@@ -107,7 +107,6 @@ void GcsActor::SetActorWorkerAssignment(
 /////////////////////////////////////////////////////////////////////////////////////////
 GcsActorManager::GcsActorManager(
-    boost::asio::io_context &io_context,
     std::shared_ptr<GcsActorSchedulerInterface> scheduler,
     std::shared_ptr<GcsTableStorage> gcs_table_storage,
     std::shared_ptr<GcsPubSub> gcs_pub_sub, RuntimeEnvManager &runtime_env_manager,
@@ -116,8 +115,7 @@ GcsActorManager::GcsActorManager(
     std::function<void(std::function<void(void)>, boost::posix_time::milliseconds)>
         run_delayed,
     const rpc::ClientFactoryFn &worker_client_factory)
-    : io_context_(io_context),
-      gcs_actor_scheduler_(std::move(scheduler)),
+    : gcs_actor_scheduler_(std::move(scheduler)),
       gcs_table_storage_(std::move(gcs_table_storage)),
       gcs_pub_sub_(std::move(gcs_pub_sub)),
       worker_client_factory_(worker_client_factory),
@@ -128,17 +126,6 @@ GcsActorManager::GcsActorManager(
       actor_gc_delay_(RayConfig::instance().gcs_actor_table_min_duration_ms()) {
   RAY_CHECK(worker_client_factory_);
   RAY_CHECK(destroy_owned_placement_group_if_needed_);
-  if (RayConfig::instance().gcs_actor_scheduling_enabled()) {
-    auto gcs_actor_scheduler =
-        std::dynamic_pointer_cast<GcsBasedActorScheduler>(gcs_actor_scheduler_);
-    gcs_actor_scheduler->AddResourcesChangedListener([this] {
-      bool posted = GetSchedulePendingActorsPosted();
-      if (!posted) {
-        SetSchedulePendingActorsPosted(true);
-        io_context_.post([this] { SchedulePendingActors(); });
-      }
-    });
-  }
 }
 
 void GcsActorManager::HandleRegisterActor(const rpc::RegisterActorRequest &request,
@@ -200,13 +187,13 @@ void GcsActorManager::HandleGetActorInfo(const rpc::GetActorInfoRequest &request
 
   const auto &registered_actor_iter = registered_actors_.find(actor_id);
   if (registered_actor_iter != registered_actors_.end()) {
-    reply->unsafe_arena_set_allocated_actor_table_data(
-        registered_actor_iter->second->GetMutableActorTableData());
+    reply->mutable_actor_table_data()->CopyFrom(
+        registered_actor_iter->second->GetActorTableData());
   } else {
     const auto &destroyed_actor_iter = destroyed_actors_.find(actor_id);
     if (destroyed_actor_iter != destroyed_actors_.end()) {
-      reply->unsafe_arena_set_allocated_actor_table_data(
-          destroyed_actor_iter->second->GetMutableActorTableData());
+      reply->mutable_actor_table_data()->CopyFrom(
+          destroyed_actor_iter->second->GetActorTableData());
     }
   }
 
@@ -223,12 +210,10 @@ void GcsActorManager::HandleGetAllActorInfo(const rpc::GetAllActorInfoRequest &r
   ++counts_[CountType::GET_ALL_ACTOR_INFO_REQUEST];
   if (request.show_dead_jobs() == false) {
     for (const auto &iter : registered_actors_) {
-      reply->mutable_actor_table_data()->UnsafeArenaAddAllocated(
-          const_cast<rpc::ActorTableData *>(iter.second->GetMutableActorTableData()));
+      reply->add_actor_table_data()->CopyFrom(iter.second->GetActorTableData());
     }
     for (const auto &iter : destroyed_actors_) {
-      reply->mutable_actor_table_data()->UnsafeArenaAddAllocated(
-          const_cast<rpc::ActorTableData *>(iter.second->GetMutableActorTableData()));
+      reply->add_actor_table_data()->CopyFrom(iter.second->GetActorTableData());
     }
     RAY_LOG(DEBUG) << "Finished getting all actor info.";
     GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK());
@@ -242,9 +227,7 @@ void GcsActorManager::HandleGetAllActorInfo(const rpc::GetAllActorInfoRequest &r
       [reply, send_reply_callback](
          const std::unordered_map<ActorID, rpc::ActorTableData> &result) {
         for (const auto &pair : result) {
-          // TODO yic: Fix const cast
-          reply->mutable_actor_table_data()->UnsafeArenaAddAllocated(
-              const_cast<rpc::ActorTableData *>(&pair.second));
+          reply->add_actor_table_data()->CopyFrom(pair.second);
         }
         GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK());
         RAY_LOG(DEBUG) << "Finished getting all actor info.";
@@ -275,8 +258,7 @@ void GcsActorManager::HandleGetNamedActorInfo(
     RAY_LOG(WARNING) << stream.str();
     status = Status::NotFound(stream.str());
   } else {
-    reply->unsafe_arena_set_allocated_actor_table_data(
-        iter->second->GetMutableActorTableData());
+    reply->mutable_actor_table_data()->CopyFrom(iter->second->GetActorTableData());
     RAY_LOG(DEBUG) << "Finished getting actor info, job id = " << actor_id.JobId()
                    << ", actor id = " << actor_id;
   }
@@ -293,9 +275,10 @@ void GcsActorManager::HandleListNamedActors(const rpc::ListNamedActorsRequest &r
   std::vector<std::pair<std::string, std::string>> actors =
       ListNamedActors(request.all_namespaces(), ray_namespace);
   for (const auto &actor : actors) {
-    auto named_actor_indo = reply->add_named_actors_list();
-    named_actor_indo->set_ray_namespace(actor.first);
-    named_actor_indo->set_name(actor.second);
+    rpc::NamedActorInfo named_actor_info;
+    named_actor_info.set_ray_namespace(actor.first);
+    named_actor_info.set_name(actor.second);
+    reply->add_named_actors_list()->CopyFrom(named_actor_info);
   }
   GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK());
   ++counts_[CountType::LIST_NAMED_ACTORS_REQUEST];
@@ -398,9 +381,13 @@ Status GcsActorManager::RegisterActor(const ray::rpc::RegisterActorRequest &requ
     // owner to determine when the actor should be removed.
     PollOwnerForActorOutOfScope(actor);
   } else {
     // If it's a detached actor, we need to register the runtime env it used for GC.
-    runtime_env_manager_.AddURIReference(actor->GetActorID().Hex(),
-                                         request.task_spec().runtime_env());
+    auto job_id = JobID::FromBinary(request.task_spec().job_id());
+    const auto &uris = runtime_env_manager_.GetReferences(job_id.Hex());
+    auto actor_id_hex = actor->GetActorID().Hex();
+    for (const auto &uri : uris) {
+      runtime_env_manager_.AddURIReference(actor_id_hex, uri);
+    }
   }
 
   // The backend storage is supposed to be reliable, so the status must be ok.
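The RegisterActor hunk above changes how a detached actor keeps its runtime environment alive: instead of registering the serialized runtime env directly, it re-registers every URI currently referenced by the actor's job under the actor's own ID, so those URIs survive after the job exits. The toy reference counter below sketches that ownership transfer; it is illustrative only, and none of its names come from Ray.

    #include <iostream>
    #include <map>
    #include <set>
    #include <string>
    #include <vector>

    struct ToyUriRefCounter {
      std::map<std::string, std::set<std::string>> refs;  // owner id -> URIs

      void AddURIReference(const std::string &owner, const std::string &uri) {
        refs[owner].insert(uri);
      }
      std::vector<std::string> GetReferences(const std::string &owner) {
        return {refs[owner].begin(), refs[owner].end()};
      }
      void RemoveReferences(const std::string &owner) { refs.erase(owner); }
      bool IsUriLive(const std::string &uri) {
        for (const auto &kv : refs)
          if (kv.second.count(uri)) return true;
        return false;
      }
    };

    int main() {
      ToyUriRefCounter rc;
      rc.AddURIReference("job-1", "gcs://env.zip");
      // Detached-actor registration copies the job's URIs under the actor's id.
      for (const auto &uri : rc.GetReferences("job-1"))
        rc.AddURIReference("actor-a", uri);
      rc.RemoveReferences("job-1");                        // job exits
      std::cout << rc.IsUriLive("gcs://env.zip") << "\n";  // 1: actor keeps it alive
      return 0;
    }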
@@ -588,11 +575,6 @@ void GcsActorManager::DestroyActor(const ActorID &actor_id) {
     RAY_LOG(INFO) << "Tried to destroy actor that does not exist " << actor_id;
     return;
   }
-
-  if (RayConfig::instance().gcs_actor_scheduling_enabled()) {
-    gcs_actor_scheduler_->OnActorDestruction(it->second);
-  }
-
   const auto &task_id = it->second->GetCreationTaskSpecification().TaskId();
   it->second->GetMutableActorTableData()->mutable_task_spec()->Clear();
   it->second->GetMutableActorTableData()->set_timestamp(current_sys_time_ms());
@@ -975,7 +957,6 @@ void GcsActorManager::OnActorCreationSuccess(const std::shared_ptr<GcsActor> &ac
 }
 
 void GcsActorManager::SchedulePendingActors() {
-  schedule_pending_actors_posted_ = false;
   if (pending_actors_.empty()) {
     return;
   }
@@ -987,14 +968,6 @@ void GcsActorManager::SchedulePendingActors() {
   }
 }
 
-bool GcsActorManager::GetSchedulePendingActorsPosted() const {
-  return schedule_pending_actors_posted_;
-}
-
-void GcsActorManager::SetSchedulePendingActorsPosted(bool posted) {
-  schedule_pending_actors_posted_ = posted;
-}
-
 void GcsActorManager::Initialize(const GcsInitData &gcs_init_data) {
   const auto &jobs = gcs_init_data.Jobs();
   std::unordered_map<NodeID, std::vector<WorkerID>> node_to_workers;
diff --git a/src/ray/gcs/gcs_server/gcs_actor_manager.h b/src/ray/gcs/gcs_server/gcs_actor_manager.h
index 9050eb4dfc9fe..569c9b2b19172 100644
--- a/src/ray/gcs/gcs_server/gcs_actor_manager.h
+++ b/src/ray/gcs/gcs_server/gcs_actor_manager.h
@@ -21,7 +21,6 @@
 #include "ray/common/runtime_env_manager.h"
 #include "ray/common/task/task_execution_spec.h"
 #include "ray/common/task/task_spec.h"
-#include "ray/gcs/gcs_server/gcs_actor_distribution.h"
 #include "ray/gcs/gcs_server/gcs_actor_scheduler.h"
 #include "ray/gcs/gcs_server/gcs_init_data.h"
 #include "ray/gcs/gcs_server/gcs_table_storage.h"
@@ -87,8 +86,7 @@ class GcsActor {
       break;
     }
 
-    actor_table_data_.set_serialized_runtime_env(
-        task_spec.runtime_env().serialized_runtime_env());
+    actor_table_data_.set_serialized_runtime_env(task_spec.serialized_runtime_env());
   }
 
   /// Get the node id on which this actor is created.
@@ -195,7 +193,6 @@ class GcsActorManager : public rpc::ActorInfoHandler {
   /// \param gcs_table_storage Used to flush actor data to storage.
   /// \param gcs_pub_sub Used to publish gcs message.
   GcsActorManager(
-      boost::asio::io_context &io_context,
      std::shared_ptr<GcsActorSchedulerInterface> scheduler,
      std::shared_ptr<GcsTableStorage> gcs_table_storage,
      std::shared_ptr<GcsPubSub> gcs_pub_sub, RuntimeEnvManager &runtime_env_manager,
@@ -344,10 +341,6 @@ class GcsActorManager : public rpc::ActorInfoHandler {
 
   std::string DebugString() const;
 
-  bool GetSchedulePendingActorsPosted() const;
-
-  void SetSchedulePendingActorsPosted(bool posted);
-
  private:
   /// A data structure representing an actor's owner.
   struct Owner {
@@ -492,7 +485,6 @@ class GcsActorManager : public rpc::ActorInfoHandler {
   /// according to its owner, or the owner dies.
   absl::flat_hash_map> owners_;
 
-  boost::asio::io_context &io_context_;
   /// The scheduler to schedule all registered actors.
   std::shared_ptr<GcsActorSchedulerInterface> gcs_actor_scheduler_;
   /// Used to update actor information upon creation, deletion, etc.
@@ -516,9 +508,6 @@ class GcsActorManager : public rpc::ActorInfoHandler {
       run_delayed_;
   const boost::posix_time::milliseconds actor_gc_delay_;
 
-  /// Indicates whether a call to SchedulePendingActors has been posted.
-  bool schedule_pending_actors_posted_;
-
   // Debug info.
   enum CountType {
     REGISTER_ACTOR_REQUEST = 0,
diff --git a/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc b/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc
index cc4a426cea653..81d476a80854b 100644
--- a/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc
+++ b/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc
@@ -38,6 +38,7 @@ GcsActorScheduler::GcsActorScheduler(
       gcs_pub_sub_(std::move(gcs_pub_sub)),
       schedule_failure_handler_(std::move(schedule_failure_handler)),
       schedule_success_handler_(std::move(schedule_success_handler)),
+      report_worker_backlog_(RayConfig::instance().report_worker_backlog()),
      raylet_client_pool_(raylet_client_pool),
       core_worker_clients_(client_factory) {
   RAY_CHECK(schedule_failure_handler_ != nullptr && schedule_success_handler_ != nullptr);
@@ -229,13 +230,14 @@ void GcsActorScheduler::LeaseWorkerFromNode(std::shared_ptr<GcsActor> actor,
   auto lease_client = GetOrConnectLeaseClient(remote_address);
   // Actor leases should be sent to the raylet immediately, so we should never build up a
   // backlog in GCS.
+  int backlog_size = report_worker_backlog_ ? 0 : -1;
   lease_client->RequestWorkerLease(
-      actor->GetActorTableData().task_spec(),
+      actor->GetCreationTaskSpecification(),
       [this, actor, node](const Status &status,
                           const rpc::RequestWorkerLeaseReply &reply) {
         HandleWorkerLeaseReply(actor, node, status, reply);
       },
-      0);
+      backlog_size);
 }
 
 void GcsActorScheduler::RetryLeasingWorkerFromNode(
diff --git a/src/ray/gcs/gcs_server/gcs_actor_scheduler.h b/src/ray/gcs/gcs_server/gcs_actor_scheduler.h
index 55bd6b6bd73f6..34d7d3ea3a186 100644
--- a/src/ray/gcs/gcs_server/gcs_actor_scheduler.h
+++ b/src/ray/gcs/gcs_server/gcs_actor_scheduler.h
@@ -75,11 +75,6 @@ class GcsActorSchedulerInterface {
   virtual void ReleaseUnusedWorkers(
       const std::unordered_map<NodeID, std::vector<WorkerID>> &node_to_workers) = 0;
 
-  /// Handle the destruction of an actor.
-  ///
-  /// \param actor The actor to be destroyed.
-  virtual void OnActorDestruction(std::shared_ptr<GcsActor> actor) = 0;
-
   virtual ~GcsActorSchedulerInterface() {}
 };
@@ -151,11 +146,6 @@ class GcsActorScheduler : public GcsActorSchedulerInterface {
   void ReleaseUnusedWorkers(
       const std::unordered_map<NodeID, std::vector<WorkerID>> &node_to_workers) override;
 
-  /// Handle the destruction of an actor.
-  ///
-  /// \param actor The actor to be destroyed.
-  void OnActorDestruction(std::shared_ptr<GcsActor> actor) override {}
-
  protected:
   /// The GcsLeasedWorker is kind of abstraction of remote leased worker inside raylet. It
   /// contains the address of remote leased worker as well as the leased resources and the
@@ -312,6 +302,8 @@ class GcsActorScheduler : public GcsActorSchedulerInterface {
   /// The handler to handle the successful scheduling.
   std::function<void(std::shared_ptr<GcsActor>, const rpc::PushTaskReply &reply)>
       schedule_success_handler_;
+  /// Whether or not to report the backlog of actors waiting to be scheduled.
+  bool report_worker_backlog_;
   /// The nodes which are releasing unused workers.
   absl::flat_hash_set<NodeID> nodes_of_releasing_unused_workers_;
   /// The cached raylet clients used to communicate with raylet.
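In LeaseWorkerFromNode above, the trailing argument of RequestWorkerLease is a backlog size: 0 is passed when report_worker_backlog is enabled and -1 otherwise, which reads as a sentinel meaning "do not attach a backlog report". That interpretation is an assumption drawn from the call site, not documented behavior; the sketch below uses a hypothetical signature purely to illustrate the convention.

    #include <cstdint>
    #include <iostream>

    // Hypothetical stand-in for the real RPC call; only the sentinel logic matters.
    void RequestWorkerLease(int64_t backlog_size) {
      if (backlog_size < 0) {
        std::cout << "lease request without a backlog report\n";
      } else {
        std::cout << "lease request, reported backlog = " << backlog_size << "\n";
      }
    }

    int main() {
      bool report_worker_backlog = false;  // e.g. a config flag
      int backlog_size = report_worker_backlog ? 0 : -1;
      RequestWorkerLease(backlog_size);        // report suppressed
      RequestWorkerLease(/*backlog_size=*/0);  // explicit empty backlog
      return 0;
    }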
diff --git a/src/ray/gcs/gcs_server/gcs_node_manager.cc b/src/ray/gcs/gcs_server/gcs_node_manager.cc
index c84e19372a6b4..91782a712db5b 100644
--- a/src/ray/gcs/gcs_server/gcs_node_manager.cc
+++ b/src/ray/gcs/gcs_server/gcs_node_manager.cc
@@ -84,15 +84,11 @@ void GcsNodeManager::HandleUnregisterNode(const rpc::UnregisterNodeRequest &requ
 void GcsNodeManager::HandleGetAllNodeInfo(const rpc::GetAllNodeInfoRequest &request,
                                           rpc::GetAllNodeInfoReply *reply,
                                           rpc::SendReplyCallback send_reply_callback) {
-  // The unsafe allocate is safe here, because entry.second's life cycle is longer
-  // than the reply.
-  // The request will be sent when we call send_reply_callback, and after that the
-  // reply will not be used any more. But entry is still valid.
   for (const auto &entry : alive_nodes_) {
-    reply->mutable_node_info_list()->UnsafeArenaAddAllocated(entry.second.get());
+    reply->add_node_info_list()->CopyFrom(*entry.second);
   }
   for (const auto &entry : dead_nodes_) {
-    reply->mutable_node_info_list()->UnsafeArenaAddAllocated(entry.second.get());
+    reply->add_node_info_list()->CopyFrom(*entry.second);
   }
   GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK());
   ++counts_[CountType::GET_ALL_NODE_INFO_REQUEST];
diff --git a/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc b/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc
index 7879a9fd71bce..f41f9d45bd6e7 100644
--- a/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc
+++ b/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc
@@ -45,26 +45,22 @@ std::string GcsPlacementGroup::GetRayNamespace() const {
   return placement_group_table_data_.ray_namespace();
 }
 
-std::vector<std::shared_ptr<const BundleSpecification>> &GcsPlacementGroup::GetBundles()
-    const {
-  // Fill the cache if it hasn't been filled yet.
-  if (cached_bundle_specs_.empty()) {
-    const auto &bundles = placement_group_table_data_.bundles();
-    for (const auto &bundle : bundles) {
-      cached_bundle_specs_.push_back(std::make_shared<const BundleSpecification>(bundle));
-    }
+std::vector<std::shared_ptr<BundleSpecification>> GcsPlacementGroup::GetBundles() const {
+  const auto &bundles = placement_group_table_data_.bundles();
+  std::vector<std::shared_ptr<BundleSpecification>> ret_bundles;
+  for (const auto &bundle : bundles) {
+    ret_bundles.push_back(std::make_shared<BundleSpecification>(bundle));
   }
-  return cached_bundle_specs_;
+  return ret_bundles;
 }
 
-std::vector<std::shared_ptr<const BundleSpecification>>
-GcsPlacementGroup::GetUnplacedBundles() const {
-  const auto &bundle_specs = GetBundles();
-
-  std::vector<std::shared_ptr<const BundleSpecification>> unplaced_bundles;
-  for (const auto &bundle : bundle_specs) {
-    if (bundle->NodeId().IsNil()) {
-      unplaced_bundles.push_back(bundle);
+std::vector<std::shared_ptr<BundleSpecification>> GcsPlacementGroup::GetUnplacedBundles()
+    const {
+  const auto &bundles = placement_group_table_data_.bundles();
+  std::vector<std::shared_ptr<BundleSpecification>> unplaced_bundles;
+  for (const auto &bundle : bundles) {
+    if (NodeID::FromBinary(bundle.node_id()).IsNil()) {
+      unplaced_bundles.push_back(std::make_shared<BundleSpecification>(bundle));
     }
   }
   return unplaced_bundles;
@@ -87,8 +83,6 @@ std::string GcsPlacementGroup::DebugString() const {
 }
 
 rpc::Bundle *GcsPlacementGroup::GetMutableBundle(int bundle_index) {
-  // Invalidate the cache.
-  cached_bundle_specs_.clear();
   return placement_group_table_data_.mutable_bundles(bundle_index);
 }
@@ -182,7 +176,7 @@ void GcsPlacementGroupManager::RegisterPlacementGroup(
       .emplace_back(std::move(callback));
   registered_placement_groups_.emplace(placement_group->GetPlacementGroupID(),
                                        placement_group);
-  AddToPendingQueue(placement_group);
+  pending_placement_groups_.emplace_back(placement_group);
 
   RAY_CHECK_OK(gcs_table_storage_->PlacementGroupTable().Put(
       placement_group_id, placement_group->GetPlacementGroupTableData(),
@@ -227,8 +221,7 @@ PlacementGroupID GcsPlacementGroupManager::GetPlacementGroupIDByName(
 }
 
 void GcsPlacementGroupManager::OnPlacementGroupCreationFailed(
-    std::shared_ptr<GcsPlacementGroup> placement_group, ExponentialBackOff backoff,
-    bool is_feasible) {
+    std::shared_ptr<GcsPlacementGroup> placement_group, bool is_feasible) {
   RAY_LOG(DEBUG) << "Failed to create placement group " << placement_group->GetName()
                  << ", id: " << placement_group->GetPlacementGroupID() << ", try again.";
 
@@ -236,6 +229,7 @@ void GcsPlacementGroupManager::OnPlacementGroupCreationFailed(
     // We will attempt to schedule this placement_group once an eligible node is
     // registered.
     infeasible_placement_groups_.emplace_back(std::move(placement_group));
+    MarkSchedulingDone();
   } else {
     auto state = placement_group->GetState();
     RAY_CHECK(state == rpc::PlacementGroupTableData::RESCHEDULING ||
@@ -247,13 +241,14 @@ void GcsPlacementGroupManager::OnPlacementGroupCreationFailed(
       // NOTE: If a node is dead, the placement group scheduler should try to recover the
       // group by rescheduling the bundles of the dead node. This should have higher
       // priority than trying to place other placement groups.
-      AddToPendingQueue(std::move(placement_group), /* rank */ 0);
+      pending_placement_groups_.emplace_front(std::move(placement_group));
     } else {
-      AddToPendingQueue(std::move(placement_group), std::nullopt, backoff);
+      pending_placement_groups_.emplace_back(std::move(placement_group));
     }
+
+    MarkSchedulingDone();
+    RetryCreatingPlacementGroup();
   }
-  io_context_.post([this] { SchedulePendingPlacementGroups(); });
-  MarkSchedulingDone();
 }
 
 void GcsPlacementGroupManager::OnPlacementGroupCreationSuccess(
@@ -261,11 +256,16 @@ void GcsPlacementGroupManager::OnPlacementGroupCreationSuccess(
   RAY_LOG(INFO) << "Successfully created placement group " << placement_group->GetName()
                 << ", id: " << placement_group->GetPlacementGroupID();
   placement_group->UpdateState(rpc::PlacementGroupTableData::CREATED);
+  // Mark the scheduling done first.
+  MarkSchedulingDone();
   auto placement_group_id = placement_group->GetPlacementGroupID();
   RAY_CHECK_OK(gcs_table_storage_->PlacementGroupTable().Put(
       placement_group_id, placement_group->GetPlacementGroupTableData(),
       [this, placement_group_id](Status status) {
         RAY_CHECK_OK(status);
+
+        SchedulePendingPlacementGroups();
+
         // Invoke all callbacks for all `WaitPlacementGroupUntilReady` requests of this
         // placement group and remove all of them from
         // placement_group_to_create_callbacks_.
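With this change, a failed placement group creation is simply re-queued and retried after a fixed interval (RetryCreatingPlacementGroup, restored later in this diff, schedules SchedulePendingPlacementGroups via execute_after), replacing the exponential-backoff ranked queue that the rest of the series removes. The following sketch contrasts the two retry-delay policies; the constants are invented for illustration and are not Ray's defaults.

    #include <algorithm>
    #include <cstdint>
    #include <iostream>

    int main() {
      const uint64_t fixed_ms = 500;           // fixed retry interval
      uint64_t base_ms = 100, max_ms = 10000;  // hypothetical backoff parameters
      double multiplier = 2.0;

      double backoff = base_ms;
      for (int attempt = 1; attempt <= 5; ++attempt) {
        std::cout << "attempt " << attempt << ": fixed = " << fixed_ms << " ms"
                  << ", exponential = " << static_cast<uint64_t>(backoff) << " ms\n";
        // Exponential backoff grows per attempt but is capped at max_ms.
        backoff = std::min(backoff * multiplier, static_cast<double>(max_ms));
      }
      return 0;
    }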
@@ -278,8 +278,6 @@ void GcsPlacementGroupManager::OnPlacementGroupCreationSuccess(
           placement_group_to_create_callbacks_.erase(pg_to_create_iter);
         }
       }));
-  io_context_.post([this] { SchedulePendingPlacementGroups(); });
-  MarkSchedulingDone();
 }
 
 void GcsPlacementGroupManager::SchedulePendingPlacementGroups() {
@@ -296,28 +294,16 @@ void GcsPlacementGroupManager::SchedulePendingPlacementGroups() {
 
   bool is_new_placement_group_scheduled = false;
   while (!pending_placement_groups_.empty() && !is_new_placement_group_scheduled) {
-    auto iter = pending_placement_groups_.begin();
-    if (iter->first > absl::GetCurrentTimeNanos()) {
-      // Here the rank equals the time to schedule, and it's an ordered tree,
-      // which means all the other tasks should be scheduled after this one.
-      // If the first one won't be scheduled yet, we just skip.
-      // The next Tick will cover the retry.
-      break;
-    }
-    auto backoff = iter->second.first;
-    auto placement_group = std::move(iter->second.second);
-    pending_placement_groups_.erase(iter);
-
+    const auto placement_group = pending_placement_groups_.front();
+    pending_placement_groups_.pop_front();
     const auto &placement_group_id = placement_group->GetPlacementGroupID();
     // Do not reschedule if the placement group has already been removed.
     if (registered_placement_groups_.contains(placement_group_id)) {
       MarkSchedulingStarted(placement_group_id);
       gcs_placement_group_scheduler_->ScheduleUnplacedBundles(
           placement_group,
-          [this, backoff](std::shared_ptr<GcsPlacementGroup> placement_group,
-                          bool is_insfeasble) {
-            OnPlacementGroupCreationFailed(std::move(placement_group), backoff,
-                                           is_insfeasble);
+          [this](std::shared_ptr<GcsPlacementGroup> placement_group, bool is_infeasible) {
+            OnPlacementGroupCreationFailed(std::move(placement_group), is_infeasible);
           },
          [this](std::shared_ptr<GcsPlacementGroup> placement_group) {
            OnPlacementGroupCreationSuccess(std::move(placement_group));
@@ -326,7 +312,6 @@ void GcsPlacementGroupManager::SchedulePendingPlacementGroups() {
     }
     // If the placement group is not registered == removed.
   }
-  ++counts_[CountType::SCHEDULING_PENDING_PLACEMENT_GROUP];
 }
 
 void GcsPlacementGroupManager::HandleCreatePlacementGroup(
@@ -408,10 +393,18 @@ void GcsPlacementGroupManager::RemovePlacementGroup(
   }
 
   // Remove a placement group from the pending list if it exists.
-  RemoveFromPendingQueue(placement_group_id);
+  auto pending_it = std::find_if(
+      pending_placement_groups_.begin(), pending_placement_groups_.end(),
+      [placement_group_id](const std::shared_ptr<GcsPlacementGroup> &placement_group) {
+        return placement_group->GetPlacementGroupID() == placement_group_id;
+      });
+  if (pending_it != pending_placement_groups_.end()) {
+    // The placement group was pending scheduling, remove it from the queue.
+    pending_placement_groups_.erase(pending_it);
+  }
 
   // Remove a placement group from the infeasible queue if it exists.
-  auto pending_it = std::find_if(
+  pending_it = std::find_if(
       infeasible_placement_groups_.begin(), infeasible_placement_groups_.end(),
      [placement_group_id](const std::shared_ptr<GcsPlacementGroup> &placement_group) {
        return placement_group->GetPlacementGroupID() == placement_group_id;
@@ -580,36 +573,9 @@ void GcsPlacementGroupManager::WaitPlacementGroup(
   }
 }
 
-void GcsPlacementGroupManager::AddToPendingQueue(
-    std::shared_ptr<GcsPlacementGroup> pg, std::optional<int64_t> rank,
-    std::optional<ExponentialBackOff> exp_backer) {
-  if (!rank) {
-    rank = absl::GetCurrentTimeNanos();
-  }
-
-  if (!exp_backer) {
-    exp_backer = ExponentialBackOff(
-        1000000 *
-            RayConfig::instance().gcs_create_placement_group_retry_min_interval_ms(),
-        RayConfig::instance().gcs_create_placement_group_retry_multiplier(),
-        1000000 *
-            RayConfig::instance().gcs_create_placement_group_retry_max_interval_ms());
-  } else {
-    *rank += static_cast<int64_t>(exp_backer->Next());
-  }
-  auto val = std::make_pair(*exp_backer, std::move(pg));
-  pending_placement_groups_.emplace(*rank, std::move(val));
-}
-
-void GcsPlacementGroupManager::RemoveFromPendingQueue(const PlacementGroupID &pg_id) {
-  auto it = std::find_if(pending_placement_groups_.begin(),
-                         pending_placement_groups_.end(), [&pg_id](const auto &val) {
-                           return val.second.second->GetPlacementGroupID() == pg_id;
-                         });
-  // The placement group was pending scheduling, remove it from the queue.
-  if (it != pending_placement_groups_.end()) {
-    pending_placement_groups_.erase(it);
-  }
+void GcsPlacementGroupManager::RetryCreatingPlacementGroup() {
+  execute_after(io_context_, [this] { SchedulePendingPlacementGroups(); },
+                RayConfig::instance().gcs_create_placement_group_retry_interval_ms());
 }
 
 void GcsPlacementGroupManager::OnNodeDead(const NodeID &node_id) {
@@ -627,7 +593,7 @@ void GcsPlacementGroupManager::OnNodeDead(const NodeID &node_id) {
     // creating until a node with the resources is added. we will solve it in next pr.
     if (iter->second->GetState() != rpc::PlacementGroupTableData::RESCHEDULING) {
       iter->second->UpdateState(rpc::PlacementGroupTableData::RESCHEDULING);
-      AddToPendingQueue(iter->second, 0);
+      pending_placement_groups_.emplace_front(iter->second);
     }
   }
@@ -643,9 +609,9 @@ void GcsPlacementGroupManager::OnNodeAdd(const NodeID &node_id) {
   // Move all the infeasible placement groups to the pending queue so that we can
   // reschedule them.
   if (infeasible_placement_groups_.size() > 0) {
-    for (auto &pg : infeasible_placement_groups_) {
-      AddToPendingQueue(std::move(pg));
-    }
+    auto end_it = pending_placement_groups_.end();
+    pending_placement_groups_.insert(end_it, infeasible_placement_groups_.cbegin(),
+                                     infeasible_placement_groups_.cend());
     infeasible_placement_groups_.clear();
   }
   SchedulePendingPlacementGroups();
@@ -701,16 +667,14 @@ void GcsPlacementGroupManager::Tick() {
   // Note that we don't currently have a known race condition that requires this,
   // but we added it as a safety check.
   // https://github.com/ray-project/ray/pull/18419
   SchedulePendingPlacementGroups();
-  execute_after(
-      io_context_, [this] { Tick(); }, 1000 /* milliseconds */);
+  execute_after(io_context_, [this] { Tick(); }, 1000 /* milliseconds */);
 }
 
 void GcsPlacementGroupManager::UpdatePlacementGroupLoad() {
   std::shared_ptr<rpc::PlacementGroupLoad> placement_group_load =
       std::make_shared<rpc::PlacementGroupLoad>();
   int total_cnt = 0;
-  for (const auto &elem : pending_placement_groups_) {
-    const auto pending_pg_spec = elem.second.second;
+  for (const auto &pending_pg_spec : pending_placement_groups_) {
     auto placement_group_data = placement_group_load->add_placement_group_data();
     auto placement_group_table_data = pending_pg_spec->GetPlacementGroupTableData();
     placement_group_data->Swap(&placement_group_table_data);
@@ -746,7 +710,7 @@ void GcsPlacementGroupManager::Initialize(const GcsInitData &gcs_init_data) {
 
     if (item.second.state() == rpc::PlacementGroupTableData::PENDING ||
         item.second.state() == rpc::PlacementGroupTableData::RESCHEDULING) {
-      AddToPendingQueue(std::move(placement_group));
+      pending_placement_groups_.emplace_back(std::move(placement_group));
     }
 
     if (item.second.state() == rpc::PlacementGroupTableData::CREATED ||
@@ -785,8 +749,6 @@ std::string GcsPlacementGroupManager::DebugString() const {
      << counts_[CountType::WAIT_PLACEMENT_GROUP_UNTIL_READY_REQUEST]
      << ", GetNamedPlacementGroup request count: "
      << counts_[CountType::GET_NAMED_PLACEMENT_GROUP_REQUEST]
-     << ", Scheduling pending placement group count: "
-     << counts_[CountType::SCHEDULING_PENDING_PLACEMENT_GROUP]
      << ", Registered placement groups count: " << registered_placement_groups_.size()
      << ", Named placement group count: " << num_pgs
      << ", Pending placement groups count: " << pending_placement_groups_.size()
diff --git a/src/ray/gcs/gcs_server/gcs_placement_group_manager.h b/src/ray/gcs/gcs_server/gcs_placement_group_manager.h
index 93bc68d306e43..bc3407fd8ac02 100644
--- a/src/ray/gcs/gcs_server/gcs_placement_group_manager.h
+++ b/src/ray/gcs/gcs_server/gcs_placement_group_manager.h
@@ -13,12 +13,8 @@
 // limitations under the License.
 
 #pragma once
-#include 
-
-#include 
 #include 
 
-#include "absl/container/btree_map.h"
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
 #include "ray/common/asio/instrumented_io_context.h"
@@ -93,10 +89,10 @@ class GcsPlacementGroup {
   std::string GetRayNamespace() const;
 
   /// Get the bundles of this placement_group (including unplaced).
-  std::vector<std::shared_ptr<const BundleSpecification>> &GetBundles() const;
+  std::vector<std::shared_ptr<BundleSpecification>> GetBundles() const;
 
   /// Get the unplaced bundles of this placement group.
-  std::vector<std::shared_ptr<const BundleSpecification>> GetUnplacedBundles() const;
+  std::vector<std::shared_ptr<BundleSpecification>> GetUnplacedBundles() const;
 
   /// Get the Strategy
   rpc::PlacementStrategy GetStrategy() const;
@@ -125,14 +121,9 @@ class GcsPlacementGroup {
   bool IsDetached() const;
 
  private:
-  FRIEND_TEST(GcsPlacementGroupManagerTest, TestPlacementGroupBundleCache);
-
   /// The placement_group meta data which contains the task specification as well as the
   /// state of the gcs placement_group and so on (see gcs.proto).
   rpc::PlacementGroupTableData placement_group_table_data_;
-
-  /// Creating bundle specifications requires heavy computation because it needs to
-  /// compute formatted strings for all resources (heavy string operations). To optimize
-  /// CPU usage, we cache bundle specs.
-  mutable std::vector<std::shared_ptr<const BundleSpecification>> cached_bundle_specs_;
 };
@@ -218,7 +209,7 @@ class GcsPlacementGroupManager : public rpc::PlacementGroupInfoHandler {
   /// \param placement_group The placement_group whose creation task is infeasible.
   /// \param is_feasible Whether the scheduler can currently retry or not.
   void OnPlacementGroupCreationFailed(std::shared_ptr<GcsPlacementGroup> placement_group,
-                                      ExponentialBackOff backoff, bool is_feasible);
+                                      bool is_feasible = true);
@@ -286,19 +277,6 @@ class GcsPlacementGroupManager : public rpc::PlacementGroupInfoHandler {
   std::string DebugString() const;
 
  private:
-  /// Push a placement group to the pending queue.
-  ///
-  /// \param pg The placement group we are adding.
-  /// \param rank The rank for this placement group. Semantically it's the time
-  /// this placement group is to be scheduled. By default it is assigned to be
-  /// the current time.
-  /// \param exp_backer The exponential backoff. A default one will be given if
-  /// it's not set. This will be used to generate the deferred time for this pg.
-  void AddToPendingQueue(std::shared_ptr<GcsPlacementGroup> pg,
-                         std::optional<int64_t> rank = std::nullopt,
-                         std::optional<ExponentialBackOff> exp_backer = std::nullopt);
-  void RemoveFromPendingQueue(const PlacementGroupID &pg_id);
-
   /// Try to create the placement group again after a short time.
   void RetryCreatingPlacementGroup();
@@ -344,17 +322,12 @@ class GcsPlacementGroupManager : public rpc::PlacementGroupInfoHandler {
   absl::flat_hash_map<PlacementGroupID, std::shared_ptr<GcsPlacementGroup>>
       registered_placement_groups_;
 
-  /// The pending placement_groups which will not be scheduled until there's a
-  /// resource change. The pending queue is represented as an ordered map, where
-  /// the key is the time to schedule the pg and the value is a pair containing the
-  /// actual placement group and an exp-backoff.
-  /// When an error happens, we'll retry it later, and this can simply be done by
-  /// inserting an element into the queue with a bigger key. With this, we don't
-  /// need to post a retry job to the io context. And when scheduling a pending
-  /// placement group, we always start with the one with the smallest key.
-  absl::btree_multimap<int64_t,
-                       std::pair<ExponentialBackOff, std::shared_ptr<GcsPlacementGroup>>>
-      pending_placement_groups_;
+  /// The pending placement_groups which will not be scheduled until there's a resource
+  /// change.
+  /// NOTE: When we remove a placement group, we need to look through
+  /// `pending_placement_groups_` and delete the specific placement group, so we can't
+  /// use `std::priority_queue`.
+  std::deque<std::shared_ptr<GcsPlacementGroup>> pending_placement_groups_;
 
   /// The infeasible placement_groups that can't be scheduled currently.
std::deque> infeasible_placement_groups_; @@ -390,14 +363,9 @@ class GcsPlacementGroupManager : public rpc::PlacementGroupInfoHandler { GET_ALL_PLACEMENT_GROUP_REQUEST = 3, WAIT_PLACEMENT_GROUP_UNTIL_READY_REQUEST = 4, GET_NAMED_PLACEMENT_GROUP_REQUEST = 5, - SCHEDULING_PENDING_PLACEMENT_GROUP = 6, - CountType_MAX = 7, + CountType_MAX = 6, }; uint64_t counts_[CountType::CountType_MAX] = {0}; - - FRIEND_TEST(GcsPlacementGroupManagerMockTest, PendingQueuePriorityReschedule); - FRIEND_TEST(GcsPlacementGroupManagerMockTest, PendingQueuePriorityFailed); - FRIEND_TEST(GcsPlacementGroupManagerMockTest, PendingQueuePriorityOrder); }; } // namespace gcs diff --git a/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc b/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc index 7c9391315a945..c2ca3c3c8cd40 100644 --- a/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc +++ b/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc @@ -39,7 +39,7 @@ GcsPlacementGroupScheduler::GcsPlacementGroupScheduler( } std::vector GcsScheduleStrategy::GetRequiredResourcesFromBundles( - const std::vector> &bundles) { + const std::vector> &bundles) { std::vector required_resources; for (const auto &bundle : bundles) { required_resources.push_back(bundle->GetRequiredResources()); @@ -48,7 +48,7 @@ std::vector GcsScheduleStrategy::GetRequiredResourcesFromBundles( } ScheduleResult GcsScheduleStrategy::GenerateScheduleResult( - const std::vector> &bundles, + const std::vector> &bundles, const std::vector &selected_nodes, const SchedulingResultStatus &status) { ScheduleMap schedule_map; if (status == SUCCESS && !selected_nodes.empty()) { @@ -62,7 +62,7 @@ ScheduleResult GcsScheduleStrategy::GenerateScheduleResult( } ScheduleResult GcsStrictPackStrategy::Schedule( - const std::vector> &bundles, + const std::vector> &bundles, const std::unique_ptr &context, GcsResourceScheduler &gcs_resource_scheduler) { const auto &required_resources = GetRequiredResourcesFromBundles(bundles); @@ -73,7 +73,7 @@ ScheduleResult GcsStrictPackStrategy::Schedule( } ScheduleResult GcsPackStrategy::Schedule( - const std::vector> &bundles, + const std::vector> &bundles, const std::unique_ptr &context, GcsResourceScheduler &gcs_resource_scheduler) { // The current algorithm is to select a node and deploy as many bundles as possible. @@ -87,7 +87,7 @@ ScheduleResult GcsPackStrategy::Schedule( } ScheduleResult GcsSpreadStrategy::Schedule( - const std::vector> &bundles, + const std::vector> &bundles, const std::unique_ptr &context, GcsResourceScheduler &gcs_resource_scheduler) { const auto &required_resources = GetRequiredResourcesFromBundles(bundles); @@ -98,7 +98,7 @@ ScheduleResult GcsSpreadStrategy::Schedule( } ScheduleResult GcsStrictSpreadStrategy::Schedule( - const std::vector> &bundles, + const std::vector> &bundles, const std::unique_ptr &context, GcsResourceScheduler &gcs_resource_scheduler) { // TODO(ffbin): A bundle may require special resources, such as GPU. 
We need to @@ -211,7 +211,7 @@ void GcsPlacementGroupScheduler::MarkScheduleCancelled( } void GcsPlacementGroupScheduler::PrepareResources( - const std::shared_ptr &bundle, + const std::shared_ptr &bundle, const absl::optional> &node, const StatusCallback &callback) { if (!node.has_value()) { @@ -240,7 +240,7 @@ void GcsPlacementGroupScheduler::PrepareResources( } void GcsPlacementGroupScheduler::CommitResources( - const std::shared_ptr &bundle, + const std::shared_ptr &bundle, const absl::optional> &node, const StatusCallback callback) { RAY_CHECK(node.has_value()); @@ -265,7 +265,7 @@ void GcsPlacementGroupScheduler::CommitResources( } void GcsPlacementGroupScheduler::CancelResourceReserve( - const std::shared_ptr &bundle_spec, + const std::shared_ptr &bundle_spec, const absl::optional> &node) { if (!node.has_value()) { RAY_LOG(INFO) << "Node for a placement group id " << bundle_spec->PlacementGroupId() @@ -660,7 +660,7 @@ void BundleLocationIndex::AddNodes( LeaseStatusTracker::LeaseStatusTracker( std::shared_ptr placement_group, - const std::vector> &unplaced_bundles, + const std::vector> &unplaced_bundles, const ScheduleMap &schedule_map) : placement_group_(placement_group), bundles_to_schedule_(unplaced_bundles) { preparing_bundle_locations_ = std::make_shared(); @@ -675,13 +675,13 @@ LeaseStatusTracker::LeaseStatusTracker( } bool LeaseStatusTracker::MarkPreparePhaseStarted( - const NodeID &node_id, const std::shared_ptr &bundle) { + const NodeID &node_id, std::shared_ptr bundle) { const auto &bundle_id = bundle->BundleId(); return node_to_bundles_when_preparing_[node_id].emplace(bundle_id).second; } void LeaseStatusTracker::MarkPrepareRequestReturned( - const NodeID &node_id, const std::shared_ptr &bundle, + const NodeID &node_id, const std::shared_ptr bundle, const Status &status) { RAY_CHECK(prepare_request_returned_count_ <= bundles_to_schedule_.size()); auto leasing_bundles = node_to_bundles_when_preparing_.find(node_id); @@ -715,7 +715,7 @@ bool LeaseStatusTracker::AllPrepareRequestsSuccessful() const { } void LeaseStatusTracker::MarkCommitRequestReturned( - const NodeID &node_id, const std::shared_ptr &bundle, + const NodeID &node_id, const std::shared_ptr bundle, const Status &status) { commit_request_returned_count_ += 1; // If the request succeeds, record it. @@ -762,7 +762,7 @@ const std::shared_ptr &LeaseStatusTracker::GetBundleLocations() return bundle_locations_; } -const std::vector> +const std::vector> &LeaseStatusTracker::GetBundlesToSchedule() const { return bundles_to_schedule_; } diff --git a/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.h b/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.h index 4e921ab13e248..bdfee4276dec5 100644 --- a/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.h +++ b/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.h @@ -49,8 +49,9 @@ struct pair_hash { }; using ScheduleMap = std::unordered_map; using ScheduleResult = std::pair; -using BundleLocations = absl::flat_hash_map< - BundleID, std::pair>, pair_hash>; +using BundleLocations = + absl::flat_hash_map>, + pair_hash>; class GcsPlacementGroupSchedulerInterface { public: @@ -111,7 +112,7 @@ class GcsScheduleStrategy { public: virtual ~GcsScheduleStrategy() {} virtual ScheduleResult Schedule( - const std::vector> &bundles, + const std::vector> &bundles, const std::unique_ptr &context, GcsResourceScheduler &gcs_resource_scheduler) = 0; @@ -121,7 +122,7 @@ class GcsScheduleStrategy { /// \param bundles Bundles to be scheduled. /// \return Required resources. 
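PrepareResources, CommitResources, and CancelResourceReserve above implement a two-phase reservation: each bundle's resources are first tentatively prepared on its node, commits go out only once every prepare has succeeded, and a failed or cancelled group returns its reservations. A schematic of that flow; every type and signature below is a simplified stand-in, not the real Ray interface:

#include <cstddef>
#include <functional>
#include <iostream>
#include <string>
#include <vector>

using StatusCallback = std::function<void(bool ok)>;

struct Node { std::string id; };

// Phase 1: ask the node to set the bundle's resources aside.
void PrepareResources(const Node &node, const StatusCallback &cb) {
  std::cout << "prepare on " << node.id << "\n";
  cb(true);  // pretend the reservation was granted
}

// Phase 2: make the reservation permanent.
void CommitResources(const Node &node, const StatusCallback &cb) {
  std::cout << "commit on " << node.id << "\n";
  cb(true);
}

// Rollback: return a reservation that will never be committed.
void CancelResourceReserve(const Node &node) {
  std::cout << "cancel on " << node.id << "\n";
}

int main() {
  std::vector<Node> nodes{{"n1"}, {"n2"}};
  std::size_t prepared = 0;
  bool all_ok = true;
  for (const auto &node : nodes) {
    PrepareResources(node, [&](bool ok) {
      all_ok = all_ok && ok;
      if (++prepared < nodes.size()) return;
      // Every prepare has returned: commit everywhere or roll back everywhere.
      for (const auto &n : nodes) {
        if (all_ok) CommitResources(n, [](bool) {});
        else CancelResourceReserve(n);
      }
    });
  }
}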
std::vector GetRequiredResourcesFromBundles( - const std::vector> &bundles); + const std::vector> &bundles); /// Generate `ScheduleResult` from bundles and nodes . /// @@ -130,7 +131,7 @@ class GcsScheduleStrategy { /// \param status Status of the scheduling result. /// \return The scheduling result from the required resource. ScheduleResult GenerateScheduleResult( - const std::vector> &bundles, + const std::vector> &bundles, const std::vector &selected_nodes, const SchedulingResultStatus &status); }; @@ -140,7 +141,7 @@ class GcsScheduleStrategy { class GcsPackStrategy : public GcsScheduleStrategy { public: ScheduleResult Schedule( - const std::vector> &bundles, + const std::vector> &bundles, const std::unique_ptr &context, GcsResourceScheduler &gcs_resource_scheduler) override; }; @@ -149,7 +150,7 @@ class GcsPackStrategy : public GcsScheduleStrategy { class GcsSpreadStrategy : public GcsScheduleStrategy { public: ScheduleResult Schedule( - const std::vector> &bundles, + const std::vector> &bundles, const std::unique_ptr &context, GcsResourceScheduler &gcs_resource_scheduler) override; }; @@ -159,7 +160,7 @@ class GcsSpreadStrategy : public GcsScheduleStrategy { class GcsStrictPackStrategy : public GcsScheduleStrategy { public: ScheduleResult Schedule( - const std::vector> &bundles, + const std::vector> &bundles, const std::unique_ptr &context, GcsResourceScheduler &gcs_resource_scheduler) override; }; @@ -170,7 +171,7 @@ class GcsStrictPackStrategy : public GcsScheduleStrategy { class GcsStrictSpreadStrategy : public GcsScheduleStrategy { public: ScheduleResult Schedule( - const std::vector> &bundles, + const std::vector> &bundles, const std::unique_ptr &context, GcsResourceScheduler &gcs_resource_scheduler) override; }; @@ -191,7 +192,7 @@ class LeaseStatusTracker { public: LeaseStatusTracker( std::shared_ptr placement_group, - const std::vector> &unplaced_bundles, + const std::vector> &unplaced_bundles, const ScheduleMap &schedule_map); ~LeaseStatusTracker() = default; @@ -201,7 +202,7 @@ class LeaseStatusTracker { /// \param bundle Bundle specification the node is supposed to prepare. /// \return False if the prepare phase was already started. True otherwise. bool MarkPreparePhaseStarted(const NodeID &node_id, - const std::shared_ptr &bundle); + std::shared_ptr bundle); /// Indicate the tracker that all prepare requests are returned. /// @@ -209,9 +210,9 @@ class LeaseStatusTracker { /// \param bundle Bundle specification the node was supposed to schedule. /// \param status Status of the prepare response. /// \param void - void MarkPrepareRequestReturned( - const NodeID &node_id, const std::shared_ptr &bundle, - const Status &status); + void MarkPrepareRequestReturned(const NodeID &node_id, + std::shared_ptr bundle, + const Status &status); /// Used to know if all prepare requests are returned. /// @@ -229,7 +230,7 @@ class LeaseStatusTracker { /// \param bundle Bundle specification the node was supposed to schedule. /// \param status Status of the returned commit request. void MarkCommitRequestReturned(const NodeID &node_id, - const std::shared_ptr &bundle, + const std::shared_ptr bundle, const Status &status); /// Used to know if all commit requests are returend. @@ -250,8 +251,7 @@ class LeaseStatusTracker { /// Return bundles that should be scheduled. /// /// \return List of bundle specification that are supposed to be scheduled. 
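The four Schedule overrides above follow the classic strategy pattern: a single virtual entry point, with PACK, SPREAD, STRICT_PACK, and STRICT_SPREAD each supplying their own node-selection policy. A reduced sketch of the shape, with invented signatures and trivial policy bodies:

#include <memory>
#include <vector>

enum class PlacementStrategy { PACK, SPREAD, STRICT_PACK, STRICT_SPREAD };

struct ScheduleResult { bool success = false; };

class ScheduleStrategy {
 public:
  virtual ~ScheduleStrategy() = default;
  virtual ScheduleResult Schedule(const std::vector<int> &bundles) = 0;
};

// Put as many bundles as possible on one node.
class PackStrategy : public ScheduleStrategy {
 public:
  ScheduleResult Schedule(const std::vector<int> &bundles) override {
    return {!bundles.empty()};
  }
};

// Best-effort spread across distinct nodes.
class SpreadStrategy : public ScheduleStrategy {
 public:
  ScheduleResult Schedule(const std::vector<int> &bundles) override {
    return {!bundles.empty()};
  }
};

std::unique_ptr<ScheduleStrategy> MakeStrategy(PlacementStrategy s) {
  switch (s) {
    case PlacementStrategy::PACK:
    case PlacementStrategy::STRICT_PACK:
      return std::make_unique<PackStrategy>();
    default:
      return std::make_unique<SpreadStrategy>();
  }
}

int main() {
  auto strategy = MakeStrategy(PlacementStrategy::SPREAD);
  return strategy->Schedule({1, 2, 3}).success ? 0 : 1;
}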
- [[nodiscard]] const std::vector> - &GetBundlesToSchedule() const; + const std::vector> &GetBundlesToSchedule() const; /// This method returns bundle locations that succeed to prepare resources. /// @@ -324,7 +324,7 @@ class LeaseStatusTracker { node_to_bundles_when_preparing_; /// Bundles to schedule. - std::vector> bundles_to_schedule_; + std::vector> bundles_to_schedule_; /// Location of bundles. std::shared_ptr bundle_locations_; @@ -460,7 +460,7 @@ class GcsPlacementGroupScheduler : public GcsPlacementGroupSchedulerInterface { /// \param node A node to prepare resources for a given bundle. /// \param callback void PrepareResources( - const std::shared_ptr &bundle, + const std::shared_ptr &bundle, const absl::optional> &node, const StatusCallback &callback); @@ -470,7 +470,7 @@ class GcsPlacementGroupScheduler : public GcsPlacementGroupSchedulerInterface { /// \param bundle A bundle to schedule on a node. /// \param node A node to commit resources for a given bundle. /// \param callback - void CommitResources(const std::shared_ptr &bundle, + void CommitResources(const std::shared_ptr &bundle, const absl::optional> &node, const StatusCallback callback); @@ -481,7 +481,7 @@ class GcsPlacementGroupScheduler : public GcsPlacementGroupSchedulerInterface { /// \param bundle A description of the bundle to return. /// \param node The node that the worker will be returned for. void CancelResourceReserve( - const std::shared_ptr &bundle_spec, + const std::shared_ptr &bundle_spec, const absl::optional> &node); /// Get an existing lease client or connect a new one or connect a new one. diff --git a/src/ray/gcs/gcs_server/gcs_resource_manager.cc b/src/ray/gcs/gcs_server/gcs_resource_manager.cc index eec2d5d4dfd22..983edfe7df9c3 100644 --- a/src/ray/gcs/gcs_server/gcs_resource_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_resource_manager.cc @@ -13,7 +13,6 @@ // limitations under the License. 
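LeaseStatusTracker, whose accessors appear above, is essentially bookkeeping: it records which bundles are in flight per node and compares returned-request counts against the bundles to schedule to detect when a phase has finished. A stripped-down sketch under those assumptions, with NodeID and BundleID simplified to a string and a pair:

#include <cstddef>
#include <map>
#include <set>
#include <string>
#include <utility>

using NodeID = std::string;
using BundleID = std::pair<std::string, int>;  // (placement group, index)

class LeaseStatusTracker {
 public:
  explicit LeaseStatusTracker(std::size_t bundles_to_schedule)
      : bundles_to_schedule_(bundles_to_schedule) {}

  // Returns false if this bundle's prepare was already started on the node.
  bool MarkPreparePhaseStarted(const NodeID &node, const BundleID &bundle) {
    return preparing_[node].insert(bundle).second;
  }

  void MarkPrepareRequestReturned(const NodeID &node, const BundleID &bundle,
                                  bool ok) {
    preparing_[node].erase(bundle);
    ++returned_;
    all_ok_ = all_ok_ && ok;
  }

  bool AllPrepareRequestsReturned() const {
    return returned_ == bundles_to_schedule_;
  }
  bool AllPrepareRequestsSuccessful() const { return all_ok_; }

 private:
  std::size_t bundles_to_schedule_;
  std::size_t returned_ = 0;
  bool all_ok_ = true;
  std::map<NodeID, std::set<BundleID>> preparing_;
};

int main() {
  LeaseStatusTracker tracker(/*bundles_to_schedule=*/1);
  tracker.MarkPreparePhaseStarted("node-1", {"pg-1", 0});
  tracker.MarkPrepareRequestReturned("node-1", {"pg-1", 0}, /*ok=*/true);
  return tracker.AllPrepareRequestsReturned() ? 0 : 1;
}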
#include "ray/gcs/gcs_server/gcs_resource_manager.h" - #include "ray/common/ray_config.h" #include "ray/stats/stats.h" @@ -234,8 +233,10 @@ void GcsResourceManager::HandleGetAllResourceUsage( aggregate_demand.set_num_infeasible_requests_queued( aggregate_demand.num_infeasible_requests_queued() + demand.num_infeasible_requests_queued()); - aggregate_demand.set_backlog_size(aggregate_demand.backlog_size() + - demand.backlog_size()); + if (RayConfig::instance().report_worker_backlog()) { + aggregate_demand.set_backlog_size(aggregate_demand.backlog_size() + + demand.backlog_size()); + } } batch->add_batch()->CopyFrom(usage.second); diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc index 84821c6af1d3a..6a4d60c685e9b 100644 --- a/src/ray/gcs/gcs_server/gcs_server.cc +++ b/src/ray/gcs/gcs_server/gcs_server.cc @@ -35,7 +35,7 @@ GcsServer::GcsServer(const ray::gcs::GcsServerConfig &config, : config_(config), main_service_(main_service), rpc_server_(config.grpc_server_name, config.grpc_server_port, - config.node_ip_address == "127.0.0.1", config.grpc_server_thread_num, + config.grpc_server_thread_num, /*keepalive_time_ms=*/RayConfig::instance().grpc_keepalive_time_ms()), client_call_manager_(main_service), raylet_client_pool_( @@ -267,8 +267,7 @@ void GcsServer::InitGcsActorManager(const GcsInitData &gcs_init_data) { client_factory); } gcs_actor_manager_ = std::make_shared( - main_service_, std::move(scheduler), gcs_table_storage_, gcs_pub_sub_, - *runtime_env_manager_, + std::move(scheduler), gcs_table_storage_, gcs_pub_sub_, *runtime_env_manager_, [this](const ActorID &actor_id) { gcs_placement_group_manager_->CleanPlacementGroupIfNeededWhenActorDead(actor_id); }, @@ -479,7 +478,7 @@ void GcsServer::InstallEventListeners() { gcs_placement_group_manager_->CleanPlacementGroupIfNeededWhenJobDead(*job_id); }); - // Install scheduling event listeners. + // Install scheduling policy event listeners. if (RayConfig::instance().gcs_actor_scheduling_enabled()) { gcs_resource_manager_->AddResourcesChangedListener([this] { main_service_.post([this] { @@ -514,10 +513,9 @@ void GcsServer::PrintDebugInfo() { // TODO(ffbin): We will get the session_dir in the next PR, and write the log to // gcs_debug_state.txt. 
RAY_LOG(INFO) << stream.str(); - execute_after( - main_service_, [this] { PrintDebugInfo(); }, - (RayConfig::instance().gcs_dump_debug_log_interval_minutes() * - 60000) /* milliseconds */); + execute_after(main_service_, [this] { PrintDebugInfo(); }, + (RayConfig::instance().gcs_dump_debug_log_interval_minutes() * + 60000) /* milliseconds */); } void GcsServer::PrintAsioStats() { @@ -526,9 +524,8 @@ void GcsServer::PrintAsioStats() { RayConfig::instance().event_stats_print_interval_ms(); if (event_stats_print_interval_ms != -1 && RayConfig::instance().event_stats()) { RAY_LOG(INFO) << "Event stats:\n\n" << main_service_.StatsString() << "\n\n"; - execute_after( - main_service_, [this] { PrintAsioStats(); }, - event_stats_print_interval_ms /* milliseconds */); + execute_after(main_service_, [this] { PrintAsioStats(); }, + event_stats_print_interval_ms /* milliseconds */); } } diff --git a/src/ray/gcs/gcs_server/gcs_server.h b/src/ray/gcs/gcs_server/gcs_server.h index 507ab2820cab7..cadb70a3f3541 100644 --- a/src/ray/gcs/gcs_server/gcs_server.h +++ b/src/ray/gcs/gcs_server/gcs_server.h @@ -16,6 +16,7 @@ #include "ray/common/asio/instrumented_io_context.h" #include "ray/common/runtime_env_manager.h" +#include "ray/gcs/gcs_server/gcs_actor_distribution.h" #include "ray/gcs/gcs_server/gcs_heartbeat_manager.h" #include "ray/gcs/gcs_server/gcs_init_data.h" #include "ray/gcs/gcs_server/gcs_kv_manager.h" diff --git a/src/ray/gcs/gcs_server/gcs_server_main.cc b/src/ray/gcs/gcs_server/gcs_server_main.cc index abb39469b427e..79a64514409b6 100644 --- a/src/ray/gcs/gcs_server/gcs_server_main.cc +++ b/src/ray/gcs/gcs_server/gcs_server_main.cc @@ -80,12 +80,13 @@ int main(int argc, char *argv[]) { storage->InternalConfigTable().Put(ray::UniqueID::Nil(), config, on_done)); boost::asio::io_service::work work(service); service.run(); - }).detach(); + }) + .detach(); promise->get_future().get(); const ray::stats::TagsType global_tags = { {ray::stats::ComponentKey, "gcs_server"}, - {ray::stats::VersionKey, kRayVersion}, + {ray::stats::VersionKey, "2.0.0.dev0"}, {ray::stats::NodeAddressKey, node_ip_address}}; ray::stats::Init(global_tags, metrics_agent_port); diff --git a/src/ray/gcs/gcs_server/gcs_table_storage.h b/src/ray/gcs/gcs_server/gcs_table_storage.h index ed48cf71abdf2..84a70a347ebf7 100644 --- a/src/ray/gcs/gcs_server/gcs_table_storage.h +++ b/src/ray/gcs/gcs_server/gcs_table_storage.h @@ -14,7 +14,6 @@ #pragma once -#include #include #include "ray/common/asio/instrumented_io_context.h" @@ -52,8 +51,8 @@ using rpc::WorkerTableData; template class GcsTable { public: - explicit GcsTable(std::shared_ptr store_client) - : store_client_(std::move(store_client)) {} + explicit GcsTable(std::shared_ptr &store_client) + : store_client_(store_client) {} virtual ~GcsTable() = default; @@ -107,8 +106,8 @@ class GcsTable { template class GcsTableWithJobId : public GcsTable { public: - explicit GcsTableWithJobId(std::shared_ptr store_client) - : GcsTable(std::move(store_client)) {} + explicit GcsTableWithJobId(std::shared_ptr &store_client) + : GcsTable(store_client) {} /// Write data to the table asynchronously. 
/// @@ -153,16 +152,16 @@ class GcsTableWithJobId : public GcsTable { class GcsJobTable : public GcsTable { public: - explicit GcsJobTable(std::shared_ptr store_client) - : GcsTable(std::move(store_client)) { + explicit GcsJobTable(std::shared_ptr &store_client) + : GcsTable(store_client) { table_name_ = TablePrefix_Name(TablePrefix::JOB); } }; class GcsActorTable : public GcsTableWithJobId { public: - explicit GcsActorTable(std::shared_ptr store_client) - : GcsTableWithJobId(std::move(store_client)) { + explicit GcsActorTable(std::shared_ptr &store_client) + : GcsTableWithJobId(store_client) { table_name_ = TablePrefix_Name(TablePrefix::ACTOR); } @@ -173,16 +172,16 @@ class GcsActorTable : public GcsTableWithJobId { class GcsPlacementGroupTable : public GcsTable { public: - explicit GcsPlacementGroupTable(std::shared_ptr store_client) - : GcsTable(std::move(store_client)) { + explicit GcsPlacementGroupTable(std::shared_ptr &store_client) + : GcsTable(store_client) { table_name_ = TablePrefix_Name(TablePrefix::PLACEMENT_GROUP); } }; class GcsTaskTable : public GcsTableWithJobId { public: - explicit GcsTaskTable(std::shared_ptr store_client) - : GcsTableWithJobId(std::move(store_client)) { + explicit GcsTaskTable(std::shared_ptr &store_client) + : GcsTableWithJobId(store_client) { table_name_ = TablePrefix_Name(TablePrefix::TASK); } @@ -192,8 +191,8 @@ class GcsTaskTable : public GcsTableWithJobId { class GcsTaskLeaseTable : public GcsTableWithJobId { public: - explicit GcsTaskLeaseTable(std::shared_ptr store_client) - : GcsTableWithJobId(std::move(store_client)) { + explicit GcsTaskLeaseTable(std::shared_ptr &store_client) + : GcsTableWithJobId(store_client) { table_name_ = TablePrefix_Name(TablePrefix::TASK_LEASE); } @@ -204,8 +203,8 @@ class GcsTaskLeaseTable : public GcsTableWithJobId { class GcsTaskReconstructionTable : public GcsTableWithJobId { public: - explicit GcsTaskReconstructionTable(std::shared_ptr store_client) - : GcsTableWithJobId(std::move(store_client)) { + explicit GcsTaskReconstructionTable(std::shared_ptr &store_client) + : GcsTableWithJobId(store_client) { table_name_ = TablePrefix_Name(TablePrefix::TASK_RECONSTRUCTION); } @@ -215,8 +214,8 @@ class GcsTaskReconstructionTable class GcsObjectTable : public GcsTableWithJobId { public: - explicit GcsObjectTable(std::shared_ptr store_client) - : GcsTableWithJobId(std::move(store_client)) { + explicit GcsObjectTable(std::shared_ptr &store_client) + : GcsTableWithJobId(store_client) { table_name_ = TablePrefix_Name(TablePrefix::OBJECT); } @@ -226,56 +225,56 @@ class GcsObjectTable : public GcsTableWithJobId { class GcsNodeTable : public GcsTable { public: - explicit GcsNodeTable(std::shared_ptr store_client) - : GcsTable(std::move(store_client)) { + explicit GcsNodeTable(std::shared_ptr &store_client) + : GcsTable(store_client) { table_name_ = TablePrefix_Name(TablePrefix::NODE); } }; class GcsNodeResourceTable : public GcsTable { public: - explicit GcsNodeResourceTable(std::shared_ptr store_client) - : GcsTable(std::move(store_client)) { + explicit GcsNodeResourceTable(std::shared_ptr &store_client) + : GcsTable(store_client) { table_name_ = TablePrefix_Name(TablePrefix::NODE_RESOURCE); } }; class GcsPlacementGroupScheduleTable : public GcsTable { public: - explicit GcsPlacementGroupScheduleTable(std::shared_ptr store_client) - : GcsTable(std::move(store_client)) { + explicit GcsPlacementGroupScheduleTable(std::shared_ptr &store_client) + : GcsTable(store_client) { table_name_ = 
TablePrefix_Name(TablePrefix::PLACEMENT_GROUP_SCHEDULE); } }; class GcsResourceUsageBatchTable : public GcsTable { public: - explicit GcsResourceUsageBatchTable(std::shared_ptr store_client) - : GcsTable(std::move(store_client)) { + explicit GcsResourceUsageBatchTable(std::shared_ptr &store_client) + : GcsTable(store_client) { table_name_ = TablePrefix_Name(TablePrefix::RESOURCE_USAGE_BATCH); } }; class GcsProfileTable : public GcsTable { public: - explicit GcsProfileTable(std::shared_ptr store_client) - : GcsTable(std::move(store_client)) { + explicit GcsProfileTable(std::shared_ptr &store_client) + : GcsTable(store_client) { table_name_ = TablePrefix_Name(TablePrefix::PROFILE); } }; class GcsWorkerTable : public GcsTable { public: - explicit GcsWorkerTable(std::shared_ptr store_client) - : GcsTable(std::move(store_client)) { + explicit GcsWorkerTable(std::shared_ptr &store_client) + : GcsTable(store_client) { table_name_ = TablePrefix_Name(TablePrefix::WORKERS); } }; class GcsInternalConfigTable : public GcsTable { public: - explicit GcsInternalConfigTable(std::shared_ptr store_client) - : GcsTable(std::move(store_client)) { + explicit GcsInternalConfigTable(std::shared_ptr &store_client) + : GcsTable(store_client) { table_name_ = TablePrefix_Name(TablePrefix::INTERNAL_CONFIG); } }; @@ -286,29 +285,6 @@ class GcsInternalConfigTable : public GcsTable { /// derive from this class and override class member variables. class GcsTableStorage { public: - explicit GcsTableStorage(std::shared_ptr store_client) - : store_client_(std::move(store_client)) { - job_table_ = std::make_unique(store_client_); - actor_table_ = std::make_unique(store_client_); - placement_group_table_ = std::make_unique(store_client_); - task_table_ = std::make_unique(store_client_); - task_lease_table_ = std::make_unique(store_client_); - task_reconstruction_table_ = - std::make_unique(store_client_); - object_table_ = std::make_unique(store_client_); - node_table_ = std::make_unique(store_client_); - node_resource_table_ = std::make_unique(store_client_); - placement_group_schedule_table_ = - std::make_unique(store_client_); - placement_group_schedule_table_ = - std::make_unique(store_client_); - resource_usage_batch_table_ = - std::make_unique(store_client_); - profile_table_ = std::make_unique(store_client_); - worker_table_ = std::make_unique(store_client_); - system_config_table_ = std::make_unique(store_client_); - } - GcsJobTable &JobTable() { RAY_CHECK(job_table_ != nullptr); return *job_table_; @@ -407,8 +383,26 @@ class GcsTableStorage { /// that uses redis as storage. 
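On the constructor churn in gcs_table_storage.h above: the by-value std::shared_ptr taken with std::move lets a caller either copy in or hand over ownership, while the non-const reference form avoids a refcount bump at the call site but always copies inside and cannot bind a temporary. Side by side, with StoreClient reduced to an empty struct:

#include <memory>
#include <utility>

struct StoreClient {};

class TableByValue {
 public:
  explicit TableByValue(std::shared_ptr<StoreClient> client)
      : client_(std::move(client)) {}  // caller can copy in or move in
 private:
  std::shared_ptr<StoreClient> client_;
};

class TableByRef {
 public:
  explicit TableByRef(std::shared_ptr<StoreClient> &client)
      : client_(client) {}  // always copies; lvalue arguments only
 private:
  std::shared_ptr<StoreClient> client_;
};

int main() {
  auto client = std::make_shared<StoreClient>();
  TableByRef by_ref(client);                  // ok: lvalue
  TableByValue by_value(std::move(client));   // ok: ownership handed over
  // TableByRef broken(std::make_shared<StoreClient>());  // won't compile
  (void)by_ref;
  (void)by_value;
}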
class RedisGcsTableStorage : public GcsTableStorage { public: - explicit RedisGcsTableStorage(std::shared_ptr redis_client) - : GcsTableStorage(std::make_shared(std::move(redis_client))) {} + explicit RedisGcsTableStorage(std::shared_ptr redis_client) { + store_client_ = std::make_shared(redis_client); + job_table_.reset(new GcsJobTable(store_client_)); + actor_table_.reset(new GcsActorTable(store_client_)); + placement_group_table_.reset(new GcsPlacementGroupTable(store_client_)); + task_table_.reset(new GcsTaskTable(store_client_)); + task_lease_table_.reset(new GcsTaskLeaseTable(store_client_)); + task_reconstruction_table_.reset(new GcsTaskReconstructionTable(store_client_)); + object_table_.reset(new GcsObjectTable(store_client_)); + node_table_.reset(new GcsNodeTable(store_client_)); + node_resource_table_.reset(new GcsNodeResourceTable(store_client_)); + placement_group_schedule_table_.reset( + new GcsPlacementGroupScheduleTable(store_client_)); + placement_group_schedule_table_.reset( + new GcsPlacementGroupScheduleTable(store_client_)); + resource_usage_batch_table_.reset(new GcsResourceUsageBatchTable(store_client_)); + profile_table_.reset(new GcsProfileTable(store_client_)); + worker_table_.reset(new GcsWorkerTable(store_client_)); + system_config_table_.reset(new GcsInternalConfigTable(store_client_)); + } }; /// \class InMemoryGcsTableStorage @@ -416,8 +410,24 @@ class RedisGcsTableStorage : public GcsTableStorage { /// that uses memory as storage. class InMemoryGcsTableStorage : public GcsTableStorage { public: - explicit InMemoryGcsTableStorage(instrumented_io_context &main_io_service) - : GcsTableStorage(std::make_shared(main_io_service)) {} + explicit InMemoryGcsTableStorage(instrumented_io_context &main_io_service) { + store_client_ = std::make_shared(main_io_service); + job_table_.reset(new GcsJobTable(store_client_)); + actor_table_.reset(new GcsActorTable(store_client_)); + placement_group_table_.reset(new GcsPlacementGroupTable(store_client_)); + task_table_.reset(new GcsTaskTable(store_client_)); + task_lease_table_.reset(new GcsTaskLeaseTable(store_client_)); + task_reconstruction_table_.reset(new GcsTaskReconstructionTable(store_client_)); + object_table_.reset(new GcsObjectTable(store_client_)); + node_table_.reset(new GcsNodeTable(store_client_)); + node_resource_table_.reset(new GcsNodeResourceTable(store_client_)); + placement_group_schedule_table_.reset( + new GcsPlacementGroupScheduleTable(store_client_)); + resource_usage_batch_table_.reset(new GcsResourceUsageBatchTable(store_client_)); + profile_table_.reset(new GcsProfileTable(store_client_)); + worker_table_.reset(new GcsWorkerTable(store_client_)); + system_config_table_.reset(new GcsInternalConfigTable(store_client_)); + } }; } // namespace gcs diff --git a/src/ray/gcs/gcs_server/test/gcs_actor_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_actor_manager_test.cc index f43d40dd392ac..b921fd2acd2a0 100644 --- a/src/ray/gcs/gcs_server/test/gcs_actor_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_actor_manager_test.cc @@ -33,7 +33,6 @@ class MockActorScheduler : public gcs::GcsActorSchedulerInterface { void Reschedule(std::shared_ptr actor) {} void ReleaseUnusedWorkers( const std::unordered_map> &node_to_workers) {} - void OnActorDestruction(std::shared_ptr actor) {} MOCK_METHOD1(CancelOnNode, std::vector(const NodeID &node_id)); MOCK_METHOD2(CancelOnWorker, ActorID(const NodeID &node_id, const WorkerID &worker_id)); @@ -106,8 +105,8 @@ class GcsActorManagerTest : public ::testing::Test { 
store_client_ = std::make_shared(io_service_); gcs_table_storage_ = std::make_shared(io_service_); gcs_actor_manager_.reset(new gcs::GcsActorManager( - io_service_, mock_actor_scheduler_, gcs_table_storage_, gcs_pub_sub_, - *runtime_env_mgr_, [](const ActorID &actor_id) {}, + mock_actor_scheduler_, gcs_table_storage_, gcs_pub_sub_, *runtime_env_mgr_, + [](const ActorID &actor_id) {}, [this](const JobID &job_id) { return job_namespace_table_[job_id]; }, [this](std::function fn, boost::posix_time::milliseconds delay) { if (skip_delay_) { @@ -954,7 +953,6 @@ TEST_F(GcsActorManagerTest, TestRayNamespace) { } TEST_F(GcsActorManagerTest, TestActorTableDataDelayedGC) { - google::protobuf::Arena arena; skip_delay_ = false; auto job_id_1 = JobID::FromInt(1); auto request1 = Mocker::GenRegisterActorRequest(job_id_1, /*max_restarts=*/0, @@ -973,8 +971,7 @@ TEST_F(GcsActorManagerTest, TestActorTableDataDelayedGC) { { rpc::GetAllActorInfoRequest request; - auto &reply = - *google::protobuf::Arena::CreateMessage(&arena); + rpc::GetAllActorInfoReply reply; bool called = false; auto callback = [&called](Status status, std::function success, std::function failure) { called = true; }; @@ -984,8 +981,7 @@ TEST_F(GcsActorManagerTest, TestActorTableDataDelayedGC) { } { rpc::GetAllActorInfoRequest request; - auto &reply = - *google::protobuf::Arena::CreateMessage(&arena); + rpc::GetAllActorInfoReply reply; request.set_show_dead_jobs(true); std::promise promise; auto callback = [&promise](Status status, std::function success, @@ -998,8 +994,7 @@ TEST_F(GcsActorManagerTest, TestActorTableDataDelayedGC) { delayed_to_run_(); { rpc::GetAllActorInfoRequest request; - auto &reply = - *google::protobuf::Arena::CreateMessage(&arena); + rpc::GetAllActorInfoReply reply; request.set_show_dead_jobs(true); std::promise promise; auto callback = [&promise](Status status, std::function success, diff --git a/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_mock_test.cc b/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_mock_test.cc deleted file mode 100644 index 0829caf3e0d91..0000000000000 --- a/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_mock_test.cc +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright 2017 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
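The test hunks above swap google::protobuf::Arena-allocated replies for plain stack-allocated ones. For reference, the arena form places the message in a region freed in one shot when the arena dies, which is what the removed Arena::CreateMessage calls did; google::protobuf::Empty below merely stands in for a generated reply type:

#include <google/protobuf/arena.h>
#include <google/protobuf/empty.pb.h>  // stand-in for a generated reply

int main() {
  google::protobuf::Arena arena;
  // Arena-owned: no delete; storage is reclaimed when `arena` is destroyed.
  auto *arena_reply =
      google::protobuf::Arena::CreateMessage<google::protobuf::Empty>(&arena);
  (void)arena_reply;

  // Stack-owned: what the updated tests use instead.
  google::protobuf::Empty plain_reply;
  (void)plain_reply;
  return 0;
}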
- -// clang-format off -#include "gtest/gtest.h" -#include "gmock/gmock.h" -#include "ray/gcs/gcs_server/gcs_actor_manager.h" -#include "ray/gcs/gcs_server/gcs_actor_scheduler.h" -#include "mock/ray/gcs/store_client/store_client.h" -#include "mock/ray/gcs/gcs_server/gcs_node_manager.h" -#include "mock/ray/raylet_client/raylet_client.h" -#include "mock/ray/pubsub/subscriber.h" -#include "mock/ray/gcs/pubsub/gcs_pub_sub.h" -#include "mock/ray/rpc/worker/core_worker_client.h" -// clang-format on -using namespace ::testing; - -namespace ray { -namespace gcs { -struct MockCallback { - MOCK_METHOD(void, Call, ((std::shared_ptr))); - void operator()(std::shared_ptr a) { return Call(a); } -}; - -class GcsActorSchedulerTest : public Test { - public: - void SetUp() override { - store_client = std::make_shared(); - actor_table = std::make_unique(store_client); - gcs_node_manager = std::make_unique(); - pub_sub = std::make_shared(); - raylet_client = std::make_shared(); - core_worker_client = std::make_shared(); - client_pool = std::make_shared( - [this](const rpc::Address &) { return raylet_client; }); - actor_scheduler = std::make_unique( - io_context, *actor_table, *gcs_node_manager, pub_sub, - [this](auto a) { schedule_failure_handler(a); }, - [this](auto a, const rpc::PushTaskReply) { schedule_success_handler(a); }, - client_pool, [this](const rpc::Address &) { return core_worker_client; }); - auto node_info = std::make_shared(); - node_info->set_state(rpc::GcsNodeInfo::ALIVE); - node_id = NodeID::FromRandom(); - node_info->set_node_id(node_id.Binary()); - worker_id = WorkerID::FromRandom(); - gcs_node_manager->AddNode(node_info); - } - std::shared_ptr raylet_client; - instrumented_io_context io_context; - std::shared_ptr store_client; - std::unique_ptr actor_table; - std::unique_ptr actor_scheduler; - std::unique_ptr gcs_node_manager; - std::shared_ptr pub_sub; - std::shared_ptr core_worker_client; - std::shared_ptr client_pool; - MockCallback schedule_failure_handler; - MockCallback schedule_success_handler; - NodeID node_id; - WorkerID worker_id; -}; - -TEST_F(GcsActorSchedulerTest, KillWorkerLeak1) { - // Ensure worker is not leak in the following case: - // 1. Gcs start to lease a worker - // 2. Gcs cancel the actor - // 3. Gcs lease reply with a grant - // We'd like to test the worker got released eventually. - // Worker is released with actor killing - auto actor_id = ActorID::FromHex("f4ce02420592ca68c1738a0d01000000"); - rpc::ActorTableData actor_data; - actor_data.set_state(rpc::ActorTableData::PENDING_CREATION); - actor_data.set_actor_id(actor_id.Binary()); - auto actor = std::make_shared(actor_data); - std::function cb; - EXPECT_CALL(*raylet_client, RequestWorkerLease(Matcher(), _, _)) - .WillOnce(testing::SaveArg<1>(&cb)); - // Ensure actor is killed - EXPECT_CALL(*core_worker_client, KillActor(_, _)); - actor_scheduler->Schedule(actor); - actor->GetMutableActorTableData()->set_state(rpc::ActorTableData::DEAD); - actor_scheduler->CancelOnNode(node_id); - ray::rpc::RequestWorkerLeaseReply reply; - reply.mutable_worker_address()->set_raylet_id(node_id.Binary()); - reply.mutable_worker_address()->set_worker_id(worker_id.Binary()); - cb(Status::OK(), reply); -} - -TEST_F(GcsActorSchedulerTest, KillWorkerLeak2) { - // Ensure worker is not leak in the following case: - // 1. Actor is in pending creation - // 2. Gcs push creation task to run in worker - // 3. Cancel the task - // 4. Task creating reply received - // We'd like to test the worker got released eventually. 
- // Worker is released with actor killing - auto actor_id = ActorID::FromHex("f4ce02420592ca68c1738a0d01000000"); - rpc::ActorTableData actor_data; - actor_data.set_state(rpc::ActorTableData::PENDING_CREATION); - actor_data.set_actor_id(actor_id.Binary()); - auto actor = std::make_shared(actor_data); - rpc::ClientCallback request_worker_lease_cb; - // Ensure actor is killed - EXPECT_CALL(*core_worker_client, KillActor(_, _)); - EXPECT_CALL(*raylet_client, RequestWorkerLease(Matcher(), _, _)) - .WillOnce(testing::SaveArg<1>(&request_worker_lease_cb)); - - std::function async_put_with_index_cb; - // Leasing successfully - EXPECT_CALL(*store_client, AsyncPutWithIndex(_, _, _, _, _)) - .WillOnce(DoAll(SaveArg<4>(&async_put_with_index_cb), Return(Status::OK()))); - actor_scheduler->Schedule(actor); - rpc::RequestWorkerLeaseReply reply; - reply.mutable_worker_address()->set_raylet_id(node_id.Binary()); - reply.mutable_worker_address()->set_worker_id(worker_id.Binary()); - request_worker_lease_cb(Status::OK(), reply); - - rpc::ClientCallback push_normal_task_cb; - // Worker start to run task - EXPECT_CALL(*core_worker_client, PushNormalTask(_, _)) - .WillOnce(testing::SaveArg<1>(&push_normal_task_cb)); - async_put_with_index_cb(Status::OK()); - actor->GetMutableActorTableData()->set_state(rpc::ActorTableData::DEAD); - actor_scheduler->CancelOnWorker(node_id, worker_id); - push_normal_task_cb(Status::OK(), rpc::PushTaskReply()); -} -} // namespace gcs -} // namespace ray diff --git a/src/ray/gcs/gcs_server/test/gcs_based_actor_scheduler_test.cc b/src/ray/gcs/gcs_server/test/gcs_based_actor_scheduler_test.cc index 48793907f117f..ada5f0094872b 100644 --- a/src/ray/gcs/gcs_server/test/gcs_based_actor_scheduler_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_based_actor_scheduler_test.cc @@ -147,15 +147,13 @@ TEST_F(GcsBasedActorSchedulerTest, TestNotEnoughClusterResources) { ASSERT_TRUE(actor->GetNodeID().IsNil()); } -TEST_F(GcsBasedActorSchedulerTest, TestScheduleAndDestroyOneActor) { +TEST_F(GcsBasedActorSchedulerTest, TestScheduleOneActor) { // Add a node with 64 memory units and 8 CPU. std::unordered_map node_resources = {{kMemory_ResourceLabel, 64}, {kCPU_ResourceLabel, 8}}; auto node = AddNewNode(node_resources); auto node_id = NodeID::FromBinary(node->node_id()); ASSERT_EQ(1, gcs_node_manager_->GetAllAliveNodes().size()); - auto cluster_resources_before_scheduling = gcs_resource_manager_->GetClusterResources(); - ASSERT_TRUE(cluster_resources_before_scheduling.contains(node_id)); // Schedule a actor (requiring 32 memory units and 4 CPU). std::unordered_map required_placement_resources = { @@ -184,20 +182,6 @@ TEST_F(GcsBasedActorSchedulerTest, TestScheduleAndDestroyOneActor) { ASSERT_EQ(actor, success_actors_.front()); ASSERT_EQ(actor->GetNodeID(), node_id); ASSERT_EQ(actor->GetWorkerID(), worker_id); - - auto cluster_resources_after_scheduling = gcs_resource_manager_->GetClusterResources(); - ASSERT_TRUE(cluster_resources_after_scheduling.contains(node_id)); - ASSERT_FALSE( - cluster_resources_before_scheduling[node_id].GetAvailableResources().IsEqual( - cluster_resources_after_scheduling[node_id].GetAvailableResources())); - - // When destroying an actor, its acquired resources have to be returned. 
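The deleted KillWorkerLeak tests above lean on a gMock idiom worth calling out: testing::SaveArg captures the RPC callback that the unit under test hands to the mocked client, so the test can cancel the lease first and only then deliver the late reply, reproducing the race deterministically. A self-contained rendition of the idiom; the client interface here is invented for the sketch:

#include <functional>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

using ::testing::_;
using ::testing::SaveArg;

using LeaseCallback = std::function<void(bool granted)>;

class RayletClient {
 public:
  virtual ~RayletClient() = default;
  virtual void RequestWorkerLease(int task_id, const LeaseCallback &cb) = 0;
};

class MockRayletClient : public RayletClient {
 public:
  MOCK_METHOD(void, RequestWorkerLease, (int, const LeaseCallback &),
              (override));
};

TEST(LeaseRaceTest, LateGrantAfterCancel) {
  MockRayletClient raylet;
  LeaseCallback saved_cb;
  // Capture the callback instead of answering immediately.
  EXPECT_CALL(raylet, RequestWorkerLease(_, _)).WillOnce(SaveArg<1>(&saved_cb));

  bool granted = false;
  raylet.RequestWorkerLease(42, [&granted](bool g) { granted = g; });

  // ... the test would cancel the lease here ...

  saved_cb(true);  // deliver the late grant and verify nothing leaks
  EXPECT_TRUE(granted);
}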
- gcs_actor_scheduler_->OnActorDestruction(actor); - auto cluster_resources_after_destruction = gcs_resource_manager_->GetClusterResources(); - ASSERT_TRUE(cluster_resources_after_destruction.contains(node_id)); - ASSERT_TRUE( - cluster_resources_before_scheduling[node_id].GetAvailableResources().IsEqual( - cluster_resources_after_destruction[node_id].GetAvailableResources())); } TEST_F(GcsBasedActorSchedulerTest, TestBalancedSchedule) { diff --git a/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_mock_test.cc b/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_mock_test.cc deleted file mode 100644 index e017fb793bafe..0000000000000 --- a/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_mock_test.cc +++ /dev/null @@ -1,174 +0,0 @@ -// Copyright 2017 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// clang-format off -#include "gtest/gtest.h" -#include "gmock/gmock.h" -#include "ray/gcs/gcs_server/gcs_placement_group_manager.h" -#include "mock/ray/gcs/gcs_server/gcs_placement_group_manager.h" -#include "mock/ray/gcs/gcs_server/gcs_placement_group_scheduler.h" -#include "mock/ray/gcs/gcs_server/gcs_resource_manager.h" -#include "mock/ray/gcs/store_client/store_client.h" -#include "ray/gcs/test/gcs_test_util.h" -// clang-format on - -using namespace ::testing; -using namespace ray; -using namespace ray::gcs; -namespace ray { -namespace gcs { - -class GcsPlacementGroupManagerMockTest : public Test { - public: - void SetUp() override { - store_client_ = std::make_shared(); - gcs_table_storage_ = std::make_shared(store_client_); - gcs_placement_group_scheduler_ = - std::make_shared(); - resource_manager_ = - std::make_shared(io_context_, nullptr, nullptr, true); - - gcs_placement_group_manager_ = std::make_unique( - io_context_, gcs_placement_group_scheduler_, gcs_table_storage_, - *resource_manager_, [](auto &) { return ""; }); - } - - std::unique_ptr gcs_placement_group_manager_; - std::shared_ptr gcs_placement_group_scheduler_; - std::shared_ptr gcs_table_storage_; - std::shared_ptr store_client_; - std::shared_ptr resource_manager_; - instrumented_io_context io_context_; -}; - -TEST_F(GcsPlacementGroupManagerMockTest, PendingQueuePriorityReschedule) { - // Test priority works - // When return with reschedule, it should be given with the highest pri - auto req = - Mocker::GenCreatePlacementGroupRequest("", rpc::PlacementStrategy::SPREAD, 1); - auto pg = std::make_shared(req, ""); - auto cb = [](Status s) {}; - PGSchedulingFailureCallback failure_callback; - PGSchedulingSuccessfulCallback success_callback; - StatusCallback put_cb; - EXPECT_CALL(*store_client_, AsyncPut(_, _, _, _)) - .WillOnce(DoAll(SaveArg<3>(&put_cb), Return(Status::OK()))); - EXPECT_CALL(*gcs_placement_group_scheduler_, ScheduleUnplacedBundles(_, _, _)) - .WillOnce(DoAll(SaveArg<1>(&failure_callback), SaveArg<2>(&success_callback))); - auto now = absl::GetCurrentTimeNanos(); - gcs_placement_group_manager_->RegisterPlacementGroup(pg, cb); - auto &pending_queue = 
gcs_placement_group_manager_->pending_placement_groups_; - ASSERT_EQ(1, pending_queue.size()); - ASSERT_LE(now, pending_queue.begin()->first); - ASSERT_GE(absl::GetCurrentTimeNanos(), pending_queue.begin()->first); - put_cb(Status::OK()); - pg->UpdateState(rpc::PlacementGroupTableData::RESCHEDULING); - failure_callback(pg, true); - ASSERT_EQ(1, pending_queue.size()); - ASSERT_GE(0, pending_queue.begin()->first); -} - -TEST_F(GcsPlacementGroupManagerMockTest, PendingQueuePriorityFailed) { - // Test priority works - // When return with a failure, exp backoff should work - auto req = - Mocker::GenCreatePlacementGroupRequest("", rpc::PlacementStrategy::SPREAD, 1); - auto pg = std::make_shared(req, ""); - auto cb = [](Status s) {}; - PGSchedulingFailureCallback failure_callback; - PGSchedulingSuccessfulCallback success_callback; - StatusCallback put_cb; - EXPECT_CALL(*store_client_, AsyncPut(_, _, _, _)) - .WillOnce(DoAll(SaveArg<3>(&put_cb), Return(Status::OK()))); - EXPECT_CALL(*gcs_placement_group_scheduler_, ScheduleUnplacedBundles(_, _, _)) - .Times(2) - .WillRepeatedly( - DoAll(SaveArg<1>(&failure_callback), SaveArg<2>(&success_callback))); - auto now = absl::GetCurrentTimeNanos(); - gcs_placement_group_manager_->RegisterPlacementGroup(pg, cb); - auto &pending_queue = gcs_placement_group_manager_->pending_placement_groups_; - ASSERT_EQ(1, pending_queue.size()); - ASSERT_LE(now, pending_queue.begin()->first); - ASSERT_GE(absl::GetCurrentTimeNanos(), pending_queue.begin()->first); - put_cb(Status::OK()); - pg->UpdateState(rpc::PlacementGroupTableData::PENDING); - now = absl::GetCurrentTimeNanos(); - failure_callback(pg, true); - auto exp_backer = ExponentialBackOff( - 1000000 * RayConfig::instance().gcs_create_placement_group_retry_min_interval_ms(), - RayConfig::instance().gcs_create_placement_group_retry_multiplier(), - 1000000 * RayConfig::instance().gcs_create_placement_group_retry_max_interval_ms()); - auto next = exp_backer.Next(); - ASSERT_DOUBLE_EQ( - next, - 1000000 * RayConfig::instance().gcs_create_placement_group_retry_min_interval_ms()); - ASSERT_EQ(1, pending_queue.size()); - auto rank = pending_queue.begin()->first; - ASSERT_LE(now + next, rank); - // ScheduleUnplacedBundles is not called here - gcs_placement_group_manager_->SchedulePendingPlacementGroups(); - ASSERT_EQ(1, pending_queue.size()); - ASSERT_EQ(rank, pending_queue.begin()->first); - - absl::SleepFor(absl::Milliseconds(1) + - absl::Nanoseconds(rank - absl::GetCurrentTimeNanos())); - gcs_placement_group_manager_->SchedulePendingPlacementGroups(); - ASSERT_EQ(0, pending_queue.size()); - pg->UpdateState(rpc::PlacementGroupTableData::PENDING); - now = absl::GetCurrentTimeNanos(); - failure_callback(pg, true); - next = RayConfig::instance().gcs_create_placement_group_retry_multiplier() * next; - ASSERT_EQ(1, pending_queue.size()); - ASSERT_LE(now + next, pending_queue.begin()->first); -} - -TEST_F(GcsPlacementGroupManagerMockTest, PendingQueuePriorityOrder) { - // Test priority works - // Add two pgs - // Fail one and make sure it's scheduled later - auto req1 = - Mocker::GenCreatePlacementGroupRequest("", rpc::PlacementStrategy::SPREAD, 1); - auto pg1 = std::make_shared(req1, ""); - auto req2 = - Mocker::GenCreatePlacementGroupRequest("", rpc::PlacementStrategy::SPREAD, 1); - auto pg2 = std::make_shared(req2, ""); - auto cb = [](Status s) {}; - PGSchedulingFailureCallback failure_callback; - PGSchedulingSuccessfulCallback success_callback; - StatusCallback put_cb; - EXPECT_CALL(*store_client_, AsyncPut(_, _, _, _)) - 
.Times(2) - .WillRepeatedly(DoAll(SaveArg<3>(&put_cb), Return(Status::OK()))); - EXPECT_CALL(*gcs_placement_group_scheduler_, ScheduleUnplacedBundles(_, _, _)) - .Times(2) - .WillRepeatedly( - DoAll(SaveArg<1>(&failure_callback), SaveArg<2>(&success_callback))); - gcs_placement_group_manager_->RegisterPlacementGroup(pg1, cb); - gcs_placement_group_manager_->RegisterPlacementGroup(pg2, cb); - auto &pending_queue = gcs_placement_group_manager_->pending_placement_groups_; - ASSERT_EQ(2, pending_queue.size()); - put_cb(Status::OK()); - ASSERT_EQ(1, pending_queue.size()); - // PG1 is scheduled first, so PG2 is in pending queue - ASSERT_EQ(pg2, pending_queue.begin()->second.second); - failure_callback(pg1, true); - ASSERT_EQ(2, pending_queue.size()); - gcs_placement_group_manager_->SchedulePendingPlacementGroups(); - // PG2 is scheduled for the next, so PG1 is in pending queue - ASSERT_EQ(1, pending_queue.size()); - ASSERT_EQ(pg1, pending_queue.begin()->second.second); -} - -} // namespace gcs -} // namespace ray diff --git a/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_test.cc index 8eeed97f7eca6..7c941aa27f815 100644 --- a/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_test.cc @@ -22,7 +22,6 @@ #include "ray/gcs/test/gcs_test_util.h" namespace ray { -namespace gcs { using ::testing::_; using StatusCallback = std::function; @@ -136,8 +135,6 @@ class GcsPlacementGroupManagerTest : public ::testing::Test { EXPECT_TRUE(WaitForCondition(condition, 10 * 1000)); } - ExponentialBackOff GetExpBackOff() { return ExponentialBackOff(0, 1); } - std::shared_ptr mock_placement_group_scheduler_; std::unique_ptr gcs_placement_group_manager_; std::unordered_map job_namespace_table_; @@ -151,26 +148,6 @@ class GcsPlacementGroupManagerTest : public ::testing::Test { std::shared_ptr redis_client_; }; -TEST_F(GcsPlacementGroupManagerTest, TestPlacementGroupBundleCache) { - auto request = Mocker::GenCreatePlacementGroupRequest(); - std::atomic registered_placement_group_count(0); - RegisterPlacementGroup(request, - [®istered_placement_group_count](const Status &status) { - ++registered_placement_group_count; - }); - ASSERT_EQ(registered_placement_group_count, 1); - WaitForExpectedPgCount(1); - auto placement_group = mock_placement_group_scheduler_->placement_groups_.back(); - ASSERT_TRUE(placement_group->cached_bundle_specs_.empty()); - // Fill the cache and verify it. - const auto &bundle_specs = placement_group->GetBundles(); - ASSERT_EQ(placement_group->cached_bundle_specs_, bundle_specs); - ASSERT_FALSE(placement_group->cached_bundle_specs_.empty()); - // Invalidate the cache and verify it. 
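The deleted PendingQueuePriorityFailed test above pins down the observable contract of ExponentialBackOff: the first Next() yields the configured minimum interval and each later call multiplies it, up to a cap. A sketch consistent with those assertions, not Ray's actual implementation:

#include <algorithm>
#include <cassert>
#include <cstdint>

class ExponentialBackOff {
 public:
  ExponentialBackOff(uint64_t initial_ns, double multiplier, uint64_t max_ns)
      : current_(initial_ns), multiplier_(multiplier), max_(max_ns) {}

  // Yield the current delay, then grow it for the next retry.
  uint64_t Next() {
    uint64_t value = current_;
    current_ = std::min<uint64_t>(
        static_cast<uint64_t>(current_ * multiplier_), max_);
    return value;
  }

 private:
  uint64_t current_;
  double multiplier_;
  uint64_t max_;
};

int main() {
  ExponentialBackOff backoff(/*initial_ns=*/1000, /*multiplier=*/2.0,
                             /*max_ns=*/8000);
  assert(backoff.Next() == 1000);  // first retry uses the base delay
  assert(backoff.Next() == 2000);
  assert(backoff.Next() == 4000);
  assert(backoff.Next() == 8000);
  assert(backoff.Next() == 8000);  // capped at the maximum
  return 0;
}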
- RAY_UNUSED(placement_group->GetMutableBundle(0)); - ASSERT_TRUE(placement_group->cached_bundle_specs_.empty()); -} - TEST_F(GcsPlacementGroupManagerTest, TestBasic) { auto request = Mocker::GenCreatePlacementGroupRequest(); std::atomic registered_placement_group_count(0); @@ -199,8 +176,7 @@ TEST_F(GcsPlacementGroupManagerTest, TestSchedulingFailed) { auto placement_group = mock_placement_group_scheduler_->placement_groups_.back(); mock_placement_group_scheduler_->placement_groups_.clear(); - gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group, - GetExpBackOff(), true); + gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group); gcs_placement_group_manager_->SchedulePendingPlacementGroups(); ASSERT_EQ(mock_placement_group_scheduler_->placement_groups_.size(), 1); mock_placement_group_scheduler_->placement_groups_.clear(); @@ -264,8 +240,7 @@ TEST_F(GcsPlacementGroupManagerTest, TestRescheduleWhenNodeAdd) { mock_placement_group_scheduler_->placement_groups_.pop_back(); // If the creation of placement group fails, it will be rescheduled after a short time. - gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group, - GetExpBackOff(), true); + gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group); WaitForExpectedPgCount(1); } @@ -280,8 +255,7 @@ TEST_F(GcsPlacementGroupManagerTest, TestRemovingPendingPlacementGroup) { auto placement_group = mock_placement_group_scheduler_->placement_groups_.back(); mock_placement_group_scheduler_->placement_groups_.clear(); - gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group, - GetExpBackOff(), true); + gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group); ASSERT_EQ(placement_group->GetState(), rpc::PlacementGroupTableData::PENDING); const auto &placement_group_id = placement_group->GetPlacementGroupID(); gcs_placement_group_manager_->RemovePlacementGroup(placement_group_id, @@ -317,8 +291,7 @@ TEST_F(GcsPlacementGroupManagerTest, TestRemovingLeasingPlacementGroup) { gcs_placement_group_manager_->RemovePlacementGroup(placement_group_id, [](const Status &status) {}); ASSERT_EQ(placement_group->GetState(), rpc::PlacementGroupTableData::REMOVED); - gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group, - GetExpBackOff(), true); + gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group); // Make sure it is not rescheduled gcs_placement_group_manager_->SchedulePendingPlacementGroups(); @@ -381,6 +354,7 @@ TEST_F(GcsPlacementGroupManagerTest, TestRescheduleWhenNodeDead) { placement_group->GetMutableBundle(0)->set_node_id(NodeID::FromRandom().Binary()); placement_group->GetMutableBundle(1)->set_node_id(NodeID::FromRandom().Binary()); mock_placement_group_scheduler_->placement_groups_.pop_back(); + // If a node dies, we will set the bundles above it to be unplaced and reschedule the // placement group. The placement group state is set to `RESCHEDULING` and will be // scheduled first. 
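The tests above exercise the reverted failure path: a creation failure puts a still-feasible group back on the pending queue for the next scheduling pass, while an infeasible one is parked. A toy version of that control flow; queue placement is simplified here, since the real code distinguishes PENDING from RESCHEDULING when choosing front or back:

#include <deque>
#include <iostream>
#include <memory>

struct PlacementGroup { int id = 0; };

class Manager {
 public:
  void OnCreationFailed(std::shared_ptr<PlacementGroup> pg,
                        bool is_feasible = true) {
    if (is_feasible) {
      pending_.push_front(std::move(pg));   // retry on the next pass
    } else {
      infeasible_.push_back(std::move(pg)); // park until resources change
    }
  }

  void SchedulePending() {
    if (pending_.empty()) return;
    auto pg = pending_.front();
    pending_.pop_front();
    std::cout << "scheduling pg " << pg->id << "\n";
  }

 private:
  std::deque<std::shared_ptr<PlacementGroup>> pending_;
  std::deque<std::shared_ptr<PlacementGroup>> infeasible_;
};

int main() {
  Manager m;
  m.OnCreationFailed(std::make_shared<PlacementGroup>(PlacementGroup{7}));
  m.SchedulePending();  // retries pg 7
}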
@@ -399,15 +373,14 @@ TEST_F(GcsPlacementGroupManagerTest, TestRescheduleWhenNodeDead) { placement_group->GetPlacementGroupID()); const auto &bundles = mock_placement_group_scheduler_->placement_groups_[0]->GetBundles(); - EXPECT_TRUE(NodeID::FromBinary(bundles[0]->GetMessage().node_id()).IsNil()); - EXPECT_FALSE(NodeID::FromBinary(bundles[1]->GetMessage().node_id()).IsNil()); + EXPECT_TRUE(NodeID::FromBinary(bundles[0]->GetMutableMessage().node_id()).IsNil()); + EXPECT_FALSE(NodeID::FromBinary(bundles[1]->GetMutableMessage().node_id()).IsNil()); // If `RESCHEDULING` placement group fails to create, we will schedule it again first. placement_group = mock_placement_group_scheduler_->placement_groups_.back(); mock_placement_group_scheduler_->placement_groups_.pop_back(); ASSERT_EQ(mock_placement_group_scheduler_->placement_groups_.size(), 0); - gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group, - GetExpBackOff(), true); + gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group); WaitForExpectedPgCount(1); ASSERT_EQ(mock_placement_group_scheduler_->placement_groups_[0]->GetPlacementGroupID(), placement_group->GetPlacementGroupID()); @@ -553,8 +526,7 @@ TEST_F(GcsPlacementGroupManagerTest, TestSchedulingCanceledWhenPgIsInfeasible) { mock_placement_group_scheduler_->placement_groups_.clear(); // Mark it non-retryable. - gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group, - GetExpBackOff(), false); + gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group, false); // Schedule twice to make sure it will not be scheduled afterward. gcs_placement_group_manager_->SchedulePendingPlacementGroups(); @@ -635,7 +607,6 @@ TEST_F(GcsPlacementGroupManagerTest, TestRayNamespace) { } } -} // namespace gcs } // namespace ray int main(int argc, char **argv) { diff --git a/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc b/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc index 5d265ac1bbb59..cbe1ba78495f4 100644 --- a/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc @@ -33,7 +33,6 @@ class GcsServerTest : public ::testing::Test { config.grpc_server_name = "MockedGcsServer"; config.grpc_server_thread_num = 1; config.redis_address = "127.0.0.1"; - config.node_ip_address = "127.0.0.1"; config.enable_sharding_conn = false; config.redis_port = TEST_REDIS_SERVER_PORTS.front(); gcs_server_.reset(new gcs::GcsServer(config, io_service_)); diff --git a/src/ray/gcs/gcs_server/test/gcs_server_test_util.h b/src/ray/gcs/gcs_server/test/gcs_server_test_util.h index 11f0783bb8465..249ac5a9fdd6a 100644 --- a/src/ray/gcs/gcs_server/test/gcs_server_test_util.h +++ b/src/ray/gcs/gcs_server/test/gcs_server_test_util.h @@ -70,10 +70,6 @@ struct GcsServerMocker { return Status::OK(); } - void ReportWorkerBacklog( - const WorkerID &worker_id, - const std::vector &backlog_reports) override {} - /// WorkerLeaseInterface void RequestWorkerLease( const ray::TaskSpecification &resource_spec, @@ -83,14 +79,6 @@ struct GcsServerMocker { callbacks.push_back(callback); } - void RequestWorkerLease( - const rpc::TaskSpec &spec, - const rpc::ClientCallback &callback, - const int64_t backlog_size = -1) override { - num_workers_requested += 1; - callbacks.push_back(callback); - } - /// WorkerLeaseInterface void ReleaseUnusedWorkers( const std::vector &workers_in_use, @@ -192,7 +180,7 @@ struct GcsServerMocker { /// ResourceReserveInterface void CancelResourceReserve( - const BundleSpecification 
&bundle_spec, + BundleSpecification &bundle_spec, const ray::rpc::ClientCallback &callback) override { num_return_requested += 1; diff --git a/src/ray/gcs/pubsub/gcs_pub_sub.h b/src/ray/gcs/pubsub/gcs_pub_sub.h index 70828a3679691..b871a02b13ddd 100644 --- a/src/ray/gcs/pubsub/gcs_pub_sub.h +++ b/src/ray/gcs/pubsub/gcs_pub_sub.h @@ -96,9 +96,6 @@ class GcsPubSub { std::string DebugString() const; - protected: - GcsPubSub() : GcsPubSub(nullptr) {} - private: /// Represents a caller's command to subscribe or unsubscribe to a given /// channel. diff --git a/src/ray/gcs/redis_context.h b/src/ray/gcs/redis_context.h index 443c42f9dee69..c7244aac80549 100644 --- a/src/ray/gcs/redis_context.h +++ b/src/ray/gcs/redis_context.h @@ -15,7 +15,7 @@ #pragma once #include -#include +#include #include #include #include diff --git a/src/ray/object_manager/object_buffer_pool.cc b/src/ray/object_manager/object_buffer_pool.cc index b25439cd7203c..e6e214b3062f2 100644 --- a/src/ray/object_manager/object_buffer_pool.cc +++ b/src/ray/object_manager/object_buffer_pool.cc @@ -14,7 +14,6 @@ #include "ray/object_manager/object_buffer_pool.h" -#include "absl/time/time.h" #include "ray/common/status.h" #include "ray/util/logging.h" @@ -22,49 +21,26 @@ namespace ray { ObjectBufferPool::ObjectBufferPool(const std::string &store_socket_name, uint64_t chunk_size) - : store_socket_name_(store_socket_name), default_chunk_size_(chunk_size) { + : default_chunk_size_(chunk_size) { + store_socket_name_ = store_socket_name; RAY_CHECK_OK(store_client_.Connect(store_socket_name_.c_str(), "", 0, 300)); } ObjectBufferPool::~ObjectBufferPool() { - absl::MutexLock lock(&pool_mutex_); - auto inflight_ops = create_buffer_ops_; - pool_mutex_.Unlock(); - - for (const auto &[id, cond_var] : inflight_ops) { - cond_var->SignalAll(); - } - auto no_inflight = [this]() { - pool_mutex_.AssertReaderHeld(); - return create_buffer_ops_.empty(); - }; - // Assume no request would arrive, acquire pool_mutex_ when there is no inflight - // operation. Otherwise print an error. - if (!pool_mutex_.LockWhenWithTimeout(absl::Condition(&no_inflight), absl::Seconds(5))) { - RAY_LOG(ERROR) - << create_buffer_ops_.size() << " remaining inflight create buffer operations " - << "during ObjectBufferPool destruction. Either abort these operations before " - << "destroying ObjectBufferPool, or refactor ObjectBufferPool to make it " - "unnecessary to wait for the operations' completion."; + // Abort everything in progress. + auto create_buf_state_copy = create_buffer_state_; + for (const auto &pair : create_buf_state_copy) { + AbortCreate(pair.first); } - - // Abort unfinished buffers in progress. - for (auto it = create_buffer_state_.begin(); it != create_buffer_state_.end(); it++) { - RAY_CHECK_OK(store_client_.Release(it->first)); - RAY_CHECK_OK(store_client_.Abort(it->first)); - create_buffer_state_.erase(it); - } - RAY_CHECK(create_buffer_state_.empty()); RAY_CHECK_OK(store_client_.Disconnect()); } -uint64_t ObjectBufferPool::GetNumChunks(uint64_t data_size) const { +uint64_t ObjectBufferPool::GetNumChunks(uint64_t data_size) { return (data_size + default_chunk_size_ - 1) / default_chunk_size_; } -uint64_t ObjectBufferPool::GetBufferLength(uint64_t chunk_index, - uint64_t data_size) const { +uint64_t ObjectBufferPool::GetBufferLength(uint64_t chunk_index, uint64_t data_size) { return (chunk_index + 1) * default_chunk_size_ > data_size ? 
data_size % default_chunk_size_ : default_chunk_size_; @@ -73,7 +49,7 @@ uint64_t ObjectBufferPool::GetBufferLength(uint64_t chunk_index, std::pair, ray::Status> ObjectBufferPool::CreateObjectReader(const ObjectID &object_id, rpc::Address owner_address) { - absl::MutexLock lock(&pool_mutex_); + std::lock_guard lock(pool_mutex_); std::vector object_ids{object_id}; std::vector object_buffers(1); @@ -100,21 +76,53 @@ ray::Status ObjectBufferPool::CreateChunk(const ObjectID &object_id, const rpc::Address &owner_address, uint64_t data_size, uint64_t metadata_size, uint64_t chunk_index) { - absl::MutexLock lock(&pool_mutex_); - RAY_RETURN_NOT_OK(EnsureBufferExists(object_id, owner_address, data_size, metadata_size, - chunk_index)); - auto &state = create_buffer_state_.at(object_id); - if (state.chunk_state[chunk_index] != CreateChunkState::AVAILABLE) { + std::unique_lock lock(pool_mutex_); + if (create_buffer_state_.count(object_id) == 0) { + int64_t object_size = data_size - metadata_size; + // Try to create shared buffer. + std::shared_ptr data; + + // Release the buffer pool lock during the blocking create call. + lock.unlock(); + Status s = store_client_.CreateAndSpillIfNeeded( + object_id, owner_address, object_size, NULL, metadata_size, &data, + plasma::flatbuf::ObjectSource::ReceivedFromRemoteRaylet); + lock.lock(); + + // Another thread may have succeeded in creating the chunk while the lock + // was released. In that case skip the remainder of the creation block. + if (create_buffer_state_.count(object_id) == 0) { + std::vector buffer; + if (!s.ok()) { + // Create failed. The object may already exist locally. If something else went + // wrong, another chunk will succeed in creating the buffer, and this + // chunk will eventually make it here via pull requests. + return ray::Status::IOError(s.message()); + } + // Read object into store. + uint8_t *mutable_data = data->Data(); + uint64_t num_chunks = GetNumChunks(data_size); + create_buffer_state_.emplace( + std::piecewise_construct, std::forward_as_tuple(object_id), + std::forward_as_tuple(BuildChunks(object_id, mutable_data, data_size, data))); + RAY_LOG(DEBUG) << "Created object " << object_id + << " in plasma store, number of chunks: " << num_chunks + << ", chunk index: " << chunk_index; + RAY_CHECK(create_buffer_state_[object_id].chunk_info.size() == num_chunks); + } + } + if (create_buffer_state_[object_id].chunk_state[chunk_index] != + CreateChunkState::AVAILABLE) { // There can be only one reference to this chunk at any given time. 
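CreateChunk above reintroduces a pattern worth spelling out: take a std::unique_lock, release it across the blocking plasma create call, re-acquire it, and then re-check the map, because another thread may have created the buffer in the meantime. The same shape in isolation, with the plasma call replaced by a sleep:

#include <chrono>
#include <map>
#include <mutex>
#include <string>
#include <thread>

std::mutex pool_mutex;
std::map<std::string, int> create_buffer_state;

void EnsureBuffer(const std::string &object_id) {
  std::unique_lock<std::mutex> lock(pool_mutex);
  if (create_buffer_state.count(object_id) == 0) {
    // Drop the lock for the slow, blocking create call...
    lock.unlock();
    std::this_thread::sleep_for(std::chrono::milliseconds(10));  // "create"
    lock.lock();
    // ...and re-check: another thread may have won the race meanwhile.
    if (create_buffer_state.count(object_id) == 0) {
      create_buffer_state.emplace(object_id, 1);
    }
  }
}

int main() {
  std::thread a(EnsureBuffer, "obj"), b(EnsureBuffer, "obj");
  a.join();
  b.join();
  return create_buffer_state.at("obj") == 1 ? 0 : 1;
}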
return ray::Status::IOError("Chunk already received by a different thread."); } - state.chunk_state[chunk_index] = CreateChunkState::REFERENCED; + create_buffer_state_[object_id].chunk_state[chunk_index] = CreateChunkState::REFERENCED; return ray::Status::OK(); } void ObjectBufferPool::WriteChunk(const ObjectID &object_id, const uint64_t chunk_index, const std::string &data) { - absl::MutexLock lock(&pool_mutex_); + std::lock_guard lock(pool_mutex_); auto it = create_buffer_state_.find(object_id); if (it == create_buffer_state_.end() || it->second.chunk_state.at(chunk_index) != CreateChunkState::REFERENCED) { @@ -140,7 +148,7 @@ void ObjectBufferPool::WriteChunk(const ObjectID &object_id, const uint64_t chun } void ObjectBufferPool::AbortCreate(const ObjectID &object_id) { - absl::MutexLock lock(&pool_mutex_); + std::lock_guard lock(pool_mutex_); auto it = create_buffer_state_.find(object_id); if (it != create_buffer_state_.end()) { RAY_LOG(INFO) << "Not enough memory to create requested object " << object_id @@ -171,84 +179,13 @@ std::vector ObjectBufferPool::BuildChunks( return chunks; } -ray::Status ObjectBufferPool::EnsureBufferExists(const ObjectID &object_id, - const rpc::Address &owner_address, - uint64_t data_size, - uint64_t metadata_size, - uint64_t chunk_index) { - while (true) { - // Buffer for object_id already exists. - if (create_buffer_state_.contains(object_id)) { - return ray::Status::OK(); - } - - auto it = create_buffer_ops_.find(object_id); - if (it == create_buffer_ops_.end()) { - // No inflight create buffer operation, proceed to start one. - break; - } - - auto cond_var = it->second; - // Release pool_mutex_ while waiting, until the current inflight create buffer - // operation finishes. - cond_var->Wait(&pool_mutex_); - } - - // Indicate that there is an inflight create buffer operation, by inserting into - // create_buffer_ops_. - RAY_CHECK( - create_buffer_ops_.insert({object_id, std::make_shared()}).second); - const int64_t object_size = - static_cast(data_size) - static_cast(metadata_size); - std::shared_ptr data; - - // Release pool_mutex_ during the blocking create call. - pool_mutex_.Unlock(); - Status s = store_client_.CreateAndSpillIfNeeded( - object_id, owner_address, static_cast(object_size), nullptr, - static_cast(metadata_size), &data, - plasma::flatbuf::ObjectSource::ReceivedFromRemoteRaylet); - pool_mutex_.Lock(); - - // No other thread could have created the buffer. - RAY_CHECK(!create_buffer_state_.contains(object_id)); - - // Remove object_id from create_buffer_ops_ to indicate to the waiting ops that the - // inflight operation has finished. Wake up waiters so they can either start another - // create buffer op, or proceed after the buffer has been created. - { - auto it = create_buffer_ops_.find(object_id); - it->second->SignalAll(); - create_buffer_ops_.erase(it); - } - - if (!s.ok()) { - // Create failed. Buffer creation will be tried by another chunk. - // And this chunk will eventually make it here via retried pull requests. - return ray::Status::IOError(s.message()); - } - - // Read object into store. 
- uint8_t *mutable_data = data->Data(); - uint64_t num_chunks = GetNumChunks(data_size); - create_buffer_state_.emplace( - std::piecewise_construct, std::forward_as_tuple(object_id), - std::forward_as_tuple(BuildChunks(object_id, mutable_data, data_size, data))); - RAY_CHECK(create_buffer_state_[object_id].chunk_info.size() == num_chunks); - RAY_LOG(DEBUG) << "Created object " << object_id - << " in plasma store, number of chunks: " << num_chunks - << ", chunk index: " << chunk_index; - - return ray::Status::OK(); -} - void ObjectBufferPool::FreeObjects(const std::vector &object_ids) { - absl::MutexLock lock(&pool_mutex_); + std::lock_guard lock(pool_mutex_); RAY_CHECK_OK(store_client_.Delete(object_ids)); } std::string ObjectBufferPool::DebugString() const { - absl::MutexLock lock(&pool_mutex_); + std::lock_guard lock(pool_mutex_); std::stringstream result; result << "BufferPool:"; result << "\n- create buffer state map size: " << create_buffer_state_.size(); diff --git a/src/ray/object_manager/object_buffer_pool.h b/src/ray/object_manager/object_buffer_pool.h index b2722a3eceecc..05c51e5e00117 100644 --- a/src/ray/object_manager/object_buffer_pool.h +++ b/src/ray/object_manager/object_buffer_pool.h @@ -16,14 +16,12 @@ #include #include -#include +#include #include #include +#include #include -#include "absl/base/thread_annotations.h" -#include "absl/container/flat_hash_map.h" -#include "absl/synchronization/mutex.h" #include "ray/common/id.h" #include "ray/common/status.h" #include "ray/object_manager/memory_object_reader.h" @@ -70,14 +68,14 @@ class ObjectBufferPool { /// /// \param data_size The size of the object + metadata. /// \return The number of chunks into which the object will be split. - uint64_t GetNumChunks(uint64_t data_size) const; + uint64_t GetNumChunks(uint64_t data_size); /// Computes the buffer length of a chunk of an object. /// /// \param chunk_index The chunk index for which to obtain the buffer length. /// \param data_size The size of the object + metadata. /// \return The buffer length of the chunk at chunk_index. - uint64_t GetBufferLength(uint64_t chunk_index, uint64_t data_size) const; + uint64_t GetBufferLength(uint64_t chunk_index, uint64_t data_size); /// Returns an object reader for read. /// @@ -87,7 +85,7 @@ class ObjectBufferPool { /// this method. An IOError status is returned if the Get call on the plasma store /// fails, and the MemoryObjectReader will be empty. std::pair, ray::Status> CreateObjectReader( - const ObjectID &object_id, rpc::Address owner_address) LOCKS_EXCLUDED(pool_mutex_); + const ObjectID &object_id, rpc::Address owner_address); /// Returns a chunk of an empty object at the given chunk_index. The object chunk /// serves as the buffer that is to be written to by a connection receiving an @@ -108,7 +106,7 @@ class ObjectBufferPool { /// (with no intermediate AbortCreateChunk). ray::Status CreateChunk(const ObjectID &object_id, const rpc::Address &owner_address, uint64_t data_size, uint64_t metadata_size, - uint64_t chunk_index) LOCKS_EXCLUDED(pool_mutex_); + uint64_t chunk_index); /// Write to a Chunk of an object. If all chunks of an object is written, /// it seals the object. @@ -121,44 +119,34 @@ class ObjectBufferPool { /// \param chunk_index The index of the chunk. /// \param data The data to write into the chunk. void WriteChunk(const ObjectID &object_id, uint64_t chunk_index, - const std::string &data) LOCKS_EXCLUDED(pool_mutex_); + const std::string &data); /// Free a list of objects from object store. 
/// /// \param object_ids the The list of ObjectIDs to be deleted. /// \return Void. - void FreeObjects(const std::vector &object_ids) LOCKS_EXCLUDED(pool_mutex_); + void FreeObjects(const std::vector &object_ids); /// Abort the create operation associated with an object. This destroys the buffer /// state, including create operations in progress for all chunks of the object. - void AbortCreate(const ObjectID &object_id) LOCKS_EXCLUDED(pool_mutex_); + void AbortCreate(const ObjectID &object_id); /// Returns debug string for class. /// /// \return string. - std::string DebugString() const LOCKS_EXCLUDED(pool_mutex_); + std::string DebugString() const; private: /// Splits an object into ceil(data_size/chunk_size) chunks, which will /// either be read or written to in parallel. std::vector BuildChunks(const ObjectID &object_id, uint8_t *data, uint64_t data_size, - std::shared_ptr buffer_ref) - EXCLUSIVE_LOCKS_REQUIRED(pool_mutex_); - - /// Ensures buffer for the object exists, and creates the buffer if needed. - /// Returns OK if buffer exists. - /// Must hold pool_mutex_ when calling this function. pool_mutex_ can be released - /// during the call. - ray::Status EnsureBufferExists(const ObjectID &object_id, - const rpc::Address &owner_address, uint64_t data_size, - uint64_t metadata_size, uint64_t chunk_index) - EXCLUSIVE_LOCKS_REQUIRED(pool_mutex_); + std::shared_ptr buffer_ref); /// The state of a chunk associated with a create operation. enum class CreateChunkState : unsigned int { AVAILABLE = 0, REFERENCED, SEALED }; - /// Holds the state of creating chunks. Members are protected by pool_mutex_. + /// Holds the state of a create buffer. struct CreateBufferState { CreateBufferState() {} CreateBufferState(std::vector chunk_info) @@ -178,29 +166,18 @@ class ObjectBufferPool { /// Returned when GetChunk or CreateChunk fails. const ChunkInfo errored_chunk_ = {0, nullptr, 0, nullptr}; - /// Socket name of plasma store. - const std::string store_socket_name_; - + /// Mutex on public methods for thread-safe operations on + /// get_buffer_state_, create_buffer_state_, and store_client_. + mutable std::mutex pool_mutex_; /// Determines the maximum chunk size to be transferred by a single thread. const uint64_t default_chunk_size_; - - /// Mutex to protect create_buffer_ops_, create_buffer_state_ and following invariants: - /// - create_buffer_ops_ contains an object_id iff there is an inflight operation to - /// create the buffer for the object. - /// - An object_id cannot appear in both create_buffer_ops_ and create_buffer_state_. - mutable absl::Mutex pool_mutex_; - /// Makes sure each object has at most one inflight create buffer operation. - /// Other operations can wait on the std::condition_variable for the operation - /// to complete. If successful, the corresponding entry in create_buffer_state_ - /// will be created. - absl::flat_hash_map> create_buffer_ops_ - GUARDED_BY(pool_mutex_); /// The state of a buffer that's currently being used. - absl::flat_hash_map create_buffer_state_ - GUARDED_BY(pool_mutex_); + std::unordered_map create_buffer_state_; /// Plasma client pool. plasma::PlasmaClient store_client_; + /// Socket name of plasma store. 
+ std::string store_socket_name_; }; } // namespace ray diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc index 3ee951d75553d..8e4dd703b91fb 100644 --- a/src/ray/object_manager/object_manager.cc +++ b/src/ray/object_manager/object_manager.cc @@ -88,7 +88,6 @@ ObjectManager::ObjectManager( buffer_pool_(config_.store_socket_name, config_.object_chunk_size), rpc_work_(rpc_service_), object_manager_server_("ObjectManager", config_.object_manager_port, - config_.object_manager_address == "127.0.0.1", config_.rpc_service_threads_number), object_manager_service_(rpc_service_, *this), client_call_manager_(main_service, config_.rpc_service_threads_number), @@ -442,18 +441,17 @@ void ObjectManager::PushObjectInternal(const ObjectID &object_id, const NodeID & [=]() { // Post to the multithreaded RPC event loop so that data is copied // off of the main thread. - SendObjectChunk( - push_id, object_id, node_id, chunk_id, rpc_client, - [=](const Status &status) { - // Post back to the main event loop because the - // PushManager is thread-safe. - main_service_->post( - [this, node_id, object_id]() { - push_manager_->OnChunkComplete(node_id, object_id); - }, - "ObjectManager.Push"); - }, - chunk_reader); + SendObjectChunk(push_id, object_id, node_id, chunk_id, rpc_client, + [=](const Status &status) { + // Post back to the main event loop because the + // PushManager is thread-safe. + main_service_->post( + [this, node_id, object_id]() { + push_manager_->OnChunkComplete(node_id, object_id); + }, + "ObjectManager.Push"); + }, + std::move(chunk_reader)); }, "ObjectManager.Push"); }); diff --git a/src/ray/object_manager/object_manager.h b/src/ray/object_manager/object_manager.h index c0519a38306bd..3aaa847f03381 100644 --- a/src/ray/object_manager/object_manager.h +++ b/src/ray/object_manager/object_manager.h @@ -17,7 +17,7 @@ #include #include #include -#include +#include #include #include #include @@ -49,8 +49,6 @@ namespace ray { struct ObjectManagerConfig { - /// The IP address this object manager is running on. - std::string object_manager_address; /// The port that the object manager should use to listen for connections /// from other object managers. If this is 0, the object manager will choose /// its own port. @@ -58,7 +56,7 @@ struct ObjectManagerConfig { /// The object manager's global timer frequency. unsigned int timer_freq_ms; /// The time in milliseconds to wait before retrying a pull - /// that failed. + /// that fails due to node id lookup. unsigned int pull_timeout_ms; /// Object chunk size, in bytes uint64_t object_chunk_size; diff --git a/src/ray/object_manager/plasma/store.cc b/src/ray/object_manager/plasma/store.cc index 0b8b24dbac56d..ff9e98ddb765c 100644 --- a/src/ray/object_manager/plasma/store.cc +++ b/src/ray/object_manager/plasma/store.cc @@ -32,7 +32,7 @@ #include #include -#include +#include #include #include #include @@ -53,7 +53,6 @@ #include "ray/object_manager/plasma/protocol.h" #include "ray/util/util.h" -namespace ph = boost::placeholders; namespace fb = plasma::flatbuf; namespace plasma { @@ -298,9 +297,7 @@ void PlasmaStore::ConnectClient(const boost::system::error_code &error) { if (!error) { // Accept a new local client and dispatch it to the node manager. auto new_connection = Client::Create( - // NOLINTNEXTLINE : handler must be of boost::AcceptHandler type. 
- boost::bind(&PlasmaStore::ProcessMessage, this, ph::_1, ph::_2, ph::_3), - std::move(socket_)); + boost::bind(&PlasmaStore::ProcessMessage, this, _1, _2, _3), std::move(socket_)); } // We're ready to accept another client. DoAccept(); diff --git a/src/ray/object_manager/pull_manager.h b/src/ray/object_manager/pull_manager.h index 6c5108f111abe..9b7f20c14a478 100644 --- a/src/ray/object_manager/pull_manager.h +++ b/src/ray/object_manager/pull_manager.h @@ -16,7 +16,7 @@ #include #include -#include +#include #include #include "absl/container/flat_hash_map.h" diff --git a/src/ray/protobuf/agent_manager.proto b/src/ray/protobuf/agent_manager.proto index cbbd127004536..f573f53766525 100644 --- a/src/ray/protobuf/agent_manager.proto +++ b/src/ray/protobuf/agent_manager.proto @@ -13,7 +13,6 @@ // limitations under the License. syntax = "proto3"; -option cc_enable_arenas = true; package ray.rpc; diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index 1d3dd8124484d..dd9cf403c305c 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -13,7 +13,6 @@ // limitations under the License. syntax = "proto3"; -option cc_enable_arenas = true; package ray.rpc; @@ -156,12 +155,10 @@ message RayException { /// The runtime environment describes all the runtime packages needed to /// run some task or actor. message RuntimeEnv { - /// The serialized runtime env passed from the user. - string serialized_runtime_env = 1; - /// URIs used in this runtime env. These will be used for reference counting. + /// The raw json passed from user + string raw_json = 1; + /// Uris used in this runtime env repeated string uris = 2; - /// Indicates whether to install runtime env eagerly before the workers are leased. - bool runtime_env_eager_install = 3; } /// The task specification encapsulates all immutable information about the @@ -212,19 +209,21 @@ message TaskSpec { int64 placement_group_bundle_index = 19; // Whether or not this task should capture parent's placement group automatically. bool placement_group_capture_child_tasks = 20; + // Environment variables to override for this task + map override_environment_variables = 21; // Whether or not to skip the execution of this task. When it's true, // the receiver will not execute the task. This field is used by async actors // to guarantee task submission order after restart. - bool skip_execution = 21; + bool skip_execution = 22; // Breakpoint if this task should drop into the debugger when it starts executing // and "" if the task should not drop into the debugger. - bytes debugger_breakpoint = 22; - // Runtime environment for this task. - RuntimeEnv runtime_env = 23; + bytes debugger_breakpoint = 23; + // Serialized JSON string of the parsed runtime environment dict for this task. + string serialized_runtime_env = 24; // The concurrency group name in which this task will be performed. - string concurrency_group_name = 24; + string concurrency_group_name = 25; // Whether application-level errors (exceptions) should be retried. - bool retry_exceptions = 25; + bool retry_exceptions = 26; } message Bundle { diff --git a/src/ray/protobuf/core_worker.proto b/src/ray/protobuf/core_worker.proto index 81a8fbb5fd3d2..9af0a87231326 100644 --- a/src/ray/protobuf/core_worker.proto +++ b/src/ray/protobuf/core_worker.proto @@ -13,7 +13,6 @@ // limitations under the License. 
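A note on the TaskSpec hunk above, which renumbers existing fields (skip_execution moves from 21 to 22 and everything after it shifts by one): protobuf identifies fields on the wire by number, not by name, so a renumbering like this is only safe when every reader and writer is rebuilt from the same .proto; mixed versions will silently decode values under the wrong field. A minimal standalone C++ sketch of why (illustrative, not Ray code; the tag math is the standard protobuf encoding rule):

    #include <cstdint>
    #include <cstdio>

    // A protobuf key on the wire is (field_number << 3) | wire_type, varint
    // encoded. Renumbering a field changes every encoded key, so an old
    // reader attributes the bytes that follow to a different field.
    int main() {
      const uint32_t kVarintWireType = 0;  // bool, int32, uint64, enums, ...
      for (uint32_t field_number = 21; field_number <= 22; ++field_number) {
        const uint32_t key = (field_number << 3) | kVarintWireType;
        std::printf("field %u -> wire key 0x%x\n", (unsigned)field_number,
                    (unsigned)key);
      }
      return 0;
    }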
syntax = "proto3"; -option cc_enable_arenas = true; package ray.rpc; diff --git a/src/ray/protobuf/event.proto b/src/ray/protobuf/event.proto index 5ec8ee9402492..2edc202776f6b 100644 --- a/src/ray/protobuf/event.proto +++ b/src/ray/protobuf/event.proto @@ -1,5 +1,4 @@ syntax = "proto3"; -option cc_enable_arenas = true; package ray.rpc; diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto index 5f35c1a21e4d5..ec1f3e7380d53 100644 --- a/src/ray/protobuf/gcs.proto +++ b/src/ray/protobuf/gcs.proto @@ -13,7 +13,6 @@ // limitations under the License. syntax = "proto3"; -option cc_enable_arenas = true; package ray.rpc; @@ -150,17 +149,19 @@ message ActorTableData { RayException creation_task_exception = 18; // The actor's namespace. Named `ray_namespace` to avoid confusions when invoked in c++. string ray_namespace = 19; + // Runtime required to run this actor + // It'll only be set if it's a detached actor and the original job has this field + RuntimeEnv runtime_env = 20; // The unix ms timestamp the actor was started at. - uint64 start_time = 20; + uint64 start_time = 21; // The unix ms timestamp the actor was ended at. - uint64 end_time = 21; - // Serialized runtime_env used to report in the dashboard snapshot. We need to populate - // it here instead of grabbing it from the task spec because the task spec is cleared - // for deleted actors: https://github.com/ray-project/ray/pull/11149. - string serialized_runtime_env = 22; + uint64 end_time = 22; // The actor's class name. This is necessary because the task spec's lifetime // is shorter than the ActorTableData. string class_name = 23; + // The actor's serialized runtime environment. This is necessary because the + // task spec's lifetime is shorter than the ActorTableData. + string serialized_runtime_env = 24; } message ErrorTableData { @@ -277,20 +278,24 @@ message TaskLeaseData { } message JobConfig { + // Environment variables to be set on worker processes. + map worker_env = 1; // The number of java workers per worker process. - uint32 num_java_workers_per_process = 1; + uint32 num_java_workers_per_process = 2; // The jvm options for java workers of the job. - repeated string jvm_options = 2; + repeated string jvm_options = 3; // A list of directories or files (jar files or dynamic libraries) that specify the // search path for user code. This will be used as `CLASSPATH` in Java, and `PYTHONPATH` // in Python. In C++, libraries under these paths will be loaded by 'dlopen'. - repeated string code_search_path = 3; + repeated string code_search_path = 4; // Runtime environment to run the code - RuntimeEnv runtime_env = 4; + RuntimeEnv runtime_env = 5; // The job's namespace. Named `ray_namespace` to avoid confusions when invoked in c++. - string ray_namespace = 5; + string ray_namespace = 6; + // Serialized JSON string of the parsed runtime environment dict for this job. + string serialized_runtime_env = 7; // An opaque kv store for job related metadata. - map metadata = 6; + map metadata = 8; } message JobTableData { diff --git a/src/ray/protobuf/gcs_service.proto b/src/ray/protobuf/gcs_service.proto index 65e9bbad13bc3..308083f201208 100644 --- a/src/ray/protobuf/gcs_service.proto +++ b/src/ray/protobuf/gcs_service.proto @@ -13,7 +13,7 @@ // limitations under the License. 
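Several files in this series drop `option cc_enable_arenas = true;`. In the protobuf releases Ray built against at the time, that option gated arena support in the generated C++ classes (recent protobuf versions enable arenas unconditionally and ignore the option). A hedged sketch of what arena allocation buys, using the well-known Empty message rather than a Ray type so it compiles standalone against libprotobuf:

    #include <google/protobuf/arena.h>
    #include <google/protobuf/empty.pb.h>

    // With arenas enabled, messages are carved out of one block and freed in
    // a single shot when the arena is destroyed, instead of running a
    // destructor per message in a large tree.
    int main() {
      google::protobuf::Arena arena;
      auto *msg =
          google::protobuf::Arena::CreateMessage<google::protobuf::Empty>(&arena);
      (void)msg;  // owned by the arena; never delete it manually
      return 0;   // ~Arena releases everything it owns at once
    }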
syntax = "proto3"; -option cc_enable_arenas = true; + package ray.rpc; import "src/ray/protobuf/common.proto"; diff --git a/src/ray/protobuf/job_agent.proto b/src/ray/protobuf/job_agent.proto index e187de67ae0f5..07355a0a8f7c0 100644 --- a/src/ray/protobuf/job_agent.proto +++ b/src/ray/protobuf/job_agent.proto @@ -15,7 +15,6 @@ syntax = "proto3"; package ray.rpc; -option cc_enable_arenas = true; import "src/ray/protobuf/agent_manager.proto"; diff --git a/src/ray/protobuf/node_manager.proto b/src/ray/protobuf/node_manager.proto index 0331369528753..0c56bb7832b3a 100644 --- a/src/ray/protobuf/node_manager.proto +++ b/src/ray/protobuf/node_manager.proto @@ -13,31 +13,12 @@ // limitations under the License. syntax = "proto3"; -option cc_enable_arenas = true; package ray.rpc; import "src/ray/protobuf/common.proto"; import "src/ray/protobuf/gcs.proto"; -message WorkerBacklogReport { - // TaskSpec indicating the scheduling class. - // Cannot send scheduling class directly - // since it's local to each process. - TaskSpec resource_spec = 1; - // Size of the backlog for the above scheduling class. - int64 backlog_size = 2; -} - -message ReportWorkerBacklogRequest { - // Unique id of the worker that's reporting the backlog - bytes worker_id = 1; - // Backlog report per scheduling class - repeated WorkerBacklogReport backlog_reports = 2; -} - -message ReportWorkerBacklogReply {} - // Request a worker from the raylet with the specified resources. message RequestWorkerLeaseRequest { // TaskSpec containing the requested resources. @@ -273,8 +254,6 @@ service NodeManagerService { returns (RequestResourceReportReply); // Request a worker from the raylet. rpc RequestWorkerLease(RequestWorkerLeaseRequest) returns (RequestWorkerLeaseReply); - // Report task backlog information from a worker to the raylet - rpc ReportWorkerBacklog(ReportWorkerBacklogRequest) returns (ReportWorkerBacklogReply); // Release a worker back to its raylet. rpc ReturnWorker(ReturnWorkerRequest) returns (ReturnWorkerReply); // This method is only used by GCS, and the purpose is to release leased workers diff --git a/src/ray/protobuf/object_manager.proto b/src/ray/protobuf/object_manager.proto index c212b18b266d1..8bd6986f6b5b1 100644 --- a/src/ray/protobuf/object_manager.proto +++ b/src/ray/protobuf/object_manager.proto @@ -13,7 +13,6 @@ // limitations under the License. syntax = "proto3"; -option cc_enable_arenas = true; package ray.rpc; diff --git a/src/ray/protobuf/pubsub.proto b/src/ray/protobuf/pubsub.proto index 8181f886ffb3c..fc046afcf69c2 100644 --- a/src/ray/protobuf/pubsub.proto +++ b/src/ray/protobuf/pubsub.proto @@ -13,7 +13,6 @@ // limitations under the License. syntax = "proto3"; -option cc_enable_arenas = true; package ray.rpc; diff --git a/src/ray/protobuf/ray_client.proto b/src/ray/protobuf/ray_client.proto index 5dab0499d7d56..e207263e515a7 100644 --- a/src/ray/protobuf/ray_client.proto +++ b/src/ray/protobuf/ray_client.proto @@ -13,7 +13,6 @@ // limitations under the License. syntax = "proto3"; -option cc_enable_arenas = true; package ray.rpc; @@ -62,6 +61,8 @@ message ClientTask { // A name parameter, if the payload can be called in more than one way // (like a method on a payload object). string name = 2; + // A namespace parameter. + string namespace = 9; // A reference to the payload. bytes payload_id = 3; // Positional parameters to pass to this call. @@ -75,8 +76,6 @@ message ClientTask { TaskOptions options = 7; // Options passed to create the default remote task excution environment. 
TaskOptions baseline_options = 8;
-  // A namespace parameter.
-  string namespace = 9;
 }
 
 message ClientTaskTicket {
diff --git a/src/ray/protobuf/reporter.proto b/src/ray/protobuf/reporter.proto
index 00849c0683960..225c520481cc5 100644
--- a/src/ray/protobuf/reporter.proto
+++ b/src/ray/protobuf/reporter.proto
@@ -13,7 +13,6 @@
 // limitations under the License.
 
 syntax = "proto3";
-option cc_enable_arenas = true;
 
 package ray.rpc;
 
diff --git a/src/ray/protobuf/runtime_env_agent.proto b/src/ray/protobuf/runtime_env_agent.proto
index f36adf38cdb2a..a7903f8939c91 100644
--- a/src/ray/protobuf/runtime_env_agent.proto
+++ b/src/ray/protobuf/runtime_env_agent.proto
@@ -13,7 +13,6 @@
 // limitations under the License.
 
 syntax = "proto3";
-option cc_enable_arenas = true;
 
 package ray.rpc;
 
@@ -22,10 +21,6 @@ import "src/ray/protobuf/agent_manager.proto";
 message CreateRuntimeEnvRequest {
   string serialized_runtime_env = 1;
   bytes job_id = 2;
-  // Serialized allocated resource instances. Key is resource type, value is allocated
-  // instances. For example, {"CPU":20000,"memory":40000,"GPU":[10000, 10000]} means 2
-  // CPU cores, 2 Gi memory, GPU 0 and GPU 1.
-  string serialized_allocated_resource_instances = 3;
 }
 
 message CreateRuntimeEnvReply {
diff --git a/src/ray/protobuf/serialization.proto b/src/ray/protobuf/serialization.proto
index 84da8dff1531c..e5fed8e4a3876 100644
--- a/src/ray/protobuf/serialization.proto
+++ b/src/ray/protobuf/serialization.proto
@@ -13,7 +13,6 @@
 // limitations under the License.
 
 syntax = "proto3";
-option cc_enable_arenas = true;
 
 package ray.serialization;
 
diff --git a/src/ray/protobuf/serve.proto b/src/ray/protobuf/serve.proto
index 2636dcf685544..24e755a0b883a 100644
--- a/src/ray/protobuf/serve.proto
+++ b/src/ray/protobuf/serve.proto
@@ -13,7 +13,6 @@
 // limitations under the License.
 
 syntax = "proto3";
-option cc_enable_arenas = true;
 
 package ray.serve;
 
@@ -32,17 +31,19 @@ message AutoscalingConfig {
   uint32 max_replicas = 2;
 
   // Target number of in-flight requests per replica. This is the primary configuration
   // knob for the replica autoscaler. The lower the number, the more rapidly the
   // replicas are scaled up. Must be a non-negative integer.
   uint32 target_num_ongoing_requests_per_replica = 3;
 
   // How frequently each replica sends metrics to the autoscaler.
   double metrics_interval_s = 4;
-
+  // The interval (in seconds) at which the autoscaler evaluates metrics and makes
+  // scaling decisions.
+  double loop_period_s = 5;
   // The window (in seconds) over which the autoscaler computes the rolling average of
   // metrics.
-  double look_back_period_s = 5;
+  double look_back_period_s = 6;
 
   // The multiplicative "gain" factor to limit scaling decisions.
-  double smoothing_factor = 6;
+  double smoothing_factor = 7;
 }
 
 // Configuration options for a backend, to be set by the user.
@@ -61,11 +62,11 @@ message BackendConfig {
 
   // Duration that backend workers will wait until there is no more work to be done
   // before shutting down. Defaults to 2s.
-  double graceful_shutdown_wait_loop_s = 4;
+  double experimental_graceful_shutdown_wait_loop_s = 4;
 
   // Controller waits for this duration to forcefully kill the replica for shutdown.
   // Defaults to 20s.
-  double graceful_shutdown_timeout_s = 5;
+  double experimental_graceful_shutdown_timeout_s = 5;
 
   // Is the construction of the backend cross-language?
bool is_cross_language = 6; @@ -94,35 +95,3 @@ message RequestMetadata { message RequestWrapper { bytes body = 1; } - -message UpdatedObject { - bytes object_snapshot = 1; - int32 snapshot_id = 2; -} - -message LongPollRequest { - map keys_to_snapshot_ids = 1; -} - -message LongPollResult { - map updated_objects = 1; -} - -message EndpointInfo { - string endpoint_name = 1; - string route = 2; - map config = 3; -} - -message EndpointSet { - map endpoints = 1; -} - -message ActorSet { - repeated string names = 1; -} - -message BackendVersion { - string code_version = 1; - bytes user_config = 2; -} diff --git a/src/ray/ray_version_script.lds b/src/ray/ray_version_script.lds index b18b99d675dfa..6d53de5ed92d1 100644 --- a/src/ray/ray_version_script.lds +++ b/src/ray/ray_version_script.lds @@ -39,6 +39,7 @@ VERSION_1.0 { *ray*streaming*; *aligned_free*; *aligned_malloc*; + *absl*; *grpc*; local: *; }; diff --git a/src/ray/raylet/agent_manager.cc b/src/ray/raylet/agent_manager.cc index ec15c27a85be4..55fe1392f6686 100644 --- a/src/ray/raylet/agent_manager.cc +++ b/src/ray/raylet/agent_manager.cc @@ -36,8 +36,6 @@ void AgentManager::HandleRegisterAgent(const rpc::RegisterAgentRequest &request, RAY_LOG(INFO) << "HandleRegisterAgent, ip: " << agent_ip_address_ << ", port: " << agent_port_ << ", pid: " << agent_pid_; reply->set_status(rpc::AGENT_RPC_STATUS_OK); - // Reset the restart count after registration is done. - agent_restart_count_ = 0; send_reply_callback(ray::Status::OK(), nullptr, nullptr); } @@ -67,16 +65,14 @@ void AgentManager::StartAgent() { ProcessEnvironment env; env.insert({"RAY_NODE_ID", options_.node_id.Hex()}); env.insert({"RAY_RAYLET_PID", std::to_string(getpid())}); - // Report the restart count to the agent so that we can decide whether or not - // report the error message to drivers. - env.insert({"RESTART_COUNT", std::to_string(agent_restart_count_)}); - env.insert({"MAX_RESTART_COUNT", - std::to_string(RayConfig::instance().agent_max_restart_count())}); Process child(argv.data(), nullptr, ec, false, env); if (!child.IsValid() || ec) { // The worker failed to start. This is a fatal error. RAY_LOG(FATAL) << "Failed to start agent with return value " << ec << ": " << ec.message(); + RAY_UNUSED(delay_executor_([this] { StartAgent(); }, + RayConfig::instance().agent_restart_interval_ms())); + return; } std::thread monitor_thread([this, child]() mutable { @@ -105,39 +101,22 @@ void AgentManager::StartAgent() { .WithField("pid", agent_pid_) << "Agent process with pid " << child.GetId() << " exit, return value " << exit_code; - if (agent_restart_count_ < RayConfig::instance().agent_max_restart_count()) { - RAY_UNUSED(delay_executor_( - [this] { - agent_restart_count_++; - StartAgent(); - }, - // Retrying with exponential backoff - RayConfig::instance().agent_restart_interval_ms() * - std::pow(2, (agent_restart_count_ + 1)))); - } else { - RAY_LOG(INFO) << "Agent has failed " - << RayConfig::instance().agent_max_restart_count() - << " times in a row without registering the agent. This is highly " - "likely there's a bug in the dashboard agent. 
Please check out " - "the dashboard_agent.log file."; - } + RAY_UNUSED(delay_executor_([this] { StartAgent(); }, + RayConfig::instance().agent_restart_interval_ms())); }); monitor_thread.detach(); } -void AgentManager::CreateRuntimeEnv( - const JobID &job_id, const std::string &serialized_runtime_env, - const std::string &serialized_allocated_resource_instances, - CreateRuntimeEnvCallback callback) { +void AgentManager::CreateRuntimeEnv(const JobID &job_id, + const std::string &serialized_runtime_env, + CreateRuntimeEnvCallback callback) { if (runtime_env_agent_client_ == nullptr) { RAY_LOG(INFO) << "Runtime env agent is not registered yet. Will retry CreateRuntimeEnv later: " << serialized_runtime_env; delay_executor_( - [this, job_id, serialized_runtime_env, serialized_allocated_resource_instances, - callback] { - CreateRuntimeEnv(job_id, serialized_runtime_env, - serialized_allocated_resource_instances, callback); + [this, job_id, serialized_runtime_env, callback] { + CreateRuntimeEnv(job_id, serialized_runtime_env, callback); }, RayConfig::instance().agent_manager_retry_interval_ms()); return; @@ -145,12 +124,9 @@ void AgentManager::CreateRuntimeEnv( rpc::CreateRuntimeEnvRequest request; request.set_job_id(job_id.Hex()); request.set_serialized_runtime_env(serialized_runtime_env); - request.set_serialized_allocated_resource_instances( - serialized_allocated_resource_instances); runtime_env_agent_client_->CreateRuntimeEnv( - request, - [this, job_id, serialized_runtime_env, serialized_allocated_resource_instances, - callback](const Status &status, const rpc::CreateRuntimeEnvReply &reply) { + request, [this, job_id, serialized_runtime_env, callback]( + Status status, const rpc::CreateRuntimeEnvReply &reply) { if (status.ok()) { if (reply.status() == rpc::AGENT_RPC_STATUS_OK) { callback(true, reply.serialized_runtime_env_context()); @@ -166,10 +142,8 @@ void AgentManager::CreateRuntimeEnv( << ", status = " << status << ", maybe there are some network problems, will retry it later."; delay_executor_( - [this, job_id, serialized_runtime_env, - serialized_allocated_resource_instances, callback] { - CreateRuntimeEnv(job_id, serialized_runtime_env, - serialized_allocated_resource_instances, callback); + [this, job_id, serialized_runtime_env, callback] { + CreateRuntimeEnv(job_id, serialized_runtime_env, callback); }, RayConfig::instance().agent_manager_retry_interval_ms()); } diff --git a/src/ray/raylet/agent_manager.h b/src/ray/raylet/agent_manager.h index ba81454b84536..bb12df0f64da4 100644 --- a/src/ray/raylet/agent_manager.h +++ b/src/ray/raylet/agent_manager.h @@ -64,10 +64,9 @@ class AgentManager : public rpc::AgentManagerServiceHandler { /// Request agent to create a runtime env. /// \param[in] runtime_env The runtime env. - virtual void CreateRuntimeEnv( - const JobID &job_id, const std::string &serialized_runtime_env, - const std::string &serialized_allocated_resource_instances, - CreateRuntimeEnvCallback callback); + virtual void CreateRuntimeEnv(const JobID &job_id, + const std::string &serialized_runtime_env, + CreateRuntimeEnvCallback callback); /// Request agent to delete a list of URIs. /// \param[in] URIs The list of URIs to delete. @@ -81,8 +80,6 @@ class AgentManager : public rpc::AgentManagerServiceHandler { Options options_; pid_t agent_pid_ = 0; int agent_port_ = 0; - /// The number of times the agent is restarted. 
- std::atomic agent_restart_count_ = 0; std::string agent_ip_address_; DelayExecutorFn delay_executor_; RuntimeEnvAgentClientFactoryFn runtime_env_agent_client_factory_; diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index 93655b7501d1e..aa096b3f1e86b 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -212,7 +212,6 @@ int main(int argc, char *argv[]) { // Configuration for the object manager. ray::ObjectManagerConfig object_manager_config; - object_manager_config.object_manager_address = node_ip_address; object_manager_config.object_manager_port = object_manager_port; object_manager_config.store_socket_name = store_socket_name; @@ -245,7 +244,7 @@ int main(int argc, char *argv[]) { // Initialize stats. const ray::stats::TagsType global_tags = { {ray::stats::ComponentKey, "raylet"}, - {ray::stats::VersionKey, kRayVersion}, + {ray::stats::VersionKey, "2.0.0.dev0"}, {ray::stats::NodeAddressKey, node_ip_address}}; ray::stats::Init(global_tags, metrics_agent_port); diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index eb0e1f7cadc37..4260542319060 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -252,8 +252,7 @@ NodeManager::NodeManager(instrumented_io_context &io_service, const NodeID &self temp_dir_(config.temp_dir), initial_config_(config), dependency_manager_(object_manager_), - node_manager_server_("NodeManager", config.node_manager_port, - config.node_manager_address == "127.0.0.1"), + node_manager_server_("NodeManager", config.node_manager_port), node_manager_service_(io_service, *this), agent_manager_service_handler_( new DefaultAgentManagerServiceHandler(agent_manager_)), @@ -373,8 +372,7 @@ NodeManager::NodeManager(instrumented_io_context &io_service, const NodeID &self }, /*runtime_env_agent_factory=*/ [this](const std::string &ip_address, int port) { - RAY_CHECK(!ip_address.empty() && port != 0) - << "ip_address: " << ip_address << " port: " << port; + RAY_CHECK(!ip_address.empty() && port != 0); return std::shared_ptr( new rpc::RuntimeEnvAgentClient(ip_address, port, client_call_manager_)); }); @@ -527,7 +525,7 @@ void NodeManager::DestroyWorker(std::shared_ptr worker, } void NodeManager::HandleJobStarted(const JobID &job_id, const JobTableData &job_data) { - RAY_LOG(DEBUG) << "HandleJobStarted for job " << job_id; + RAY_LOG(DEBUG) << "HandleJobStarted " << job_id; worker_pool_.HandleJobStarted(job_id, job_data.config()); // NOTE: Technically `HandleJobStarted` isn't idempotent because we'll // increment the ref count multiple times. This is fine because @@ -1257,8 +1255,6 @@ void NodeManager::DisconnectClient( // Return the resources that were being used by this worker. cluster_task_manager_->ReleaseWorkerResources(worker); - cluster_task_manager_->ClearWorkerBacklog(worker->WorkerId()); - // Since some resources may have been released, we can try to dispatch more tasks. 
cluster_task_manager_->ScheduleAndDispatchTasks();
   } else if (is_driver) {
@@ -1504,28 +1500,19 @@ void NodeManager::HandleRequestResourceReport(
   send_reply_callback(Status::OK(), nullptr, nullptr);
 }
 
-void NodeManager::HandleReportWorkerBacklog(
-    const rpc::ReportWorkerBacklogRequest &request, rpc::ReportWorkerBacklogReply *reply,
-    rpc::SendReplyCallback send_reply_callback) {
-  const WorkerID worker_id = WorkerID::FromBinary(request.worker_id());
-  cluster_task_manager_->ClearWorkerBacklog(worker_id);
-  std::unordered_set<SchedulingClass> seen;
-  for (const auto &backlog_report : request.backlog_reports()) {
-    const TaskSpecification resource_spec(backlog_report.resource_spec());
-    const SchedulingClass scheduling_class = resource_spec.GetSchedulingClass();
-    RAY_CHECK(seen.find(scheduling_class) == seen.end());
-    cluster_task_manager_->SetWorkerBacklog(scheduling_class, worker_id,
-                                            backlog_report.backlog_size());
-  }
-  send_reply_callback(Status::OK(), nullptr, nullptr);
-}
-
 void NodeManager::HandleRequestWorkerLease(const rpc::RequestWorkerLeaseRequest &request,
                                            rpc::RequestWorkerLeaseReply *reply,
                                            rpc::SendReplyCallback send_reply_callback) {
   rpc::Task task_message;
   task_message.mutable_task_spec()->CopyFrom(request.resource_spec());
-  RayTask task(task_message);
+  auto backlog_size = -1;
+  if (RayConfig::instance().report_worker_backlog()) {
+    // We add 1 to the backlog size because we need a worker to fulfill the
+    // current request, as well as workers to serve the requests in the
+    // backlog.
+    backlog_size = request.backlog_size() + 1;
+  }
+  RayTask task(task_message, backlog_size);
   bool is_actor_creation_task = task.GetTaskSpecification().IsActorCreationTask();
   ActorID actor_id = ActorID::Nil();
   metrics_num_task_scheduled_ += 1;
@@ -1675,7 +1662,7 @@ void NodeManager::HandleReturnWorker(const rpc::ReturnWorkerRequest &request,
     if (worker->IsBlocked()) {
       HandleDirectCallTaskUnblocked(worker);
     }
-    cluster_task_manager_->ReleaseWorkerResources(worker);
+    cluster_task_manager_->ReturnWorkerResources(worker);
     HandleWorkerAvailable(worker);
   }
 } else {
@@ -1881,8 +1868,7 @@ void NodeManager::FinishAssignedActorCreationTask(WorkerInterface &worker,
     auto job_id = task.GetTaskSpecification().JobId();
     auto job_config = worker_pool_.GetJobConfig(job_id);
     RAY_CHECK(job_config);
-    runtime_env_manager_.AddURIReference(actor_id.Hex(),
-                                         task.GetTaskSpecification().RuntimeEnv());
+    runtime_env_manager_.AddURIReference(actor_id.Hex(), job_config->runtime_env());
   }
 }
 
diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h
index e8fb4e3050254..a699635c439f7 100644
--- a/src/ray/raylet/node_manager.h
+++ b/src/ray/raylet/node_manager.h
@@ -48,6 +48,7 @@ namespace ray {
 
 namespace raylet {
 
+using rpc::ActorTableData;
 using rpc::ErrorType;
 using rpc::GcsNodeInfo;
 using rpc::HeartbeatTableData;
@@ -272,6 +273,13 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
   /// returned to idle.
   bool FinishAssignedTask(const std::shared_ptr<WorkerInterface> &worker_ptr);
 
+  /// Helper function to produce actor table data for a newly created actor.
+  ///
+  /// \param task_spec RayTask specification of the actor creation task that created the
+  /// actor.
+  /// \param port The port that the actor is listening on.
+  std::shared_ptr<ActorTableData> CreateActorTableDataFromCreationTask(
+      const TaskSpecification &task_spec, int port, const WorkerID &worker_id);
   /// Handle a worker finishing an assigned actor creation task.
   /// \param worker The worker that finished the task.
   /// \param task The actor task or actor creation task.
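The HandleRequestWorkerLease hunk above restores per-request backlog accounting: one lease request stands for the task to run now plus request.backlog_size() tasks queued behind it at the caller, hence the +1, with -1 as the sentinel when reporting is disabled. A standalone sketch of that arithmetic (illustrative names, not the real Ray API):

    #include <cstdint>
    #include <cstdio>

    // One lease request = 1 task to run immediately + the caller's queued
    // backlog. Returns -1 when backlog reporting is turned off, mirroring
    // the sentinel used in the hunk above.
    int64_t DesiredWorkers(bool report_worker_backlog, int64_t reported_backlog) {
      if (!report_worker_backlog) return -1;
      return reported_backlog + 1;
    }

    int main() {
      std::printf("reporting off: %lld\n", (long long)DesiredWorkers(false, 10));
      std::printf("backlog of 10: %lld\n", (long long)DesiredWorkers(true, 10));
      return 0;
    }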
@@ -487,11 +495,6 @@ class NodeManager : public rpc::NodeManagerServiceHandler { rpc::RequestWorkerLeaseReply *reply, rpc::SendReplyCallback send_reply_callback) override; - /// Handle a `ReportWorkerBacklog` request. - void HandleReportWorkerBacklog(const rpc::ReportWorkerBacklogRequest &request, - rpc::ReportWorkerBacklogReply *reply, - rpc::SendReplyCallback send_reply_callback) override; - /// Handle a `ReturnWorker` request. void HandleReturnWorker(const rpc::ReturnWorkerRequest &request, rpc::ReturnWorkerReply *reply, diff --git a/src/ray/raylet/placement_group_resource_manager.cc b/src/ray/raylet/placement_group_resource_manager.cc index d9ccfd1ac0574..8639689edb949 100644 --- a/src/ray/raylet/placement_group_resource_manager.cc +++ b/src/ray/raylet/placement_group_resource_manager.cc @@ -152,9 +152,6 @@ void NewPlacementGroupResourceManager::ReturnBundle( // will be resource leak. cluster_resource_scheduler_->DeleteLocalResource(resource.first); deleted.push_back(resource.first); - } else { - RAY_LOG(DEBUG) << "Available bundle resource:[" << resource.first - << "] is not empty. Resources are not deleted from the local node."; } } pg_bundles_.erase(it); diff --git a/src/ray/raylet/raylet.cc b/src/ray/raylet/raylet.cc index c2f431b20027d..b8040b6f8acdc 100644 --- a/src/ray/raylet/raylet.cc +++ b/src/ray/raylet/raylet.cc @@ -15,7 +15,7 @@ #include "ray/raylet/raylet.h" #include -#include +#include #include #include @@ -61,10 +61,7 @@ Raylet::Raylet(instrumented_io_context &main_service, const std::string &socket_ const ObjectManagerConfig &object_manager_config, std::shared_ptr gcs_client, int metrics_export_port) : main_service_(main_service), - self_node_id_( - !RayConfig::instance().OVERRIDE_NODE_ID_FOR_TESTING().empty() - ? NodeID::FromHex(RayConfig::instance().OVERRIDE_NODE_ID_FOR_TESTING()) - : NodeID::FromRandom()), + self_node_id_(NodeID::FromRandom()), gcs_client_(gcs_client), node_manager_(main_service, self_node_id_, node_manager_config, object_manager_config, gcs_client_), diff --git a/src/ray/raylet/scheduling/cluster_resource_data.cc b/src/ray/raylet/scheduling/cluster_resource_data.cc index f19287d0915f5..ea4ae6621f6b5 100644 --- a/src/ray/raylet/scheduling/cluster_resource_data.cc +++ b/src/ray/raylet/scheduling/cluster_resource_data.cc @@ -13,7 +13,6 @@ // limitations under the License. 
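The raylet.cc hunk above reverts self_node_id_ to an unconditional NodeID::FromRandom(), dropping the OVERRIDE_NODE_ID_FOR_TESTING escape hatch. The pattern being removed is the usual "explicit override wins, otherwise randomize" idiom; a minimal standalone sketch, with an environment variable standing in for the RayConfig knob (the variable name is invented for illustration, and the 28-byte/56-hex-character length follows Ray's UniqueID but is otherwise incidental):

    #include <cstdlib>
    #include <random>
    #include <string>

    // Prefer a caller-supplied ID (deterministic tests), else generate a
    // random one. RAY_OVERRIDE_NODE_ID is a stand-in, not a real Ray knob.
    std::string MakeNodeIdHex() {
      if (const char *forced = std::getenv("RAY_OVERRIDE_NODE_ID")) return forced;
      static const char kHex[] = "0123456789abcdef";
      std::mt19937_64 rng{std::random_device{}()};
      std::string id(56, '0');  // 28 bytes -> 56 hex characters
      for (auto &c : id) c = kHex[rng() % 16];
      return id;
    }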
#include "ray/raylet/scheduling/cluster_resource_data.h" - #include "ray/common/bundle_spec.h" #include "ray/common/task/scheduling_resources.h" @@ -537,7 +536,7 @@ bool TaskResourceInstances::IsEmpty() const { return true; } -std::string TaskResourceInstances::DebugString(const StringIdMap &string_id_map) const { +std::string TaskResourceInstances::DebugString() const { std::stringstream buffer; buffer << std::endl << " Allocation: {"; for (size_t i = 0; i < this->predefined_resources.size(); i++) { @@ -548,7 +547,7 @@ std::string TaskResourceInstances::DebugString(const StringIdMap &string_id_map) buffer << " ["; for (auto it = this->custom_resources.begin(); it != this->custom_resources.end(); ++it) { - buffer << string_id_map.Get(it->first) << ":" << VectorToString(it->second) << ", "; + buffer << it->first << ":" << VectorToString(it->second) << ", "; } buffer << "]" << std::endl; diff --git a/src/ray/raylet/scheduling/cluster_resource_data.h b/src/ray/raylet/scheduling/cluster_resource_data.h index 783ab12da9eee..0398726f39d42 100644 --- a/src/ray/raylet/scheduling/cluster_resource_data.h +++ b/src/ray/raylet/scheduling/cluster_resource_data.h @@ -138,7 +138,7 @@ class TaskResourceInstances { /// Check whether there are no resource instances. bool IsEmpty() const; /// Returns human-readable string for these resources. - [[nodiscard]] std::string DebugString(const StringIdMap &string_id_map) const; + std::string DebugString() const; }; /// Total and available capacities of each resource of a node. @@ -189,7 +189,7 @@ class NodeResourceInstances { /// Returns if this equals another node resources. bool operator==(const NodeResourceInstances &other); /// Returns human-readable string for these resources. - [[nodiscard]] std::string DebugString(StringIdMap string_to_int_map) const; + std::string DebugString(StringIdMap string_to_int_map) const; }; struct Node { diff --git a/src/ray/raylet/scheduling/cluster_resource_scheduler.cc b/src/ray/raylet/scheduling/cluster_resource_scheduler.cc index 1174f138395e0..6fcff8a501c55 100644 --- a/src/ray/raylet/scheduling/cluster_resource_scheduler.cc +++ b/src/ray/raylet/scheduling/cluster_resource_scheduler.cc @@ -456,7 +456,8 @@ void ClusterResourceScheduler::AddLocalResourceInstances( for (size_t i = 0; i < instances.size(); i++) { node_instances->available[i] += instances[i]; - node_instances->total[i] += instances[i]; + node_instances->total[i] = + std::max(node_instances->total[i], node_instances->available[i]); } UpdateLocalAvailableResourcesFromResourceInstances(); } diff --git a/src/ray/raylet/scheduling/cluster_task_manager.cc b/src/ray/raylet/scheduling/cluster_task_manager.cc index 14a999d3be4e4..1b90e93fb1bf4 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager.cc +++ b/src/ray/raylet/scheduling/cluster_task_manager.cc @@ -48,6 +48,7 @@ ClusterTaskManager::ClusterTaskManager( announce_infeasible_task_(announce_infeasible_task), max_resource_shapes_per_load_report_( RayConfig::instance().max_resource_shapes_per_load_report()), + report_worker_backlog_(RayConfig::instance().report_worker_backlog()), worker_pool_(worker_pool), leased_workers_(leased_workers), get_task_arguments_(get_task_arguments), @@ -425,6 +426,7 @@ void ClusterTaskManager::QueueAndScheduleTask( } else { tasks_to_schedule_[scheduling_class].push_back(work); } + AddToBacklogTracker(task); ScheduleAndDispatchTasks(); } @@ -561,6 +563,12 @@ void ClusterTaskManager::ReleaseTaskArgs(const TaskID &task_id) { } } +void 
ClusterTaskManager::ReturnWorkerResources(std::shared_ptr worker) { + // TODO(Shanly): This method will be removed and can be replaced by + // `ReleaseWorkerResources` directly once we remove the legacy scheduler. + ReleaseWorkerResources(worker); +} + void ReplyCancelled(std::shared_ptr &work, bool runtime_env_setup_failed) { auto reply = work->reply; auto callback = work->callback; @@ -579,6 +587,7 @@ bool ClusterTaskManager::CancelTask(const TaskID &task_id, for (auto work_it = work_queue.begin(); work_it != work_queue.end(); work_it++) { const auto &task = (*work_it)->task; if (task.GetTaskSpecification().TaskId() == task_id) { + RemoveFromBacklogTracker(task); RAY_LOG(DEBUG) << "Canceling task " << task_id << " from schedule queue."; ReplyCancelled(*work_it, runtime_env_setup_failed); work_queue.erase(work_it); @@ -595,6 +604,7 @@ bool ClusterTaskManager::CancelTask(const TaskID &task_id, for (auto work_it = work_queue.begin(); work_it != work_queue.end(); work_it++) { const auto &task = (*work_it)->task; if (task.GetTaskSpecification().TaskId() == task_id) { + RemoveFromBacklogTracker(task); RAY_LOG(DEBUG) << "Canceling task " << task_id << " from dispatch queue."; ReplyCancelled(*work_it, runtime_env_setup_failed); if ((*work_it)->status == WorkStatus::WAITING_FOR_WORKER) { @@ -624,6 +634,7 @@ bool ClusterTaskManager::CancelTask(const TaskID &task_id, for (auto work_it = work_queue.begin(); work_it != work_queue.end(); work_it++) { const auto &task = (*work_it)->task; if (task.GetTaskSpecification().TaskId() == task_id) { + RemoveFromBacklogTracker(task); RAY_LOG(DEBUG) << "Canceling task " << task_id << " from infeasible queue."; ReplyCancelled(*work_it, runtime_env_setup_failed); work_queue.erase(work_it); @@ -638,6 +649,7 @@ bool ClusterTaskManager::CancelTask(const TaskID &task_id, auto iter = waiting_tasks_index_.find(task_id); if (iter != waiting_tasks_index_.end()) { const auto &task = (*iter->second)->task; + RemoveFromBacklogTracker(task); ReplyCancelled(*iter->second, runtime_env_setup_failed); if (!task.GetTaskSpecification().GetDependencies().empty()) { task_dependency_manager_.RemoveTaskDependencies( @@ -704,35 +716,36 @@ void ClusterTaskManager::FillResourceUsage( TaskSpecification::GetSchedulingClass(one_cpu_resource_set)); { num_reported++; - int ready_count = 0; + int count = 0; auto it = tasks_to_schedule_.find(one_cpu_scheduling_cls); if (it != tasks_to_schedule_.end()) { - ready_count += it->second.size(); + count += it->second.size(); } it = tasks_to_dispatch_.find(one_cpu_scheduling_cls); if (it != tasks_to_dispatch_.end()) { - ready_count += it->second.size(); - } - int infeasible_count = 0; - it = infeasible_tasks_.find(one_cpu_scheduling_cls); - if (it != infeasible_tasks_.end()) { - infeasible_count += it->second.size(); + count += it->second.size(); } - const int total_count = ready_count + infeasible_count; - if (total_count > 0) { + + if (count > 0) { auto by_shape_entry = resource_load_by_shape->Add(); - for (const auto &[label, quantity] : one_cpu_resource_set.GetResourceMap()) { + for (const auto &resource : one_cpu_resource_set.GetResourceMap()) { // Add to `resource_loads`. - (*resource_loads)[label] += quantity * total_count; + const auto &label = resource.first; + const auto &quantity = resource.second; + (*resource_loads)[label] += quantity * count; // Add to `resource_load_by_shape`. 
(*by_shape_entry->mutable_shape())[label] = quantity; } - by_shape_entry->set_num_ready_requests_queued(ready_count); - by_shape_entry->set_num_infeasible_requests_queued(infeasible_count); - by_shape_entry->set_backlog_size(TotalBacklogSize(one_cpu_scheduling_cls)); + int num_ready = by_shape_entry->num_ready_requests_queued(); + by_shape_entry->set_num_ready_requests_queued(num_ready + count); + + auto backlog_it = backlog_tracker_.find(one_cpu_scheduling_cls); + if (backlog_it != backlog_tracker_.end()) { + by_shape_entry->set_backlog_size(backlog_it->second); + } } } @@ -770,7 +783,10 @@ void ClusterTaskManager::FillResourceUsage( // ClusterResourceScheduler::GetBestSchedulableNode for more details. int num_ready = by_shape_entry->num_ready_requests_queued(); by_shape_entry->set_num_ready_requests_queued(num_ready + count); - by_shape_entry->set_backlog_size(TotalBacklogSize(scheduling_class)); + auto backlog_it = backlog_tracker_.find(scheduling_class); + if (backlog_it != backlog_tracker_.end()) { + by_shape_entry->set_backlog_size(backlog_it->second); + } } for (const auto &pair : tasks_to_dispatch_) { @@ -803,7 +819,10 @@ void ClusterTaskManager::FillResourceUsage( } int num_ready = by_shape_entry->num_ready_requests_queued(); by_shape_entry->set_num_ready_requests_queued(num_ready + count); - by_shape_entry->set_backlog_size(TotalBacklogSize(scheduling_class)); + auto backlog_it = backlog_tracker_.find(scheduling_class); + if (backlog_it != backlog_tracker_.end()) { + by_shape_entry->set_backlog_size(backlog_it->second); + } } for (const auto &pair : infeasible_tasks_) { @@ -839,7 +858,10 @@ void ClusterTaskManager::FillResourceUsage( // ClusterResourceScheduler::GetBestSchedulableNode for more details. int num_infeasible = by_shape_entry->num_infeasible_requests_queued(); by_shape_entry->set_num_infeasible_requests_queued(num_infeasible + count); - by_shape_entry->set_backlog_size(TotalBacklogSize(scheduling_class)); + auto backlog_it = backlog_tracker_.find(scheduling_class); + if (backlog_it != backlog_tracker_.end()) { + by_shape_entry->set_backlog_size(backlog_it->second); + } } if (RayConfig::instance().enable_light_weight_resource_report()) { @@ -993,6 +1015,7 @@ void ClusterTaskManager::Dispatch( RAY_CHECK(leased_workers.find(worker->WorkerId()) == leased_workers.end()); leased_workers[worker->WorkerId()] = worker; + RemoveFromBacklogTracker(task); // Update our internal view of the cluster state. 
std::shared_ptr allocated_resources; @@ -1048,6 +1071,7 @@ void ClusterTaskManager::Spillback(const NodeID &spillback_to, metric_tasks_spilled_++; const auto &task = work->task; const auto &task_spec = task.GetTaskSpecification(); + RemoveFromBacklogTracker(task); RAY_LOG(DEBUG) << "Spilling task " << task_spec.TaskId() << " to node " << spillback_to; if (!cluster_resource_scheduler_->AllocateRemoteTaskResources( @@ -1074,44 +1098,23 @@ void ClusterTaskManager::Spillback(const NodeID &spillback_to, send_reply_callback(); } -void ClusterTaskManager::ClearWorkerBacklog(const WorkerID &worker_id) { - for (auto it = backlog_tracker_.begin(); it != backlog_tracker_.end();) { - it->second.erase(worker_id); - if (it->second.empty()) { - it = backlog_tracker_.erase(it); - } else { - ++it; - } +void ClusterTaskManager::AddToBacklogTracker(const RayTask &task) { + if (report_worker_backlog_) { + auto cls = task.GetTaskSpecification().GetSchedulingClass(); + backlog_tracker_[cls] += task.BacklogSize(); } } -void ClusterTaskManager::SetWorkerBacklog(SchedulingClass scheduling_class, - const WorkerID &worker_id, - int64_t backlog_size) { - if (backlog_size == 0) { - backlog_tracker_[scheduling_class].erase(worker_id); - if (backlog_tracker_[scheduling_class].empty()) { - backlog_tracker_.erase(scheduling_class); +void ClusterTaskManager::RemoveFromBacklogTracker(const RayTask &task) { + if (report_worker_backlog_) { + SchedulingClass cls = task.GetTaskSpecification().GetSchedulingClass(); + backlog_tracker_[cls] -= task.BacklogSize(); + if (backlog_tracker_[cls] == 0) { + backlog_tracker_.erase(backlog_tracker_.find(cls)); } - } else { - backlog_tracker_[scheduling_class][worker_id] = backlog_size; } } -int64_t ClusterTaskManager::TotalBacklogSize(SchedulingClass scheduling_class) { - auto backlog_it = backlog_tracker_.find(scheduling_class); - if (backlog_it == backlog_tracker_.end()) { - return 0; - } - - int64_t sum = 0; - for (const auto &worker_id_and_backlog_size : backlog_it->second) { - sum += worker_id_and_backlog_size.second; - } - - return sum; -} - void ClusterTaskManager::ReleaseWorkerResources(std::shared_ptr worker) { RAY_CHECK(worker != nullptr); auto allocated_instances = worker->GetAllocatedInstances(); @@ -1193,6 +1196,8 @@ void ClusterTaskManager::ScheduleAndDispatchTasks() { } void ClusterTaskManager::SpillWaitingTasks() { + RAY_LOG(DEBUG) << "Attempting to spill back from waiting task queue, num waiting: " + << waiting_task_queue_.size(); // Try to spill waiting tasks to a remote node, prioritizing those at the end // of the queue. Waiting tasks are spilled if there are enough remote // resources AND (we have no resources available locally OR their diff --git a/src/ray/raylet/scheduling/cluster_task_manager.h b/src/ray/raylet/scheduling/cluster_task_manager.h index 4259afc8d04c5..57ee9aab80678 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager.h +++ b/src/ray/raylet/scheduling/cluster_task_manager.h @@ -102,11 +102,6 @@ class ClusterTaskManager : public ClusterTaskManagerInterface { get_task_arguments, size_t max_pinned_task_arguments_bytes); - void SetWorkerBacklog(SchedulingClass scheduling_class, const WorkerID &worker_id, - int64_t backlog_size) override; - - void ClearWorkerBacklog(const WorkerID &worker_id) override; - /// (Step 1) Queue tasks and schedule. /// Queue task and schedule. This hanppens when processing the worker lease request. /// @@ -130,6 +125,13 @@ class ClusterTaskManager : public ClusterTaskManagerInterface { /// \param task: Output parameter. 
void TaskFinished(std::shared_ptr worker, RayTask *task) override; + /// Return worker resources. + /// This method will be removed and can be replaced by `ReleaseWorkerResources` directly + /// once we remove the legacy scheduler. + /// + /// \param worker: The worker which was running the task. + void ReturnWorkerResources(std::shared_ptr worker) override; + /// Attempt to cancel an already queued task. /// /// \param task_id: The id of the task to remove. @@ -259,6 +261,7 @@ class ClusterTaskManager : public ClusterTaskManagerInterface { std::function announce_infeasible_task_; const int max_resource_shapes_per_load_report_; + const bool report_worker_backlog_; /// TODO(swang): Add index from TaskID -> Work to avoid having to iterate /// through queues to cancel tasks, etc. @@ -304,9 +307,8 @@ class ClusterTaskManager : public ClusterTaskManagerInterface { std::unordered_map>> infeasible_tasks_; - /// Track the backlog of all workers belonging to this raylet. - std::unordered_map> - backlog_tracker_; + /// Track the cumulative backlog of all workers requesting a lease to this raylet. + std::unordered_map backlog_tracker_; /// TODO(Shanly): Remove `worker_pool_` and `leased_workers_` and make them as /// parameters of methods if necessary once we remove the legacy scheduler. @@ -358,8 +360,8 @@ class ClusterTaskManager : public ClusterTaskManagerInterface { void Spillback(const NodeID &spillback_to, const std::shared_ptr &work); - /// Sum up the backlog size across all workers for a given scheduling class. - int64_t TotalBacklogSize(SchedulingClass scheduling_class); + void AddToBacklogTracker(const RayTask &task); + void RemoveFromBacklogTracker(const RayTask &task); // Helper function to pin a task's args immediately before dispatch. This // returns false if there are missing args (due to eviction) or if there is diff --git a/src/ray/raylet/scheduling/cluster_task_manager_interface.h b/src/ray/raylet/scheduling/cluster_task_manager_interface.h index 71864f38df846..457daa4d7b320 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager_interface.h +++ b/src/ray/raylet/scheduling/cluster_task_manager_interface.h @@ -78,6 +78,13 @@ class ClusterTaskManagerInterface { /// \param task: Output parameter. virtual void TaskFinished(std::shared_ptr worker, RayTask *task) = 0; + /// Return worker resources. + /// This method will be removed and can be replaced by `ReleaseWorkerResources` directly + /// once we remove the legacy scheduler + /// + /// \param worker: The worker which was running the task. + virtual void ReturnWorkerResources(std::shared_ptr worker) = 0; + /// Attempt to cancel an already queued task. /// /// \param task_id: The id of the task to remove. @@ -89,20 +96,6 @@ class ClusterTaskManagerInterface { virtual bool CancelTask(const TaskID &task_id, bool runtime_env_setup_failed = false) = 0; - /// Set the worker backlog size for a particular scheduling class. - /// - /// \param scheduling_class: The scheduling class this backlog is for. - /// \param worker_id: The ID of the worker that owns the backlog information. - /// \param backlog_size: The size of the backlog. - virtual void SetWorkerBacklog(SchedulingClass scheduling_class, - const WorkerID &worker_id, int64_t backlog_size) = 0; - - /// Remove all backlog information about the given worker. - /// - /// \param worker_id: The ID of the worker owning the backlog information - /// that we want to remove. - virtual void ClearWorkerBacklog(const WorkerID &worker_id) = 0; - /// Queue task and schedule. 
This happens when processing the worker lease request.
   ///
   /// \param task: The incoming task to be queued and scheduled.
diff --git a/src/ray/raylet/scheduling/cluster_task_manager_test.cc b/src/ray/raylet/scheduling/cluster_task_manager_test.cc
index f19386ad4bab5..78fe7320c8631 100644
--- a/src/ray/raylet/scheduling/cluster_task_manager_test.cc
+++ b/src/ray/raylet/scheduling/cluster_task_manager_test.cc
@@ -48,7 +48,8 @@ class MockWorkerPool : public WorkerPoolInterface {
   void PopWorker(const TaskSpecification &task_spec, const PopWorkerCallback &callback,
                  const std::string &allocated_instances_serialized_json) {
     num_pops++;
-    const WorkerCacheKey env = {task_spec.SerializedRuntimeEnv(), {}};
+    const WorkerCacheKey env = {
+        task_spec.OverrideEnvironmentVariables(), task_spec.SerializedRuntimeEnv(), {}};
     const int runtime_env_hash = env.IntHash();
     callbacks[runtime_env_hash].push_back(callback);
   }
@@ -100,11 +101,10 @@ class MockWorkerPool : public WorkerPoolInterface {
   int num_pops;
 };
 
-std::shared_ptr CreateSingleNodeScheduler(const std::string &id,
-                                          double num_cpus,
-                                          double num_gpus) {
+std::shared_ptr CreateSingleNodeScheduler(
+    const std::string &id, double num_gpus = 0.0) {
   std::unordered_map local_node_resources;
-  local_node_resources[ray::kCPU_ResourceLabel] = num_cpus;
+  local_node_resources[ray::kCPU_ResourceLabel] = 8;
   local_node_resources[ray::kGPU_ResourceLabel] = num_gpus;
   local_node_resources[ray::kMemory_ResourceLabel] = 128;
 
@@ -116,18 +116,16 @@ std::shared_ptr CreateSingleNodeScheduler(const std::s
 RayTask CreateTask(const std::unordered_map &required_resources,
                    int num_args = 0, std::vector args = {},
-                   const std::string &serialized_runtime_env = "{}",
-                   const std::vector &runtime_env_uris = {}) {
+                   std::string serialized_runtime_env = "{}") {
   TaskSpecBuilder spec_builder;
   TaskID id = RandomTaskId();
   JobID job_id = RandomJobId();
   rpc::Address address;
-  spec_builder.SetCommonTaskSpec(id, "dummy_task", Language::PYTHON,
-                                 FunctionDescriptorBuilder::BuildPython("", "", "", ""),
-                                 job_id, TaskID::Nil(), 0, TaskID::Nil(), address, 0,
-                                 required_resources, {},
-                                 std::make_pair(PlacementGroupID::Nil(), -1), true, "",
-                                 serialized_runtime_env, runtime_env_uris);
+  spec_builder.SetCommonTaskSpec(
+      id, "dummy_task", Language::PYTHON,
+      FunctionDescriptorBuilder::BuildPython("", "", "", ""), job_id, TaskID::Nil(), 0,
+      TaskID::Nil(), address, 0, required_resources, {},
+      std::make_pair(PlacementGroupID::Nil(), -1), true, "", serialized_runtime_env);
 
   if (!args.empty()) {
     for (auto &arg : args) {
@@ -179,41 +177,39 @@ class MockTaskDependencyManager : public TaskDependencyManagerInterface {
 
 class ClusterTaskManagerTest : public ::testing::Test {
  public:
-  ClusterTaskManagerTest(double num_cpus_at_head = 8.0, double num_gpus_at_head = 0.0)
+  ClusterTaskManagerTest(double num_gpus_at_head = 0.0)
       : id_(NodeID::FromRandom()),
-        scheduler_(
-            CreateSingleNodeScheduler(id_.Binary(), num_cpus_at_head, num_gpus_at_head)),
+        scheduler_(CreateSingleNodeScheduler(id_.Binary(), num_gpus_at_head)),
         is_owner_alive_(true),
         node_info_calls_(0),
         announce_infeasible_task_calls_(0),
         dependency_manager_(missing_objects_),
-        task_manager_(
-            id_, scheduler_, dependency_manager_,
-            /* is_owner_alive= */
-            [this](const WorkerID &worker_id, const NodeID &node_id) {
-              return is_owner_alive_;
-            },
-            /* get_node_info= */
-            [this](const NodeID &node_id) {
-              node_info_calls_++;
-              return node_info_[node_id];
-            },
-            /* announce_infeasible_task= */
-            [this](const RayTask &task) {
announce_infeasible_task_calls_++; }, pool_,
-            leased_workers_,
-            /* get_task_arguments= */
-            [this](const std::vector &object_ids,
-                   std::vector> *results) {
-              for (auto &obj_id : object_ids) {
-                if (missing_objects_.count(obj_id) == 0) {
-                  results->emplace_back(MakeDummyArg());
-                } else {
-                  results->emplace_back(nullptr);
-                }
-              }
-              return true;
-            },
-            /*max_pinned_task_arguments_bytes=*/1000) {}
+        task_manager_(id_, scheduler_, dependency_manager_,
+                      /* is_owner_alive= */
+                      [this](const WorkerID &worker_id, const NodeID &node_id) {
+                        return is_owner_alive_;
+                      },
+                      /* get_node_info= */
+                      [this](const NodeID &node_id) {
+                        node_info_calls_++;
+                        return node_info_[node_id];
+                      },
+                      /* announce_infeasible_task= */
+                      [this](const RayTask &task) { announce_infeasible_task_calls_++; },
+                      pool_, leased_workers_,
+                      /* get_task_arguments= */
+                      [this](const std::vector &object_ids,
+                             std::vector> *results) {
+                        for (auto &obj_id : object_ids) {
+                          if (missing_objects_.count(obj_id) == 0) {
+                            results->emplace_back(MakeDummyArg());
+                          } else {
+                            results->emplace_back(nullptr);
+                          }
+                        }
+                        return true;
+                      },
+                      /*max_pinned_task_arguments_bytes=*/1000) {}
 
   RayObject *MakeDummyArg() {
     std::vector data;
@@ -291,15 +287,7 @@ class ClusterTaskManagerTest : public ::testing::Test {
 // Same as ClusterTaskManagerTest, but the head node starts with 4.0 num gpus.
 class ClusterTaskManagerTestWithGPUsAtHead : public ClusterTaskManagerTest {
  public:
-  ClusterTaskManagerTestWithGPUsAtHead()
-      : ClusterTaskManagerTest(/*num_cpus_at_head=*/8.0, /*num_gpus_at_head=*/4.0) {}
-};
-
-// Same as ClusterTaskManagerTest, but the head node starts with 0.0 num cpus.
-class ClusterTaskManagerTestWithoutCPUsAtHead : public ClusterTaskManagerTest {
- public:
-  ClusterTaskManagerTestWithoutCPUsAtHead()
-      : ClusterTaskManagerTest(/*num_cpus_at_head=*/0.0) {}
+  ClusterTaskManagerTestWithGPUsAtHead() : ClusterTaskManagerTest(4.0) {}
 };
 
 TEST_F(ClusterTaskManagerTest, BasicTest) {
@@ -379,7 +367,8 @@ TEST_F(ClusterTaskManagerTest, DispatchQueueNonBlockingTest) {
   pool_.TriggerCallbacks();
 
   // Push a worker that can only run task A.
-  const WorkerCacheKey env_A = {serialized_runtime_env_A, {}};
+  const WorkerCacheKey env_A = {
+      /*override_environment_variables=*/{}, serialized_runtime_env_A, {}};
   const int runtime_env_hash_A = env_A.IntHash();
   std::shared_ptr worker_A =
       std::make_shared(WorkerID::FromRandom(), 1234, runtime_env_hash_A);
@@ -871,7 +860,7 @@ TEST_F(ClusterTaskManagerTest, HeartbeatTest) {
 TEST_F(ClusterTaskManagerTest, BacklogReportTest) {
   /*
     Test basic scheduler functionality:
     1. Queue and attempt to schedule/dispatch a test with no workers available
     2. A worker becomes available, dispatch again.
   */
   rpc::RequestWorkerLeaseReply reply;
@@ -884,21 +873,18 @@ TEST_F(ClusterTaskManagerTest, BacklogReportTest) {
 
   std::vector to_cancel;
 
-  const WorkerID worker_id_submitting_first_task = WorkerID::FromRandom();
-  // Don't add the fist task to `to_cancel`.
+  // Don't add these first 2 tasks to `to_cancel`.
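With this revision the backlog rides on the lease request itself (task.SetBacklogSize in the two loops just below) and the raylet aggregates it per scheduling class, as in the tracker sketched earlier. All ten tasks here request {CPU: 8} and therefore share one scheduling class, so the cumulative figure before any dispatch or cancellation is plain arithmetic; a hypothetical check, not part of the test:

    #include <cassert>

    int main() {
      int total = 10;  // first loop: one task with backlog 10 - 0
      for (int i = 1; i < 10; i++) {
        total += 10 - i;  // second loop: backlogs 9, 8, ..., 1
      }
      assert(total == 55);  // cumulative backlog for the single class
      return 0;
    }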
for (int i = 0; i < 1; i++) { RayTask task = CreateTask({{ray::kCPU_ResourceLabel, 8}}); + task.SetBacklogSize(10 - i); task_manager_.QueueAndScheduleTask(task, &reply, callback); - task_manager_.SetWorkerBacklog(task.GetTaskSpecification().GetSchedulingClass(), - worker_id_submitting_first_task, 10 - i); pool_.TriggerCallbacks(); } for (int i = 1; i < 10; i++) { RayTask task = CreateTask({{ray::kCPU_ResourceLabel, 8}}); + task.SetBacklogSize(10 - i); task_manager_.QueueAndScheduleTask(task, &reply, callback); - task_manager_.SetWorkerBacklog(task.GetTaskSpecification().GetSchedulingClass(), - WorkerID::FromRandom(), 10 - i); pool_.TriggerCallbacks(); to_cancel.push_back(task.GetTaskSpecification().TaskId()); } @@ -924,7 +910,6 @@ TEST_F(ClusterTaskManagerTest, BacklogReportTest) { std::make_shared(WorkerID::FromRandom(), 1234); pool_.PushWorker(worker); task_manager_.ScheduleAndDispatchTasks(); - task_manager_.ClearWorkerBacklog(worker_id_submitting_first_task); pool_.TriggerCallbacks(); { @@ -1540,50 +1525,6 @@ TEST_F(ClusterTaskManagerTest, PopWorkerExactlyOnce) { AssertNoLeaks(); } -// Regression test for https://github.com/ray-project/ray/issues/16935: -// When a task requires 1 CPU and is infeasible because head node has 0 CPU, -// make sure the task's resource demand is reported. -TEST_F(ClusterTaskManagerTestWithoutCPUsAtHead, OneCpuInfeasibleTask) { - rpc::RequestWorkerLeaseReply reply; - bool callback_occurred = false; - bool *callback_occurred_ptr = &callback_occurred; - auto callback = [callback_occurred_ptr](const Status &, const std::function &, - const std::function &) { - *callback_occurred_ptr = true; - }; - - constexpr int num_cases = 5; - // Create 5 tasks with different CPU requests. - const std::array cpu_request = {1, 2, 1, 3, 1}; - // Each type of CPU request corresponds to a types of resource demand. - const std::array demand_types = {1, 2, 2, 3, 3}; - // Number of infeasible 1 CPU requests.. - const std::array num_infeasible_1cpu = {1, 1, 2, 2, 3}; - - for (int i = 0; i < num_cases; ++i) { - RayTask task = CreateTask({{ray::kCPU_ResourceLabel, cpu_request[i]}}); - task_manager_.QueueAndScheduleTask(task, &reply, callback); - pool_.TriggerCallbacks(); - - // The task cannot run because there is only 1 node (head) with 0 CPU. - ASSERT_FALSE(callback_occurred); - ASSERT_EQ(leased_workers_.size(), 0); - ASSERT_EQ(pool_.workers.size(), 0); - ASSERT_EQ(node_info_calls_, 0); - - rpc::ResourcesData data; - task_manager_.FillResourceUsage(data); - const auto &resource_load_by_shape = data.resource_load_by_shape(); - ASSERT_EQ(resource_load_by_shape.resource_demands().size(), demand_types[i]); - - // 1 CPU demand currently is always the 1st. - const auto &demand = resource_load_by_shape.resource_demands()[0]; - EXPECT_EQ(demand.num_infeasible_requests_queued(), num_infeasible_1cpu[i]); - ASSERT_EQ(demand.shape().size(), 1); - ASSERT_EQ(demand.shape().at("CPU"), 1); - } -} - int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); diff --git a/src/ray/raylet/scheduling/fixed_point.cc b/src/ray/raylet/scheduling/fixed_point.cc new file mode 100644 index 0000000000000..ec0b3ed9af16d --- /dev/null +++ b/src/ray/raylet/scheduling/fixed_point.cc @@ -0,0 +1,96 @@ +// Copyright 2020-2021 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ray/raylet/scheduling/fixed_point.h" + +#include + +FixedPoint::FixedPoint(double d) { i_ = (uint64_t)(d * RESOURCE_UNIT_SCALING); } + +FixedPoint::FixedPoint(int i) { i_ = (i * RESOURCE_UNIT_SCALING); } + +FixedPoint::FixedPoint(uint32_t i) { i_ = (i * RESOURCE_UNIT_SCALING); } + +FixedPoint::FixedPoint(int64_t i) : FixedPoint((double)i) {} + +FixedPoint::FixedPoint(uint64_t i) : FixedPoint((double)i) {} + +FixedPoint FixedPoint::operator+(FixedPoint const &ru) const { + FixedPoint res; + res.i_ = i_ + ru.i_; + return res; +} + +FixedPoint FixedPoint::operator+=(FixedPoint const &ru) { + i_ += ru.i_; + return *this; +} + +FixedPoint FixedPoint::operator-(FixedPoint const &ru) const { + FixedPoint res; + res.i_ = i_ - ru.i_; + return res; +} + +FixedPoint FixedPoint::operator-=(FixedPoint const &ru) { + i_ -= ru.i_; + return *this; +} + +FixedPoint FixedPoint::operator-() const { + FixedPoint res; + res.i_ = -i_; + return res; +} + +FixedPoint FixedPoint::operator+(double const d) const { + FixedPoint res; + res.i_ = i_ + (int64_t)(d * RESOURCE_UNIT_SCALING); + return res; +} + +FixedPoint FixedPoint::operator-(double const d) const { + FixedPoint res; + res.i_ = i_ - (int64_t)(d * RESOURCE_UNIT_SCALING); + return res; +} + +FixedPoint FixedPoint::operator=(double const d) { + i_ = (int64_t)(d * RESOURCE_UNIT_SCALING); + return *this; +} + +FixedPoint FixedPoint::operator+=(double const d) { + i_ += (int64_t)(d * RESOURCE_UNIT_SCALING); + return *this; +} + +FixedPoint FixedPoint::operator+=(int64_t const ru) { + *this += (double)ru; + return *this; +} + +bool FixedPoint::operator<(FixedPoint const &ru1) const { return (i_ < ru1.i_); }; +bool FixedPoint::operator>(FixedPoint const &ru1) const { return (i_ > ru1.i_); }; +bool FixedPoint::operator<=(FixedPoint const &ru1) const { return (i_ <= ru1.i_); }; +bool FixedPoint::operator>=(FixedPoint const &ru1) const { return (i_ >= ru1.i_); }; +bool FixedPoint::operator==(FixedPoint const &ru1) const { return (i_ == ru1.i_); }; +bool FixedPoint::operator!=(FixedPoint const &ru1) const { return (i_ != ru1.i_); }; + +std::ostream &operator<<(std::ostream &out, FixedPoint const &ru1) { + out << ru1.i_; + return out; +} + +double FixedPoint::Double() const { return round(i_) / RESOURCE_UNIT_SCALING; }; diff --git a/src/ray/raylet/scheduling/fixed_point.h b/src/ray/raylet/scheduling/fixed_point.h index a18ffd1873218..f133397ec6251 100644 --- a/src/ray/raylet/scheduling/fixed_point.h +++ b/src/ray/raylet/scheduling/fixed_point.h @@ -14,7 +14,6 @@ #pragma once -#include #include #include @@ -26,85 +25,41 @@ class FixedPoint { int64_t i_ = 0; public: - FixedPoint() : FixedPoint(0.0) {} - FixedPoint(double d) { i_ = (uint64_t)(d * RESOURCE_UNIT_SCALING); } // NOLINT - - FixedPoint(int i) { i_ = (i * RESOURCE_UNIT_SCALING); } // NOLINT - - FixedPoint(uint32_t i) { i_ = (i * RESOURCE_UNIT_SCALING); } // NOLINT - - FixedPoint(int64_t i) : FixedPoint((double)i) {} // NOLINT - - FixedPoint(uint64_t i) : FixedPoint((double)i) {} // NOLINT - - FixedPoint operator+(FixedPoint const &ru) const { - FixedPoint res; - res.i_ = i_ + 
ru.i_; - return res; - } - - FixedPoint &operator+=(FixedPoint const &ru) { - i_ += ru.i_; - return *this; - } - - FixedPoint operator-(FixedPoint const &ru) const { - FixedPoint res; - res.i_ = i_ - ru.i_; - return res; - } - - FixedPoint &operator-=(FixedPoint const &ru) { - i_ -= ru.i_; - return *this; - } - - FixedPoint operator-() const { - FixedPoint res; - res.i_ = -i_; - return res; - } - - FixedPoint operator+(double const d) const { - FixedPoint res; - res.i_ = i_ + static_cast(d * RESOURCE_UNIT_SCALING); - return res; - } - - FixedPoint operator-(double const d) const { - FixedPoint res; - res.i_ = i_ + static_cast(d * RESOURCE_UNIT_SCALING); - return res; - } - - FixedPoint operator=(double const d) { - i_ = static_cast(d * RESOURCE_UNIT_SCALING); - return *this; - } - - FixedPoint operator+=(double const d) { - i_ += static_cast(d * RESOURCE_UNIT_SCALING); - return *this; - } - - FixedPoint operator+=(int64_t const ru) { - *this += static_cast(ru); - return *this; - } - - bool operator<(FixedPoint const &ru1) const { return (i_ < ru1.i_); }; - bool operator>(FixedPoint const &ru1) const { return (i_ > ru1.i_); }; - bool operator<=(FixedPoint const &ru1) const { return (i_ <= ru1.i_); }; - bool operator>=(FixedPoint const &ru1) const { return (i_ >= ru1.i_); }; - bool operator==(FixedPoint const &ru1) const { return (i_ == ru1.i_); }; - bool operator!=(FixedPoint const &ru1) const { return (i_ != ru1.i_); }; - - [[nodiscard]] double Double() const { return round(i_) / RESOURCE_UNIT_SCALING; }; + FixedPoint() = default; + FixedPoint(double d); + FixedPoint(int i); + FixedPoint(uint32_t i); + FixedPoint(int64_t i); + FixedPoint(uint64_t i); + + FixedPoint operator+(FixedPoint const &ru) const; + + FixedPoint operator+=(FixedPoint const &ru); + + FixedPoint operator-(FixedPoint const &ru) const; + + FixedPoint operator-=(FixedPoint const &ru); + + FixedPoint operator-() const; + + FixedPoint operator+(double const d) const; + + FixedPoint operator-(double const d) const; + + FixedPoint operator=(double const d); + + FixedPoint operator+=(double const d); + + FixedPoint operator+=(int64_t const ru); + + bool operator<(FixedPoint const &ru1) const; + bool operator>(FixedPoint const &ru1) const; + bool operator<=(FixedPoint const &ru1) const; + bool operator>=(FixedPoint const &ru1) const; + bool operator==(FixedPoint const &ru1) const; + bool operator!=(FixedPoint const &ru1) const; + + double Double() const; friend std::ostream &operator<<(std::ostream &out, FixedPoint const &ru1); }; - -inline std::ostream &operator<<(std::ostream &out, FixedPoint const &ru1) { - out << ru1.i_; - return out; -} diff --git a/src/ray/raylet/scheduling/scheduling_policy.cc b/src/ray/raylet/scheduling/scheduling_policy.cc index 40c1ca39605d8..4bf28bdb75a21 100644 --- a/src/ray/raylet/scheduling/scheduling_policy.cc +++ b/src/ray/raylet/scheduling/scheduling_policy.cc @@ -57,7 +57,7 @@ int64_t HybridPolicyWithFilter(const ResourceRequest &resource_request, if (node_filter == NodeFilter::kGPU) { return has_gpu; } - RAY_CHECK(node_filter == NodeFilter::kNonGpu); + RAY_CHECK(node_filter == NodeFilter::kCPUOnly); return !has_gpu; }; @@ -149,18 +149,16 @@ int64_t HybridPolicy(const ResourceRequest &resource_request, const int64_t loca spread_threshold, force_spillback, require_available); } - // Try schedule on non-GPU nodes. 
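For orientation in the HybridPolicy hunk being rewritten around this point: when GPU avoidance is enabled, scheduling becomes a two-phase search, first restricting candidates to non-GPU nodes and only then falling back to GPU nodes. A toy sketch of that filter-then-fallback shape, with a stand-in node model rather than the real ResourceRequest/HybridPolicyWithFilter signatures:

    #include <cassert>
    #include <cstdint>
    #include <map>

    struct Node {
      bool has_gpu;
      bool feasible;  // stand-in for "can fit the resource request"
    };

    enum class NodeFilter { kAny, kGPU, kCPUOnly };

    int64_t PickWithFilter(const std::map<int64_t, Node> &nodes, NodeFilter filter) {
      for (const auto &[id, node] : nodes) {
        if (filter == NodeFilter::kGPU && !node.has_gpu) continue;
        if (filter == NodeFilter::kCPUOnly && node.has_gpu) continue;
        if (node.feasible) return id;
      }
      return -1;  // nothing matches under this filter
    }

    int64_t Pick(const std::map<int64_t, Node> &nodes) {
      // Keep scarce GPU nodes free for GPU work: CPU-only candidates first.
      int64_t id = PickWithFilter(nodes, NodeFilter::kCPUOnly);
      if (id != -1) return id;
      // Last resort, mirroring the new NodeFilter::kGPU fallback below.
      return PickWithFilter(nodes, NodeFilter::kGPU);
    }

    int main() {
      std::map<int64_t, Node> nodes = {{0, {/*has_gpu=*/true, true}},
                                       {1, {/*has_gpu=*/false, true}}};
      assert(Pick(nodes) == 1);  // the CPU-only node wins even though 0 fits
      nodes[1].feasible = false;
      assert(Pick(nodes) == 0);  // a GPU node only as a last resort
      return 0;
    }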
-  auto best_node_id = HybridPolicyWithFilter(
-      resource_request, local_node_id, nodes, spread_threshold, force_spillback,
-      /*require_available*/ true, NodeFilter::kNonGpu);
-  if (best_node_id != -1) {
-    return best_node_id;
+  // Try schedule on CPU-only nodes.
+  const auto node_id =
+      HybridPolicyWithFilter(resource_request, local_node_id, nodes, spread_threshold,
+                             force_spillback, require_available, NodeFilter::kCPUOnly);
+  if (node_id != -1) {
+    return node_id;
   }
-
-  // If we cannot find any available node from non-gpu nodes, fallback to the original
-  // scheduling
+  // Could not schedule on CPU-only nodes, schedule on GPU nodes as a last resort.
   return HybridPolicyWithFilter(resource_request, local_node_id, nodes, spread_threshold,
-                                force_spillback, require_available);
+                                force_spillback, require_available, NodeFilter::kGPU);
 }
 
 }  // namespace raylet_scheduling_policy
diff --git a/src/ray/raylet/scheduling/scheduling_policy.h b/src/ray/raylet/scheduling/scheduling_policy.h
index b137491576690..b6f382ff1d078 100644
--- a/src/ray/raylet/scheduling/scheduling_policy.h
+++ b/src/ray/raylet/scheduling/scheduling_policy.h
@@ -62,15 +62,8 @@ int64_t HybridPolicy(
     bool force_spillback, bool require_available,
     bool scheduler_avoid_gpu_nodes = RayConfig::instance().scheduler_avoid_gpu_nodes());
 
-enum class NodeFilter {
-  /// Default scheduling.
-  kAny,
-  /// Schedule on GPU only nodes.
-  kGPU,
-  /// Schedule on nodes that don't have GPU. Since GPUs are more scarce resources, we need
-  /// special handling for this.
-  kNonGpu
-};
+// kAny: all nodes; kGPU: GPU nodes only; kCPUOnly: non-GPU nodes only.
+enum class NodeFilter { kAny, kGPU, kCPUOnly };
 
@@ -79,7 +72,7 @@ enum class NodeFilter {
 /// \param resource_request: The resource request we're attempting to schedule.
 /// \param local_node_id: The id of the local node, which is needed for traversal order.
 /// truncated to 0.
 /// \param node_filter: defines the subset of nodes we are allowed to schedule on.
 /// can be one of kAny (can schedule on all nodes), kGPU (can only schedule on kGPU
-/// nodes), kNonGpu (can only schedule on non-GPU nodes.
+/// nodes), kCPUOnly (can only schedule on non-GPU nodes).
 ///
 /// \return -1 if the task is unfeasible, otherwise the node id (key in `nodes`) to
 /// schedule on.
diff --git a/src/ray/raylet/scheduling/scheduling_policy_test.cc b/src/ray/raylet/scheduling/scheduling_policy_test.cc
index 6a834db1966e9..fb51d7f4c8711 100644
--- a/src/ray/raylet/scheduling/scheduling_policy_test.cc
+++ b/src/ray/raylet/scheduling/scheduling_policy_test.cc
@@ -338,42 +338,6 @@ TEST_F(SchedulingPolicyTest, ForceSpillbackOnlyFeasibleLocallyTest) {
   ASSERT_EQ(to_schedule, -1);
 }
 
-TEST_F(SchedulingPolicyTest, NonGpuNodePreferredSchedulingTest) {
-  // Prefer to schedule on CPU nodes first.
-  // GPU nodes should be preferred as a last resort.
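Stepping back to the FixedPoint implementation split out above (fixed_point.cc/.h): the class stores resource quantities as an int64_t scaled by RESOURCE_UNIT_SCALING, which keeps fractional resource accounting exact where naive double accumulation drifts. A self-contained illustration; the constant 10000 is assumed here to match the scheduler's scaling, and the truncating cast mirrors FixedPoint's double constructor:

    #include <cassert>
    #include <cstdint>

    // Assumed scaling constant; Ray's scheduler uses RESOURCE_UNIT_SCALING.
    constexpr int64_t kScale = 10000;

    int64_t ToFixed(double d) { return static_cast<int64_t>(d * kScale); }
    double ToDouble(int64_t i) { return static_cast<double>(i) / kScale; }

    int main() {
      // Ten increments of 0.1 CPU: exact in scaled integers...
      int64_t fixed_total = 0;
      double double_total = 0.0;
      for (int i = 0; i < 10; i++) {
        fixed_total += ToFixed(0.1);  // 0.1 becomes exactly 1000
        double_total += 0.1;
      }
      assert(fixed_total == 10000 && ToDouble(fixed_total) == 1.0);
      // ...while the double sum drifts to 0.9999999999999999.
      assert(double_total != 1.0);
      return 0;
    }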
- StringIdMap map; - int64_t local_node = 0; - int64_t remote_node_1 = 1; - int64_t remote_node_2 = 2; - - // local {CPU:2, GPU:1} - // Remote {CPU: 2} - absl::flat_hash_map nodes; - nodes.emplace(local_node, CreateNodeResources(2, 2, 0, 0, 1, 1)); - nodes.emplace(remote_node_1, CreateNodeResources(2, 2, 0, 0, 0, 0)); - nodes.emplace(remote_node_2, CreateNodeResources(3, 3, 0, 0, 0, 0)); - - ResourceRequest req = ResourceMapToResourceRequest(map, {{"CPU", 1}}, false); - int to_schedule = raylet_scheduling_policy::HybridPolicy( - req, local_node, nodes, 0.51, false, true, /*gpu_avoid_scheduling*/ true); - ASSERT_EQ(to_schedule, remote_node_1); - - req = ResourceMapToResourceRequest(map, {{"CPU", 3}}, false); - to_schedule = raylet_scheduling_policy::HybridPolicy( - req, local_node, nodes, 0.51, false, true, /*gpu_avoid_scheduling*/ true); - ASSERT_EQ(to_schedule, remote_node_2); - - req = ResourceMapToResourceRequest(map, {{"CPU", 1}, {"GPU", 1}}, false); - to_schedule = raylet_scheduling_policy::HybridPolicy( - req, local_node, nodes, 0.51, false, true, /*gpu_avoid_scheduling*/ true); - ASSERT_EQ(to_schedule, local_node); - - req = ResourceMapToResourceRequest(map, {{"CPU", 2}}, false); - to_schedule = raylet_scheduling_policy::HybridPolicy( - req, local_node, nodes, 0.51, false, true, /*gpu_avoid_scheduling*/ true); - ASSERT_EQ(to_schedule, remote_node_1); -} - int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); diff --git a/src/ray/raylet/worker.cc b/src/ray/raylet/worker.cc index fd2b7b723f755..08331a75f176d 100644 --- a/src/ray/raylet/worker.cc +++ b/src/ray/raylet/worker.cc @@ -14,7 +14,7 @@ #include "ray/raylet/worker.h" -#include +#include #include "ray/raylet/format/node_manager_generated.h" #include "ray/raylet/raylet.h" diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index 959cc551f0dbc..f0268021280f8 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -176,6 +176,7 @@ Process WorkerPool::StartWorkerProcess( const Language &language, const rpc::WorkerType worker_type, const JobID &job_id, PopWorkerStatus *status, const std::vector &dynamic_options, const int runtime_env_hash, const std::string &serialized_runtime_env, + std::unordered_map override_environment_variables, const std::string &serialized_runtime_env_context, const std::string &allocated_instances_serialized_json) { rpc::JobConfig *job_config = nullptr; @@ -312,41 +313,39 @@ Process WorkerPool::StartWorkerProcess( // need to add a new CLI parameter for both Python and Java workers. env.emplace(kEnvVarKeyJobId, job_id.Hex()); } + if (job_config) { + env.insert(job_config->worker_env().begin(), job_config->worker_env().end()); + } + + for (const auto &pair : override_environment_variables) { + env[pair.first] = pair.second; + } - if (language == Language::PYTHON || language == Language::JAVA) { + if (language == Language::PYTHON) { if (serialized_runtime_env != "{}" && serialized_runtime_env != "") { worker_command_args.push_back("--serialized-runtime-env=" + serialized_runtime_env); // Allocated_resource_json is only used in "shim process". 
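One detail in the StartWorkerProcess hunk above deserves a note: the job-level env is merged with insert(), which never overwrites an existing key, while the per-task override_environment_variables are written through operator[], which always does, so per-task values win over job-wide worker_env. A minimal sketch of that precedence using plain std::map as a stand-in for the worker env container:

    #include <cassert>
    #include <map>
    #include <string>

    int main() {
      std::map<std::string, std::string> env = {{"FOO", "from_system"}};
      std::map<std::string, std::string> job_env = {{"FOO", "from_job"},
                                                    {"BAR", "job"}};
      std::map<std::string, std::string> task_overrides = {{"BAR", "task"}};

      env.insert(job_env.begin(), job_env.end());  // keeps existing "FOO"
      for (const auto &pair : task_overrides) {
        env[pair.first] = pair.second;  // overwrites unconditionally
      }
      assert(env["FOO"] == "from_system");
      assert(env["BAR"] == "task");
      return 0;
    }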
worker_command_args.push_back("--allocated-instances-serialized-json=" + allocated_instances_serialized_json); - - worker_command_args.push_back("--language=" + Language_Name(language)); - - worker_command_args.push_back("--runtime-env-hash=" + - std::to_string(runtime_env_hash)); - - if (serialized_runtime_env_context != "{}" && - !serialized_runtime_env_context.empty()) { - worker_command_args.push_back("--serialized-runtime-env-context=" + - serialized_runtime_env_context); - } } else { // The "shim process" setup worker is not needed, so do not run it. // Check that the arg really is the path to the setup worker before erasing it, to // prevent breaking tests that mock out the worker command args. if (worker_command_args.size() >= 2 && worker_command_args[1].find(kSetupWorkerFilename) != std::string::npos) { - if (language == Language::PYTHON) { - worker_command_args.erase(worker_command_args.begin() + 1, - worker_command_args.begin() + 2); - } else { - // Erase the python executable as well for other languages. - worker_command_args.erase(worker_command_args.begin(), - worker_command_args.begin() + 2); - } + worker_command_args.erase(worker_command_args.begin() + 1, + worker_command_args.begin() + 2); } } + worker_command_args.push_back("--runtime-env-hash=" + + std::to_string(runtime_env_hash)); + + if (serialized_runtime_env_context != "{}" && serialized_runtime_env_context != "") { + worker_command_args.push_back("--serialized-runtime-env-context=" + + serialized_runtime_env_context); + } + if (ray_debugger_external) { worker_command_args.push_back("--ray-debugger-external"); } @@ -484,24 +483,6 @@ void WorkerPool::MarkPortAsFree(int port) { void WorkerPool::HandleJobStarted(const JobID &job_id, const rpc::JobConfig &job_config) { all_jobs_[job_id] = job_config; - if (job_config.runtime_env().runtime_env_eager_install() && - job_config.has_runtime_env()) { - auto const &runtime_env = job_config.runtime_env().serialized_runtime_env(); - RAY_LOG(INFO) << "[Eagerly] Start install runtime environment for job " << job_id - << ". The runtime environment was " << runtime_env << "."; - CreateRuntimeEnv( - runtime_env, job_id, - [job_id](bool successful, const std::string &serialized_runtime_env_context) { - if (successful) { - RAY_LOG(INFO) << "[Eagerly] Create runtime env successful for job " << job_id - << ". The result context was " << serialized_runtime_env_context - << "."; - } else { - RAY_LOG(ERROR) << "[Eagerly] Couldn't create a runtime environment for job " - << job_id << "."; - } - }); - } } void WorkerPool::HandleJobFinished(const JobID &job_id) { @@ -768,7 +749,7 @@ void WorkerPool::PushWorker(const std::shared_ptr &worker) { // The worker is used for the actor creation task with dynamic options. if (!used) { // Put it into idle dedicated worker pool. - // TODO(SongGuyang): This worker will not be used forever. We should kill it. + // TODO(guyang.sgy): This worker will not be used forever. We should kill it. 
state.idle_dedicated_workers[task_id] = worker; } return; @@ -940,8 +921,7 @@ void WorkerPool::TryKillingIdleWorkers() { void WorkerPool::PopWorker(const TaskSpecification &task_spec, const PopWorkerCallback &callback, const std::string &allocated_instances_serialized_json) { - RAY_LOG(DEBUG) << "Pop worker for task " << task_spec.TaskId() << " task name " - << task_spec.FunctionDescriptor()->ToString(); + RAY_LOG(DEBUG) << "Pop worker for task " << task_spec.TaskId(); auto &state = GetStateForLanguage(task_spec.GetLanguage()); std::shared_ptr worker = nullptr; @@ -956,7 +936,8 @@ void WorkerPool::PopWorker(const TaskSpecification &task_spec, Process proc = StartWorkerProcess( task_spec.GetLanguage(), rpc::WorkerType::WORKER, task_spec.JobId(), &status, dynamic_options, task_spec.GetRuntimeEnvHash(), serialized_runtime_env, - serialized_runtime_env_context, allocated_instances_serialized_json); + task_spec.OverrideEnvironmentVariables(), serialized_runtime_env_context, + allocated_instances_serialized_json); if (status == PopWorkerStatus::OK) { RAY_CHECK(proc.IsValid()); WarnAboutSize(); @@ -967,7 +948,7 @@ void WorkerPool::PopWorker(const TaskSpecification &task_spec, state.starting_workers_to_tasks[proc] = std::move(task_info); } } else { - // TODO(SongGuyang): Wait until a worker is pushed or a worker can be started If + // TODO(guyang.sgy): Wait until a worker is pushed or a worker can be started If // startup concurrency maxed out or job not started. PopWorkerCallbackAsync(callback, nullptr, status); } @@ -995,24 +976,24 @@ void WorkerPool::PopWorker(const TaskSpecification &task_spec, dynamic_options = task_spec.DynamicWorkerOptions(); } + // create runtime env. if (task_spec.HasRuntimeEnv()) { - // create runtime env. - CreateRuntimeEnv( - task_spec.SerializedRuntimeEnv(), task_spec.JobId(), - [start_worker_process_fn, callback, &state, task_spec, dynamic_options]( - bool successful, const std::string &serialized_runtime_env_context) { - if (successful) { + agent_manager_->CreateRuntimeEnv( + task_spec.JobId(), task_spec.SerializedRuntimeEnv(), + [start_worker_process_fn, callback, &state, task_spec, dynamic_options, + allocated_instances_serialized_json]( + bool success, const std::string &serialized_runtime_env_context) { + if (success) { start_worker_process_fn(task_spec, state, dynamic_options, true, task_spec.SerializedRuntimeEnv(), serialized_runtime_env_context, callback); } else { + RAY_LOG(WARNING) << "Couldn't create a runtime environment for task " + << task_spec.TaskId() << ". The runtime environment was " + << task_spec.SerializedRuntimeEnv() << "."; callback(nullptr, PopWorkerStatus::RuntimeEnvCreationFailed); - RAY_LOG(WARNING) - << "Create runtime env failed for task " << task_spec.TaskId() - << " and couldn't create the dedicated worker."; } - }, - allocated_instances_serialized_json); + }); } else { start_worker_process_fn(task_spec, state, dynamic_options, true, "", "", callback); @@ -1055,8 +1036,8 @@ void WorkerPool::PopWorker(const TaskSpecification &task_spec, // Start a new worker process. if (task_spec.HasRuntimeEnv()) { // create runtime env. 
- CreateRuntimeEnv( - task_spec.SerializedRuntimeEnv(), task_spec.JobId(), + agent_manager_->CreateRuntimeEnv( + task_spec.JobId(), task_spec.SerializedRuntimeEnv(), [start_worker_process_fn, callback, &state, task_spec]( bool successful, const std::string &serialized_runtime_env_context) { if (successful) { @@ -1064,13 +1045,12 @@ void WorkerPool::PopWorker(const TaskSpecification &task_spec, task_spec.SerializedRuntimeEnv(), serialized_runtime_env_context, callback); } else { + RAY_LOG(WARNING) << "Couldn't create a runtime environment for task " + << task_spec.TaskId() << ". The runtime environment was " + << task_spec.SerializedRuntimeEnv() << "."; callback(nullptr, PopWorkerStatus::RuntimeEnvCreationFailed); - RAY_LOG(WARNING) - << "Create runtime env failed for task " << task_spec.TaskId() - << " and couldn't create the worker."; } - }, - allocated_instances_serialized_json); + }); } else { start_worker_process_fn(task_spec, state, {}, false, "", "", callback); } @@ -1087,7 +1067,7 @@ void WorkerPool::PrestartWorkers(const TaskSpecification &task_spec, int64_t bac int64_t num_available_cpus) { // Code path of task that needs a dedicated worker. if ((task_spec.IsActorCreationTask() && !task_spec.DynamicWorkerOptions().empty()) || - task_spec.HasRuntimeEnv()) { + task_spec.OverrideEnvironmentVariables().size() > 0 || task_spec.HasRuntimeEnv()) { return; // Not handled. // TODO(architkulkarni): We'd eventually like to prestart workers with the same // runtime env to improve initial startup performance. @@ -1344,26 +1324,6 @@ WorkerPool::IOWorkerState &WorkerPool::GetIOWorkerStateFromWorkerType( UNREACHABLE; } -void WorkerPool::CreateRuntimeEnv( - const std::string &serialized_runtime_env, const JobID &job_id, - const std::function &callback, - const std::string &serialized_allocated_resource_instances) { - // create runtime env. - agent_manager_->CreateRuntimeEnv( - job_id, serialized_runtime_env, serialized_allocated_resource_instances, - [job_id, serialized_runtime_env, callback]( - bool successful, const std::string &serialized_runtime_env_context) { - if (successful) { - callback(true, serialized_runtime_env_context); - } else { - RAY_LOG(WARNING) << "Couldn't create a runtime environment for job " << job_id - << ". The runtime environment was " << serialized_runtime_env - << "."; - callback(false, ""); - } - }); -} - } // namespace raylet } // namespace ray diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h index 92c19329c17dc..7991600cfd6c6 100644 --- a/src/ray/raylet/worker_pool.h +++ b/src/ray/raylet/worker_pool.h @@ -397,6 +397,7 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface { PopWorkerStatus *status /*output*/, const std::vector &dynamic_options = {}, const int runtime_env_hash = 0, const std::string &serialized_runtime_env = "{}", + std::unordered_map override_environment_variables = {}, const std::string &serialized_runtime_env_context = "{}", const std::string &allocated_instances_serialized_json = "{}"); @@ -588,12 +589,6 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface { const PopWorkerStatus &status, bool *found /* output */, bool *worker_used /* output */, TaskID *task_id /* output */); - /// Create runtime env asynchronously by runtime env agent. 
- void CreateRuntimeEnv( - const std::string &serialized_runtime_env, const JobID &job_id, - const std::function &callback, - const std::string &serialized_allocated_resource_instances = "{}"); - /// For Process class for managing subprocesses (e.g. reaping zombies). instrumented_io_context *io_service_; /// Node ID of the current node. diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index 37fb903b4a7ab..9a28520700a8e 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -103,10 +103,9 @@ class WorkerPoolMock : public WorkerPool { const WorkerCommandMap &worker_commands, absl::flat_hash_map> &mock_worker_rpc_clients) - : WorkerPool( - io_service, NodeID::FromRandom(), "", POOL_SIZE_SOFT_LIMIT, 0, - MAXIMUM_STARTUP_CONCURRENCY, 0, 0, {}, nullptr, worker_commands, []() {}, 0, - [this]() { return current_time_ms_; }), + : WorkerPool(io_service, NodeID::FromRandom(), "", POOL_SIZE_SOFT_LIMIT, 0, + MAXIMUM_STARTUP_CONCURRENCY, 0, 0, {}, nullptr, worker_commands, + []() {}, 0, [this]() { return current_time_ms_; }), last_worker_process_(), instrumented_io_service_(io_service), error_message_type_(1), @@ -258,7 +257,7 @@ class WorkerPoolMock : public WorkerPool { is_java = true; } } - // TODO(SongGuyang): support C++ language workers. + // TODO(guyang.sgy): support C++ language workers. int num_workers = is_java ? NUM_WORKERS_PER_PROCESS_JAVA : 1; for (int i = 0; i < num_workers; i++) { auto worker = @@ -459,7 +458,7 @@ static inline TaskSpecification ExampleTaskSpec( } else { message.set_type(TaskType::NORMAL_TASK); } - message.mutable_runtime_env()->set_serialized_runtime_env(serialized_runtime_env); + message.set_serialized_runtime_env(serialized_runtime_env); return TaskSpecification(std::move(message)); } @@ -1258,7 +1257,8 @@ TEST_F(WorkerPoolTest, CacheWorkersByRuntimeEnvHash) { ActorID::Nil(), Language::PYTHON, JOB_ID, ActorID::Nil(), /*dynamic_options=*/{}, TaskID::ForFakeTask(), "mock_runtime_env_2"); - const WorkerCacheKey env1 = {"mock_runtime_env_1", {}}; + const WorkerCacheKey env1 = { + /*override_environment_variables=*/{}, "mock_runtime_env_1", {}}; const int runtime_env_hash_1 = env1.IntHash(); // Push worker with runtime env 1. diff --git a/src/ray/raylet_client/raylet_client.cc b/src/ray/raylet_client/raylet_client.cc index 3da524b8611a4..41e9611491c7d 100644 --- a/src/ray/raylet_client/raylet_client.cc +++ b/src/ray/raylet_client/raylet_client.cc @@ -296,20 +296,13 @@ Status raylet::RayletClient::FreeObjects(const std::vector &object_ids } void raylet::RayletClient::RequestWorkerLease( - const rpc::TaskSpec &task_spec, + const TaskSpecification &resource_spec, const rpc::ClientCallback &callback, const int64_t backlog_size) { - google::protobuf::Arena arena; - auto request = - google::protobuf::Arena::CreateMessage(&arena); - // The unsafe allocating here is actually safe because the life-cycle of - // task_spec is longer than request. - // Request will be sent before the end of this call, and after that, it won't be - // used any more. - request->unsafe_arena_set_allocated_resource_spec( - const_cast(&task_spec)); - request->set_backlog_size(backlog_size); - grpc_client_->RequestWorkerLease(*request, callback); + rpc::RequestWorkerLeaseRequest request; + request.mutable_resource_spec()->CopyFrom(resource_spec.GetMessage()); + request.set_backlog_size(backlog_size); + grpc_client_->RequestWorkerLease(request, callback); } /// Spill objects to external storage. 
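On the RequestWorkerLease change just above: the removed version avoided copying the TaskSpec into the request by borrowing it on an arena (unsafe_arena_set_allocated_resource_spec), at the price of a lifetime contract with the caller; the restored version pays for a deep CopyFrom but has no such coupling. A sketch of the two allocation styles using google::protobuf::Struct as a stand-in message, since Ray's RequestWorkerLeaseRequest type isn't available here:

    #include <google/protobuf/arena.h>
    #include <google/protobuf/struct.pb.h>

    int main() {
      google::protobuf::Struct spec;
      (*spec.mutable_fields())["num_cpus"].set_number_value(8);

      // Arena style (what the removed code used): the request lives on the
      // arena and all arena-owned memory is released in one shot when `arena`
      // dies. The removed Ray code went further and *borrowed* the spec via an
      // unsafe_arena_set_allocated_* call instead of copying it, which is why
      // the spec had to outlive the request.
      google::protobuf::Arena arena;
      auto *arena_request =
          google::protobuf::Arena::CreateMessage<google::protobuf::Struct>(&arena);
      arena_request->CopyFrom(spec);

      // Plain style (what this commit restores): stack-owned request, deep
      // copy, no lifetime coupling to `spec`.
      google::protobuf::Struct plain_request;
      plain_request.CopyFrom(spec);
      return 0;
    }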
@@ -321,20 +314,6 @@ void raylet::RayletClient::RequestObjectSpillage( grpc_client_->RequestObjectSpillage(request, callback); } -void raylet::RayletClient::ReportWorkerBacklog( - const WorkerID &worker_id, - const std::vector &backlog_reports) { - rpc::ReportWorkerBacklogRequest request; - request.set_worker_id(worker_id.Binary()); - request.mutable_backlog_reports()->Add(backlog_reports.begin(), backlog_reports.end()); - grpc_client_->ReportWorkerBacklog( - request, [](const Status &status, const rpc::ReportWorkerBacklogReply &reply) { - if (!status.ok()) { - RAY_LOG(INFO) << "Error reporting task backlog information: " << status; - } - }); -} - Status raylet::RayletClient::ReturnWorker(int worker_port, const WorkerID &worker_id, bool disconnect_worker) { rpc::ReturnWorkerRequest request; @@ -394,7 +373,7 @@ void raylet::RayletClient::CommitBundleResources( } void raylet::RayletClient::CancelResourceReserve( - const BundleSpecification &bundle_spec, + BundleSpecification &bundle_spec, const ray::rpc::ClientCallback &callback) { rpc::CancelResourceReserveRequest request; request.mutable_bundle_spec()->CopyFrom(bundle_spec.GetMessage()); diff --git a/src/ray/raylet_client/raylet_client.h b/src/ray/raylet_client/raylet_client.h index 547e8eaa7ee00..558fed24b24cf 100644 --- a/src/ray/raylet_client/raylet_client.h +++ b/src/ray/raylet_client/raylet_client.h @@ -68,10 +68,6 @@ class WorkerLeaseInterface { const ray::TaskSpecification &resource_spec, const ray::rpc::ClientCallback &callback, const int64_t backlog_size = -1) = 0; - virtual void RequestWorkerLease( - const rpc::TaskSpec &task_spec, - const ray::rpc::ClientCallback &callback, - const int64_t backlog_size = -1) = 0; /// Returns a worker to the raylet. /// \param worker_port The local port of the worker on the raylet node. @@ -93,14 +89,6 @@ class WorkerLeaseInterface { const TaskID &task_id, const rpc::ClientCallback &callback) = 0; - /// Report the backlog size of a given worker and a given scheduling class to the - /// raylet. - /// \param worker_id The ID of the worker that reports the backlog size. - /// \param backlog_reports The backlog report for each scheduling class - virtual void ReportWorkerBacklog( - const WorkerID &worker_id, - const std::vector &backlog_reports) = 0; - virtual ~WorkerLeaseInterface(){}; }; @@ -129,7 +117,7 @@ class ResourceReserveInterface { const ray::rpc::ClientCallback &callback) = 0; virtual void CancelResourceReserve( - const BundleSpecification &bundle_spec, + BundleSpecification &bundle_spec, const ray::rpc::ClientCallback &callback) = 0; virtual void ReleaseUnusedBundles( @@ -372,24 +360,12 @@ class RayletClient : public RayletClientInterface { void RequestWorkerLease( const ray::TaskSpecification &resource_spec, const ray::rpc::ClientCallback &callback, - const int64_t backlog_size) override { - RequestWorkerLease(resource_spec.GetMessage(), callback, backlog_size); - } - - void RequestWorkerLease( - const rpc::TaskSpec &resource_spec, - const ray::rpc::ClientCallback &callback, const int64_t backlog_size) override; /// Implements WorkerLeaseInterface. ray::Status ReturnWorker(int worker_port, const WorkerID &worker_id, bool disconnect_worker) override; - /// Implements WorkerLeaseInterface. - void ReportWorkerBacklog( - const WorkerID &worker_id, - const std::vector &backlog_reports) override; - /// Implements WorkerLeaseInterface. 
  void ReleaseUnusedWorkers(
      const std::vector &workers_in_use,
@@ -413,7 +389,7 @@
   /// Implements CancelResourceReserveInterface.
   void CancelResourceReserve(
-      const BundleSpecification &bundle_spec,
+      BundleSpecification &bundle_spec,
       const ray::rpc::ClientCallback &callback)
       override;
diff --git a/src/ray/rpc/common.cc b/src/ray/rpc/common.cc
index 7526c1e6efc6f..eef01f3e1e2f5 100644
--- a/src/ray/rpc/common.cc
+++ b/src/ray/rpc/common.cc
@@ -12,11 +12,11 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "ray/rpc/common.h"
-
 #include 
 #include 
 
+#include "ray/rpc/common.h"
+
 namespace ray::rpc {
 
 std::string ReadCert(const std::string &cert_filepath) {
@@ -26,4 +26,4 @@ std::string ReadCert(const std::string &cert_filepath) {
   return buffer.str();
 };
 
-}  // namespace ray::rpc
+}  // namespace ray::rpc
diff --git a/src/ray/rpc/common.h b/src/ray/rpc/common.h
index 314e1eccf382c..929a555a942f6 100644
--- a/src/ray/rpc/common.h
+++ b/src/ray/rpc/common.h
@@ -12,8 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include 
-
 namespace ray::rpc {
 
 // Utility to read cert file from a particular location
diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc
index 8f3f98b67445c..3840527fb5a9a 100644
--- a/src/ray/rpc/grpc_server.cc
+++ b/src/ray/rpc/grpc_server.cc
@@ -36,13 +36,11 @@ DEFINE_stats(grpc_server_req_finished, "Finished request number in grpc server",
 namespace ray {
 namespace rpc {
 
-GrpcServer::GrpcServer(std::string name, const uint32_t port,
-                       bool listen_to_localhost_only, int num_threads,
-                       int64_t keepalive_time_ms, bool use_tls)
+GrpcServer::GrpcServer(std::string name, const uint32_t port, int num_threads,
+                       bool use_tls, int64_t keepalive_time_ms)
     : name_(std::move(name)),
       port_(port),
       use_tls_(use_tls),
-      listen_to_localhost_only_(listen_to_localhost_only),
      is_closed_(true),
      num_threads_(num_threads),
      keepalive_time_ms_(keepalive_time_ms) {
@@ -51,8 +49,7 @@ GrpcServer::GrpcServer(std::string name, const uint32_t port,
 void GrpcServer::Run() {
   uint32_t specified_port = port_;
-  std::string server_address((listen_to_localhost_only_ ? "127.0.0.1:" : "0.0.0.0:") +
-                             std::to_string(port_));
+  std::string server_address("0.0.0.0:" + std::to_string(port_));
   grpc::ServerBuilder builder;
   // Disable the SO_REUSEPORT option. We don't need it in ray. If the option is enabled
   // (default behavior in grpc), we may see multiple workers listen on the same port and
diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h
index c83628b72b2e8..826efbdf260bb 100644
--- a/src/ray/rpc/grpc_server.h
+++ b/src/ray/rpc/grpc_server.h
@@ -61,11 +61,9 @@ class GrpcServer {
   /// \param[in] name Name of this server, used for logging and debugging purpose.
   /// \param[in] port The port to bind this server to. If it's 0, a random available port
   /// will be chosen.
-
-  GrpcServer(std::string name, const uint32_t port, bool listen_to_localhost_only,
-             int num_threads = 1,
-             int64_t keepalive_time_ms = 7200000, /*2 hours, grpc default*/
-             bool use_tls = false);
+  GrpcServer(std::string name, const uint32_t port, int num_threads = 1,
+             bool use_tls = false,
+             int64_t keepalive_time_ms = 7200000 /*2 hours, grpc default*/);
 
   /// Destruct this gRPC server.
   ~GrpcServer() { Shutdown(); }
@@ -116,9 +114,6 @@ class GrpcServer {
   int port_;
   /// Whether to use TLS.
  bool use_tls_;
-  /// Listen to localhost (127.0.0.1) only if it's true, otherwise listen to all network
-  /// interfaces (0.0.0.0)
-  const bool listen_to_localhost_only_;
   /// Indicates whether this server has been closed.
   bool is_closed_;
   /// The `grpc::Service` objects which should be registered to `ServerBuilder`.
diff --git a/src/ray/rpc/node_manager/node_manager_client.h b/src/ray/rpc/node_manager/node_manager_client.h
index fad890c990e00..341613a848e98 100644
--- a/src/ray/rpc/node_manager/node_manager_client.h
+++ b/src/ray/rpc/node_manager/node_manager_client.h
@@ -79,9 +79,6 @@ class NodeManagerWorkerClient
   /// Request a worker lease.
   VOID_RPC_CLIENT_METHOD(NodeManagerService, RequestWorkerLease, grpc_client_, )
 
-  /// Report task backlog information
-  VOID_RPC_CLIENT_METHOD(NodeManagerService, ReportWorkerBacklog, grpc_client_, )
-
   /// Return a worker lease.
   VOID_RPC_CLIENT_METHOD(NodeManagerService, ReturnWorker, grpc_client_, )
 
diff --git a/src/ray/rpc/node_manager/node_manager_server.h b/src/ray/rpc/node_manager/node_manager_server.h
index 2cec90a4512f7..7f7d2a5a9738b 100644
--- a/src/ray/rpc/node_manager/node_manager_server.h
+++ b/src/ray/rpc/node_manager/node_manager_server.h
@@ -28,7 +28,6 @@ namespace rpc {
   RPC_SERVICE_HANDLER(NodeManagerService, UpdateResourceUsage, -1)       \
   RPC_SERVICE_HANDLER(NodeManagerService, RequestResourceReport, -1)     \
   RPC_SERVICE_HANDLER(NodeManagerService, RequestWorkerLease, -1)        \
-  RPC_SERVICE_HANDLER(NodeManagerService, ReportWorkerBacklog, -1)       \
   RPC_SERVICE_HANDLER(NodeManagerService, ReturnWorker, -1)              \
   RPC_SERVICE_HANDLER(NodeManagerService, ReleaseUnusedWorkers, -1)      \
   RPC_SERVICE_HANDLER(NodeManagerService, CancelWorkerLease, -1)         \
@@ -71,10 +70,6 @@ class NodeManagerServiceHandler {
                                       RequestWorkerLeaseReply *reply,
                                       SendReplyCallback send_reply_callback) = 0;
 
-  virtual void HandleReportWorkerBacklog(const ReportWorkerBacklogRequest &request,
-                                         ReportWorkerBacklogReply *reply,
-                                         SendReplyCallback send_reply_callback) = 0;
-
   virtual void HandleReturnWorker(const ReturnWorkerRequest &request,
                                   ReturnWorkerReply *reply,
                                   SendReplyCallback send_reply_callback) = 0;
diff --git a/src/ray/rpc/server_call.h b/src/ray/rpc/server_call.h
index 9e2d50e8324e4..d3c199b50c6bb 100644
--- a/src/ray/rpc/server_call.h
+++ b/src/ray/rpc/server_call.h
@@ -14,11 +14,10 @@
 
 #pragma once
 
-#include 
 #include 
+#include 
 #include 
-
 #include "ray/common/asio/instrumented_io_context.h"
 #include "ray/common/grpc_util.h"
 #include "ray/common/status.h"
@@ -146,7 +145,6 @@ class ServerCallImpl : public ServerCall {
         response_writer_(&context_),
         io_service_(io_service),
         call_name_(std::move(call_name)) {
-    reply_ = google::protobuf::Arena::CreateMessage(&arena_);
     // TODO call_name_ sometimes gets corrupted due to memory issues.
     RAY_CHECK(!call_name_.empty()) << "Call name is empty";
     STATS_grpc_server_req_new.Record(1.0, call_name_);
@@ -189,7 +187,7 @@ class ServerCallImpl : public ServerCall {
       factory.CreateCall();
     }
     (service_handler_.*handle_request_function_)(
-        request_, reply_,
+        request_, &reply_,
         [this](Status status, std::function success, std::function failure) {
           // These two callbacks must be set before `SendReply`, because `SendReply`
@@ -224,13 +222,9 @@ class ServerCallImpl : public ServerCall {
   /// Tell gRPC to finish this request and send reply asynchronously.
void SendReply(const Status &status) { state_ = ServerCallState::SENDING_REPLY; - response_writer_.Finish(*reply_, RayStatusToGrpcStatus(status), this); + response_writer_.Finish(reply_, RayStatusToGrpcStatus(status), this); } - /// The memory pool for this request. It's used for reply. - /// With arena, we'll be able to setup the reply without copying some field. - google::protobuf::Arena arena_; - /// State of this call. ServerCallState state_; @@ -256,9 +250,8 @@ class ServerCallImpl : public ServerCall { /// The request message. Request request_; - /// The reply message. This one is owned by arena. It's not valid beyond - /// the life-cycle of this call. - Reply *reply_; + /// The reply message. + Reply reply_; /// Human-readable name for this RPC call. std::string call_name_; diff --git a/src/ray/rpc/test/grpc_server_client_test.cc b/src/ray/rpc/test/grpc_server_client_test.cc index 3bd86f5a24f63..e7b602e6b316f 100644 --- a/src/ray/rpc/test/grpc_server_client_test.cc +++ b/src/ray/rpc/test/grpc_server_client_test.cc @@ -13,7 +13,6 @@ // limitations under the License. #include - #include "gtest/gtest.h" #include "ray/rpc/grpc_client.h" #include "ray/rpc/grpc_server.h" @@ -36,14 +35,13 @@ class TestServiceHandler { RAY_LOG(INFO) << "No reply!"; return; } - send_reply_callback( - ray::Status::OK(), - /*reply_success=*/[]() { RAY_LOG(INFO) << "Reply success."; }, - /*reply_failure=*/ - [this]() { - RAY_LOG(INFO) << "Reply failed."; - reply_failure_count++; - }); + send_reply_callback(ray::Status::OK(), + /*reply_success=*/[]() { RAY_LOG(INFO) << "Reply success."; }, + /*reply_failure=*/ + [this]() { + RAY_LOG(INFO) << "Reply failed."; + reply_failure_count++; + }); } std::atomic request_count{0}; std::atomic reply_failure_count{0}; @@ -85,7 +83,7 @@ class TestGrpcServerClientFixture : public ::testing::Test { handler_io_service_.run(); }); test_service_.reset(new TestGrpcService(handler_io_service_, test_service_handler_)); - grpc_server_.reset(new GrpcServer("test", 0, true)); + grpc_server_.reset(new GrpcServer("test", 0)); grpc_server_->RegisterService(*test_service_); grpc_server_->Run(); diff --git a/src/ray/util/event.h b/src/ray/util/event.h index 4f2e98a4427c3..9caed946f3af1 100644 --- a/src/ray/util/event.h +++ b/src/ray/util/event.h @@ -13,8 +13,6 @@ // limitations under the License. #pragma once -#include - #include #include #include @@ -24,8 +22,6 @@ #include #include #include - -#include "nlohmann/json.hpp" #include "ray/util/logging.h" #include "ray/util/util.h" #include "spdlog/sinks/basic_file_sink.h" @@ -33,6 +29,10 @@ #include "spdlog/spdlog.h" #include "src/ray/protobuf/event.pb.h" +#include "nlohmann/json.hpp" + +#include + using json = nlohmann::json; namespace ray { @@ -102,7 +102,7 @@ class EventManager final { // We added `const json &custom_fields` here because we need to support typed custom // fields. - // TODO(SongGuyang): Remove the protobuf `rpc::Event` and use an internal struct + // TODO(guyang.sgy): Remove the protobuf `rpc::Event` and use an internal struct // instead. void Publish(const rpc::Event &event, const json &custom_fields); diff --git a/src/ray/util/util.h b/src/ray/util/util.h index 95500e91694a7..9b2e3f443dbac 100644 --- a/src/ray/util/util.h +++ b/src/ray/util/util.h @@ -21,6 +21,7 @@ #include #include #include + #include #include "ray/util/logging.h" @@ -166,7 +167,7 @@ class InitShutdownRAII { /// \param shutdown_func The shutdown function. /// \param args The arguments for the init function. 
template - InitShutdownRAII(InitFunc init_func, ShutdownFunc shutdown_func, Args &&...args) + InitShutdownRAII(InitFunc init_func, ShutdownFunc shutdown_func, Args &&... args) : shutdown_(shutdown_func) { init_func(args...); } @@ -258,7 +259,7 @@ template class ThreadPrivate { public: template - explicit ThreadPrivate(Ts &&...ts) : t_(std::forward(ts)...) {} + ThreadPrivate(Ts &&... ts) : t_(std::forward(ts)...) {} T &operator*() { ThreadCheck(); @@ -311,43 +312,4 @@ class ThreadPrivate { mutable std::mutex mutex_; }; -class ExponentialBackOff { - public: - ExponentialBackOff() = default; - ExponentialBackOff(const ExponentialBackOff &) = default; - ExponentialBackOff(ExponentialBackOff &&) = default; - ExponentialBackOff &operator=(const ExponentialBackOff &) = default; - ExponentialBackOff &operator=(ExponentialBackOff &&) = default; - - /// Construct an exponential back off counter. - /// - /// \param[in] initial_value The start value for this counter - /// \param[in] multiplier The multiplier for this counter. - /// \param[in] max_value The maximum value for this counter. By default it's - /// infinite double. - ExponentialBackOff(uint64_t initial_value, double multiplier, - uint64_t max_value = std::numeric_limits::max()) - : curr_value_(initial_value), - initial_value_(initial_value), - max_value_(max_value), - multiplier_(multiplier) { - RAY_CHECK(multiplier > 0.0) << "Multiplier must be greater than 0"; - } - - uint64_t Next() { - auto ret = curr_value_; - curr_value_ = curr_value_ * multiplier_; - curr_value_ = std::min(curr_value_, max_value_); - return ret; - } - - void Reset() { curr_value_ = initial_value_; } - - private: - uint64_t curr_value_; - uint64_t initial_value_; - uint64_t max_value_; - double multiplier_; -}; - } // namespace ray diff --git a/src/ray/util/util_test.cc b/src/ray/util/util_test.cc index 3e13dedb10bf9..435f1598f4f69 100644 --- a/src/ray/util/util_test.cc +++ b/src/ray/util/util_test.cc @@ -102,23 +102,6 @@ TEST(UtilTest, ParseCommandLineTest) { ASSERT_EQ(ParseCommandLine(R"(x' a \b')", win32), ArgList({R"(x')", R"(a)", R"(\b')"})); } -TEST(UtilTest, ExponentialBackOffTest) { - auto exp = ExponentialBackOff(1, 2, 9); - ASSERT_EQ(1, exp.Next()); - ASSERT_EQ(2, exp.Next()); - ASSERT_EQ(4, exp.Next()); - ASSERT_EQ(8, exp.Next()); - ASSERT_EQ(9, exp.Next()); - ASSERT_EQ(9, exp.Next()); - exp.Reset(); - ASSERT_EQ(1, exp.Next()); - ASSERT_EQ(2, exp.Next()); - ASSERT_EQ(4, exp.Next()); - ASSERT_EQ(8, exp.Next()); - ASSERT_EQ(9, exp.Next()); - ASSERT_EQ(9, exp.Next()); -} - TEST(UtilTest, ParseURLTest) { const std::string url = "http://abc?num_objects=9&offset=8388878&size=8388878"; auto parsed_url = *ParseURL(url); diff --git a/streaming/src/queue/queue_handler.h b/streaming/src/queue/queue_handler.h index 26bd863e85ecc..e04e34b359804 100644 --- a/streaming/src/queue/queue_handler.h +++ b/streaming/src/queue/queue_handler.h @@ -1,7 +1,7 @@ #pragma once #include -#include +#include #include #include diff --git a/streaming/src/test/mock_actor.cc b/streaming/src/test/mock_actor.cc index 5e5b575223a6b..c51b1a8a11a5b 100644 --- a/streaming/src/test/mock_actor.cc +++ b/streaming/src/test/mock_actor.cc @@ -639,7 +639,7 @@ class StreamingWorker { } // namespace ray int main(int argc, char **argv) { - RAY_CHECK(argc >= 4); + RAY_CHECK(argc == 5); auto store_socket = std::string(argv[1]); auto raylet_socket = std::string(argv[2]); auto node_manager_port = std::stoi(std::string(argv[3])); diff --git a/thirdparty/patches/prometheus-windows-pollfd.patch 
b/thirdparty/patches/prometheus-windows-pollfd.patch index 3b30942bb85f2..1941b6cb247c0 100644 --- a/thirdparty/patches/prometheus-windows-pollfd.patch +++ b/thirdparty/patches/prometheus-windows-pollfd.patch @@ -6,46 +6,17 @@ Windows Vista and later SDKs define struct pollfd for WSAPoll(), but it has a pe civetweb provides its own implementation of poll, but it has a conflicting definition for pollfd. Hence we block Windows from defining pollfd (which this project doesn't use). --- - bazel/civetweb.BUILD | 7 +++++++ - 1 file changed, 7 insertions(+) + bazel/civetweb.BUILD | 1 + + 1 file changed, 1 insertion(+) diff --git bazel/civetweb.BUILD bazel/civetweb.BUILD --- bazel/civetweb.BUILD +++ bazel/civetweb.BUILD -@@ -9,6 +9,11 @@ config_setting( - values = {"cpu": "darwin_x86_64"}, - ) - -+config_setting( -+ name = "darwin_arm64", -+ values = {"cpu": "darwin_arm64"}, -+) -+ - config_setting( - name = "windows", - values = { "cpu": "x64_windows" }, -@@ -34,6 +39,7 @@ cc_library( +@@ -34,5 +34,6 @@ cc_library( "-DNO_CACHING", "-DNO_SSL", "-DNO_FILES", + "-D_WIN32_WINNT=0x0502", "-UDEBUG", ], - includes = [ -@@ -46,6 +52,7 @@ cc_library( - }) + select({ - ":darwin": [], - ":darwin_x86_64": [], -+ ":darwin_arm64": [], - ":windows": [], - ":windows_msvc": [], - "//conditions:default": ["-lrt"], -@@ -86,6 +93,7 @@ cc_library( - }) + select({ - ":darwin": [], - ":darwin_x86_64": [], -+ ":darwin_arm64": [], - ":windows": [], - ":windows_msvc": [], - "//conditions:default": ["-lrt"], --- +-- diff --git a/thirdparty/patches/rules_boost-undefine-boost_fallthrough.patch b/thirdparty/patches/rules_boost-undefine-boost_fallthrough.patch new file mode 100644 index 0000000000000..9cd53fe60f842 --- /dev/null +++ b/thirdparty/patches/rules_boost-undefine-boost_fallthrough.patch @@ -0,0 +1,8 @@ +diff --git BUILD.boost BUILD.boost +--- BUILD.boost ++++ BUILD.boost +@@ -1356,3 +1356,2 @@ boost_library( + defines = [ +- "BOOST_FALLTHROUGH", + ], +-- diff --git a/thirdparty/patches/rules_boost-windows-linkopts.patch b/thirdparty/patches/rules_boost-windows-linkopts.patch index 204443d3c7186..28bda4eb06939 100644 --- a/thirdparty/patches/rules_boost-windows-linkopts.patch +++ b/thirdparty/patches/rules_boost-windows-linkopts.patch @@ -1,12 +1,15 @@ diff --git BUILD.boost BUILD.boost --- BUILD.boost +++ BUILD.boost -@@ -428,6 +428,7 @@ boost_library( - }), - linkopts = select({ - ":android": [], -+ ":windows": [], - "//conditions:default": ["-lpthread"], - }), - deps = [ --- +@@ -313,1 +313,9 @@ boost_library(name = "asio", +- linkopts = ["-lpthread"], ++ linkopts = select({ ++ ":linux": [ ++ "-lpthread", ++ ], ++ ":osx_x86_64": [ ++ "-lpthread", ++ ], ++ "//conditions:default": [], ++ }), +-- From 8f02386d7f2c15f3ea2f056ab5491a033a61a3cf Mon Sep 17 00:00:00 2001 From: Oscar Knagg Date: Tue, 12 Oct 2021 16:13:10 +0100 Subject: [PATCH 41/56] Merge master --- .bazelrc | 10 +- .buildkite/pipeline.gpu.large.yml | 8 + .buildkite/pipeline.gpu.yml | 10 + .buildkite/pipeline.macos.yml | 2 +- .buildkite/pipeline.yml | 20 +- .clang-tidy | 41 +- .flake8 | 16 + .github/CODEOWNERS | 12 +- .github/workflows/main.yml | 3 +- .gitpod/Dockerfile | 2 +- BUILD.bazel | 33 +- bazel/ray_deps_setup.bzl | 9 +- benchmarks/object_store/test_object_store.py | 1 + benchmarks/single_node/test_single_node.py | 3 +- ci/travis/bazel.py | 42 +- ci/travis/ci.sh | 56 ++- ci/travis/format.sh | 4 + ci/travis/install-dependencies.sh | 2 +- cpp/BUILD.bazel | 1 + cpp/src/ray/api.cc | 2 +- cpp/src/ray/runtime/abstract_ray_runtime.cc | 2 +- 
.../ray/runtime/object/native_object_store.cc | 2 +- .../runtime/task/local_mode_task_submitter.cc | 4 +- cpp/src/ray/runtime/task/task_executor.cc | 2 +- cpp/src/ray/runtime/task/task_executor.h | 4 +- cpp/src/ray/util/process_helper.cc | 11 +- dashboard/agent.py | 46 +- dashboard/client/src/pages/job/JobDetail.tsx | 11 + dashboard/client/src/pages/job/index.tsx | 19 +- dashboard/client/src/type/job.d.ts | 2 + dashboard/modules/job/job_agent.py | 4 +- .../modules/runtime_env/runtime_env_agent.py | 24 +- dashboard/modules/snapshot/snapshot_head.py | 3 +- .../modules/snapshot/snapshot_schema.json | 4 - dashboard/tests/test_dashboard.py | 2 +- doc/BUILD | 25 + doc/examples/dask_xgboost/README.rst | 1 + doc/examples/dask_xgboost/dask_xgboost.py | 321 ++++++++++++ doc/examples/dask_xgboost/dask_xgboost.yaml | 24 + doc/examples/modin_xgboost/README.rst | 1 + doc/examples/modin_xgboost/modin_xgboost.py | 233 +++++++++ doc/examples/modin_xgboost/modin_xgboost.yaml | 24 + doc/examples/overview.rst | 10 + doc/kubernetes/ray-cluster.yaml | 4 +- doc/source/advanced.rst | 35 +- doc/source/cluster/config.rst | 62 +++ doc/source/cluster/ray-client.rst | 69 ++- doc/source/data/dask-on-ray.rst | 20 +- doc/source/data/dataset-pipeline.rst | 121 ++++- doc/source/data/dataset-tensor-support.rst | 72 +-- doc/source/data/dataset.rst | 8 +- doc/source/data/package-ref.rst | 2 + doc/source/index.rst | 5 +- doc/source/raysgd/raysgd.rst | 2 +- doc/source/raysgd/raysgd_pytorch.rst | 5 +- doc/source/raysgd/raysgd_tensorflow.rst | 5 +- doc/source/raysgd/raysgd_tune.rst | 3 + doc/source/raysgd/v2/api.rst | 21 +- doc/source/raysgd/v2/examples.rst | 3 + .../tune_cifar_pytorch_pbt_example.rst | 6 + doc/source/raysgd/v2/migration-guide.rst | 393 +++++++++++++++ doc/source/raysgd/v2/raysgd.rst | 3 +- doc/source/raysgd/v2/user_guide.rst | 12 +- doc/source/serve/core-apis.rst | 43 +- doc/source/serve/ml-models.rst | 15 +- doc/source/tune/_tutorials/_faq.inc | 55 ++- doc/source/tune/api_docs/suggestion.rst | 3 +- doc/source/tune/user-guide.rst | 10 +- java/BUILD.bazel | 5 + java/dependencies.bzl | 2 + .../java/io/ray/runtime/RayNativeRuntime.java | 18 +- .../runtime/object/LocalModeObjectStore.java | 2 +- .../ray/runtime/object/NativeObjectStore.java | 6 +- .../io/ray/runtime/object/ObjectRefImpl.java | 2 +- .../io/ray/runtime/object/ObjectStore.java | 4 +- java/serve/pom.xml | 15 + .../src/main/java/io/ray/serve/Constants.java | 6 + .../java/io/ray/serve/DeploymentInfo.java | 38 ++ .../io/ray/serve/DummyBackendReplica.java | 12 + .../main/java/io/ray/serve/HandleOptions.java | 15 + .../src/main/java/io/ray/serve/HttpProxy.java | 161 ++++++ .../main/java/io/ray/serve/ProxyActor.java | 175 +++++++ .../main/java/io/ray/serve/ProxyRouter.java | 72 +++ .../java/io/ray/serve/RayServeConfig.java | 6 + .../java/io/ray/serve/RayServeHandle.java | 73 +++ .../java/io/ray/serve/RayServeMetrics.java | 74 +++ .../java/io/ray/serve/RayServeReplica.java | 211 +++++--- .../io/ray/serve/RayServeWrappedReplica.java | 42 +- .../main/java/io/ray/serve/ReplicaConfig.java | 8 +- .../java/io/ray/serve/ReplicaContext.java | 2 +- .../main/java/io/ray/serve/ReplicaSet.java | 138 ++++++ .../src/main/java/io/ray/serve/Router.java | 64 +++ .../java/io/ray/serve/ServeController.java | 6 + .../main/java/io/ray/serve/ServeProxy.java | 14 + .../main/java/io/ray/serve/api/Client.java | 72 +++ .../src/main/java/io/ray/serve/api/Serve.java | 54 +- .../java/io/ray/serve/poll/KeyListener.java | 2 +- .../io/ray/serve/poll/LongPollClient.java | 69 ++- 
.../io/ray/serve/poll/LongPollNamespace.java | 4 +- .../java/io/ray/serve/poll/UpdatedObject.java | 33 -- .../io/ray/serve/util/CollectionUtil.java | 10 + .../java/io/ray/serve/util/CommonUtil.java | 13 + .../java/io/ray/serve/util/ReflectUtil.java | 14 + .../io/ray/serve/util/ServeProtoUtil.java | 75 ++- .../java/io/ray/serve/util/SocketUtil.java | 49 ++ .../io/ray/serve/DummyServeController.java | 21 + .../test/java/io/ray/serve/HttpProxyTest.java | 74 +++ .../java/io/ray/serve/ProxyActorTest.java | 110 +++++ .../java/io/ray/serve/ProxyRouterTest.java | 68 +++ .../java/io/ray/serve/RayServeHandleTest.java | 76 +++ .../io/ray/serve/RayServeReplicaTest.java | 46 +- .../java/io/ray/serve/ReplicaSetTest.java | 108 ++++ .../test/java/io/ray/serve/RouterTest.java | 80 +++ .../java/io/ray/serve/api/ClientTest.java | 47 ++ .../test/java/io/ray/serve/api/ServeTest.java | 71 ++- .../java/io/ray/serve/poll/KeyTypeTest.java | 15 +- .../io/ray/serve/poll/LongPollClientTest.java | 29 +- python/build-wheel-windows.sh | 7 + python/ray/_private/client_mode_hook.py | 50 +- python/ray/_private/parameter.py | 4 +- python/ray/_private/runtime_env/__init__.py | 3 - python/ray/_private/runtime_env/conda.py | 4 +- .../ray/_private/runtime_env/conda_utils.py | 15 + python/ray/_private/runtime_env/context.py | 13 +- python/ray/_private/runtime_env/plugin.py | 70 +++ python/ray/_private/runtime_env/validation.py | 436 ++++++++++------ .../ray/_private/runtime_env/working_dir.py | 2 +- python/ray/_private/services.py | 43 +- python/ray/_raylet.pxd | 2 +- python/ray/_raylet.pyx | 142 +++--- python/ray/actor.py | 68 +-- python/ray/autoscaler/_private/autoscaler.py | 3 +- python/ray/autoscaler/_private/docker.py | 2 +- python/ray/autoscaler/_private/gcp/node.py | 20 +- python/ray/autoscaler/_private/monitor.py | 5 +- .../_private/resource_demand_scheduler.py | 24 +- python/ray/autoscaler/gcp/tpu.yaml | 18 +- python/ray/cross_language.py | 3 +- python/ray/data/__init__.py | 7 +- python/ray/data/block.py | 24 +- python/ray/data/dataset.py | 381 ++++++++++---- python/ray/data/dataset_pipeline.py | 188 ++++++- python/ray/data/datasource/datasource.py | 16 +- .../data/datasource/file_based_datasource.py | 41 +- .../ray/data/datasource/numpy_datasource.py | 13 +- python/ray/data/examples/demo_infer.py | 2 +- .../ray/data/extensions/tensor_extension.py | 8 +- python/ray/data/impl/arrow_block.py | 20 +- python/ray/data/impl/block_list.py | 7 + python/ray/data/impl/compute.py | 23 +- python/ray/data/impl/lazy_block_list.py | 57 ++- python/ray/data/impl/pipeline_executor.py | 15 +- python/ray/data/impl/progress_bar.py | 12 +- python/ray/data/impl/remote_fn.py | 5 +- python/ray/data/impl/simple_block.py | 4 +- python/ray/data/impl/tensor_block.py | 80 --- python/ray/data/read_api.py | 53 +- python/ray/data/tests/test_dataset.py | 466 ++++++++++++------ .../ray/data/tests/test_dataset_pipeline.py | 89 +++- python/ray/data/tests/test_raydp_dataset.py | 4 + python/ray/exceptions.py | 6 +- python/ray/experimental/array/remote/core.py | 4 +- python/ray/experimental/internal_kv.py | 12 +- python/ray/experimental/raysort/constants.py | 11 +- python/ray/experimental/raysort/main.py | 369 +++++++++----- python/ray/experimental/raysort/sortlib.py | 8 +- .../ray/experimental/raysort/tracing_utils.py | 127 ++++- python/ray/experimental/raysort/types.py | 12 +- python/ray/includes/common.pxd | 6 +- python/ray/includes/libcoreworker.pxd | 5 +- python/ray/job_config.py | 61 +-- python/ray/node.py | 24 +- python/ray/remote_function.py | 77 +-- 
python/ray/runtime_context.py | 16 +- python/ray/scripts/scripts.py | 7 +- python/ray/serialization.py | 2 +- python/ray/serve/BUILD | 10 +- python/ray/serve/api.py | 154 ++++-- python/ray/serve/autoscaling_metrics.py | 5 +- python/ray/serve/autoscaling_policy.py | 1 - python/ray/serve/backend_state.py | 77 ++- python/ray/serve/common.py | 5 +- python/ray/serve/config.py | 38 +- python/ray/serve/controller.py | 139 ++++-- python/ray/serve/endpoint_state.py | 1 - python/ray/serve/examples/doc/conda_env.py | 23 +- python/ray/serve/handle.py | 15 +- python/ray/serve/http_proxy.py | 7 +- python/ray/serve/long_poll.py | 26 +- .../serve/{backend_worker.py => replica.py} | 54 +- python/ray/serve/tests/conftest.py | 7 + python/ray/serve/tests/test_advanced.py | 12 +- .../serve/tests/test_autoscaling_metrics.py | 12 +- .../serve/tests/test_autoscaling_policy.py | 52 ++ python/ray/serve/tests/test_backend_state.py | 77 +-- python/ray/serve/tests/test_config.py | 2 + python/ray/serve/tests/test_deploy.py | 57 ++- python/ray/serve/tests/test_get_deployment.py | 31 ++ python/ray/serve/tests/test_handle.py | 27 +- python/ray/serve/tests/test_long_poll.py | 14 + python/ray/serve/tests/test_ray_client.py | 5 +- python/ray/serve/tests/test_regression.py | 2 +- python/ray/serve/tests/test_standalone.py | 9 +- python/ray/sgd/__init__.py | 3 +- python/ray/sgd/callbacks.py | 1 + python/ray/state.py | 14 +- python/ray/tests/BUILD | 12 +- python/ray/tests/client_test_utils.py | 17 + python/ray/tests/mock_setup_worker.py | 3 + python/ray/tests/test_advanced.py | 5 +- python/ray/tests/test_advanced_3.py | 13 +- python/ray/tests/test_autoscaler.py | 61 ++- python/ray/tests/test_client.py | 54 +- python/ray/tests/test_client_compat.py | 33 ++ .../tests/test_client_library_integration.py | 8 +- python/ray/tests/test_client_proxy.py | 10 +- python/ray/tests/test_client_reconnect.py | 9 +- python/ray/tests/test_dashboard.py | 52 +- python/ray/tests/test_distributed_sort.py | 19 +- python/ray/tests/test_failure_2.py | 3 +- python/ray/tests/test_multi_tenancy.py | 12 +- python/ray/tests/test_object_manager.py | 48 +- python/ray/tests/test_output.py | 8 +- python/ray/tests/test_placement_group_3.py | 35 ++ python/ray/tests/test_ray_debugger.py | 7 +- python/ray/tests/test_ray_init.py | 40 ++ .../tests/test_resource_demand_scheduler.py | 130 +++-- python/ray/tests/test_runtime_context.py | 115 +++++ python/ray/tests/test_runtime_env.py | 68 +-- .../ray/tests/test_runtime_env_complicated.py | 137 +++-- python/ray/tests/test_runtime_env_env_vars.py | 244 +++------ python/ray/tests/test_runtime_env_plugin.py | 75 +++ .../ray/tests/test_runtime_env_validation.py | 360 ++++++++++++++ python/ray/tests/test_scheduling.py | 99 +++- python/ray/tests/test_traceback.py | 39 ++ .../ray/tune/analysis/experiment_analysis.py | 22 +- python/ray/tune/commands.py | 7 +- python/ray/tune/durable_trainable.py | 14 +- python/ray/tune/function_runner.py | 11 +- python/ray/tune/progress_reporter.py | 96 +++- python/ray/tune/ray_trial_executor.py | 27 +- python/ray/tune/registry.py | 22 +- python/ray/tune/result.py | 4 + python/ray/tune/tests/test_api.py | 8 + python/ray/tune/tests/test_cluster.py | 2 +- .../ray/tune/tests/test_progress_reporter.py | 200 +++++--- .../ray/tune/tests/test_ray_trial_executor.py | 69 ++- python/ray/tune/tests/test_trial_runner_3.py | 3 +- .../tune/tests/test_trial_runner_callbacks.py | 2 +- python/ray/tune/tests/test_trial_scheduler.py | 1 + .../tune/tests/test_trial_scheduler_pbt.py | 12 +- 
python/ray/tune/trainable.py | 59 ++- python/ray/tune/trial.py | 65 ++- python/ray/tune/trial_runner.py | 95 +++- python/ray/tune/tune.py | 48 +- python/ray/tune/utils/util.py | 9 +- python/ray/util/__init__.py | 2 +- python/ray/util/client/__init__.py | 3 +- python/ray/util/client/client_pickler.py | 15 +- python/ray/util/client/options.py | 1 - python/ray/util/client/server/proxier.py | 6 +- python/ray/util/client/worker.py | 16 +- python/ray/util/dask/scheduler_utils.py | 5 +- python/ray/util/placement_group.py | 4 +- python/ray/util/sgd/torch/torch_runner.py | 14 +- .../ray/util/sgd/torch/training_operator.py | 42 +- python/ray/util/sgd/v2/BUILD | 27 + python/ray/util/sgd/v2/__init__.py | 4 +- python/ray/util/sgd/v2/backends/backend.py | 26 +- python/ray/util/sgd/v2/backends/horovod.py | 2 + python/ray/util/sgd/v2/backends/torch.py | 2 + python/ray/util/sgd/v2/constants.py | 4 + .../v2/examples/tensorflow_mnist_example.py | 4 +- .../tune_cifar_pytorch_pbt_example.py | 200 ++++++++ python/ray/util/sgd/v2/tests/test_backend.py | 4 + python/ray/util/sgd/v2/tests/test_gpu.py | 92 ++++ python/ray/util/sgd/v2/tests/test_trainer.py | 82 +-- python/ray/util/sgd/v2/tests/test_tune.py | 31 +- python/ray/util/tracing/tracing_helper.py | 7 +- python/ray/worker.py | 44 +- python/ray/workers/setup_worker.py | 10 +- python/ray/workflow/common.py | 3 +- python/ray/workflow/execution.py | 7 +- python/ray/workflow/recovery.py | 29 +- python/ray/workflow/step_executor.py | 92 ++-- .../workflow/tests/test_basic_workflows_2.py | 43 +- python/ray/workflow/tests/test_lifetime.py | 26 +- python/ray/workflow/workflow_access.py | 4 +- python/ray/workflow/workflow_context.py | 109 +++- python/ray/workflow/workflow_storage.py | 27 +- python/requirements.txt | 3 +- python/requirements/ml/requirements_rllib.txt | 4 +- python/requirements_linters.txt | 1 + python/setup.py | 27 +- release/.buildkite/build_pipeline.py | 1 + release/RELEASE_CHECKLIST.md | 1 + release/RELEASE_PROCESS.rst | 3 + release/alerts/xgboost_tests.py | 4 +- release/e2e.py | 207 +++++--- .../dask_xgboost_app_config.yaml | 5 +- .../golden_notebook_tests.yaml | 21 +- .../modin_xgboost_app_config.yaml | 5 +- .../workloads/dask_xgboost_test.py | 123 +---- .../workloads/modin_xgboost_test.py | 119 +---- .../workloads/torch_tune_serve_test.py | 4 +- .../golden_notebook_tests/workloads/util.py | 49 ++ .../workloads/utils/utils.py | 5 - release/kubernetes_manual_tests/README.md | 25 + release/kubernetes_manual_tests/helm-test.sh | 8 + .../kubernetes_manual_tests/k8s-test-scale.sh | 11 + release/kubernetes_manual_tests/k8s-test.sh | 9 + .../k8s_release_tests.sh | 30 ++ release/long_running_tests/tpl_cpu_1.yaml | 5 + .../large_scale_dask_on_ray_app_config.yaml | 1 - release/nightly_tests/dataset/app_config.yaml | 1 - .../dataset/dataset_shuffle_data_loader.py | 2 +- .../dataset/pipelined_ingestion_app.yaml | 1 - .../dataset/pipelined_training.py | 4 +- .../dataset/pipelined_training_app.yaml | 1 - .../dataset/shuffle_app_config.yaml | 1 - .../decision_tree_app_config.yaml | 1 - .../many_nodes_tests/app_config.yaml | 2 +- release/nightly_tests/nightly_tests.yaml | 25 +- .../placement_group_tests/app_config.yaml | 12 + .../placement_group_tests/cluster.py | 13 + .../placement_group_tests/compute.yaml | 27 + .../placement_group_tests/pg_run.py | 65 +++ .../shuffle/shuffle_app_config.yaml | 2 - .../shuffle_data_loader_app_config.yaml | 1 - .../stress_tests/stress_tests_app_config.yaml | 1 - .../1.7.0/benchmarks/many_actors.txt | 10 + 
.../1.7.0/benchmarks/many_nodes.txt | 10 + .../1.7.0/benchmarks/many_pgs.txt | 10 + .../1.7.0/benchmarks/many_tasks.txt | 10 + release/release_logs/1.7.0/microbenchmark.txt | 134 +++++ .../1.7.0/scalability/object_store.txt | 10 + .../1.7.0/scalability/single_node.txt | 16 + .../1.7.0/stress_tests/dead_actors.txt | 11 + .../1.7.0/stress_tests/many_tasks.txt | 19 + .../1.7.0/stress_tests/placement_group.txt | 9 + release/util/pip_download_test.sh | 2 +- rllib/BUILD | 89 +++- rllib/agents/a3c/a3c_tf_policy.py | 2 +- rllib/agents/a3c/a3c_torch_policy.py | 18 +- rllib/agents/a3c/tests/test_a2c.py | 11 +- rllib/agents/a3c/tests/test_a3c.py | 3 +- rllib/agents/ars/tests/test_ars.py | 10 +- rllib/agents/cql/cql.py | 3 +- rllib/agents/cql/cql_torch_policy.py | 67 +-- rllib/agents/cql/tests/test_cql.py | 11 +- rllib/agents/ddpg/ddpg_tf_model.py | 12 +- rllib/agents/ddpg/ddpg_tf_policy.py | 8 +- rllib/agents/ddpg/ddpg_torch_model.py | 12 +- rllib/agents/ddpg/ddpg_torch_policy.py | 37 +- rllib/agents/ddpg/tests/test_apex_ddpg.py | 6 +- rllib/agents/ddpg/tests/test_ddpg.py | 8 +- rllib/agents/ddpg/tests/test_td3.py | 3 +- rllib/agents/dqn/apex.py | 3 +- rllib/agents/dqn/dqn.py | 16 +- rllib/agents/dqn/dqn_torch_policy.py | 46 +- rllib/agents/dqn/learner_thread.py | 24 +- rllib/agents/dqn/r2d2.py | 14 +- rllib/agents/dqn/r2d2_tf_policy.py | 6 +- rllib/agents/dqn/r2d2_torch_policy.py | 44 +- rllib/agents/dqn/simple_q_tf_policy.py | 2 +- rllib/agents/dqn/simple_q_torch_policy.py | 17 +- rllib/agents/dqn/tests/test_apex_dqn.py | 15 +- rllib/agents/dqn/tests/test_dqn.py | 4 +- rllib/agents/dqn/tests/test_r2d2.py | 3 +- rllib/agents/dqn/tests/test_simple_q.py | 3 +- rllib/agents/dreamer/dreamer.py | 3 +- rllib/agents/impala/tests/test_impala.py | 12 +- rllib/agents/impala/vtrace_tf_policy.py | 26 +- rllib/agents/impala/vtrace_torch_policy.py | 45 +- rllib/agents/maml/maml.py | 17 +- rllib/agents/maml/tests/test_maml.py | 6 +- rllib/agents/marwil/tests/test_bc.py | 8 +- rllib/agents/marwil/tests/test_marwil.py | 8 +- rllib/agents/mbmpo/mbmpo.py | 17 +- rllib/agents/mbmpo/tests/test_mbmpo.py | 8 +- rllib/agents/pg/pg_torch_policy.py | 14 +- rllib/agents/pg/tests/test_pg.py | 10 +- rllib/agents/ppo/appo_tf_policy.py | 2 +- rllib/agents/ppo/appo_torch_policy.py | 47 +- rllib/agents/ppo/ddppo.py | 15 +- rllib/agents/ppo/ppo.py | 12 +- rllib/agents/ppo/ppo_torch_policy.py | 33 +- rllib/agents/ppo/tests/test_appo.py | 16 +- rllib/agents/ppo/tests/test_ddppo.py | 26 +- rllib/agents/ppo/tests/test_ppo.py | 29 +- rllib/agents/qmix/qmix_policy.py | 2 +- rllib/agents/sac/rnnsac.py | 7 - rllib/agents/sac/rnnsac_torch_policy.py | 32 +- rllib/agents/sac/sac_tf_model.py | 10 +- rllib/agents/sac/sac_tf_policy.py | 8 +- rllib/agents/sac/sac_torch_model.py | 8 +- rllib/agents/sac/sac_torch_policy.py | 62 +-- rllib/agents/sac/tests/test_rnnsac.py | 73 +++ rllib/agents/sac/tests/test_sac.py | 38 +- rllib/agents/tests/test_trainer.py | 3 +- rllib/agents/trainer.py | 299 +++++++---- .../alpha_zero/core/alpha_zero_policy.py | 7 +- rllib/contrib/bandits/agents/policy.py | 2 +- .../bandits/examples/LinTS_train_wheel_env.py | 3 +- rllib/contrib/maddpg/maddpg_policy.py | 2 +- rllib/contrib/sumo/connector.py | 5 +- rllib/env/base_env.py | 7 +- rllib/env/multi_agent_env.py | 3 +- rllib/env/policy_server_input.py | 16 +- rllib/env/remote_vector_env.py | 20 +- rllib/env/tests/test_local_inference.sh | 42 -- .../tests/test_policy_client_server_setup.sh | 63 +++ rllib/env/tests/test_remote_inference.sh | 41 -- 
rllib/env/tests/test_remote_worker_envs.py | 98 ++++ rllib/env/wrappers/unity3d_env.py | 16 +- .../collectors/simple_list_collector.py | 7 +- rllib/evaluation/metrics.py | 21 +- rllib/evaluation/rollout_worker.py | 30 +- rllib/examples/centralized_critic.py | 2 +- rllib/examples/custom_keras_model.py | 5 +- .../examples/custom_model_loss_and_metrics.py | 12 +- rllib/examples/deterministic_training.py | 6 +- .../env/coin_game_non_vectorized_env.py | 11 +- .../examples/env/coin_game_vectorized_env.py | 9 +- .../env/matrix_sequential_social_dilemma.py | 6 +- rllib/examples/env/random_env.py | 26 +- rllib/examples/pettingzoo_env.py | 18 +- .../remote_vector_env_with_custom_api.py | 3 +- .../rock_paper_scissors_multiagent.py | 8 +- rllib/examples/serving/cartpole_client.py | 2 +- rllib/examples/serving/unity3d_client.py | 14 +- .../examples/serving/unity3d_dummy_client.py | 144 ++++++ rllib/examples/serving/unity3d_server.py | 70 ++- rllib/examples/trajectory_view_api.py | 50 +- rllib/execution/common.py | 3 - rllib/execution/learner_thread.py | 25 +- rllib/execution/multi_gpu_learner_thread.py | 68 ++- rllib/execution/rollout_ops.py | 24 +- rllib/execution/train_ops.py | 79 ++- rllib/models/tests/test_preprocessors.py | 6 +- rllib/models/tf/complex_input_net.py | 8 +- rllib/models/torch/complex_input_net.py | 9 +- rllib/models/torch/torch_modelv2.py | 8 + rllib/policy/eager_tf_policy.py | 118 +++-- rllib/policy/policy.py | 114 ++--- rllib/policy/policy_template.py | 3 +- rllib/policy/sample_batch.py | 7 +- .../tests/test_compute_log_likelihoods.py | 2 +- rllib/policy/tf_policy.py | 31 +- rllib/policy/tf_policy_template.py | 9 +- rllib/policy/torch_policy.py | 38 +- rllib/tests/test_exec_api.py | 3 +- rllib/tests/test_supported_multi_agent.py | 26 +- rllib/tests/test_supported_spaces.py | 8 +- rllib/utils/__init__.py | 3 +- .../utils/exploration/stochastic_sampling.py | 20 +- rllib/utils/metrics/__init__.py | 0 rllib/utils/metrics/learner_info.py | 84 ++++ rllib/utils/multi_agent.py | 21 +- rllib/utils/sgd.py | 55 +-- rllib/utils/test_utils.py | 316 +++++++++--- rllib/utils/tf_ops.py | 2 +- rllib/utils/tf_run_builder.py | 5 +- rllib/utils/torch_ops.py | 6 +- .../ray/gcs/gcs_server/gcs_node_manager.h | 1 + .../gcs_placement_group_scheduler.h | 65 +-- .../ray/gcs/gcs_server/gcs_resource_manager.h | 1 + src/mock/ray/gcs/pubsub/gcs_pub_sub.h | 27 + .../gcs/store_client/in_memory_store_client.h | 66 +++ .../ray/gcs/store_client/redis_store_client.h | 67 +++ src/mock/ray/gcs/store_client/store_client.h | 66 +++ src/mock/ray/pubsub/publisher.h | 100 ++++ src/mock/ray/pubsub/subscriber.h | 155 ++++++ src/mock/ray/raylet_client/raylet_client.h | 43 +- src/mock/ray/rpc/worker/core_worker_client.h | 123 +++++ .../ray/rpc/worker/core_worker_client_pool.h | 23 + src/ray/common/bundle_spec.cc | 25 +- src/ray/common/bundle_spec.h | 6 + src/ray/common/client_connection.cc | 2 +- src/ray/common/id.h | 1 + src/ray/common/network_util.h | 2 +- src/ray/common/ray_config_def.h | 22 +- src/ray/common/runtime_env_manager.cc | 16 +- src/ray/common/runtime_env_manager.h | 7 +- src/ray/common/task/task_spec.cc | 32 +- src/ray/common/task/task_spec.h | 16 +- src/ray/common/task/task_util.h | 11 +- src/ray/core_worker/common.h | 24 +- src/ray/core_worker/context.cc | 15 +- src/ray/core_worker/context.h | 9 +- src/ray/core_worker/core_worker.cc | 315 ++++++------ src/ray/core_worker/core_worker.h | 105 ++-- ...io_ray_runtime_object_NativeObjectStore.cc | 5 +- .../io_ray_runtime_object_NativeObjectStore.h | 7 +- 
...io_ray_runtime_task_NativeTaskSubmitter.cc | 6 +- src/ray/core_worker/reference_count.h | 2 - src/ray/core_worker/reference_count_test.cc | 6 +- .../memory_store/memory_store.cc | 40 +- .../memory_store/memory_store.h | 16 - src/ray/core_worker/test/core_worker_test.cc | 2 +- .../test/direct_task_transport_mock_test.cc | 4 +- .../test/direct_task_transport_test.cc | 114 ++++- src/ray/core_worker/test/memory_store_test.cc | 11 +- .../transport/dependency_resolver.cc | 6 +- src/ray/gcs/asio.h | 2 +- .../gcs/gcs_client/service_based_accessor.cc | 2 +- .../test/global_state_accessor_test.cc | 1 + .../test/service_based_gcs_client_test.cc | 1 + .../gcs/gcs_server/gcs_actor_distribution.cc | 30 ++ .../gcs/gcs_server/gcs_actor_distribution.h | 17 + src/ray/gcs/gcs_server/gcs_actor_manager.cc | 67 ++- src/ray/gcs/gcs_server/gcs_actor_manager.h | 13 +- src/ray/gcs/gcs_server/gcs_actor_scheduler.cc | 2 +- src/ray/gcs/gcs_server/gcs_actor_scheduler.h | 10 + src/ray/gcs/gcs_server/gcs_node_manager.cc | 8 +- .../gcs_server/gcs_placement_group_manager.cc | 138 ++++-- .../gcs_server/gcs_placement_group_manager.h | 52 +- .../gcs_placement_group_scheduler.cc | 28 +- .../gcs_placement_group_scheduler.h | 42 +- src/ray/gcs/gcs_server/gcs_server.cc | 19 +- src/ray/gcs/gcs_server/gcs_server.h | 1 - src/ray/gcs/gcs_server/gcs_table_storage.h | 130 +++-- .../gcs_server/test/gcs_actor_manager_test.cc | 15 +- .../test/gcs_actor_scheduler_mock_test.cc | 139 ++++++ .../test/gcs_based_actor_scheduler_test.cc | 18 +- .../gcs_placement_group_manager_mock_test.cc | 174 +++++++ .../test/gcs_placement_group_manager_test.cc | 47 +- .../gcs_server/test/gcs_server_rpc_test.cc | 1 + .../gcs_server/test/gcs_server_test_util.h | 10 +- src/ray/gcs/pubsub/gcs_pub_sub.h | 3 + src/ray/gcs/redis_context.h | 2 +- src/ray/object_manager/object_buffer_pool.cc | 165 +++++-- src/ray/object_manager/object_buffer_pool.h | 59 ++- src/ray/object_manager/object_manager.cc | 24 +- src/ray/object_manager/object_manager.h | 6 +- src/ray/object_manager/plasma/store.cc | 7 +- src/ray/object_manager/pull_manager.h | 2 +- src/ray/protobuf/agent_manager.proto | 1 + src/ray/protobuf/common.proto | 23 +- src/ray/protobuf/core_worker.proto | 1 + src/ray/protobuf/event.proto | 1 + src/ray/protobuf/gcs.proto | 31 +- src/ray/protobuf/gcs_service.proto | 2 +- src/ray/protobuf/job_agent.proto | 1 + src/ray/protobuf/node_manager.proto | 1 + src/ray/protobuf/object_manager.proto | 1 + src/ray/protobuf/pubsub.proto | 1 + src/ray/protobuf/ray_client.proto | 5 +- src/ray/protobuf/reporter.proto | 1 + src/ray/protobuf/runtime_env_agent.proto | 5 + src/ray/protobuf/serialization.proto | 1 + src/ray/protobuf/serve.proto | 47 +- src/ray/ray_version_script.lds | 1 - src/ray/raylet/agent_manager.cc | 54 +- src/ray/raylet/agent_manager.h | 9 +- src/ray/raylet/main.cc | 1 + src/ray/raylet/node_manager.cc | 11 +- src/ray/raylet/node_manager.h | 8 - .../placement_group_resource_manager.cc | 3 + src/ray/raylet/raylet.cc | 2 +- .../scheduling/cluster_resource_data.cc | 5 +- .../raylet/scheduling/cluster_resource_data.h | 4 +- .../scheduling/cluster_resource_scheduler.cc | 3 +- .../raylet/scheduling/cluster_task_manager.cc | 27 +- .../scheduling/cluster_task_manager_test.cc | 139 ++++-- src/ray/raylet/scheduling/fixed_point.cc | 96 ---- src/ray/raylet/scheduling/fixed_point.h | 115 +++-- .../raylet/scheduling/scheduling_policy.cc | 20 +- src/ray/raylet/scheduling/scheduling_policy.h | 13 +- .../scheduling/scheduling_policy_test.cc | 36 ++ src/ray/raylet/worker.cc | 2 +- 
src/ray/raylet/worker_pool.cc | 124 +++-- src/ray/raylet/worker_pool.h | 7 +- src/ray/raylet/worker_pool_test.cc | 14 +- src/ray/raylet_client/raylet_client.cc | 19 +- src/ray/raylet_client/raylet_client.h | 15 +- src/ray/rpc/grpc_server.cc | 9 +- src/ray/rpc/grpc_server.h | 13 +- src/ray/rpc/server_call.h | 17 +- src/ray/rpc/test/grpc_server_client_test.cc | 18 +- src/ray/util/event.h | 10 +- src/ray/util/util.h | 44 +- src/ray/util/util_test.cc | 17 + streaming/src/queue/queue_handler.h | 2 +- streaming/src/test/mock_actor.cc | 2 +- .../patches/prometheus-windows-pollfd.patch | 37 +- ...les_boost-undefine-boost_fallthrough.patch | 8 - .../rules_boost-windows-linkopts.patch | 21 +- 588 files changed, 14996 insertions(+), 5084 deletions(-) create mode 100644 .buildkite/pipeline.gpu.large.yml create mode 100644 doc/examples/dask_xgboost/README.rst create mode 100644 doc/examples/dask_xgboost/dask_xgboost.py create mode 100644 doc/examples/dask_xgboost/dask_xgboost.yaml create mode 100644 doc/examples/modin_xgboost/README.rst create mode 100644 doc/examples/modin_xgboost/modin_xgboost.py create mode 100644 doc/examples/modin_xgboost/modin_xgboost.yaml create mode 100644 doc/source/raysgd/v2/examples/tune_cifar_pytorch_pbt_example.rst create mode 100644 doc/source/raysgd/v2/migration-guide.rst create mode 100644 java/serve/src/main/java/io/ray/serve/DeploymentInfo.java create mode 100644 java/serve/src/main/java/io/ray/serve/DummyBackendReplica.java create mode 100644 java/serve/src/main/java/io/ray/serve/HandleOptions.java create mode 100644 java/serve/src/main/java/io/ray/serve/HttpProxy.java create mode 100644 java/serve/src/main/java/io/ray/serve/ProxyActor.java create mode 100644 java/serve/src/main/java/io/ray/serve/ProxyRouter.java create mode 100644 java/serve/src/main/java/io/ray/serve/RayServeConfig.java create mode 100644 java/serve/src/main/java/io/ray/serve/RayServeHandle.java create mode 100644 java/serve/src/main/java/io/ray/serve/RayServeMetrics.java create mode 100644 java/serve/src/main/java/io/ray/serve/ReplicaSet.java create mode 100644 java/serve/src/main/java/io/ray/serve/Router.java create mode 100644 java/serve/src/main/java/io/ray/serve/ServeController.java create mode 100644 java/serve/src/main/java/io/ray/serve/ServeProxy.java create mode 100644 java/serve/src/main/java/io/ray/serve/api/Client.java delete mode 100644 java/serve/src/main/java/io/ray/serve/poll/UpdatedObject.java create mode 100644 java/serve/src/main/java/io/ray/serve/util/CollectionUtil.java create mode 100644 java/serve/src/main/java/io/ray/serve/util/CommonUtil.java create mode 100644 java/serve/src/main/java/io/ray/serve/util/SocketUtil.java create mode 100644 java/serve/src/test/java/io/ray/serve/DummyServeController.java create mode 100644 java/serve/src/test/java/io/ray/serve/HttpProxyTest.java create mode 100644 java/serve/src/test/java/io/ray/serve/ProxyActorTest.java create mode 100644 java/serve/src/test/java/io/ray/serve/ProxyRouterTest.java create mode 100644 java/serve/src/test/java/io/ray/serve/RayServeHandleTest.java create mode 100644 java/serve/src/test/java/io/ray/serve/ReplicaSetTest.java create mode 100644 java/serve/src/test/java/io/ray/serve/RouterTest.java create mode 100644 java/serve/src/test/java/io/ray/serve/api/ClientTest.java create mode 100644 python/ray/_private/runtime_env/plugin.py delete mode 100644 python/ray/data/impl/tensor_block.py rename python/ray/serve/{backend_worker.py => replica.py} (91%) create mode 100644 python/ray/sgd/callbacks.py create mode 100644 
python/ray/tests/test_client_compat.py create mode 100644 python/ray/tests/test_runtime_env_plugin.py create mode 100644 python/ray/tests/test_runtime_env_validation.py create mode 100644 python/ray/util/sgd/v2/examples/tune_cifar_pytorch_pbt_example.py create mode 100644 python/ray/util/sgd/v2/tests/test_gpu.py create mode 100644 release/golden_notebook_tests/workloads/util.py delete mode 100644 release/golden_notebook_tests/workloads/utils/utils.py create mode 100644 release/kubernetes_manual_tests/README.md create mode 100755 release/kubernetes_manual_tests/helm-test.sh create mode 100755 release/kubernetes_manual_tests/k8s-test-scale.sh create mode 100755 release/kubernetes_manual_tests/k8s-test.sh create mode 100644 release/kubernetes_manual_tests/k8s_release_tests.sh create mode 100644 release/nightly_tests/placement_group_tests/app_config.yaml create mode 100644 release/nightly_tests/placement_group_tests/cluster.py create mode 100644 release/nightly_tests/placement_group_tests/compute.yaml create mode 100644 release/nightly_tests/placement_group_tests/pg_run.py create mode 100644 release/release_logs/1.7.0/benchmarks/many_actors.txt create mode 100644 release/release_logs/1.7.0/benchmarks/many_nodes.txt create mode 100644 release/release_logs/1.7.0/benchmarks/many_pgs.txt create mode 100644 release/release_logs/1.7.0/benchmarks/many_tasks.txt create mode 100644 release/release_logs/1.7.0/microbenchmark.txt create mode 100644 release/release_logs/1.7.0/scalability/object_store.txt create mode 100644 release/release_logs/1.7.0/scalability/single_node.txt create mode 100644 release/release_logs/1.7.0/stress_tests/dead_actors.txt create mode 100644 release/release_logs/1.7.0/stress_tests/many_tasks.txt create mode 100644 release/release_logs/1.7.0/stress_tests/placement_group.txt create mode 100644 rllib/agents/sac/tests/test_rnnsac.py delete mode 100755 rllib/env/tests/test_local_inference.sh create mode 100755 rllib/env/tests/test_policy_client_server_setup.sh delete mode 100755 rllib/env/tests/test_remote_inference.sh create mode 100644 rllib/env/tests/test_remote_worker_envs.py create mode 100644 rllib/examples/serving/unity3d_dummy_client.py create mode 100644 rllib/utils/metrics/__init__.py create mode 100644 rllib/utils/metrics/learner_info.py create mode 100644 src/mock/ray/gcs/pubsub/gcs_pub_sub.h create mode 100644 src/mock/ray/gcs/store_client/in_memory_store_client.h create mode 100644 src/mock/ray/gcs/store_client/redis_store_client.h create mode 100644 src/mock/ray/gcs/store_client/store_client.h create mode 100644 src/mock/ray/pubsub/publisher.h create mode 100644 src/mock/ray/pubsub/subscriber.h create mode 100644 src/mock/ray/rpc/worker/core_worker_client.h create mode 100644 src/mock/ray/rpc/worker/core_worker_client_pool.h create mode 100644 src/ray/gcs/gcs_server/test/gcs_actor_scheduler_mock_test.cc create mode 100644 src/ray/gcs/gcs_server/test/gcs_placement_group_manager_mock_test.cc delete mode 100644 src/ray/raylet/scheduling/fixed_point.cc delete mode 100644 thirdparty/patches/rules_boost-undefine-boost_fallthrough.patch diff --git a/.bazelrc b/.bazelrc index a6ebeba272c0f..2e4e7b36d10f9 100644 --- a/.bazelrc +++ b/.bazelrc @@ -14,12 +14,13 @@ build:macos --copt="-g1" build:linux --cxxopt="-std=c++17" build:macos --cxxopt="-std=c++17" build:clang-cl --cxxopt="-std=c++17" -build:msvc --cxxopt="/std:c++17" +build:msvc-cl --cxxopt="/std:c++17" +build:windows --cxxopt="/std:c++17" # This workaround is needed to prevent Bazel from compiling the same file twice 
(once PIC and once not). build:linux --force_pic build:macos --force_pic build:clang-cl --compiler=clang-cl -build:msvc --compiler=msvc-cl +build:msvc-cl --compiler=msvc-cl # `LC_ALL` and `LANG` is needed for cpp worker tests, because they will call "ray start". # If we don't add them, python's `click` library will raise an error. build --action_env=LC_ALL @@ -38,7 +39,7 @@ build:windows --enable_runfiles build:linux --per_file_copt="-\\.(asm|S)$@-Werror" build:macos --per_file_copt="-\\.(asm|S)$@-Werror" build:clang-cl --per_file_copt="-\\.(asm|S)$@-Werror" -build:msvc --per_file_copt="-\\.(asm|S)$@-WX" +build:msvc-cl --per_file_copt="-\\.(asm|S)$@-WX" # Ignore warnings for protobuf generated files and external projects. build --per_file_copt="\\.pb\\.cc$@-w" build --per_file_copt="-\\.(asm|S)$,external/.*@-w" @@ -51,7 +52,7 @@ build --per_file_copt="-\\.(asm|S)$,external/com_github_grpc_grpc/.*@-DGRPC_BAZE # Don't generate warnings about kernel features we don't need https://github.com/ray-project/ray/issues/6832 build:linux --per_file_copt="-\\.(asm|S)$,external/com_github_grpc_grpc/.*@-DGPR_MANYLINUX1" # Ignore wchar_t -> char conversion warning on MSVC -build:msvc --per_file_copt="external/boost/libs/regex/src/wc_regex_traits\\.cpp@-wd4244" +build:msvc-cl --per_file_copt="external/boost/libs/regex/src/wc_regex_traits\\.cpp@-wd4244" build --http_timeout_scaling=5.0 build --verbose_failures build:iwyu --experimental_action_listener=//:iwyu_cpp @@ -177,6 +178,7 @@ build:debug --strip="never" # Undefined Behavior Sanitizer build:ubsan --strip=never build:ubsan --copt -fsanitize=undefined +build:ubsan --copt -fno-sanitize=vptr build:ubsan --copt -fno-sanitize-recover=all build:ubsan --copt -g build:ubsan --linkopt -fsanitize=undefined diff --git a/.buildkite/pipeline.gpu.large.yml b/.buildkite/pipeline.gpu.large.yml new file mode 100644 index 0000000000000..0bdbca8846841 --- /dev/null +++ b/.buildkite/pipeline.gpu.large.yml @@ -0,0 +1,8 @@ +- label: ":tv: :octopus: SGD GPU tests " + conditions: ["RAY_CI_SGD_AFFECTED"] + commands: + - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/travis/upload_build_info.sh; fi }; trap cleanup EXIT + - SGD_TESTING=1 INSTALL_HOROVOD=1 ./ci/travis/install-dependencies.sh + - pip install -Ur ./python/requirements_ml_docker.txt + - ./ci/travis/env_info.sh + - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=gpu,gpu_only python/ray/util/sgd/... diff --git a/.buildkite/pipeline.gpu.yml b/.buildkite/pipeline.gpu.yml index 0c2c14ecf805f..eaf5a55e3b155 100644 --- a/.buildkite/pipeline.gpu.yml +++ b/.buildkite/pipeline.gpu.yml @@ -1,3 +1,13 @@ +# Todo: Enable once tests are available +#- label: ":tv: :octopus: Tune GPU tests " +# conditions: ["RAY_CI_TUNE_AFFECTED"] +# commands: +# - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/travis/upload_build_info.sh; fi }; trap cleanup EXIT +# - TUNE_TESTING=1 ./ci/travis/install-dependencies.sh +# - pip install -Ur ./python/requirements_ml_docker.txt +# - ./ci/travis/env_info.sh +# - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=-jenkins_only,-flaky,gpu,gpu_only python/ray/tune/... 
+ - label: ":tv: :brain: RLlib: GPU Examples {A/B}" conditions: ["RAY_CI_RLLIB_AFFECTED"] commands: diff --git a/.buildkite/pipeline.macos.yml b/.buildkite/pipeline.macos.yml index 592347d44007c..4efee1df1351c 100644 --- a/.buildkite/pipeline.macos.yml +++ b/.buildkite/pipeline.macos.yml @@ -64,7 +64,7 @@ steps: commands: - *prelude_commands - TORCH_VERSION=1.6 ./ci/travis/install-dependencies.sh - - bazel test --config=ci --test_env=CI $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=-flaky,-flaky-mac -- + - bazel test --config=ci --test_env=CI $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=-flaky,-flaky-mac,-post_wheel_build -- //:all python/ray/serve/... python/ray/dashboard/... -rllib/... -core_worker_test - *epilogue_commands diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index c0f6ccda286df..3058abe413175 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -182,7 +182,9 @@ - TORCH_VERSION=1.6 ./ci/travis/install-dependencies.sh - ./dashboard/tests/run_ui_tests.sh - bazel test --config=ci $(./scripts/bazel_export_options) python/ray/dashboard/... - - bazel test --config=ci $(./scripts/bazel_export_options) python/ray/serve/... + - bazel test --config=ci $(./scripts/bazel_export_options) + --test_tag_filters=-post_wheel_build + python/ray/serve/... - label: ":python: Minimal install" conditions: ["RAY_CI_PYTHON_AFFECTED"] @@ -462,16 +464,16 @@ commands: - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/travis/upload_build_info.sh; fi }; trap cleanup EXIT - TUNE_TESTING=1 ./ci/travis/install-dependencies.sh - - bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-jenkins_only,-example,-flaky,-py37,-soft_imports python/ray/tune/... - - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=example,-tf,-pytorch,-py37,-flaky,-soft_imports python/ray/tune/... + - bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-jenkins_only,-example,-flaky,-py37,-soft_imports,-gpu_only python/ray/tune/... + - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=example,-tf,-pytorch,-py37,-flaky,-soft_imports,-gpu_only python/ray/tune/... - label: ":octopus: Tune tests and examples {2/2}" conditions: ["RAY_CI_TUNE_AFFECTED"] commands: - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/travis/upload_build_info.sh; fi }; trap cleanup EXIT - TUNE_TESTING=1 ./ci/travis/install-dependencies.sh - - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=tf,-pytorch,-py37,-flaky,-soft_imports python/ray/tune/... - - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=-tf,pytorch,-py37,-flaky,-soft_imports python/ray/tune/... + - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=tf,-pytorch,-py37,-flaky,-soft_imports,-gpu_only python/ray/tune/... + - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=-tf,pytorch,-py37,-flaky,-soft_imports,-gpu_only python/ray/tune/... 
- label: ":octopus: Tune soft imports test" conditions: ["RAY_CI_TUNE_AFFECTED"] @@ -486,10 +488,10 @@ commands: - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/travis/upload_build_info.sh; fi }; trap cleanup EXIT - SGD_TESTING=1 INSTALL_HOROVOD=1 ./ci/travis/install-dependencies.sh - - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=tf,-pytorch,-py37,-flaky,-client python/ray/util/sgd/... - - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=-tf,pytorch,-py37,-flaky,-client python/ray/util/sgd/... - - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=client_unit_tests --test_env=RAY_CLIENT_MODE=1 python/ray/util/sgd/... - - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only python/ray/util/sgd/v2/... + - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=tf,-pytorch,-py37,-flaky,-client,-gpu_only python/ray/util/sgd/... + - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=-tf,pytorch,-py37,-flaky,-client,-gpu_only python/ray/util/sgd/... + - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=client_unit_tests,-gpu_only --test_env=RAY_CLIENT_MODE=1 python/ray/util/sgd/... + - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=-gpu_only python/ray/util/sgd/v2/... - label: ":octopus: Tune/SGD/Modin/Dask tests and examples. Python 3.7" conditions: ["RAY_CI_TUNE_AFFECTED", "RAY_CI_SGD_AFFECTED"] diff --git a/.clang-tidy b/.clang-tidy index 2aa176da910cc..607f19902f3f4 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -1,27 +1,64 @@ -# Disable the following checks due to frequent false positives, noisiness, -# inconsistent style with existing codebase and other reasons: +# Disable the following checks with reasons in parenthesis: +# +# -bugprone-macro-parentheses (inconsistent style) +# -google-readability-todo (potentially too restrictive) # -misc-non-private-member-variables-in-classes (potentially too restrictive) # -misc-unused-parameters (can be cleaned up in batch and enabled) # -modernize-avoid-c-arrays (too restrictive) +# -modernize-concat-nested-namespaces (inconsistent style) # -modernize-pass-by-value (too restrictive) # -modernize-return-braced-init-list (inconsistent style) # -modernize-use-emplace (more subtle behavior) +# -modernize-use-nodiscard (too much noise) # -modernize-use-trailing-return-type (inconsistent style) +# -modernize-avoid-bind (incorrect conversion) +# -modernize-loop-convert (more subtle behavior) +# -modernize-replace-disallow-copy-and-assign-macro (inconsistent style) +# -modernize-make-unique (doesn't work with private constructor) +# -modernize-make-shared (doesn't work with private constructor) +# Other readability-* rules (potentially too noisy, inconsistent style) +# Other rules not mentioned here or below (not yet evaluated) # # TODO: enable google-* and readability-* families of checks. 
Checks: > abseil-*, bugprone-*, + -bugprone-macro-parentheses, + google-*, + -google-readability-todo, misc-*, -misc-non-private-member-variables-in-classes, -misc-unused-parameters, modernize-*, -modernize-avoid-c-arrays, + -modernize-concat-nested-namespaces, -modernize-pass-by-value, -modernize-return-braced-init-list, -modernize-use-emplace, + -modernize-use-nodiscard, -modernize-use-trailing-return-type, + -modernize-avoid-bind, + -modernize-loop-convert, + -modernize-replace-disallow-copy-and-assign-macro, + -modernize-make-unique, + -modernize-make-shared, performance-*, + readability-avoid-const-params-in-decls, + readability-braces-around-statements, + readability-const-return-type, + readability-container-size-empty, + readability-delete-null-pointer, + readability-else-after-return, + readability-implicit-bool-conversion, + readability-make-member-function-const, + readability-misleading-indentation, + readability-misplaced-array-index, + readability-named-parameter, + readability-non-const-parameter, + readability-redundant-*, + readability-static-definition-in-anonymous-namespace, + readability-string-compare, + readability-suspicious-call-argument, CheckOptions: # Reduce noisiness of the bugprone-narrowing-conversions check. diff --git a/.flake8 b/.flake8 index a4a3510a1bbeb..cb93e3096d3ef 100644 --- a/.flake8 +++ b/.flake8 @@ -24,4 +24,20 @@ ignore = W605 I N + B001 + B002 + B003 + B004 + B005 + B007 + B008 + B009 + B010 + B011 + B012 + B013 + B014 + B015 + B016 + B017 avoid-escape = no diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index c4e254c2dd0f9..3502b7042bf20 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -18,6 +18,9 @@ # Dependencies /python/setup.py @richardliaw @ericl @edoakes +# Formatting tool +/ci/travis/format.sh @richardliaw @ericl @edoakes + # Python worker. #/python/ray/ @ray-project/ray-core-python #!/python/ray/tune/ @ray-project/ray-core-python @@ -30,7 +33,6 @@ /java/*/pom_template.xml @jovany-wang @kfstorm @raulchen /java/api/ @jovany-wang @kfstorm @raulchen - # Ray Client /src/ray/protobuf/ray_client.proto @ijrsvt @ameerhajali @ckw017 @mwtian @@ -39,6 +41,14 @@ # Ray tune. /python/ray/tune/ @ray-project/ray-tune +# Ray data. +/python/ray/data/ @ericl @scv119 +/doc/source/data/ @ericl @scv119 + +# Ray workflows. +/python/ray/workflow/ @ericl @iycheng +/doc/source/workflows/ @ericl @iycheng + # RLlib. 
#/python/ray/rllib/ @ray-project/rllib diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 8df9fe895df63..9404a4a4d2517 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -26,7 +26,7 @@ jobs: os: windows-2019 python-version: 3.8 # Can be 'msvc' or 'clang-cl' - config: msvc + config: msvc-cl env: BAZEL_CONFIG: ${{ matrix.config }} PYTHON: ${{ matrix.python-version }} @@ -111,7 +111,6 @@ jobs: TRAVIS_COMMIT: ${{ github.sha }} TRAVIS_JOB_ID: ${{ github.run_id }} run: | - # Multi thread in windowns for grpc not working now function clean_up() { echo "Performing cleanup" if [ "${GITHUB_EVENT_NAME}" != "pull_request" ]; then ./ci/travis/upload_build_info.sh; fi diff --git a/.gitpod/Dockerfile b/.gitpod/Dockerfile index 23682c0ed9687..ce2af682e0ed9 100644 --- a/.gitpod/Dockerfile +++ b/.gitpod/Dockerfile @@ -15,7 +15,7 @@ RUN set -x; apt update \ && mv bazel.gpg /etc/apt/trusted.gpg.d/ \ && echo "deb [arch=amd64] https://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list \ && apt update && apt install bazel-3.7.2 -y \ - && pip3 install cython==0.29.0 pytest pandas tree tabulate pexpect sklearn joblib yapf==0.23.0 flake8==3.9.1 mypy==0.782 flake8-quotes setproctitle==1.1.10 psutil \ + && pip3 install cython==0.29.0 pytest pandas tree tabulate pexpect sklearn joblib yapf==0.23.0 flake8==3.9.1 mypy==0.782 flake8-quotes flake8-bugbear==21.9.2 setproctitle==1.1.10 psutil \ && python3 -c 'print("startup --output_base=/workspace/ray/.bazel-cache\nstartup --host_jvm_args=-Xmx1800m\nbuild --jobs=6")' > /etc/bazel.bazelrc RUN update-alternatives --install /usr/local/bin/python python /usr/bin/python3 30 \ diff --git a/BUILD.bazel b/BUILD.bazel index ad6bd083fd4ad..0483d1bf062ae 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -414,7 +414,6 @@ cc_library( ], ) + [ "src/ray/raylet/scheduling/cluster_resource_data.cc", - "src/ray/raylet/scheduling/fixed_point.cc", "src/ray/raylet/scheduling/scheduling_ids.cc", ], hdrs = glob( @@ -553,6 +552,7 @@ cc_library( ":pubsub_lib", ":raylet_client_lib", ":worker_rpc", + "@com_google_absl//absl/container:btree", ], ) @@ -1181,6 +1181,22 @@ cc_test( ], ) +cc_test( + name = "gcs_placement_group_manager_mock_test", + size = "small", + srcs = [ + "src/ray/gcs/gcs_server/test/gcs_placement_group_manager_mock_test.cc", + ], + copts = COPTS, + tags = ["team:core"], + deps = [ + ":gcs_server_lib", + ":gcs_test_util_lib", + ":ray_mock", + "@com_google_googletest//:gtest_main", + ], +) + cc_test( name = "placement_group_resource_manager_test", size = "small", @@ -1513,6 +1529,21 @@ cc_test( ], ) +# cc_test( +# name = "gcs_actor_scheduler_mock_test", +# size = "small", +# srcs = [ +# "src/ray/gcs/gcs_server/test/gcs_actor_scheduler_mock_test.cc", +# ], +# copts = COPTS, +# tags = ["team:core"], +# deps = [ +# ":gcs_server_lib", +# ":ray_mock", +# "@com_google_googletest//:gtest_main", +# ], +# ) + cc_test( name = "gcs_based_actor_scheduler_test", size = "small", diff --git a/bazel/ray_deps_setup.bzl b/bazel/ray_deps_setup.bzl index 1925aedfa4edb..96131feadba41 100644 --- a/bazel/ray_deps_setup.bzl +++ b/bazel/ray_deps_setup.bzl @@ -151,8 +151,8 @@ def ray_deps_setup(): # declaring it here allows us to avoid patching the latter. 
name = "boost", build_file = "@com_github_nelhage_rules_boost//:BUILD.boost", - sha256 = "d73a8da01e8bf8c7eda40b4c84915071a8c8a0df4a6734537ddde4a8580524ee", - url = "https://boostorg.jfrog.io/artifactory/main/release/1.71.0/source/boost_1_71_0.tar.bz2", + sha256 = "83bfc1507731a0906e387fc28b7ef5417d591429e51e788417fe9ff025e116b1", + url = "https://boostorg.jfrog.io/artifactory/main/release/1.74.0/source/boost_1_74_0.tar.bz2", patches = [ "//thirdparty/patches:boost-exception-no_warn_typeid_evaluated.patch", ], @@ -161,10 +161,9 @@ def ray_deps_setup(): auto_http_archive( name = "com_github_nelhage_rules_boost", # If you update the Boost version, remember to update the 'boost' rule. - url = "https://github.com/nelhage/rules_boost/archive/2613d04ab3d22dfc4543ea0a083d9adeaa0daf09.tar.gz", - sha256 = "512f913240e026099d4ca4a98b1ce8048c99de77fdc8e8584e9e2539ee119ca2", + url = "https://github.com/nelhage/rules_boost/archive/652b21e35e4eeed5579e696da0facbe8dba52b1f.tar.gz", + sha256 = "c1b8b2adc3b4201683cf94dda7eef3fc0f4f4c0ea5caa3ed3feffe07e1fb5b15", patches = [ - "//thirdparty/patches:rules_boost-undefine-boost_fallthrough.patch", "//thirdparty/patches:rules_boost-windows-linkopts.patch", ], ) diff --git a/benchmarks/object_store/test_object_store.py b/benchmarks/object_store/test_object_store.py index 5e251f55f8884..022cb17e8b890 100644 --- a/benchmarks/object_store/test_object_store.py +++ b/benchmarks/object_store/test_object_store.py @@ -65,6 +65,7 @@ def sum(self, arr): if "TEST_OUTPUT_JSON" in os.environ: out_file = open(os.environ["TEST_OUTPUT_JSON"], "w") results = { + "broadcast_time": end - start, "object_size": OBJECT_SIZE, "num_nodes": NUM_NODES, "success": "1" diff --git a/benchmarks/single_node/test_single_node.py b/benchmarks/single_node/test_single_node.py index fb44e7fe29ade..3deaa389de600 100644 --- a/benchmarks/single_node/test_single_node.py +++ b/benchmarks/single_node/test_single_node.py @@ -199,7 +199,8 @@ def test_large_object(): "num_args": MAX_ARGS, "returns_time": returns_time, "num_returns": MAX_RETURNS, - "get_time": MAX_RAY_GET_ARGS, + "get_time": get_time, + "num_get_args": MAX_RAY_GET_ARGS, "queued_time": queued_time, "num_queued": MAX_QUEUED_TASKS, "large_object_time": large_object_time, diff --git a/ci/travis/bazel.py b/ci/travis/bazel.py index d731734b6faa9..d462459fc1ead 100755 --- a/ci/travis/bazel.py +++ b/ci/travis/bazel.py @@ -98,35 +98,45 @@ def info(self, *args): return result def aquery(self, *args): - lines = self._call("aquery", "--output=textproto", *args).splitlines() - return textproto_parse(lines, self.encoding, json.JSONEncoder()) + out = self._call("aquery", "--output=jsonproto", *args) + return json.loads(out.decode(self.encoding)) def parse_aquery_shell_calls(aquery_results): """Extracts and yields the command lines representing the genrule() rules from Bazel aquery results. """ - for (key, val) in aquery_results: - if key == "actions": - [mnemonic] = [pair[1] for pair in val if pair[0] == "mnemonic"] - if mnemonic == "Genrule": - yield [pair[1] for pair in val if pair[0] == "arguments"] + for action in aquery_results["actions"]: + if action["mnemonic"] != "Genrule": + continue + yield action["arguments"] def parse_aquery_output_artifacts(aquery_results): """Extracts and yields the file paths representing the output artifact from the provided Bazel aquery results. 
+
+    To understand the output of the aquery command in jsonproto format, try:
+    bazel aquery --include_artifacts=true --output=jsonproto \
+        'mnemonic("Genrule", deps(//:*))'
     """
+    fragments = {}
+    for fragment in aquery_results["pathFragments"]:
+        fragments[fragment["id"]] = fragment
+
     artifacts = {}
-    for (key, val) in aquery_results:
-        if key == "artifacts":
-            [artifact_id] = [pair[1] for pair in val if pair[0] == "id"]
-            [exec_path] = [pair[1] for pair in val if pair[0] == "exec_path"]
-            artifacts[artifact_id] = exec_path
-        elif key == "actions":
-            output_ids = [pair[1] for pair in val if pair[0] == "output_ids"]
-            for output_id in output_ids:
-                yield artifacts[output_id]
+    for artifact in aquery_results["artifacts"]:
+        artifacts[artifact["id"]] = artifact
+
+    def _path(fragment_id):
+        fragment = fragments[fragment_id]
+        parent = _path(fragment["parentId"]) if "parentId" in fragment else []
+        return parent + [fragment["label"]]
+
+    for action in aquery_results["actions"]:
+        for output_id in action["outputIds"]:
+            path = os.path.join(*_path(artifacts[output_id]["pathFragmentId"]))
+            yield path


 def textproto2json(infile, outfile):
diff --git a/ci/travis/ci.sh b/ci/travis/ci.sh
index 6aa33a22a2000..7faa9ae02a5be 100755
--- a/ci/travis/ci.sh
+++ b/ci/travis/ci.sh
@@ -139,6 +139,7 @@ test_python() {
   args+=(
     python/ray/serve/...
     python/ray/tests/...
+    -python/ray/serve:conda_env # runtime_env unsupported on Windows
    -python/ray/serve:test_api # segfault on windows? https://github.com/ray-project/ray/issues/12541
    -python/ray/serve:test_router # timeout
    -python/ray/serve:test_handle # "fatal error" (?) https://github.com/ray-project/ray/pull/13695
@@ -181,6 +182,7 @@ test_python() {
    -python/ray/tests:test_ray_init # test_redis_port() seems to fail here, but pass in isolation
    -python/ray/tests:test_resource_demand_scheduler
    -python/ray/tests:test_reference_counting # too flaky 9/25/21
+    -python/ray/tests:test_runtime_env_plugin # runtime_env not supported on Windows
    -python/ray/tests:test_runtime_env_env_vars # runtime_env not supported on Windows
    -python/ray/tests:test_runtime_env_complicated # conda install slow leading to timeout
    -python/ray/tests:test_stress # timeout
@@ -332,7 +334,52 @@ install_ray() {
   )
 }

+validate_wheels_commit_str() {
+  if [ "${OSTYPE}" = msys ]; then
+    echo "Windows builds do not set the commit string, skipping wheel commit validity check."
+    return 0
+  fi
+
+  if [ -n "${BUILDKITE_COMMIT}" ]; then
+    EXPECTED_COMMIT=${BUILDKITE_COMMIT:-}
+  else
+    EXPECTED_COMMIT=${TRAVIS_COMMIT:-}
+  fi
+
+  if [ -z "$EXPECTED_COMMIT" ]; then
+    echo "Could not validate expected wheel commits: TRAVIS_COMMIT is empty."
+    return 0
+  fi
+
+  for whl in .whl/*.whl; do
+    basename=${whl##*/}
+
+    if [[ "$basename" =~ "_cpp" ]]; then
+      # cpp wheels cannot be checked this way
+      echo "Skipping CPP wheel ${basename} for wheel commit validation."
+      continue
+    fi
+
+    folder=${basename%%-cp*}
+    WHL_COMMIT=$(unzip -p "$whl" "${folder}.data/purelib/ray/__init__.py" | grep "__commit__" | awk -F'"' '{print $2}')
+
+    if [ "${WHL_COMMIT}" != "${EXPECTED_COMMIT}" ]; then
+      echo "Error: Observed wheel commit (${WHL_COMMIT}) is not expected commit (${EXPECTED_COMMIT}). Aborting."
+      exit 1
+    fi
+
+    echo "Wheel ${basename} has the correct commit: ${WHL_COMMIT}"
+  done
+
+  echo "All wheels passed the sanity check and have the correct wheel commit set."
+} + build_wheels() { + # Create wheel output directory and empty contents + # If buildkite runners are re-used, wheels from previous builds might be here, so we delete them. + mkdir -p .whl + rm -rf .whl/* || true + case "${OSTYPE}" in linux*) # Mount bazel cache dir to the docker container. @@ -353,7 +400,6 @@ build_wheels() { -e "RAY_DEBUG_BUILD=${RAY_DEBUG_BUILD:-}" ) - if [ -z "${BUILDKITE-}" ]; then # This command should be kept in sync with ray/python/README-building-wheels.md, # except the "${MOUNT_BAZEL_CACHE[@]}" part. @@ -361,19 +407,25 @@ build_wheels() { quay.io/pypa/manylinux2014_x86_64 /ray/python/build-wheel-manylinux2014.sh else rm -rf /ray-mount/* + rm -rf /ray-mount/.whl || true + rm -rf /ray/.whl || true cp -rT /ray /ray-mount - ls /ray-mount + ls -a /ray-mount docker run --rm -v /ray:/ray-mounted ubuntu:focal ls / docker run --rm -v /ray:/ray-mounted ubuntu:focal ls /ray-mounted docker run --rm -w /ray -v /ray:/ray "${MOUNT_BAZEL_CACHE[@]}" \ quay.io/pypa/manylinux2014_x86_64 /ray/python/build-wheel-manylinux2014.sh cp -rT /ray-mount /ray # copy new files back here find . | grep whl # testing + + validate_wheels_commit_str fi ;; darwin*) # This command should be kept in sync with ray/python/README-building-wheels.md. "${WORKSPACE_DIR}"/python/build-wheel-macos.sh + + validate_wheels_commit_str ;; msys*) keep_alive "${WORKSPACE_DIR}"/python/build-wheel-windows.sh diff --git a/ci/travis/format.sh b/ci/travis/format.sh index e31245faad61d..7dbf608d18734 100755 --- a/ci/travis/format.sh +++ b/ci/travis/format.sh @@ -83,6 +83,10 @@ if [[ $(flake8 --version) != *"flake8_quotes"* ]]; then echo "WARNING: Ray uses flake8 with flake8_quotes. Might error without it. Install with: pip install flake8-quotes" fi +if [[ $(flake8 --version) != *"flake8-bugbear"* ]]; then + echo "WARNING: Ray uses flake8 with flake8-bugbear. Might error without it. Install with: pip install flake8-bugbear" +fi + SHELLCHECK_FLAGS=( --exclude=1090 # "Can't follow non-constant source. Use a directive to specify location." --exclude=1091 # "Not following {file} due to some error" diff --git a/ci/travis/install-dependencies.sh b/ci/travis/install-dependencies.sh index 32b39ded1401e..b52f75e8a4164 100755 --- a/ci/travis/install-dependencies.sh +++ b/ci/travis/install-dependencies.sh @@ -408,7 +408,7 @@ install_dependencies() { # RLlib testing with TF 1.x. if [ "${RLLIB_TESTING-}" = 1 ] && { [ -n "${TF_VERSION-}" ] || [ -n "${TFP_VERSION-}" ]; }; then - pip install --upgrade tensorflow-probability=="${TFP_VERSION}" tensorflow=="${TF_VERSION}" gym + pip install --upgrade tensorflow-probability=="${TFP_VERSION}" tensorflow=="${TF_VERSION}" gym==0.19 fi # Additional Tune dependency for Horovod. 
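A note on the wheel validation added to ci/travis/ci.sh above: `validate_wheels_commit_str` unzips each wheel and compares the `__commit__` string baked into `ray/__init__.py` against the commit CI built from. Below is a rough Python equivalent of that shell logic for running the same check locally; `wheel_commit` is an illustrative helper, not part of this patch.

    import re
    import zipfile

    def wheel_commit(whl_path: str) -> str:
        # Mirror ${basename%%-cp*}: the wheel name up to the first "-cp"
        # names the .data directory inside the archive.
        basename = whl_path.rsplit("/", 1)[-1]
        folder = basename.split("-cp")[0]
        with zipfile.ZipFile(whl_path) as whl:
            init_src = whl.read(f"{folder}.data/purelib/ray/__init__.py").decode()
        # Same extraction as `grep "__commit__" | awk -F'"' '{print $2}'`.
        match = re.search(r'__commit__\s*=\s*"([^"]*)"', init_src)
        return match.group(1) if match else ""

    # Expected usage, assuming a built wheel under .whl/:
    #   wheel_commit(".whl/ray-1.7.0-cp38-cp38-manylinux2014_x86_64.whl")
    # should equal BUILDKITE_COMMIT (or TRAVIS_COMMIT on Travis builds).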
diff --git a/cpp/BUILD.bazel b/cpp/BUILD.bazel
index 9d4e7416cda1b..9603c863546c1 100644
--- a/cpp/BUILD.bazel
+++ b/cpp/BUILD.bazel
@@ -90,6 +90,7 @@ genrule(
         mkdir -p "$$PY_CPP_DIR/lib/" &&
         cp -f -r $$WORK_DIR/external/msgpack/include/* "$$PY_CPP_DIR/include" &&
         cp -f -r "$$WORK_DIR/external/boost/boost/archive" "$$BOOST_DIR" &&
+        cp -f -r "$$WORK_DIR/external/boost/boost/assert" "$$BOOST_DIR" &&
         cp -f -r "$$WORK_DIR/external/boost/boost/bind" "$$BOOST_DIR" &&
         cp -f -r "$$WORK_DIR/external/boost/boost/callable_traits" "$$BOOST_DIR" &&
         cp -f -r "$$WORK_DIR/external/boost/boost/concept" "$$BOOST_DIR" &&
diff --git a/cpp/src/ray/api.cc b/cpp/src/ray/api.cc
index a1a8c6507541c..ed2b1b89230cd 100644
--- a/cpp/src/ray/api.cc
+++ b/cpp/src/ray/api.cc
@@ -40,7 +40,7 @@ void Init() {
 bool IsInitialized() { return is_init_; }

 void Shutdown() {
-  // TODO(guyang.sgy): Clean the ray runtime.
+  // TODO(SongGuyang): Clean the ray runtime.
   internal::AbstractRayRuntime::DoShutdown();
   is_init_ = false;
 }
diff --git a/cpp/src/ray/runtime/abstract_ray_runtime.cc b/cpp/src/ray/runtime/abstract_ray_runtime.cc
index 177fae17d3122..db9fac32db4e8 100644
--- a/cpp/src/ray/runtime/abstract_ray_runtime.cc
+++ b/cpp/src/ray/runtime/abstract_ray_runtime.cc
@@ -145,7 +145,7 @@ InvocationSpec BuildInvocationSpec1(TaskType task_type,
   InvocationSpec invocation_spec;
   invocation_spec.task_type = task_type;
   invocation_spec.task_id =
-      TaskID::ForFakeTask();  // TODO(Guyang Song): make it from different task
+      TaskID::ForFakeTask();  // TODO(SongGuyang): make it from different task
   invocation_spec.remote_function_holder = remote_function_holder;
   invocation_spec.actor_id = actor;
   invocation_spec.args = TransformArgs(args);
diff --git a/cpp/src/ray/runtime/object/native_object_store.cc b/cpp/src/ray/runtime/object/native_object_store.cc
index d9326feb2ae66..7add3b72b73af 100644
--- a/cpp/src/ray/runtime/object/native_object_store.cc
+++ b/cpp/src/ray/runtime/object/native_object_store.cc
@@ -116,7 +116,7 @@ std::vector<bool> NativeObjectStore::Wait(const std::vector<ObjectID> &ids,
                                           int num_objects, int timeout_ms) {
   std::vector<bool> results;
   auto &core_worker = CoreWorkerProcess::GetCoreWorker();
-  // TODO(guyang.sgy): Support `fetch_local` option in API.
+  // TODO(SongGuyang): Support `fetch_local` option in API.
   // Simply set `fetch_local` to be true.
   ::ray::Status status = core_worker.Wait(ids, num_objects, timeout_ms, &results, true);
   if (!status.ok()) {
diff --git a/cpp/src/ray/runtime/task/local_mode_task_submitter.cc b/cpp/src/ray/runtime/task/local_mode_task_submitter.cc
index cb24e9d3a2b8d..40b7845578a74 100644
--- a/cpp/src/ray/runtime/task/local_mode_task_submitter.cc
+++ b/cpp/src/ray/runtime/task/local_mode_task_submitter.cc
@@ -32,7 +32,7 @@ LocalModeTaskSubmitter::LocalModeTaskSubmitter(

 ObjectID LocalModeTaskSubmitter::Submit(InvocationSpec &invocation,
                                         const ActorCreationOptions &options) {
-  /// TODO(Guyang Song): Make the information of TaskSpecification more reasonable
+  /// TODO(SongGuyang): Make the information of TaskSpecification more reasonable
   /// We just reuse the TaskSpecification class and make the single process mode work.
   /// Maybe some information of TaskSpecification are not reasonable or invalid.
   /// We will enhance this after implementing the cluster mode.
@@ -82,7 +82,7 @@ ObjectID LocalModeTaskSubmitter::Submit(InvocationSpec &invocation,
   AbstractRayRuntime *runtime = &local_mode_ray_tuntime_;
   if (invocation.task_type == TaskType::ACTOR_CREATION_TASK ||
       invocation.task_type == TaskType::ACTOR_TASK) {
-    /// TODO(Guyang Song): Handle task dependencies.
+    /// TODO(SongGuyang): Handle task dependencies.
     /// Execute actor task directly in the main thread because we must guarantee the actor
     /// task executed by calling order.
     TaskExecutor::Invoke(task_specification, actor, runtime, actor_contexts_,
diff --git a/cpp/src/ray/runtime/task/task_executor.cc b/cpp/src/ray/runtime/task/task_executor.cc
index be24fe98d9a27..f0a1e12faaa78 100644
--- a/cpp/src/ray/runtime/task/task_executor.cc
+++ b/cpp/src/ray/runtime/task/task_executor.cc
@@ -75,7 +75,7 @@ std::shared_ptr TaskExecutor::current_actor_ = nullptr;
 TaskExecutor::TaskExecutor(AbstractRayRuntime &abstract_ray_tuntime_)
     : abstract_ray_tuntime_(abstract_ray_tuntime_) {}

-// TODO(Guyang Song): Make a common task execution function used for both local mode and
+// TODO(SongGuyang): Make a common task execution function used for both local mode and
 // cluster mode.
 std::unique_ptr TaskExecutor::Execute(InvocationSpec &invocation) {
   abstract_ray_tuntime_.GetWorkerContext();
diff --git a/cpp/src/ray/runtime/task/task_executor.h b/cpp/src/ray/runtime/task/task_executor.h
index a528f17e03af3..825e5ca52ab20 100644
--- a/cpp/src/ray/runtime/task/task_executor.h
+++ b/cpp/src/ray/runtime/task/task_executor.h
@@ -16,8 +16,10 @@
 #include
 #include
+
 #include
 #include
+
 #include "absl/synchronization/mutex.h"
 #include "invocation_spec.h"
 #include "ray/common/id.h"
@@ -62,7 +64,7 @@ class TaskExecutor {
  public:
   TaskExecutor(AbstractRayRuntime &abstract_ray_tuntime_);

-  /// TODO(Guyang Song): support multiple tasks execution
+  /// TODO(SongGuyang): support multiple tasks execution
   std::unique_ptr Execute(InvocationSpec &invocation);

   static void Invoke(
diff --git a/cpp/src/ray/util/process_helper.cc b/cpp/src/ray/util/process_helper.cc
index 40f115e646e95..35ecd8123daa2 100644
--- a/cpp/src/ray/util/process_helper.cc
+++ b/cpp/src/ray/util/process_helper.cc
@@ -12,9 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+#include "process_helper.h"
+
 #include

-#include "process_helper.h"
 #include "ray/util/process.h"
 #include "ray/util/util.h"
 #include "src/ray/protobuf/gcs.pb.h"
@@ -27,9 +28,9 @@ using ray::core::WorkerType;

 void ProcessHelper::StartRayNode(const int redis_port, const std::string redis_password,
                                  const std::vector<std::string> &head_args) {
-  std::vector<std::string> cmdargs({"ray", "start", "--head", "--port",
-                                    std::to_string(redis_port), "--redis-password",
-                                    redis_password});
+  std::vector<std::string> cmdargs(
+      {"ray", "start", "--head", "--port", std::to_string(redis_port), "--redis-password",
+       redis_password, "--node-ip-address", GetNodeIpAddress()});
   if (!head_args.empty()) {
     cmdargs.insert(cmdargs.end(), head_args.begin(), head_args.end());
   }
@@ -124,7 +125,7 @@ void ProcessHelper::RayStart(CoreWorkerOptions::TaskExecutionCallback callback)
   if (!ConfigInternal::Instance().job_id.empty()) {
     options.job_id = JobID::FromHex(ConfigInternal::Instance().job_id);
   } else {
-    /// TODO(Guyang Song): Get next job id from core worker by GCS client.
+    /// TODO(SongGuyang): Get next job id from core worker by GCS client.
     /// Random a number to avoid repeated job ids.
     /// The repeated job ids will lead to task hang when driver connects to an existing
     /// cluster more than once.
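For readers following the C++ change above: after this patch, `ProcessHelper::StartRayNode` passes the node IP to `ray start` explicitly instead of letting the CLI infer it. A minimal Python sketch of the equivalent invocation follows; `start_head_node` is a hypothetical helper, shown only to make the assembled command line concrete.

    import subprocess

    def start_head_node(redis_port: int, redis_password: str,
                        node_ip: str, head_args=()):
        # Same argument vector that the patched StartRayNode builds, with
        # --node-ip-address now supplied up front.
        cmd = ["ray", "start", "--head",
               "--port", str(redis_port),
               "--redis-password", redis_password,
               "--node-ip-address", node_ip]
        cmd.extend(head_args)  # extra head-node flags, appended verbatim
        subprocess.check_call(cmd)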
diff --git a/dashboard/agent.py b/dashboard/agent.py
index 7301b4299f95f..522972fb1c06f 100644
--- a/dashboard/agent.py
+++ b/dashboard/agent.py
@@ -337,8 +337,8 @@ async def _check_parent():
     # https://github.com/ray-project/ray/issues/14026.
     if sys.platform == "win32":
         logger.warning(
-            "The dashboard is currently disabled on windows."
-            "See https://github.com/ray-project/ray/issues/14026"
+            "The dashboard is currently disabled on Windows. "
+            "See https://github.com/ray-project/ray/issues/14026 "
             "for more details")
         while True:
             time.sleep(999)
@@ -362,14 +362,35 @@ async def _check_parent():
         loop = asyncio.get_event_loop()
         loop.run_until_complete(agent.run())
     except Exception as e:
-        # Something went wrong, so push an error to all drivers.
-        redis_client = ray._private.services.create_redis_client(
-            args.redis_address, password=args.redis_password)
-        traceback_str = ray._private.utils.format_error_message(
-            traceback.format_exc())
-        message = ("The agent on node {} failed with the following "
-                   "error:\n{}".format(platform.uname()[1], traceback_str))
-        ray._private.utils.push_error_to_driver_through_redis(
-            redis_client, ray_constants.DASHBOARD_AGENT_DIED_ERROR, message)
-        logger.exception(message)
-        raise e
+        # All these env vars should be available because they are provided
+        # by the parent raylet. Note that the values in os.environ are
+        # strings, so cast the counters to int before comparing them.
+        restart_count = int(os.environ["RESTART_COUNT"])
+        max_restart_count = int(os.environ["MAX_RESTART_COUNT"])
+        raylet_pid = os.environ["RAY_RAYLET_PID"]
+        node_ip = args.node_ip_address
+        if restart_count >= max_restart_count:
+            # The agent failed to start too many times.
+            # Push an error to all drivers, so that users can know the
+            # impact of the issue.
+            redis_client = ray._private.services.create_redis_client(
+                args.redis_address, password=args.redis_password)
+            traceback_str = ray._private.utils.format_error_message(
+                traceback.format_exc())
+            message = (
+                f"(ip={node_ip}) "
+                f"The agent on node {platform.uname()[1]} failed to "
+                f"be restarted {max_restart_count} "
+                "times. There are 3 possible problems if you see this error."
+                "\n 1. The dashboard might not display correct "
+                "information on this node."
+                "\n 2. Metrics on this node won't be reported."
+                "\n 3. runtime_env APIs won't work."
+                "\nCheck out the `dashboard_agent.log` to see the "
+                "detailed failure messages.")
+            ray._private.utils.push_error_to_driver_through_redis(
+                redis_client, ray_constants.DASHBOARD_AGENT_DIED_ERROR,
+                message)
+            logger.error(message)
+        logger.exception(e)
+        exit(1)
diff --git a/dashboard/client/src/pages/job/JobDetail.tsx b/dashboard/client/src/pages/job/JobDetail.tsx
index b720b9c057de1..892034937f107 100644
--- a/dashboard/client/src/pages/job/JobDetail.tsx
+++ b/dashboard/client/src/pages/job/JobDetail.tsx
@@ -11,6 +11,7 @@ import {
   TableRow,
   Tabs,
 } from "@material-ui/core";
+import dayjs from "dayjs";
 import React from "react";
 import { Link, RouteComponentProps } from "react-router-dom";
 import ActorTable from "../../components/ActorTable";
@@ -140,6 +141,16 @@ const JobDetailPage = (props: RouteComponentProps<{ id: string }>) => {
             Driver Pid:{" "}
             {jobInfo.driverPid}
+
+            StartTime:{" "}
+            {dayjs(Number(jobInfo.startTime)).format("YYYY/MM/DD HH:mm:ss")}
+
+
+            EndTime:{" "}
+            {jobInfo.endTime > 0
+              ?
dayjs(Number(jobInfo.endTime)).format("YYYY/MM/DD HH:mm:ss") + : "-"} + {jobInfo.eventUrl && ( Event Link:{" "} diff --git a/dashboard/client/src/pages/job/index.tsx b/dashboard/client/src/pages/job/index.tsx index e52af1ce5ec01..81be74b03e2f4 100644 --- a/dashboard/client/src/pages/job/index.tsx +++ b/dashboard/client/src/pages/job/index.tsx @@ -24,7 +24,14 @@ const useStyles = makeStyles((theme) => ({ }, })); -const columns = ["ID", "DriverIpAddress", "DriverPid", "IsDead", "Timestamp"]; +const columns = [ + "ID", + "DriverIpAddress", + "DriverPid", + "IsDead", + "StartTime", + "EndTime", +]; const JobList = () => { const classes = useStyles(); @@ -98,7 +105,8 @@ const JobList = () => { driverIpAddress, isDead, driverPid, - timestamp, + startTime, + endTime, }) => ( @@ -110,7 +118,12 @@ const JobList = () => { {isDead ? "true" : "false"} - {dayjs(Number(timestamp)).format("YYYY/MM/DD HH:mm:ss")} + {dayjs(Number(startTime)).format("YYYY/MM/DD HH:mm:ss")} + + + {endTime > 0 + ? dayjs(Number(endTime)).format("YYYY/MM/DD HH:mm:ss") + : "-"} ), diff --git a/dashboard/client/src/type/job.d.ts b/dashboard/client/src/type/job.d.ts index c5ca4dce874c1..ef9181dd2c92d 100644 --- a/dashboard/client/src/type/job.d.ts +++ b/dashboard/client/src/type/job.d.ts @@ -9,6 +9,8 @@ export type Job = { driverEntry: string; state: string; timestamp: number; + startTime: number; + endTime: number; namespaceId: string; driverPid: number; driverIpAddress: string; diff --git a/dashboard/modules/job/job_agent.py b/dashboard/modules/job/job_agent.py index 34b72462501ab..f56a24db83586 100644 --- a/dashboard/modules/job/job_agent.py +++ b/dashboard/modules/job/job_agent.py @@ -202,7 +202,9 @@ def _gen_driver_code(self): # Per job config job_config_items = { - "worker_env": self._job_info.env, + "runtime_env": { + "env_vars": self._job_info.env + }, "code_search_path": [job_package_dir], } diff --git a/dashboard/modules/runtime_env/runtime_env_agent.py b/dashboard/modules/runtime_env/runtime_env_agent.py index 5151278b1ab26..3c8b9c18bf9f3 100644 --- a/dashboard/modules/runtime_env/runtime_env_agent.py +++ b/dashboard/modules/runtime_env/runtime_env_agent.py @@ -6,6 +6,7 @@ import os import time from typing import Dict, Set +from ray._private.utils import import_attr from ray.core.generated import runtime_env_agent_pb2 from ray.core.generated import runtime_env_agent_pb2_grpc @@ -17,8 +18,8 @@ _internal_kv_initialized) from ray._private.ray_logging import setup_component_logger from ray._private.runtime_env.conda import CondaManager +from ray._private.runtime_env.context import RuntimeEnvContext from ray._private.runtime_env.working_dir import WorkingDirManager -from ray._private.runtime_env import RuntimeEnvContext logger = logging.getLogger(__name__) @@ -78,13 +79,20 @@ def get_or_create_logger(self, job_id: bytes): return self._per_job_logger_cache[job_id] async def CreateRuntimeEnv(self, request, context): - async def _setup_runtime_env(serialized_runtime_env): + async def _setup_runtime_env(serialized_runtime_env, + serialized_allocated_resource_instances): # This function will be ran inside a thread def run_setup_with_logger(): runtime_env: dict = json.loads(serialized_runtime_env or "{}") + allocated_resource: dict = json.loads( + serialized_allocated_resource_instances or "{}") # Use a separate logger for each job. per_job_logger = self.get_or_create_logger(request.job_id) + # TODO(chenk008): Add log about allocated_resource to + # avoid lint error. That will be moved to cgroup plugin. 
+            per_job_logger.debug(f"Worker has resources: "
+                                 f"{allocated_resource}")
             context = RuntimeEnvContext(
                 env_vars=runtime_env.get("env_vars"))
             self._conda_manager.setup(
@@ -98,6 +106,15 @@ def run_setup_with_logger():
                 self._working_dir_uri_to_envs[uri].add(
                     serialized_runtime_env)
 
+            # Run the setup functions from all the plugins.
+            for plugin_class_path in runtime_env.get("plugins", {}).keys():
+                plugin_class = import_attr(plugin_class_path)
+                # TODO(simon): implement uri support
+                plugin_class.create("uri not implemented", runtime_env,
+                                    context)
+                plugin_class.modify_context("uri not implemented",
+                                            runtime_env, context)
+
             return context
 
         loop = asyncio.get_event_loop()
@@ -138,7 +155,8 @@ def run_setup_with_logger():
         for _ in range(runtime_env_consts.RUNTIME_ENV_RETRY_TIMES):
             try:
                 runtime_env_context = await _setup_runtime_env(
-                    serialized_env)
+                    serialized_env,
+                    request.serialized_allocated_resource_instances)
                 break
             except Exception as ex:
                 logger.exception("Runtime env creation failed.")
diff --git a/dashboard/modules/snapshot/snapshot_head.py b/dashboard/modules/snapshot/snapshot_head.py
index 424e41ff45e16..87082f5463147 100644
--- a/dashboard/modules/snapshot/snapshot_head.py
+++ b/dashboard/modules/snapshot/snapshot_head.py
@@ -73,11 +73,10 @@ async def get_job_info(self):
         for job_table_entry in reply.job_info_list:
             job_id = job_table_entry.job_id.hex()
             config = {
-                "env_vars": dict(job_table_entry.config.worker_env),
                 "namespace": job_table_entry.config.ray_namespace,
                 "metadata": dict(job_table_entry.config.metadata),
                 "runtime_env": json.loads(
-                    job_table_entry.config.serialized_runtime_env),
+                    job_table_entry.config.runtime_env.serialized_runtime_env),
             }
             entry = {
                 "is_dead": job_table_entry.is_dead,
diff --git a/dashboard/modules/snapshot/snapshot_schema.json b/dashboard/modules/snapshot/snapshot_schema.json
index f660813110f1e..4768c2a5e292c 100644
--- a/dashboard/modules/snapshot/snapshot_schema.json
+++ b/dashboard/modules/snapshot/snapshot_schema.json
@@ -39,9 +39,6 @@
         "config": {
           "type": "object",
           "properties": {
-            "envVars": {
-              "type": "object"
-            },
             "namespace": {
               "type": "string"
             },
@@ -53,7 +50,6 @@
           }
         },
         "required": [
-          "envVars",
           "namespace",
           "metadata",
           "runtimeEnv"
diff --git a/dashboard/tests/test_dashboard.py b/dashboard/tests/test_dashboard.py
index ea335c61bad21..6565ea08814cf 100644
--- a/dashboard/tests/test_dashboard.py
+++ b/dashboard/tests/test_dashboard.py
@@ -107,7 +107,7 @@ def _search_agent(processes):
             agent_proc.kill()
             agent_proc.wait()
             # The agent will be restarted for imports failure.
-            for x in range(50):
+            for _ in range(300):
                 agent_proc = _search_agent(raylet_proc.children())
                 if agent_proc:
                     agent_pids.add(agent_proc.pid)
diff --git a/doc/BUILD b/doc/BUILD
index eed30be63b145..5f03ce17afc26 100644
--- a/doc/BUILD
+++ b/doc/BUILD
@@ -3,6 +3,31 @@
 # Please keep these sorted alphabetically, but start with the
 # root directory.
 # --------------------------------------------------------------------
+
+# Dask has dropped support for Python 3.6, so this test requires 3.7.
+py_test(
+    name = "dask_xgboost",
+    size = "medium",
+    main = "examples/dask_xgboost/dask_xgboost.py",
+    srcs = ["examples/dask_xgboost/dask_xgboost.py"],
+    tags = ["exclusive", "team:ml", "py37"],
+    args = ["--smoke-test", "--address ''", "--num-actors 4",
+            "--cpus-per-actor 1", "--num-actors-inference 4",
+            "--cpus-per-actor-inference 1"]
+)
+
+# Modin has dropped support for Python 3.6, so this test requires 3.7.
+py_test( + name = "modin_xgboost", + size = "medium", + main = "examples/modin_xgboost/modin_xgboost.py", + srcs = ["examples/modin_xgboost/modin_xgboost.py"], + tags = ["exclusive", "team:ml", "py37"], + args = ["--smoke-test", "--address ''", "--num-actors 4", + "--cpus-per-actor 1", "--num-actors-inference 4", + "--cpus-per-actor-inference 1"] +) + py_test( name = "plot_hyperparameter", size = "small", diff --git a/doc/examples/dask_xgboost/README.rst b/doc/examples/dask_xgboost/README.rst new file mode 100644 index 0000000000000..8feca331c5d78 --- /dev/null +++ b/doc/examples/dask_xgboost/README.rst @@ -0,0 +1 @@ +:orphan: diff --git a/doc/examples/dask_xgboost/dask_xgboost.py b/doc/examples/dask_xgboost/dask_xgboost.py new file mode 100644 index 0000000000000..d4e50a33faf70 --- /dev/null +++ b/doc/examples/dask_xgboost/dask_xgboost.py @@ -0,0 +1,321 @@ +# flake8: noqa: E501 +""" +XGBoost-Ray with Dask +====================== + +This notebook includes an example workflow using +`XGBoost-Ray `_ and +`Dask `_ for distributed model training, +hyperparameter optimization, and prediction. +""" + +############################################################################### +# Cluster Setup +# ------------- +# +# First, we'll set up our Ray Cluster. The provided ``dask_xgboost.yaml`` +# cluster config can be used to set up an AWS cluster with 64 CPUs. +# +# The following steps assume you are in a directory with both +# ``dask_xgboost.yaml`` and this file saved as ``dask_xgboost.ipynb``. +# +# **Step 1:** Bring up the Ray cluster. +# +# .. code-block:: bash +# +# $ pip install ray boto3 +# $ ray up dask_xgboost.yaml +# +# **Step 2:** Move ``dask_xgboost.ipynb`` to the cluster and start Jupyter. +# +# .. code-block:: bash +# +# $ ray rsync_up dask_xgboost.yaml "./dask_xgboost.ipynb" \ +# "~/dask_xgboost.ipynb" +# $ ray exec dask_xgboost.yaml --port-forward=9999 "jupyter notebook \ +# --port=9999" +# +# You can then access this notebook at the URL that is output: +# ``http://localhost:9999/?token=`` + +############################################################################### +# Python Setup +# ------------ +# +# First, we'll import all the libraries we'll be using. This step also helps us +# verify that the environment is configured correctly. If any of the imports +# are missing, an exception will be raised. + +import argparse +import time + +import dask +import dask.dataframe as dd +from xgboost_ray import RayDMatrix, RayParams, train, predict + +import ray +from ray import tune +from ray.util.dask import ray_dask_get + +############################################################################### +# +# Next, let's parse some arguments. This will be used for executing the ``.py`` +# file, but not for the ``.ipynb``. If you are using the interactive notebook, +# you can directly override the arguments manually. 
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--address", type=str, default="auto", help="The address to use for Ray.")
+parser.add_argument(
+    "--smoke-test",
+    action="store_true",
+    help="Read a smaller dataset for quick testing purposes.")
+parser.add_argument(
+    "--num-actors",
+    type=int,
+    default=4,
+    help="Sets number of actors for training.")
+parser.add_argument(
+    "--cpus-per-actor",
+    type=int,
+    default=6,
+    help="The number of CPUs per actor for training.")
+parser.add_argument(
+    "--num-actors-inference",
+    type=int,
+    default=16,
+    help="Sets number of actors for inference.")
+parser.add_argument(
+    "--cpus-per-actor-inference",
+    type=int,
+    default=2,
+    help="The number of CPUs per actor for inference.")
+# Ignore -f from ipykernel_launcher
+args, _ = parser.parse_known_args()
+
+###############################################################################
+# Override these arguments as needed:
+
+address = args.address
+smoke_test = args.smoke_test
+num_actors = args.num_actors
+cpus_per_actor = args.cpus_per_actor
+num_actors_inference = args.num_actors_inference
+cpus_per_actor_inference = args.cpus_per_actor_inference
+
+###############################################################################
+# Connecting to the Ray cluster
+# -----------------------------
+# Now, let's connect our Python script to this newly deployed Ray cluster!
+
+if not ray.is_initialized():
+    ray.init(address=address)
+
+###############################################################################
+# Data Preparation
+# -----------------
+# We will use the `HIGGS dataset from the UCI Machine Learning dataset
+# repository `_. The HIGGS
+# dataset consists of 11,000,000 samples and 28 attributes, which is large
+# enough to show the benefits of distributed computation.
+#
+# We set the Dask scheduler to ``ray_dask_get`` to use the `Dask on Ray
+# `_ backend.
+
+LABEL_COLUMN = "label"
+if smoke_test:
+    # Test dataset with only 10,000 records.
+    FILE_URL = "https://ray-ci-higgs.s3.us-west-2.amazonaws.com/simpleHIGGS" \
+               ".csv"
+else:
+    # Full dataset. This may take a couple of minutes to load.
+    FILE_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases" \
+               "/00280/HIGGS.csv.gz"
+colnames = [LABEL_COLUMN] + ["feature-%02d" % i for i in range(1, 29)]
+dask.config.set(scheduler=ray_dask_get)
+
+###############################################################################
+
+load_data_start_time = time.time()
+
+data = dd.read_csv(FILE_URL, names=colnames)
+data = data[sorted(colnames)]
+data = data.persist()
+
+load_data_end_time = time.time()
+load_data_duration = load_data_end_time - load_data_start_time
+print(f"Dataset loaded in {load_data_duration} seconds.")
+
+###############################################################################
+# With the Dask dataframe loaded, we will now split the data into a training
+# set and an evaluation set using an 80-20 proportion.
+
+train_df, eval_df = data.random_split([0.8, 0.2])
+train_df, eval_df = train_df.persist(), eval_df.persist()
+print(train_df, eval_df)
+
+###############################################################################
+# Distributed Training
+# --------------------
+# The ``train_xgboost`` function contains all of the logic necessary for
+# training using XGBoost-Ray.
+#
+# Distributed training can not only speed up the process, but also allow you
+# to use datasets that are too large to fit in the memory of a single node. With
+# distributed training, the dataset is sharded across different actors
+# running on separate nodes. Those actors communicate with each other to
+# create the final model.
+#
+# First, the dataframes are wrapped in ``RayDMatrix`` objects, which handle
+# data sharding across the cluster. Then, the ``train`` function is called.
+# The evaluation scores will be saved to the ``evals_result`` dictionary. The
+# function returns a tuple of the trained model (booster) and the evaluation
+# scores.
+#
+# The ``ray_params`` variable expects a ``RayParams`` object that contains
+# Ray-specific settings, such as the number of workers.
+
+
+def train_xgboost(config, train_df, test_df, target_column, ray_params):
+    train_set = RayDMatrix(train_df, target_column)
+    test_set = RayDMatrix(test_df, target_column)
+
+    evals_result = {}
+
+    train_start_time = time.time()
+
+    # Train the classifier
+    bst = train(
+        params=config,
+        dtrain=train_set,
+        evals=[(test_set, "eval")],
+        evals_result=evals_result,
+        ray_params=ray_params)
+
+    train_end_time = time.time()
+    train_duration = train_end_time - train_start_time
+    print(f"Total time taken: {train_duration} seconds.")
+
+    model_path = "model.xgb"
+    bst.save_model(model_path)
+    print("Final validation error: {:.4f}".format(
+        evals_result["eval"]["error"][-1]))
+
+    return bst, evals_result
+
+
+###############################################################################
+# We can now pass our Dask dataframes and run the function. We will use
+# ``RayParams`` to specify the number of actors and CPUs to train with.
+#
+# The dataset has to be downloaded onto the cluster, which may take a few
+# minutes.
+
+# standard XGBoost config for classification
+config = {
+    "tree_method": "approx",
+    "objective": "binary:logistic",
+    "eval_metric": ["logloss", "error"],
+}
+
+bst, evals_result = train_xgboost(
+    config, train_df, eval_df, LABEL_COLUMN,
+    RayParams(cpus_per_actor=cpus_per_actor, num_actors=num_actors))
+print(f"Results: {evals_result}")
+
+###############################################################################
+# Hyperparameter optimization
+# ---------------------------
+# If we are not content with the results obtained with default XGBoost
+# parameters, we can use `Ray Tune
+# `_ for cutting-edge
+# distributed hyperparameter tuning. XGBoost-Ray automatically integrates
+# with Ray Tune, meaning we can use the same training function as before.
+#
+# In this workflow, we will tune three hyperparameters: ``eta``, ``subsample``
+# and ``max_depth``. We are using `Tune's samplers to define the search
+# space `_.
+#
+# The experiment configuration is done through ``tune.run``. We set the amount
+# of resources each trial (hyperparameter combination) requires by using the
+# ``get_tune_resources`` method of ``RayParams``. The ``num_samples`` argument
+# controls how many trials will be run in total. In the end, the best
+# combination of hyperparameters evaluated during the experiment will be
+# returned.
+#
+# By default, Tune will use simple random search. However, Tune also
+# provides various `search algorithms
+# `_ and
+# `schedulers `_
+# to further improve the optimization process.
+
+
+def tune_xgboost(train_df, test_df, target_column):
+    # Set XGBoost config.
+ config = { + "tree_method": "approx", + "objective": "binary:logistic", + "eval_metric": ["logloss", "error"], + "eta": tune.loguniform(1e-4, 1e-1), + "subsample": tune.uniform(0.5, 1.0), + "max_depth": tune.randint(1, 9) + } + + ray_params = RayParams( + max_actor_restarts=1, + cpus_per_actor=cpus_per_actor, + num_actors=num_actors) + + tune_start_time = time.time() + + analysis = tune.run( + tune.with_parameters( + train_xgboost, + train_df=train_df, + test_df=test_df, + target_column=target_column, + ray_params=ray_params), + # Use the `get_tune_resources` helper function to set the resources. + resources_per_trial=ray_params.get_tune_resources(), + config=config, + num_samples=10, + metric="eval-error", + mode="min") + + tune_end_time = time.time() + tune_duration = tune_end_time - tune_start_time + print(f"Total time taken: {tune_duration} seconds.") + + accuracy = 1. - analysis.best_result["eval-error"] + print(f"Best model parameters: {analysis.best_config}") + print(f"Best model total accuracy: {accuracy:.4f}") + + return analysis.best_config + + +############################################################################### +# Hyperparameter optimization may take some time to complete. + +tune_xgboost(train_df, eval_df, LABEL_COLUMN) + +############################################################################### +# Prediction +# ---------- +# With the model trained, we can now predict on unseen data. For the +# purposes of this example, we will use the same dataset for prediction as +# for training. +# +# Since prediction is naively parallelizable, distributing it over multiple +# actors can measurably reduce the amount of time needed. + +inference_df = RayDMatrix(data, ignore=[LABEL_COLUMN, "partition"]) +results = predict( + bst, + inference_df, + ray_params=RayParams( + cpus_per_actor=cpus_per_actor_inference, + num_actors=num_actors_inference)) + +print(results) diff --git a/doc/examples/dask_xgboost/dask_xgboost.yaml b/doc/examples/dask_xgboost/dask_xgboost.yaml new file mode 100644 index 0000000000000..e598a115069b6 --- /dev/null +++ b/doc/examples/dask_xgboost/dask_xgboost.yaml @@ -0,0 +1,24 @@ +cluster_name: dask_xgboost + +max_workers: 3 + +provider: + type: aws + region: us-west-1 + +auth: + ssh_user: ubuntu + +available_node_types: + 16_cpu_node: + min_workers: 3 + max_workers: 3 + node_config: + InstanceType: m5.4xlarge + ImageId: latest_dlami + resources: { } + +head_node_type: 16_cpu_node + +setup_commands: + - pip install -U jupyter ray[tune] xgboost_ray[default] dask pandas diff --git a/doc/examples/modin_xgboost/README.rst b/doc/examples/modin_xgboost/README.rst new file mode 100644 index 0000000000000..8feca331c5d78 --- /dev/null +++ b/doc/examples/modin_xgboost/README.rst @@ -0,0 +1 @@ +:orphan: diff --git a/doc/examples/modin_xgboost/modin_xgboost.py b/doc/examples/modin_xgboost/modin_xgboost.py new file mode 100644 index 0000000000000..bcbe6c0968068 --- /dev/null +++ b/doc/examples/modin_xgboost/modin_xgboost.py @@ -0,0 +1,233 @@ +""" +XGBoost-Ray with Modin +====================== + +This notebook includes an example workflow using +`XGBoost-Ray `_ and +`Modin `_ for distributed model +training and prediction. +""" + +############################################################################### +# Cluster Setup +# ------------- +# +# First, we'll set up our Ray Cluster. The provided ``modin_xgboost.yaml`` +# cluster config can be used to set up an AWS cluster with 64 CPUs. 
+# +# The following steps assume you are in a directory with both +# ``modin_xgboost.yaml`` and this file saved as ``modin_xgboost.ipynb``. +# +# **Step 1:** Bring up the Ray cluster. +# +# .. code-block:: bash +# +# $ pip install ray boto3 +# $ ray up modin_xgboost.yaml +# +# **Step 2:** Move ``modin_xgboost.ipynb`` to the cluster and start Jupyter. +# +# .. code-block:: bash +# +# $ ray rsync_up modin_xgboost.yaml "./modin_xgboost.ipynb" \ +# "~/modin_xgboost.ipynb" +# $ ray exec modin_xgboost.yaml --port-forward=9999 "jupyter notebook \ +# --port=9999" +# +# You can then access this notebook at the URL that is output: +# ``http://localhost:9999/?token=`` + +############################################################################### +# Python Setup +# ------------ +# +# First, we'll import all the libraries we'll be using. This step also helps us +# verify that the environment is configured correctly. If any of the imports +# are missing, an exception will be raised. + +import argparse +import time + +import modin.pandas as pd +from modin.experimental.sklearn.model_selection import train_test_split +from xgboost_ray import RayDMatrix, RayParams, train, predict + +import ray + +############################################################################### +# +# Next, let's parse some arguments. This will be used for executing the ``.py`` +# file, but not for the ``.ipynb``. If you are using the interactive notebook, +# you can directly override the arguments manually. + +parser = argparse.ArgumentParser() +parser.add_argument( + "--address", type=str, default="auto", help="The address to use for Ray.") +parser.add_argument( + "--smoke-test", + action="store_true", + help="Read a smaller dataset for quick testing purposes.") +parser.add_argument( + "--num-actors", + type=int, + default=4, + help="Sets number of actors for training.") +parser.add_argument( + "--cpus-per-actor", + type=int, + default=8, + help="The number of CPUs per actor for training.") +parser.add_argument( + "--num-actors-inference", + type=int, + default=16, + help="Sets number of actors for inference.") +parser.add_argument( + "--cpus-per-actor-inference", + type=int, + default=2, + help="The number of CPUs per actor for inference.") +# Ignore -f from ipykernel_launcher +args, _ = parser.parse_known_args() + +############################################################################### +# Override these arguments as needed: + +address = args.address +smoke_test = args.smoke_test +num_actors = args.num_actors +cpus_per_actor = args.cpus_per_actor +num_actors_inference = args.num_actors_inference +cpus_per_actor_inference = args.cpus_per_actor_inference + +############################################################################### +# Connecting to the Ray cluster +# ----------------------------- +# Now, let's connect our Python script to this newly deployed Ray cluster! + +if not ray.is_initialized(): + ray.init(address=address) + +############################################################################### +# Data Preparation +# ----------------- +# We will use the `HIGGS dataset from the UCI Machine Learning dataset +# repository `_. The HIGGS +# dataset consists of 11,000,000 samples and 28 attributes, which is large +# enough size to show the benefits of distributed computation. + +LABEL_COLUMN = "label" +if smoke_test: + # Test dataset with only 10,000 records. + FILE_URL = "https://ray-ci-higgs.s3.us-west-2.amazonaws.com/simpleHIGGS" \ + ".csv" +else: + # Full dataset. 
This may take a couple of minutes to load.
+    FILE_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases" \
+               "/00280/HIGGS.csv.gz"
+
+colnames = [LABEL_COLUMN] + ["feature-%02d" % i for i in range(1, 29)]
+
+###############################################################################
+
+load_data_start_time = time.time()
+
+df = pd.read_csv(FILE_URL, names=colnames)
+
+load_data_end_time = time.time()
+load_data_duration = load_data_end_time - load_data_start_time
+print(f"Dataset loaded in {load_data_duration} seconds.")
+
+###############################################################################
+# Split data into training and validation.
+
+df_train, df_validation = train_test_split(df)
+print(df_train, df_validation)
+
+###############################################################################
+# Distributed Training
+# --------------------
+# The ``train_xgboost`` function contains all of the logic necessary for
+# training using XGBoost-Ray.
+#
+# Distributed training can not only speed up the process, but also allow you
+# to use datasets that are too large to fit in the memory of a single node. With
+# distributed training, the dataset is sharded across different actors
+# running on separate nodes. Those actors communicate with each other to
+# create the final model.
+#
+# First, the dataframes are wrapped in ``RayDMatrix`` objects, which handle
+# data sharding across the cluster. Then, the ``train`` function is called.
+# The evaluation scores will be saved to the ``evals_result`` dictionary. The
+# function returns a tuple of the trained model (booster) and the evaluation
+# scores.
+#
+# The ``ray_params`` variable expects a ``RayParams`` object that contains
+# Ray-specific settings, such as the number of workers.
+
+
+def train_xgboost(config, train_df, test_df, target_column, ray_params):
+    train_set = RayDMatrix(train_df, target_column)
+    test_set = RayDMatrix(test_df, target_column)
+
+    evals_result = {}
+
+    train_start_time = time.time()
+
+    # Train the classifier
+    bst = train(
+        params=config,
+        dtrain=train_set,
+        evals=[(test_set, "eval")],
+        evals_result=evals_result,
+        verbose_eval=False,
+        num_boost_round=100,
+        ray_params=ray_params)
+
+    train_end_time = time.time()
+    train_duration = train_end_time - train_start_time
+    print(f"Total time taken: {train_duration} seconds.")
+
+    model_path = "model.xgb"
+    bst.save_model(model_path)
+    print("Final validation error: {:.4f}".format(
+        evals_result["eval"]["error"][-1]))
+
+    return bst, evals_result
+
+
+###############################################################################
+# We can now pass our Modin dataframes and run the function. We will use
+# ``RayParams`` to specify the number of actors and CPUs to train with.
+
+# standard XGBoost config for classification
+config = {
+    "tree_method": "approx",
+    "objective": "binary:logistic",
+    "eval_metric": ["logloss", "error"],
+}
+
+bst, evals_result = train_xgboost(
+    config, df_train, df_validation, LABEL_COLUMN,
+    RayParams(cpus_per_actor=cpus_per_actor, num_actors=num_actors))
+print(f"Results: {evals_result}")
+
+###############################################################################
+# Prediction
+# ----------
+# With the model trained, we can now predict on unseen data. For the
+# purposes of this example, we will use the same dataset for prediction as
+# for training.
+#
+# Since prediction is naively parallelizable, distributing it over multiple
+# actors can measurably reduce the amount of time needed.
+ +inference_df = RayDMatrix(df, ignore=[LABEL_COLUMN, "partition"]) +results = predict( + bst, + inference_df, + ray_params=RayParams( + cpus_per_actor=cpus_per_actor_inference, + num_actors=num_actors_inference)) + +print(results) diff --git a/doc/examples/modin_xgboost/modin_xgboost.yaml b/doc/examples/modin_xgboost/modin_xgboost.yaml new file mode 100644 index 0000000000000..914cbdb207af2 --- /dev/null +++ b/doc/examples/modin_xgboost/modin_xgboost.yaml @@ -0,0 +1,24 @@ +cluster_name: modin_xgboost + +max_workers: 3 + +provider: + type: aws + region: us-west-1 + +auth: + ssh_user: ubuntu + +available_node_types: + 16_cpu_node: + min_workers: 3 + max_workers: 3 + node_config: + InstanceType: m5.4xlarge + ImageId: latest_dlami + resources: { } + +head_node_type: 16_cpu_node + +setup_commands: + - pip install -U jupyter ray xgboost_ray[default] modin pandas diff --git a/doc/examples/overview.rst b/doc/examples/overview.rst index 8555799094ef9..48cf3c2805918 100644 --- a/doc/examples/overview.rst +++ b/doc/examples/overview.rst @@ -61,6 +61,8 @@ Machine Learning Examples plot_lbfgs.rst plot_example-lm.rst plot_newsreader.rst + dask_xgboost/dask_xgboost.rst + modin_xgboost/modin_xgboost.rst .. customgalleryitem:: @@ -86,6 +88,14 @@ Machine Learning Examples :tooltip: Implementing a simple news reader using Ray. :description: :doc:`/auto_examples/plot_newsreader` +.. customgalleryitem:: + :tooltip: Train an XGBoost-Ray model using Dask for data processing. + :description: :doc:`/auto_examples/dask_xgboost/dask_xgboost` + +.. customgalleryitem:: + :tooltip: Train an XGBoost-Ray model using Modin for data processing. + :description: :doc:`/auto_examples/modin_xgboost/modin_xgboost` + .. raw:: html diff --git a/doc/kubernetes/ray-cluster.yaml b/doc/kubernetes/ray-cluster.yaml index 1b3da82e9ccaa..f4f493152608c 100644 --- a/doc/kubernetes/ray-cluster.yaml +++ b/doc/kubernetes/ray-cluster.yaml @@ -3,7 +3,7 @@ apiVersion: v1 kind: Service metadata: namespace: ray - name: ray-head + name: example-cluster-ray-head spec: ports: - name: client @@ -111,7 +111,7 @@ spec: imagePullPolicy: IfNotPresent command: ["/bin/bash", "-c", "--"] args: - - "ray start --num-cpus=$MY_CPU_REQUEST --address=$RAY_HEAD_SERVICE_HOST:$RAY_HEAD_SERVICE_PORT_REDIS --object-manager-port=12345 --node-manager-port=12346 --block" + - "ray start --num-cpus=$MY_CPU_REQUEST --address=$EXAMPLE_CLUSTER_RAY_HEAD_SERVICE_HOST:$EXAMPLE_CLUSTER_RAY_HEAD_SERVICE_PORT_REDIS --object-manager-port=12345 --node-manager-port=12346 --block" # This volume allocates shared memory for Ray to use for its plasma # object store. If you do not provide this, Ray will fall back to # /tmp which cause slowdowns if is not a shared memory volume. diff --git a/doc/source/advanced.rst b/doc/source/advanced.rst index 75ff25045592e..fa4ff9cffa65c 100644 --- a/doc/source/advanced.rst +++ b/doc/source/advanced.rst @@ -42,17 +42,23 @@ This often occurs for data loading and preprocessing. # hi there! # hi there! -Multi-node synchronization using ``SignalActor`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Multi-node synchronization using an Actor +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -When you have multiple tasks that need to wait on some condition, you can use a ``SignalActor`` to coordinate. +When you have multiple tasks that need to wait on some condition or otherwise +need to synchronize across tasks & actors on a cluster, you can use a central +actor to coordinate among them. 
Below is an example of using a ``SignalActor`` +that wraps an ``asyncio.Event`` for basic synchronization. .. code-block:: python - # Also available via `from ray._private.test_utils import SignalActor` - import ray import asyncio + import ray + + ray.init() + + # We set num_cpus to zero because this actor will mostly just block on I/O. @ray.remote(num_cpus=0) class SignalActor: def __init__(self): @@ -73,7 +79,6 @@ When you have multiple tasks that need to wait on some condition, you can use a print("go!") - ray.init() signal = SignalActor.remote() tasks = [wait_and_go.remote(signal) for _ in range(4)] print("ready...") @@ -441,7 +446,7 @@ On Mac OS and Linux, Ray 1.4+ supports dynamically setting the runtime environme The ``runtime_env`` is a (JSON-serializable) dictionary that can be passed as an option to tasks and actors, and can also be passed to ``ray.init()``. The runtime environment defines the dependencies required for your workload. -You can specify a runtime environment for your whole job using ``ray.init()`` or Ray Client... +You can specify a runtime environment for your whole job using ``ray.init()`` or Ray Client: .. literalinclude:: ../examples/doc_code/runtime_env_example.py :language: python @@ -456,19 +461,20 @@ You can specify a runtime environment for your whole job using ``ray.init()`` or # Using Ray Client ray.init("ray://localhost:10001", runtime_env=runtime_env) -...or specify per-actor or per-task in the ``@ray.remote()`` decorator or by using ``.options()``: +Or specify per-actor or per-task in the ``@ray.remote()`` decorator or by using ``.options()``: .. literalinclude:: ../examples/doc_code/runtime_env_example.py :language: python :start-after: __per_task_per_actor_start__ :end-before: __per_task_per_actor_end__ +Note: specifying within the ``@ray.remote()`` decorator is currently unsupported while using Ray Client; please use ``.options()`` instead in this case. + The ``runtime_env`` is a Python dictionary including one or more of the following arguments: - ``working_dir`` (Path): Specifies the working directory for your job. This must be an existing local directory. It will be cached on the cluster, so the next time you connect with Ray Client you will be able to skip uploading the directory contents. - Furthermore, if you locally make a small change to your directory, the next time you connect only the updated part will be uploaded. - All Ray workers for your job will be started in their node's copy of this working directory. + All Ray workers for your job will be started in their node's local copy of this working directory. - Examples @@ -486,7 +492,7 @@ The ``runtime_env`` is a Python dictionary including one or more of the followin - Example: ``["my_file.txt", "path/to/dir", "*.log"]`` - ``pip`` (List[str] | str): Either a list of pip packages, or a string containing the path to a pip - `“requirements.txt” `_ file. The path may be an absolute path or a relative path. (Note: A relative path will be interpreted relative to ``working_dir`` if ``working_dir`` is specified.) + `“requirements.txt” `_ file. The path may be an absolute path or a relative path. This will be dynamically installed in the ``runtime_env``. To use a library like Ray Serve or Ray Tune, you will need to include ``"ray[serve]"`` or ``"ray[tune]"`` here. 
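The fields documented in this section compose as a plain Python dict. As a hedged sketch (the directory and package names below are illustrative, not defaults):

.. code-block:: python

    import ray

    runtime_env = {
        # An existing local directory; it will be cached on the cluster.
        "working_dir": "/path/to/my_project",
        # Pip packages installed dynamically for this job; include
        # "ray[serve]" or "ray[tune]" if the workload uses those libraries.
        "pip": ["requests", "ray[serve]"],
    }

    # Apply the environment to the whole job.
    ray.init(runtime_env=runtime_env)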
@@ -494,7 +500,7 @@ The ``runtime_env`` is a Python dictionary including one or more of the followin - Example: ``"./requirements.txt"`` -- ``conda`` (dict | str): Either (1) a dict representing the conda environment YAML, (2) a string containing the path to a +- ``conda`` (dict | str): Either (1) a dict representing the conda environment YAML, (2) a string containing the absolute or relative path to a `conda “environment.yml” `_ file, or (3) the name of a local conda environment already installed on each node in your cluster (e.g., ``"pytorch_p36"``). In the first two cases, the Ray and Python dependencies will be automatically injected into the environment to ensure compatibility, so there is no need to manually include them. @@ -506,12 +512,15 @@ The ``runtime_env`` is a Python dictionary including one or more of the followin - Example: ``"pytorch_p36"`` - Note: if specifying the path to an "environment.yml" file, you may provide an absolute path or a relative path. A relative path will be interpreted relative to ``working_dir`` if ``working_dir`` is specified. - ``env_vars`` (Dict[str, str]): Environment variables to set. - Example: ``{"OMP_NUM_THREADS": "32", "TF_WARNINGS": "none"}`` +- ``eager_install`` (bool): A boolean indicates whether to install runtime env eagerly before the workers are leased. This flag is set to false by default. + + - Example: ``{"eager_install": True}`` + The runtime environment is inheritable, so it will apply to all tasks/actors within a job and all child tasks/actors of a task or actor, once set. If a child actor or task specifies a new ``runtime_env``, it will be merged with the parent’s ``runtime_env`` via a simple dict update. diff --git a/doc/source/cluster/config.rst b/doc/source/cluster/config.rst index 7ba7e2ccbcbef..867e8398e6985 100644 --- a/doc/source/cluster/config.rst +++ b/doc/source/cluster/config.rst @@ -109,6 +109,8 @@ Provider :ref:`region `: str :ref:`availability_zone `: str :ref:`cache_stopped_nodes `: bool + :ref:`security_group `: + :ref:`Security Group ` .. group-tab:: Azure @@ -130,6 +132,20 @@ Provider :ref:`project_id `: str :ref:`cache_stopped_nodes `: bool +.. _cluster-configuration-security-group-type: + +Security Group +~~~~~~~~~~~~~~ + +.. tabs:: + .. group-tab:: AWS + + .. parsed-literal:: + + :ref:`GroupName `: str + :ref:`IpPermissions `: + - `IpPermission `_ + .. _cluster-configuration-node-types-type: Node types @@ -923,6 +939,52 @@ If enabled, nodes will be *stopped* when the cluster scales down. If disabled, n * **Type:** Boolean * **Default:** ``True`` +.. _cluster-configuration-security-group: + +``provider.security_group`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. tabs:: + .. group-tab:: AWS + + A security group that can be used to specify custom inbound rules. + + * **Required:** No + * **Importance:** Medium + * **Type:** :ref:`Security Group ` + + .. group-tab:: Azure + + Not available. + + .. group-tab:: GCP + + Not available. + + +.. _cluster-configuration-group-name: + +``security_group.GroupName`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The name of the security group. This name must be unique within the VPC. + +* **Required:** No +* **Importance:** Low +* **Type:** String +* **Default:** ``"ray-autoscaler-{cluster-name}"`` + +.. _cluster-configuration-ip-permissions: + +``security_group.IpPermissions`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The inbound rules associated with the security group. + +* **Required:** No +* **Importance:** Medium +* **Type:** `IpPermission `_ + .. 
_cluster-configuration-node-config: ``available_node_types..node_type.node_config`` diff --git a/doc/source/cluster/ray-client.rst b/doc/source/cluster/ray-client.rst index 550bb75480127..1b9099160f9ae 100644 --- a/doc/source/cluster/ray-client.rst +++ b/doc/source/cluster/ray-client.rst @@ -62,9 +62,9 @@ Step 1: set up your Ray cluster First, you'll want to create a remote Ray cluster. Follow the directions in :ref:`ref-cluster-quick-start` to do this. -If using the `Ray cluster launcher `_, the remote cluster will be listening on port ``10001`` of the head node. If necessary, you can modify this port by setting ``--ray-client-server-port`` to the ``ray start`` `command `_. +If using the :doc:`Ray cluster launcher `, the remote cluster will be listening on port ``10001`` of the head node. If necessary, you can modify this port by setting ``--ray-client-server-port`` to the ``ray start`` `command `_. -If not using the `Ray cluster launcher `_, you can start the "Ray Client Server" manually on the head node of your remote cluster by running the following: +If not using the :doc:`Ray cluster launcher `, you can start the "Ray Client Server" manually on the head node of your remote cluster by running the following: .. code-block:: bash @@ -77,6 +77,32 @@ Ensure that the Ray Client port on the head node is reachable from your local ma This means opening that port up by configuring security groups or other access controls (on `EC2 `_) or proxying from your local machine to the cluster (on `K8s `_). +.. tabs:: + .. group-tab:: AWS + + With the Ray cluster launcher, you can configure the security group + to allow inbound access by defining :ref:`cluster-configuration-security-group` + in your `cluster.yaml`. + + .. code-block:: yaml + + # An unique identifier for the head node and workers of this cluster. + cluster_name: minimal_security_group + + # Cloud-provider specific configuration. + provider: + type: aws + region: us-west-2 + security_group: + GroupName: ray_client_security_group + IpPermissions: + - FromPort: 10001 + ToPort: 10001 + IpProtocol: TCP + IpRanges: + # This will enable inbound access from ALL IPv4 addresses. + - CidrIp: 0.0.0.0/0 + Step 3: Run Ray code ~~~~~~~~~~~~~~~~~~~~ @@ -99,8 +125,43 @@ Now, connect to the Ray Cluster with the following and then use Ray like you nor #.... -Connect to multiple ray clusters --------------------------------- +Alternative Approach: SSH Port Forwarding +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +As an alternative to configuring inbound traffic rules, you can also set up +Ray Client via port forwarding. While this approach does require an open SSH +connection, it can be useful in a test environment where the +``head_node_host`` often changes. + +First, open up an SSH connection with your Ray cluster and forward the +listening port (``10001``). + +.. code-block:: bash + + $ ray up cluster.yaml + $ ray attach cluster.yaml -p 10001 + +Then, you can connect to the Ray cluster using ``localhost`` as the +``head_node_host``. + +.. code-block:: python + + import ray + + # This will connect to the cluster via the open SSH session. + ray.init("ray://localhost:10001") + + # Normal Ray code follows + @ray.remote + def do_work(x): + return x ** x + + do_work.remote(2) + + #.... + +Connect to multiple ray clusters (Experimental) +----------------------------------------------- Ray client allows connecting to multiple ray clusters in one Python process. 
To do this, just pass ``allow_multiple=True`` to ``ray.init``: diff --git a/doc/source/data/dask-on-ray.rst b/doc/source/data/dask-on-ray.rst index 6057b740db441..9e08977bdb16e 100644 --- a/doc/source/data/dask-on-ray.rst +++ b/doc/source/data/dask-on-ray.rst @@ -6,16 +6,16 @@ Dask on Ray `Dask `__ is a Python parallel computing library geared towards scaling analytics and scientific computing workloads. It provides `big data collections `__ that mimic the APIs of -the familiar `NumPy `__ and `Pandas `__ libraries, +the familiar `NumPy `__ and `Pandas `__ libraries, allowing those abstractions to represent -larger-than-memory data and/or allowing operations on that data to be run on a multi-machine cluster, +larger-than-memory data and/or allowing operations on that data to be run on a multi-machine cluster, while also providing automatic data parallelism, smart scheduling, and optimized operations. Operations on these collections create a task graph, which is executed by a scheduler. Ray provides a scheduler for Dask (`dask_on_ray`) which allows you to build data analyses using Dask's collections and execute -the underlying tasks on a Ray cluster. +the underlying tasks on a Ray cluster. `dask_on_ray` uses Dask's scheduler API, which allows you to specify any callable as the scheduler that you would like Dask to use to execute your @@ -30,8 +30,12 @@ workload. Using the Dask-on-Ray scheduler, the entire Dask ecosystem can be exec * - Ray Version - Dask Version + * - ``1.7.0`` + - ``2021.9.1`` + * - ``1.6.0`` + - ``2021.8.1`` * - ``1.5.0`` - - ``2021.7.0`` + - ``2021.7.0`` * - ``1.4.1`` - ``2021.6.1`` * - ``1.4.0`` @@ -82,7 +86,7 @@ In this case, there are two recommended setup. # Head node. Set `num_cpus=0` to avoid tasks are being scheduled on a head node. RAY_SCHEDULER_SPREAD_THRESHOLD=0.0 ray start --head --num-cpus=0 - # Worker node. + # Worker node. RAY_SCHEDULER_SPREAD_THRESHOLD=0.0 ray start --address=[head-node-address] Out-of-Core Data Processing @@ -101,10 +105,10 @@ Persist .. _dask-on-ray-persist: -Dask-on-Ray patches `dask.persist() -`__ in order to match `Dask +Dask-on-Ray patches `dask.persist() +`__ in order to match `Dask Distributed's persist semantics -`; namely, calling `dask.persist()` with a Dask-on-Ray +`; namely, calling `dask.persist()` with a Dask-on-Ray scheduler will submit the tasks to the Ray cluster and return Ray futures inlined in the Dask collection. This is nice if you wish to compute some base collection (such as a Dask array), followed by multiple different downstream computations (such as diff --git a/doc/source/data/dataset-pipeline.rst b/doc/source/data/dataset-pipeline.rst index 8b60ca3cb7985..d954df8051eb5 100644 --- a/doc/source/data/dataset-pipeline.rst +++ b/doc/source/data/dataset-pipeline.rst @@ -6,12 +6,12 @@ Overview Datasets execute their transformations synchronously in blocking calls. However, it can be useful to overlap dataset computations with output. This can be done with a `DatasetPipeline `__. -A DatasetPipeline is an unified iterator over a (potentially infinite) sequence of Ray Datasets. Conceptually it is similar to a `Spark DStream `__, but manages execution over a bounded amount of source data instead of an unbounded stream. Ray computes each dataset on-demand and stitches their output together into a single logical data iterator. DatasetPipeline implements most of the same transformation and output methods as Datasets (e.g., map, filter, split, iter_rows, to_torch, etc.). 
+A DatasetPipeline is a unified iterator over a (potentially infinite) sequence of Ray Datasets, each of which represents a *window* over the original data. Conceptually it is similar to a `Spark DStream `__, but manages execution over a bounded amount of source data instead of an unbounded stream. Ray computes each dataset window on-demand and stitches their output together into a single logical data iterator. DatasetPipeline implements most of the same transformation and output methods as Datasets (e.g., map, filter, split, iter_rows, to_torch, etc.).
 
 Creating a DatasetPipeline
 ~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-A DatasetPipeline can be constructed in two ways: either by pipelining the execution of an existing Dataset (via ``Dataset.pipeline``), or generating repeats of an existing Dataset (via ``Dataset.repeat``). Similar to Datasets, you can freely pass DatasetPipelines between Ray tasks, actors, and libraries. Get started with this synthetic data example:
+A DatasetPipeline can be constructed in two ways: either by pipelining the execution of an existing Dataset (via ``Dataset.window``), or generating repeats of an existing Dataset (via ``Dataset.repeat``). Similar to Datasets, you can freely pass DatasetPipelines between Ray tasks, actors, and libraries. Get started with this synthetic data example:
 
 .. code-block:: python
 
@@ -30,16 +30,16 @@ A DatasetPipeline can be constructed in two ways: either by pipelining the execu
     base = ray.data.range(1000000)
     print(base)
     # -> Dataset(num_blocks=200, num_rows=1000000, schema=)
-    pipe = base.pipeline(parallelism=10)
+    pipe = base.window(blocks_per_window=10)
     print(pipe)
-    # -> DatasetPipeline(length=20, num_stages=1)
+    # -> DatasetPipeline(num_windows=20, num_stages=1)
 
     # Applying transforms to pipelines adds more pipeline stages.
     pipe = pipe.map(func1)
     pipe = pipe.map(func2)
     pipe = pipe.map(func3)
     print(pipe)
-    # -> DatasetPipeline(length=20, num_stages=4)
+    # -> DatasetPipeline(num_windows=20, num_stages=4)
 
     # Output can be pulled from the pipeline concurrently with its execution.
     num_rows = 0
@@ -53,8 +53,7 @@ A DatasetPipeline can be constructed in two ways: either by pipelining the execu
     print("Total num rows", num_rows)
     # -> Total num rows 1000000
 
-
-You can also create a DatasetPipeline from a custom iterator over dataset creators using ``DatasetPipeline.from_iterable``. For example, this is how you would implement ``Dataset.repeat`` and ``Dataset.pipeline`` using ``from_iterable``:
+You can also create a DatasetPipeline from a custom iterator over dataset creators using ``DatasetPipeline.from_iterable``. For example, this is how you would implement ``Dataset.repeat`` and ``Dataset.window`` using ``from_iterable``:
 
 .. code-block:: python
 
@@ -66,10 +65,52 @@ You can also create a DatasetPipeline from a custom iterator over dataset creato
     pipe = DatasetPipeline.from_iterable(
         [lambda: source, lambda: source, lambda: source, lambda: source])
 
-    # Equivalent to ray.data.range(1000).pipeline(parallelism=10)
+    # Equivalent to ray.data.range(1000).window(blocks_per_window=10)
     splits = ray.data.range(1000, parallelism=200).split(20)
     pipe = DatasetPipeline.from_iterable([lambda s=s: s for s in splits])
 
+Per-Window Transformations
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+While most Dataset operations are per-row (e.g., map, filter), some operations apply to the Dataset as a whole (e.g., sort, shuffle). When applied to a pipeline, holistic transforms like shuffle are applied separately to each window in the pipeline:
+
+.. 
code-block:: python + + # Example of randomly shuffling each window of a pipeline. + ray.data.range(5).repeat(2).random_shuffle_each_window().show_windows() + # -> + # === Window 0 === + # 4 + # 3 + # 1 + # 0 + # 2 + # === Window 1 === + # 2 + # 1 + # 4 + # 0 + # 3 + +You can also apply arbitrary transformations to each window using ``DatasetPipeline.foreach_window()``: + +.. code-block:: python + + # Equivalent transformation using .foreach_window() + ray.data.range(5).repeat(2).foreach_window(lambda w: w.random_shuffle()).show_windows() + # -> + # === Window 0 === + # 1 + # 0 + # 4 + # 2 + # 3 + # === Window 1 === + # 4 + # 2 + # 0 + # 3 + # 1 Example: Pipelined Batch Inference ---------------------------------- @@ -109,28 +150,28 @@ Ignoring the output, the above script has three separate stages: loading, prepro Enabling Pipelining ~~~~~~~~~~~~~~~~~~~ -We can optimize this by *pipelining* the execution of the dataset with the ``.pipeline()`` call, which returns a DatasetPIpeline instead of a Dataset object. The pipeline supports similar transformations to the original Dataset: +We can optimize this by *pipelining* the execution of the dataset with the ``.window()`` call, which returns a DatasetPipeline instead of a Dataset object. The pipeline supports similar transformations to the original Dataset: .. code-block:: python # Convert the Dataset into a DatasetPipeline. pipe: DatasetPipeline = ray.data \ .read_binary_files("s3://bucket/image-dir") \ - .pipeline(parallelism=2) + .window(blocks_per_window=2) # The remainder of the steps do not change. pipe = pipe.map(preprocess) pipe = pipe.map_batches(BatchInferModel, compute="actors", batch_size=256, num_gpus=1) pipe.write_json("/tmp/results") -Here we specified ``parallelism=2``, which means that the Dataset is split into smaller sub-Datasets of two blocks each. Each transformation or *stage* of the pipeline is operating over these two-block Datasets in parallel. This means batch inference processing can start as soon as two blocks are read and preprocessed, greatly reducing the GPU idle time: +Here we specified ``blocks_per_window=2``, which means that the Dataset is split into smaller sub-Datasets of two blocks each. Each transformation or *stage* of the pipeline is operating over these two-block Datasets in parallel. This means batch inference processing can start as soon as two blocks are read and preprocessed, greatly reducing the GPU idle time: .. image:: dataset-pipeline-2.svg Tuning Parallelism ~~~~~~~~~~~~~~~~~~ -Tune the throughput vs latency of your pipeline with the ``parallelism`` setting. As a rule of thumb, higher parallelism settings perform better, however ``parallelism == num_blocks`` effectively disables pipelining, since the DatasetPipeline will only contain a single Dataset. The other extreme is setting ``parallelism=1``, which minimizes the latency to initial output but only allows one concurrent transformation task per stage: +Tune the throughput vs latency of your pipeline with the ``blocks_per_window`` setting. As a rule of thumb, higher parallelism settings perform better, however ``blocks_per_window == num_blocks`` effectively disables pipelining, since the DatasetPipeline will only contain a single Dataset. The other extreme is setting ``blocks_per_window=1``, which minimizes the latency to initial output but only allows one concurrent transformation task per stage: .. 
image:: dataset-pipeline-3.svg @@ -155,7 +196,7 @@ Transformations made prior to the Dataset prior to the call to ``.repeat()`` are pipe: DatasetPipeline = ray.data \ .read_datasource(...) \ .repeat() \ - .random_shuffle() + .random_shuffle_each_window() @ray.remote(num_gpus=1) def train_func(pipe: DatasetPipeline): @@ -184,7 +225,7 @@ Similar to how you can ``.split()`` a Dataset, you can also split a DatasetPipel pipe: DatasetPipeline = ray.data \ .read_parquet("s3://bucket/dir") \ .repeat() \ - .random_shuffle() + .random_shuffle_each_window() @ray.remote(num_gpus=1) class TrainingWorker: @@ -201,3 +242,55 @@ Similar to how you can ``.split()`` a Dataset, you can also split a DatasetPipel **Pipeline**: .. image:: dataset-repeat-2.svg + +Changing Pipeline Structure +--------------------------- + +Sometimes, you may want to change the structure of an existing pipeline. For example, after generating a pipeline with ``ds.window(k)``, you may want to repeat that windowed pipeline ``n`` times. This can be done with ``ds.window(k).repeat(n)``. As another example, suppose you have a repeating pipeline generated with ``ds.repeat(n)``. The windowing of that pipeline can be changed with ``ds.repeat(n).rewindow(k)``. Note the subtle difference in the two examples: the former is repeating a windowed pipeline that has a base window size of ``k``, while the latter is re-windowing a pipeline of initial window size of ``ds.num_blocks()``. The latter may produce windows that span multiple copies of the same original data: + +.. code-block:: python + + # Window followed by repeat. + ray.data.range(5) \ + .window(blocks_per_window=2) \ + .repeat(2) \ + .show_windows() + # -> + # === Window 0 === + # 0 + # 1 + # === Window 1 === + # 2 + # 3 + # === Window 2 === + # 4 + # === Window 3 === + # 0 + # 1 + # === Window 4 === + # 2 + # 3 + # === Window 5 === + # 4 + + # Repeat followed by window. + ray.data.range(5) \ + .repeat(2) \ + .rewindow(blocks_per_window=2) \ + .show_windows() + # -> + # === Window 0 === + # 0 + # 1 + # === Window 1 === + # 2 + # 3 + # === Window 2 === + # 4 + # 0 + # === Window 3 === + # 1 + # 2 + # === Window 4 === + # 3 + # 4 diff --git a/doc/source/data/dataset-tensor-support.rst b/doc/source/data/dataset-tensor-support.rst index b8a4ad68eed4e..d2d3ebf40c6f1 100644 --- a/doc/source/data/dataset-tensor-support.rst +++ b/doc/source/data/dataset-tensor-support.rst @@ -3,66 +3,34 @@ Dataset Tensor Support ====================== -Tensor-typed values -------------------- +Tables with tensor columns +-------------------------- + +Datasets supports tables with fixed-shape tensor columns, where each element in the column is a tensor (n-dimensional array) with the same shape. As an example, this allows you to use Pandas and Ray Datasets to read, write, and manipulate e.g., images. All conversions between Pandas, Arrow, and Parquet, and all application of aggregations/operations to the underlying image ndarrays are taken care of by Ray Datasets. + +With our Pandas extension type, :class:`TensorDtype `, and extension array, :class:`TensorArray `, you can do familiar aggregations and arithmetic, comparison, and logical operations on a DataFrame containing a tensor column and the operations will be applied to the underlying tensors as expected. With our Arrow extension type, :class:`ArrowTensorType `, and extension array, :class:`ArrowTensorArray `, you'll be able to import that DataFrame into Ray Datasets and read/write the data from/to the Parquet format. 
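+
+As a quick sketch (assuming the extension classes are importable from
+``ray.data.extensions``; the column names are arbitrary), a DataFrame with a
+tensor column can be constructed directly from an ndarray whose first axis is
+the row dimension:
+
+.. code-block:: python
+
+    import numpy as np
+    import pandas as pd
+    from ray.data.extensions import TensorArray
+
+    # Four rows, each holding a 2x2 tensor (e.g., a tiny image).
+    df = pd.DataFrame({
+        "id": [1, 2, 3, 4],
+        "image": TensorArray(np.arange(16).reshape(4, 2, 2)),
+    })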
+ +Automatic conversion between the Pandas and Arrow extension types/arrays keeps the details under-the-hood, so you only have to worry about casting the column to a tensor column using our Pandas extension type when first ingesting the table into a ``Dataset``, whether from storage or in-memory. All table operations downstream from that cast should work automatically. -Datasets support tensor-typed values, which are represented in-memory as Arrow tensors (i.e., np.ndarray format). Tensor datasets can be read from and written to ``.npy`` files. Here are some examples: +Single-column tensor datasets +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The most basic case is when a dataset only has a single column, which is of tensor type. This kind of dataset can be created with ``.range_tensor()``, and can be read from and written to ``.npy`` files. Here are some examples: .. code-block:: python # Create a Dataset of tensor-typed values. ds = ray.data.range_tensor(10000, shape=(3, 5)) # -> Dataset(num_blocks=200, num_rows=10000, - # schema=) - - ds.map_batches(lambda t: t + 2).show(2) - # -> [[2 2 2 2 2] - # [2 2 2 2 2] - # [2 2 2 2 2]] - # [[3 3 3 3 3] - # [3 3 3 3 3] - # [3 3 3 3 3]] + # schema={value: }) # Save to storage. - ds.write_numpy("/tmp/tensor_out") + ds.write_numpy("/tmp/tensor_out", column="value") # Read from storage. ray.data.read_numpy("/tmp/tensor_out") # -> Dataset(num_blocks=200, num_rows=?, - # schema=) - -Tensor datasets are also created whenever an array type is returned from a map function: - -.. code-block:: python - - # Create a dataset of Python integers. - ds = ray.data.range(10) - # -> Dataset(num_blocks=10, num_rows=10, schema=) - - # It is now converted into a Tensor dataset. - ds = ds.map_batches(lambda x: np.array(x)) - # -> Dataset(num_blocks=10, num_rows=10, - # schema=) - -Tensor datasets can also be created from NumPy ndarrays that are already stored in the Ray object store: - -.. code-block:: python - - import numpy as np - - # Create a Dataset from a list of NumPy ndarray objects. - arr1 = np.arange(0, 10) - arr2 = np.arange(10, 20) - ds = ray.data.from_numpy([ray.put(arr1), ray.put(arr2)]) - -Tables with tensor columns --------------------------- - -In addition to tensor datasets, Datasets also supports tables with fixed-shape tensor columns, where each element in the column is a tensor (n-dimensional array) with the same shape. As an example, this allows you to use both Pandas and Ray Datasets to read, write, and manipulate a table with a column of e.g. images (2D arrays), with all conversions between Pandas, Arrow, and Parquet, and all application of aggregations/operations to the underlying image ndarrays, being taken care of by Ray Datasets. - -With our Pandas extension type, :class:`TensorDtype `, and extension array, :class:`TensorArray `, you can do familiar aggregations and arithmetic, comparison, and logical operations on a DataFrame containing a tensor column and the operations will be applied to the underlying tensors as expected. With our Arrow extension type, :class:`ArrowTensorType `, and extension array, :class:`ArrowTensorArray `, you'll be able to import that DataFrame into Ray Datasets and read/write the data from/to the Parquet format. - -Automatic conversion between the Pandas and Arrow extension types/arrays keeps the details under-the-hood, so you only have to worry about casting the column to a tensor column using our Pandas extension type when first ingesting the table into a ``Dataset``, whether from storage or in-memory. 
All table operations downstream from that cast should work automatically. + # schema={value: }) Reading existing serialized tensor columns ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -87,7 +55,7 @@ If you already have a Parquet dataset with columns containing serialized tensors # Write the dataset to Parquet. The tensor column will be written as an # array of opaque byte blobs. - ds = ray.data.from_pandas([ray.put(df)]) + ds = ray.data.from_pandas([df]) ds.write_parquet(path) # Read the Parquet files into a new Dataset, with the serialized tensors @@ -117,7 +85,7 @@ If your serialized tensors don't fit the above constraints (e.g. they're stored # Write the dataset to Parquet. The tensor column will be written as an # array of opaque byte blobs. - ds = ray.data.from_pandas([ray.put(df)]) + ds = ray.data.from_pandas([df]) ds.write_parquet(path) # Manually deserialize the tensor pickle bytes and cast to our tensor @@ -150,7 +118,7 @@ Now that the tensor column is properly typed and in a ``Dataset``, we can perfor # Arrow and Pandas is now aware of this tensor column, so we can do the # typical DataFrame operations on this column. - ds = ds.map_batches(lambda x: 2 * (x + 1), format="pandas") + ds = ds.map_batches(lambda x: 2 * (x + 1), batch_format="pandas") # -> Map Progress: 100%|████████████████████| 200/200 [00:00<00:00, 1123.54it/s] print(ds) # -> Dataset( @@ -244,7 +212,7 @@ If working with in-memory Pandas DataFrames that you want to analyze, manipulate # In addition to doing Pandas operations on the tensor column, # you can now put the DataFrame directly into a Dataset. - ds = ray.data.from_pandas([ray.put(df)]) + ds = ray.data.from_pandas([df]) # Internally, this column is represented with the corresponding # Arrow tensor extension type. print(ds.schema()) @@ -259,7 +227,7 @@ If working with in-memory Pandas DataFrames that you want to analyze, manipulate # -> one: int64 # two: extension> - read_df = ray.get(read_ds.to_pandas())[0] + read_df = read_ds.to_pandas() print(read_df.dtypes) # -> one int64 # two TensorDtype diff --git a/doc/source/data/dataset.rst b/doc/source/data/dataset.rst index 7142691e5df45..20018765c1a69 100644 --- a/doc/source/data/dataset.rst +++ b/doc/source/data/dataset.rst @@ -16,7 +16,7 @@ Ray Datasets are the standard way to load and exchange data in Ray libraries and Concepts -------- -Ray Datasets implement `Distributed Arrow `__. A Dataset consists of a list of Ray object references to *blocks*. Each block holds a set of items in either an `Arrow table `__, `Arrow tensor `__, or a Python list (for Arrow incompatible objects). Having multiple blocks in a dataset allows for parallel transformation and ingest of the data. +Ray Datasets implement `Distributed Arrow `__. A Dataset consists of a list of Ray object references to *blocks*. Each block holds a set of items in either an `Arrow table `__ or a Python list (for Arrow incompatible objects). Having multiple blocks in a dataset allows for parallel transformation and ingest of the data. The following figure visualizes a Dataset that has three Arrow table blocks, each block holding 1000 rows each: @@ -145,6 +145,10 @@ Datasource Compatibility Matrices Creating Datasets ----------------- +.. tip:: + + Run ``pip install ray[data]`` to get started! + Get started by creating Datasets from synthetic data using ``ray.data.range()`` and ``ray.data.from_items()``. Datasets can hold either plain Python objects (schema is a Python type), or Arrow records (schema is Arrow). .. 
code-block:: python @@ -198,7 +202,7 @@ Finally, you can create a ``Dataset`` from existing data in the Ray object store # Create a Dataset from a list of Pandas DataFrame objects. pdf = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]}) - ds = ray.data.from_pandas([ray.put(pdf)]) + ds = ray.data.from_pandas([pdf]) # Create a Dataset from a Dask-on-Ray DataFrame. dask_df = dd.from_pandas(pdf, npartitions=10) diff --git a/doc/source/data/package-ref.rst b/doc/source/data/package-ref.rst index 0af38ba8297c7..afdace98bf719 100644 --- a/doc/source/data/package-ref.rst +++ b/doc/source/data/package-ref.rst @@ -15,11 +15,13 @@ Creating a Dataset .. autofunction:: ray.data.read_datasource .. autofunction:: ray.data.from_items .. autofunction:: ray.data.from_arrow +.. autofunction:: ray.data.from_arrow_refs .. autofunction:: ray.data.from_spark .. autofunction:: ray.data.from_dask .. autofunction:: ray.data.from_modin .. autofunction:: ray.data.from_mars .. autofunction:: ray.data.from_pandas +.. autofunction:: ray.data.from_pandas_refs .. autofunction:: ray.data.from_numpy Dataset API diff --git a/doc/source/index.rst b/doc/source/index.rst index 2024802af37d7..492adbd42d6ee 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -277,8 +277,8 @@ Papers :caption: Ray Data data/dataset.rst - data/dataset-tensor-support.rst data/dataset-pipeline.rst + data/dataset-tensor-support.rst data/package-ref.rst data/dask-on-ray.rst data/mars-on-ray.rst @@ -338,6 +338,7 @@ Papers raysgd/v2/examples.rst raysgd/v2/architecture.rst raysgd/v2/api.rst + raysgd/v2/migration-guide.rst RaySGD v1: Distributed Training Wrappers .. toctree:: @@ -365,7 +366,7 @@ Papers .. toctree:: :hidden: :maxdepth: -1 - :caption: Contributing + :caption: Contributor Guide getting-involved.rst development.rst diff --git a/doc/source/raysgd/raysgd.rst b/doc/source/raysgd/raysgd.rst index 87696e68d6535..55ddcdb389fc1 100644 --- a/doc/source/raysgd/raysgd.rst +++ b/doc/source/raysgd/raysgd.rst @@ -6,7 +6,7 @@ RaySGD: Distributed Training Wrappers .. warning:: This is an older version of Ray SGD. A newer, more light-weight version of Ray SGD is in alpha as of Ray 1.7. - See the documentation :ref:`here `. + See the documentation :ref:`here `. To migrate from v1 to v2 you can follow the :ref:`migration guide `. RaySGD is a lightweight library for distributed deep learning, providing thin wrappers around PyTorch and TensorFlow native modules for data parallel training. diff --git a/doc/source/raysgd/raysgd_pytorch.rst b/doc/source/raysgd/raysgd_pytorch.rst index 5e9c1ce099141..635d003e55032 100644 --- a/doc/source/raysgd/raysgd_pytorch.rst +++ b/doc/source/raysgd/raysgd_pytorch.rst @@ -3,13 +3,16 @@ Distributed PyTorch =================== +.. warning:: This is an older version of Ray SGD. A newer, more light-weight version of Ray SGD is in alpha as of Ray 1.7. + See the documentation :ref:`here `. To migrate from v1 to v2 you can follow the :ref:`migration guide `. + The RaySGD ``TorchTrainer`` simplifies distributed model training for PyTorch. .. image:: raysgd-actors.svg :align: center -.. tip:: Get in touch with us if you're using or considering using `RaySGD `_! +.. tip:: Get in touch with us if you're using or considering using `RaySGD `_! The ``TorchTrainer`` is a wrapper around ``torch.distributed.launch`` with a Python API to easily incorporate distributed training into a larger Python application, as opposed to needing to wrap your training code in bash scripts. 
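+
+A minimal usage sketch of that API (assuming a user-defined ``MyTrainingOperator`` subclass; the full API is documented below):
+
+.. code-block:: python
+
+    from ray.util.sgd import TorchTrainer
+
+    trainer = TorchTrainer(
+        training_operator_cls=MyTrainingOperator,
+        num_workers=2)
+    for _ in range(10):
+        stats = trainer.train()  # one pass over the training data
+    trainer.shutdown()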
diff --git a/doc/source/raysgd/raysgd_tensorflow.rst b/doc/source/raysgd/raysgd_tensorflow.rst
index f18d7f9ec3924..2cbf01da2e3c3 100644
--- a/doc/source/raysgd/raysgd_tensorflow.rst
+++ b/doc/source/raysgd/raysgd_tensorflow.rst
@@ -1,6 +1,9 @@
 Distributed TensorFlow
 ======================
 
+.. warning:: This is an older version of Ray SGD. A newer, more light-weight version of Ray SGD is in alpha as of Ray 1.7.
+    See the documentation :ref:`here `. To migrate from v1 to v2 you can follow the :ref:`migration guide `.
+
 RaySGD's ``TFTrainer`` simplifies distributed model training for Tensorflow. The ``TFTrainer`` is a wrapper around ``MultiWorkerMirroredStrategy`` with a Python API to easily incorporate distributed training into a larger Python application, as opposed to writing custom logic for setting up environments and starting separate processes.
 
 Under the hood, ``TFTrainer`` will create *replicas* of your model (controlled by ``num_replicas``), each of which is managed by a Ray actor.
@@ -8,7 +11,7 @@ Under the hood, ``TFTrainer`` will create *replicas* of your model (controlled b
 .. image:: raysgd-actors.svg
     :align: center
 
-.. tip:: We need your feedback! RaySGD is currently early in its development, and we're hoping to get feedback from people using or considering it. We'd love `to get in touch `_!
+.. tip:: Get in touch with us if you're using or considering using `RaySGD `_!
 
 ----------
 
diff --git a/doc/source/raysgd/raysgd_tune.rst b/doc/source/raysgd/raysgd_tune.rst
index cacaea0a20c4e..740ff78b0390c 100644
--- a/doc/source/raysgd/raysgd_tune.rst
+++ b/doc/source/raysgd/raysgd_tune.rst
@@ -3,6 +3,9 @@
 RaySGD Hyperparameter Tuning
 ============================
 
+.. warning:: This is an older version of Ray SGD. A newer, more light-weight version of Ray SGD is in alpha as of Ray 1.7.
+    See the documentation :ref:`here `. To migrate from v1 to v2 you can follow the :ref:`migration guide `.
+
 RaySGD integrates with :ref:`Ray Tune ` to easily run distributed hyperparameter tuning experiments with your RaySGD Trainer.
 
 PyTorch
diff --git a/doc/source/raysgd/v2/api.rst b/doc/source/raysgd/v2/api.rst
index fc3028bc9fc19..97b48a26b11ce 100644
--- a/doc/source/raysgd/v2/api.rst
+++ b/doc/source/raysgd/v2/api.rst
@@ -22,10 +22,8 @@ SGDIterator
 
 .. _sgd-api-backend-config:
 
-BackendConfig
--------------
-
-.. autoclass:: ray.sgd.BackendConfig
+Backend Configurations
+----------------------
 
 .. _sgd-api-torch-config:
 
@@ -48,10 +46,14 @@ HorovodConfig
 
 .. autoclass:: ray.sgd.HorovodConfig
 
+
+Callbacks
+---------
+
 .. _sgd-api-callback:
 
 SGDCallback
------------
+~~~~~~~~~~~
 
 .. autoclass:: ray.sgd.SGDCallback
     :members:
@@ -61,19 +63,22 @@ SGDCallback
 JsonLoggerCallback
 ~~~~~~~~~~~~~~~~~~
 
-.. autoclass:: ray.sgd.JsonLoggerCallback
+.. autoclass:: ray.sgd.callbacks.JsonLoggerCallback
 
 .. _sgd-api-tbx-logger-callback:
 
 TBXLoggerCallback
 ~~~~~~~~~~~~~~~~~
 
-.. autoclass:: ray.sgd.TBXLoggerCallback
+.. autoclass:: ray.sgd.callbacks.TBXLoggerCallback
+
+Checkpointing
+-------------
 
 .. _sgd-api-checkpoint-strategy:
 
 CheckpointStrategy
-------------------
+~~~~~~~~~~~~~~~~~~
 
 .. autoclass:: ray.sgd.CheckpointStrategy
 
diff --git a/doc/source/raysgd/v2/examples.rst b/doc/source/raysgd/v2/examples.rst
index a35f394c7593c..3edee334aea2a 100644
--- a/doc/source/raysgd/v2/examples.rst
+++ b/doc/source/raysgd/v2/examples.rst
@@ -61,6 +61,9 @@ Ray Tune Integration Examples
 * :doc:`/raysgd/v2/examples/tune_tensorflow_mnist_example`:
   End-to-end example for tuning a TensorFlow model.
+* :doc:`/raysgd/v2/examples/tune_cifar_pytorch_pbt_example`:
+  End-to-end example for tuning a PyTorch model with PBT.
+
 
 .. TODO implement these examples!
 
diff --git a/doc/source/raysgd/v2/examples/tune_cifar_pytorch_pbt_example.rst b/doc/source/raysgd/v2/examples/tune_cifar_pytorch_pbt_example.rst
new file mode 100644
index 0000000000000..31aabc7ca78ab
--- /dev/null
+++ b/doc/source/raysgd/v2/examples/tune_cifar_pytorch_pbt_example.rst
@@ -0,0 +1,6 @@
+:orphan:
+
+tune_cifar_pytorch_pbt_example
+==============================
+
+.. literalinclude:: /../../python/ray/util/sgd/v2/examples/tune_cifar_pytorch_pbt_example.py
diff --git a/doc/source/raysgd/v2/migration-guide.rst b/doc/source/raysgd/v2/migration-guide.rst
new file mode 100644
index 0000000000000..08effe4b25e98
--- /dev/null
+++ b/doc/source/raysgd/v2/migration-guide.rst
@@ -0,0 +1,393 @@
+.. _sgd-migration:
+
+Migrating from Ray SGD v1
+=========================
+
+In Ray 1.7, we are rolling out a new and more streamlined version of Ray SGD. Ray SGD v2 focuses on usability and composability: it has a much simpler API, supports more deep learning backends, integrates better with other libraries in the Ray ecosystem, and will continue to be actively developed with more features.
+
+This guide will help you easily migrate existing code from Ray SGD v1 to Ray SGD v2. If you are new to Ray SGD as a whole, you should get started with :ref:`Ray SGD v2 directly `.
+
+For a full list of features that Ray SGD v2 provides, please check out the :ref:`user guide`.
+
+.. note:: If you run into any issues, find something missing from this guide, or have any feedback on Ray SGD v2 overall, please file a `Github issue on the Ray repo `_!
+
+What are the API differences?
+-----------------------------
+
+There are 3 primary API differences between Ray SGD v1 and v2.
+
+1. There is a single ``Trainer`` interface for all backends (torch, tensorflow, horovod), and the backend is simply specified via an argument: ``Trainer(backend="torch")``\ , ``Trainer(backend="horovod")``\ , etc. Any features that we add to Ray SGD will be supported for all backends, and there won't be any API divergence like there was with a separate ``TorchTrainer`` and ``TFTrainer``.
+2. The ``TrainingOperator`` and creator functions are replaced by a more natural user-defined training function. You no longer have to make your training logic fit into a restrictive interface. In Ray SGD v2, you simply have to provide a training function that describes the full logic for your training execution, and this will be distributed by Ray SGD v2.
+
+   .. code-block:: python
+
+      from torch.nn.parallel import DistributedDataParallel
+      from torch import nn, optim
+
+      # Torch Example
+      def train_func_distributed():
+          num_epochs = 3
+          model = NeuralNetwork()
+          model = DistributedDataParallel(model)
+          loss_fn = nn.MSELoss()
+          optimizer = optim.SGD(model.parameters(), lr=0.1)
+
+          for epoch in range(num_epochs):
+              output = model(input)
+              loss = loss_fn(output, labels)
+              optimizer.zero_grad()
+              loss.backward()
+              optimizer.step()
+              print(f"epoch: {epoch}, loss: {loss.item()}")
+
+      from ray.sgd import Trainer
+
+      trainer = Trainer(backend="torch", num_workers=4)
+      trainer.start()
+      results = trainer.run(train_func_distributed)
+      trainer.shutdown()
+
+Currently, this means that you are now responsible for modifying your code to support distributed training (specifying ``DistributedDataParallel`` for ``torch`` or ``MultiWorkerMirroredStrategy`` for ``tensorflow``) as opposed to having this be automatically handled internally. However, we have plans to provide utilities that you can use to automatically handle these recipes for you.
+
+3. Rather than iteratively calling ``trainer.train()`` or ``trainer.validate()`` for each epoch, in Ray SGD v2 the training function defines the full training execution and is run via ``trainer.run(train_func)``.
+
+In the following sections, we will guide you through the steps to migrate:
+
+1. :ref:`sgd-migration-logic`
+2. :ref:`Interacting with Trainer state (intermediate metrics, checkpointing) <sgd-migration-trainer>`
+3. :ref:`Hyperparameter Tuning with Ray Tune <sgd-migration-tune>`
+
+.. _sgd-migration-logic:
+
+Training Logic
+--------------
+The main change you will have to make is how you define your training logic. In Ray SGD v1, the API for defining training logic differed for `TorchTrainer` vs. `TFTrainer`, so the steps to migrate will be different for each of these.
+
+PyTorch
+~~~~~~~
+In v1, the training logic is defined through the ``train_epoch`` and ``train_batch`` methods of a ``TrainingOperator`` class which is passed into the ``TorchTrainer``. To migrate to Ray SGD v2, there are 2 options:
+
+1. If you felt the ``TrainingOperator`` was unnecessarily complex, or you had to customize it extensively, you can define your own training function.
+2. If you liked having your training logic in the ``TrainingOperator``, you can continue to use the ``TrainingOperator`` with Ray SGD v2.
+
+**Alternative 1: Custom Training Function**
+
+You can define your own custom training function, and use only the parts from ``TrainingOperator.train_epoch``, ``TrainingOperator.setup``, and ``TrainingOperator.validate`` that are necessary for your application.
+
+You can see a full example on how to :ref:`port over regular PyTorch DDP code to Ray SGD here `.
+
+**Alternative 2: Continue to use TrainingOperator**
+
+Alternatively, if you liked having the ``TrainingOperator``, you can define a training function that instantiates your `TrainingOperator`, and you can call methods directly on the operator object.
+
+So instead of
+
+.. code-block:: python
+
+    from ray.util.sgd import TrainingOperator, TorchTrainer
+
+    class MyTrainingOperator(TrainingOperator):
+        ...
+
+    trainer = TorchTrainer(training_operator_cls=MyTrainingOperator, num_workers=4, use_gpu=True)
+
+    num_epochs = 10
+    for _ in range(num_epochs):
+        trainer.train()
+        trainer.validate()
+
+    final_model = trainer.get_model()
+
+
+you would do
+
+.. code-block:: python
+
+    import torch
+
+    from ray.util.sgd import TrainingOperator
+    from ray.sgd import Trainer
+    from ray import sgd
+
+    class MyTrainingOperator(TrainingOperator):
+        ...
+
+    def train_func(config):
+        device = torch.device(f"cuda:{sgd.local_rank()}" if
+                              torch.cuda.is_available() else "cpu")
+        if torch.cuda.is_available():
+            torch.cuda.set_device(device)
+
+        # Set the args to whatever values you want.
+        training_operator = MyTrainingOperator(
+            config=config,
+            world_rank=sgd.world_rank(),
+            local_rank=sgd.local_rank(),
+            is_distributed=True,
+            device=device,
+            use_gpu=True,
+            wrap_ddp=True,
+            add_dist_sampler=True)
+
+        training_operator.setup(config)
+
+        for idx in range(config["num_epochs"]):
+            train_loader = training_operator._get_train_loader()
+            # If using DistributedSampler, set the epoch here.
+            train_loader.set_epoch(idx)
+            training_operator.train_epoch(iter(train_loader), info={"epoch_idx": idx})
+
+            validation_loader = training_operator._get_validation_loader()
+            training_operator.validate(iterator=iter(validation_loader))
+
+        if sgd.world_rank() == 0:
+            return training_operator._get_original_models()
+        else:
+            return None
+
+    trainer = Trainer(backend="torch", num_workers=4, use_gpu=True)
+    trainer.start()
+    results = trainer.run(train_func, config={"num_epochs": 10})
+    final_model = results[0]
+
+Tensorflow
+~~~~~~~~~~
+
+The API for ``TFTrainer`` uses creator functions instead of a ``TrainingOperator`` to define the training logic. To port over Ray SGD v1 Tensorflow code to v2 you can do the following:
+
+.. code-block:: python
+
+    from tensorflow.distribute import MultiWorkerMirroredStrategy
+
+    from ray.sgd import Trainer
+    from ray import sgd
+
+    def train_func(config):
+        train_dataset, val_dataset = data_creator(config)
+        strategy = MultiWorkerMirroredStrategy()
+        with strategy.scope():
+            model = model_creator(config)
+
+        for epoch_idx in range(config["num_epochs"]):
+            model.fit(train_dataset)
+
+        if sgd.world_rank() == 0:
+            return model
+        else:
+            return None
+
+    trainer = Trainer(backend="tensorflow", num_workers=4)
+    trainer.start()
+    model = trainer.run(train_func, config={"num_epochs": 3, ...})[0]
+
+You can see a full example :ref:`here `.
+
+.. _sgd-migration-trainer:
+
+Interacting with the ``Trainer``
+--------------------------------
+
+In Ray SGD v1, you can iteratively call ``trainer.train()`` or ``trainer.validate()`` for each epoch, and can then interact with the trainer to get certain state (model, checkpoints, results, etc.). In Ray SGD v2, this is replaced by a single training function that defines the full training & validation loop for all epochs.
+
+There are 3 ways to get state during or after the training execution:
+
+#. Return values from your training function
+#. Intermediate results via ``sgd.report()``
+#. Saving & loading checkpoints via ``sgd.save_checkpoint()`` and ``sgd.load_checkpoint()``
+
+Return Values
+~~~~~~~~~~~~~
+
+To get any state from training *after* training has completed, you can simply return it from your training function. The return values from each of the workers will be added to a list and returned from the ``trainer.run()`` call.
+
+For example, to get the final model:
+
+**SGD v1**
+
+.. code-block:: python
+
+    from ray.util.sgd import TorchTrainer, TrainingOperator
+
+    class MyTrainingOperator(TrainingOperator):
+        ...
+
+    trainer = TorchTrainer(training_operator_cls=MyTrainingOperator, num_workers=2)
+
+    trainer.train()
+
+    trained_model = trainer.get_model()
+
+**SGD v2**
+
+.. code-block:: python
+
+    from ray.sgd import Trainer
+
+    def train_func():
+        model = Net()
+        train_loader = MyDataset()
+        for batch in train_loader:
+            model.train(batch)
+
+        return model
+
+    trainer = Trainer(backend="torch", num_workers=2)
+    trainer.start()
+    results = trainer.run(train_func)
+    assert len(results) == 2
+    trained_model = results[0]
+
+Intermediate Reporting
+~~~~~~~~~~~~~~~~~~~~~~
+
+If you want to access any values *during* the training process, you can do so via ``sgd.report()``. You can pass in any values to ``sgd.report()`` and these values from all workers will be sent to any callbacks passed into your ``Trainer``.
+
+**SGD v1**
+
+.. code-block:: python
+
+    from ray.util.sgd import TorchTrainer, TrainingOperator
+
+    class MyTrainingOperator(TrainingOperator):
+        ...
+
+    trainer = TorchTrainer(training_operator_cls=MyTrainingOperator, num_workers=2)
+
+    for _ in range(3):
+        print(trainer.train(reduce_results=False))
+
+**SGD v2**
+
+.. code-block:: python
+
+    from ray import sgd
+    from ray.sgd import Trainer
+    from ray.sgd.callbacks import SGDCallback
+    from typing import List, Dict
+
+    class PrintingCallback(SGDCallback):
+        def handle_result(self, results: List[Dict], **info):
+            print(results)
+
+    def train_func():
+        for i in range(3):
+            sgd.report(epoch=i)
+
+    trainer = Trainer(backend="torch", num_workers=2)
+    trainer.start()
+    result = trainer.run(
+        train_func,
+        callbacks=[PrintingCallback()]
+    )
+    # [{'epoch': 0, '_timestamp': 1630471763, '_time_this_iter_s': 0.0020279884338378906, '_training_iteration': 1}, {'epoch': 0, '_timestamp': 1630471763, '_time_this_iter_s': 0.0014922618865966797, '_training_iteration': 1}]
+    # [{'epoch': 1, '_timestamp': 1630471763, '_time_this_iter_s': 0.0008401870727539062, '_training_iteration': 2}, {'epoch': 1, '_timestamp': 1630471763, '_time_this_iter_s': 0.0007486343383789062, '_training_iteration': 2}]
+    # [{'epoch': 2, '_timestamp': 1630471763, '_time_this_iter_s': 0.0014500617980957031, '_training_iteration': 3}, {'epoch': 2, '_timestamp': 1630471763, '_time_this_iter_s': 0.0015292167663574219, '_training_iteration': 3}]
+    trainer.shutdown()
+
+See the :ref:`v2 User Guide ` for more details.
+
+Checkpointing
+~~~~~~~~~~~~~
+
+Finally, you can also use ``sgd.save_checkpoint()`` and ``sgd.load_checkpoint()`` to write checkpoints to disk during the training process, and to load from the most recently saved checkpoint in the case of node failures.
+
+See the :ref:`Checkpointing ` and :ref:`Fault Tolerance & Elastic Training ` sections on the user guide for more info.
+
+For example, in order to save checkpoints after every epoch:
+
+**SGD v1**
+
+.. code-block:: python
+
+    from ray.util.sgd import TorchTrainer, TrainingOperator
+
+    class MyTrainingOperator(TrainingOperator):
+        ...
+
+    trainer = TorchTrainer(training_operator_cls=MyTrainingOperator, num_workers=2)
+
+    for _ in range(3):
+        trainer.train()
+        trainer.save_checkpoint(checkpoint_dir="~/ray_results")
+
+**SGD v2**
+
+.. code-block:: python
+
+    from ray.sgd import Trainer
+    from ray import sgd
+
+    def train_func():
+        model = Net()
+        train_loader = MyDataset()
+        for i in range(3):
+            for batch in train_loader:
+                model.train(batch)
+            sgd.save_checkpoint(epoch=i, model=model.state_dict())
+
+    trainer = Trainer(backend="torch", num_workers=2)
+    trainer.start()
+    trainer.run(train_func)
+
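+To resume from the most recent checkpoint after a failure, load it at the start of your training function. A sketch (``Net`` is the same placeholder model as above; ``sgd.load_checkpoint()`` returns ``None`` if no checkpoint has been saved yet):
+
+.. code-block:: python
+
+    def train_func():
+        model = Net()
+        checkpoint = sgd.load_checkpoint() or {}
+        if "model" in checkpoint:
+            model.load_state_dict(checkpoint["model"])
+        # Continue from the epoch after the last saved one.
+        start_epoch = checkpoint.get("epoch", -1) + 1
+        for i in range(start_epoch, 3):
+            sgd.save_checkpoint(epoch=i, model=model.state_dict())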
+
+.. _sgd-migration-tune:
+
+Hyperparameter Tuning with Ray Tune
+-----------------------------------
+
+Ray SGD v2 also comes with an easier-to-use interface for hyperparameter tuning with Ray Tune using Tune's function API instead of its Class API. In particular, it is much easier to define custom procedures because the logic is entirely defined by your training function.
+
+There is a 1:1 mapping between the rank 0 worker's ``sgd.report()``\ , ``sgd.save_checkpoint()``\ , and ``sgd.load_checkpoint()`` and ``tune.report()``\ , ``tune.save_checkpoint()``\ , and ``tune.load_checkpoint()``.
+
+**SGD v1**
+
+.. code-block:: python
+
+    from ray import tune
+    from ray.util.sgd import TrainingOperator, TorchTrainer
+
+    class MyTrainingOperator(TrainingOperator):
+        ...
+
+    def custom_step(trainer, info):
+        train_stats = trainer.train()
+        return train_stats
+
+    # TorchTrainable is subclass of BaseTorchTrainable.
+    TorchTrainable = TorchTrainer.as_trainable(
+        training_operator_cls=MyTrainingOperator,
+        num_workers=2,
+        use_gpu=True,
+        override_tune_step=custom_step
+    )
+
+    analysis = tune.run(
+        TorchTrainable,
+        config={"input": tune.grid_search([1, 2, 3])}
+    )
+
+**SGD v2**
+
+.. code-block:: python
+
+    from ray import tune
+    from ray import sgd
+    from ray.sgd import Trainer
+
+    def train_func(config):
+        # In this example, nothing is expected to change over epochs,
+        # and the output metric is equivalent to the input value.
+        for _ in range(config["num_epochs"]):
+            sgd.report(output=config["input"])
+
+    trainer = Trainer(backend="torch", num_workers=2)
+    trainable = trainer.to_tune_trainable(train_func)
+    analysis = tune.run(trainable, config={
+        "num_epochs": 2,
+        "input": tune.grid_search([1, 2, 3])
+    })
+    print(analysis.get_best_config(metric="output", mode="max"))
+    # {'num_epochs': 2, 'input': 3}
+
+For more information see :ref:`sgd-tune`.
\ No newline at end of file
diff --git a/doc/source/raysgd/v2/raysgd.rst b/doc/source/raysgd/v2/raysgd.rst
index 02111cdae1672..a37e583a7fe7e 100644
--- a/doc/source/raysgd/v2/raysgd.rst
+++ b/doc/source/raysgd/v2/raysgd.rst
@@ -5,6 +5,8 @@ RaySGD: Deep Learning on Ray
 
 .. _`issue on GitHub`: https://github.com/ray-project/ray/issues
 
+.. tip:: Get in touch with us if you're using or considering using `RaySGD `_!
+
 RaySGD is a lightweight library for distributed deep learning, allowing you
 to scale up and speed up training for your deep learning models.
 
@@ -21,7 +23,6 @@ The main features are:
   `issue on GitHub`_.
   If you are looking for the previous API documentation, see :ref:`sgd-index`.
 
-
 Intro to RaySGD
 ---------------
 
diff --git a/doc/source/raysgd/v2/user_guide.rst b/doc/source/raysgd/v2/user_guide.rst
index fe33949342af0..2c34e59dd29f2 100644
--- a/doc/source/raysgd/v2/user_guide.rst
+++ b/doc/source/raysgd/v2/user_guide.rst
@@ -3,6 +3,8 @@
 RaySGD User Guide
 =================
 
+.. tip:: Get in touch with us if you're using or considering using `RaySGD `_!
+
 In this guide, we cover examples for the following use cases:
 
 * How do I :ref:`port my code ` to using RaySGD?
@@ -88,6 +90,7 @@ training.
 
 If you are using GPUs, you need to make sure that the CUDA devices are properly setup inside your training function.
 
 This involves 3 steps:
+
 1. Use the local rank to set the default CUDA device for the worker.
 2. Move the model to the default CUDA device (or a specific CUDA device).
 3. Specify ``device_ids`` when wrapping in ``DistributedDataParallel``.
@@ -341,7 +344,8 @@ You can plug all of these into RaySGD with the following interface:
 
 .. code-block:: python
 
     from ray import sgd
-    from ray.sgd import SGDCallback, Trainer
+    from ray.sgd import Trainer
+    from ray.sgd.callbacks import SGDCallback
     from typing import List, Dict
 
     class PrintingCallback(SGDCallback):
@@ -395,7 +399,7 @@ A simple example for creating a callback that will print out results:
 
 .. code-block:: python
 
-    from ray.sgd import SGDCallback
+    from ray.sgd.callbacks import SGDCallback
 
     class PrintingCallback(SGDCallback):
         def handle_result(self, results: List[Dict], **info):
@@ -635,7 +639,7 @@ Underneath the hood, RaySGD will automatically shard the given dataset.
 
         return model
 
     trainer = Trainer(num_workers=8, backend="torch")
-    dataset = ray.data.read_csv("...").filter().pipeline(length=50)
+    dataset = ray.data.read_csv("...").filter().window(blocks_per_window=50)
 
     result = trainer.run(
         train_func,
@@ -738,7 +742,7 @@ A couple caveats:
 
     # Declare the specification for training.
     trainer = Trainer(backend="torch", num_workers=12, use_gpu=True)
-    dataset = ray.dataset.pipeline()
+    dataset = ray.dataset.window()
 
     # Convert this to a trainable.
     trainable = trainer.to_tune_trainable(training_func, dataset=dataset)
 
diff --git a/doc/source/serve/core-apis.rst b/doc/source/serve/core-apis.rst
index e5130821c98be..2bd1f834c465d 100644
--- a/doc/source/serve/core-apis.rst
+++ b/doc/source/serve/core-apis.rst
@@ -35,7 +35,14 @@ Deployments can be exposed in two ways: over HTTP or in Python via the :ref:`ser
 
 By default, HTTP requests will be forwarded to the ``__call__`` method of the class (or the function) and a ``Starlette Request`` object will be the sole argument. You can also define a deployment that wraps a FastAPI app for more flexible handling of HTTP requests. See :ref:`serve-fastapi-http` for details.
 
-We can also list all available deployments and dynamically get a reference to them:
+To serve multiple deployments defined by the same class, use the ``name`` option:
+
+.. code-block:: python
+
+    MyFirstDeployment.options(name="hello_service").deploy("Hello!")
+    MyFirstDeployment.options(name="hi_service").deploy("Hi!")
+
+You can also list all available deployments and dynamically get references to them:
 
 .. code-block:: python
 
@@ -238,27 +245,31 @@ Ray Serve supports serving deployments with different (possibly conflicting)
 Python dependencies. For example, you can simultaneously serve one deployment
 that uses legacy Tensorflow 1 and another that uses Tensorflow 2.
 
-Currently this is supported on Mac OS and Linux using `conda `_
-via Ray's built-in ``runtime_env`` option for actors.
-As with all other actor options, pass these in via ``ray_actor_options`` in
-your deployment.
-You must have a conda environment set up for each set of
-dependencies you want to isolate. If using a multi-node cluster, the
-desired conda environment must be present on all nodes. Also, the Python patch version
-(e.g. 3.8.10) must be identical on all nodes (this is a requirement for any Ray cluster).
-See :ref:`runtime-environments` for details.
-
-Here's an example script. For it to work, first create a conda
-environment named ``ray-tf1`` with Ray Serve and Tensorflow 1 installed,
-and another named ``ray-tf2`` with Ray Serve and Tensorflow 2. The Ray and
-Python versions must be the same in both environments.
+This is supported on Mac OS and Linux using Ray's :ref:`runtime-environments` feature.
+As with all other Ray actor options, pass the runtime environment in via ``ray_actor_options`` in
+your deployment. Be sure to first run ``pip install "ray[default]"`` to ensure the
+Runtime Environments feature is installed.
+
+Example:
 
 .. literalinclude:: ../../../python/ray/serve/examples/doc/conda_env.py
 
+.. note::
+  When using a Ray library (for example, Ray Serve) in a runtime environment, it must
+  explicitly be included in the dependencies, as in the above example. This is not
+  required when just using Ray Core.
+
+.. tip::
+  Avoid dynamically installing packages that install from source: these can be slow and
+  use up all resources while installing, leading to problems with the Ray cluster. Consider
+  precompiling such packages in a private repository or Docker image.
+
 The dependencies required in the deployment may be different than
 the dependencies installed in the driver program (the one running Serve API
 calls). In this case, you should use a delayed import within the class to avoid
-importing unavailable packages in the driver.
+importing unavailable packages in the driver. This applies even when not
+using runtime environments.
+
 Example:
 
 .. literalinclude:: ../../../python/ray/serve/examples/doc/imported_backend.py
diff --git a/doc/source/serve/ml-models.rst b/doc/source/serve/ml-models.rst
index 8fe3330af0498..192207b041ac5 100644
--- a/doc/source/serve/ml-models.rst
+++ b/doc/source/serve/ml-models.rst
@@ -70,10 +70,10 @@ Integration with Model Registries
 Ray Serve is flexible. If you can load your model as a Python
 function or class, then you can scale it up and serve it with Ray Serve.
 
-For example, if you are using the 
+For example, if you are using the
 `MLflow Model Registry `_ to manage your models,
 the following wrapper
-class will allow you to load a model using its MLflow `Model URI`: 
+class will allow you to load a model using its MLflow `Model URI`:
 
 .. code-block:: python
 
@@ -93,12 +93,19 @@ class will allow you to load a model using its MLflow `Model URI`:
     model_uri = "model:/my_registered_model/Production"
     MLflowDeployment.deploy(model_uri)
 
-.. tip::
+To serve multiple different MLflow models in the same program, use the ``name`` option:
+
+.. code-block:: python
+
+    MLflowDeployment.options(name="my_mlflow_model_1").deploy(model_uri)
+
+
+.. tip::
     The above approach will work for any model registry, not just MLflow.
     Namely, load the model from the registry in ``__init__``, and forward
     the request to the model in ``__call__``.
 
-For an even more hands-off and seamless integration with MLflow, check out the
+For an even more hands-off and seamless integration with MLflow, check out the
 `Ray Serve MLflow deployment plugin `__.
 A full tutorial is available `here `__.
 
diff --git a/doc/source/tune/_tutorials/_faq.inc b/doc/source/tune/_tutorials/_faq.inc
index d9bb39e1f94dc..c14a0aa4504cd 100644
--- a/doc/source/tune/_tutorials/_faq.inc
+++ b/doc/source/tune/_tutorials/_faq.inc
@@ -19,10 +19,18 @@ Deciding on which to use mostly depends on your problem:
 * How many hyperparameters would you like to tune?
 * What values are valid for hyperparameters?
 
+**If your model returns incremental results** (e.g. results per epoch in deep learning,
+results per added tree in GBDTs, etc.), using early stopping usually allows for sampling
+more configurations, as unpromising trials are pruned before they run their full course.
+Please note that not all search algorithms can use information from pruned trials.
+Early stopping cannot be used without incremental results - in case of the functional API,
+that means that ``tune.report()`` has to be called more than once - usually in a loop.
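+
+For example, a minimal function trainable that reports incremental results looks like this (a sketch; ``evaluate()`` stands in for your real per-epoch training and validation logic):
+
+.. code-block:: python
+
+    def train_fn(config):
+        for epoch in range(100):
+            score = evaluate(config)  # placeholder for one epoch of work
+            # Reporting more than once enables early stopping schedulers.
+            tune.report(score=score)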
+
 **If your model is small**, you can usually try to run many different configurations.
 A **random search** can be used to generate configurations. You can also grid search
 over some values. You should probably still use
-:ref:`ASHA for early termination of bad trials `.
+:ref:`ASHA for early termination of bad trials ` (if your problem
+supports early stopping).
 
 **If your model is large**, you can try to either use
 **Bayesian Optimization-based search algorithms** like :ref:`BayesOpt ` or
@@ -33,14 +41,19 @@ Alternatively, you can use :ref:`Population Based Training `
 works well with few trials, e.g. 8 or even 4. However, this will output a hyperparameter
 *schedule* rather than one fixed set of hyperparameters.
 
-**If you have a small number of hyperparameters**, Bayesian Optimization-methods
-work well. Take a look at :ref:`BOHB ` to combine the
-benefits of bayesian optimization with early stopping.
+**If you have a small number of hyperparameters**, Bayesian Optimization methods
+work well. Take a look at :ref:`BOHB ` or :ref:`Optuna `
+with the :ref:`ASHA ` scheduler to combine the
+benefits of Bayesian Optimization with early stopping.
 
 **If you only have continuous values for hyperparameters** this will work well
-with most Bayesian-Optimization methods. Discrete or categorical variables still
+with most Bayesian Optimization methods. Discrete or categorical variables still
 work, but less good with an increasing number of categories.
 
+**If you have many categorical values for hyperparameters**, consider using random search,
+or a TPE-based Bayesian Optimization algorithm such as :ref:`Optuna ` or
+:ref:`HyperOpt `.
+
 **Our go-to solution** is usually to use **random search** with
 :ref:`ASHA for early stopping ` for smaller problems. Use :ref:`BOHB ` for
 **larger problems** with a **small number of hyperparameters** and :ref:`Population Based Training ` for **larger problems** with a **large number of hyperparameters**
@@ -248,6 +261,34 @@ on other nodes as well. Please refer to the
 :ref:`placement groups documentation ` to learn more about these placement
 strategies.
 
+Why is my training stuck and Ray reporting that pending actors or tasks cannot be scheduled?
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+This is usually caused by Ray actors or tasks being started by the
+trainable without the trainable resources accounting for them, leading to a deadlock.
+This can also be "stealthily" caused by using other libraries in the trainable that are
+based on Ray, such as Modin. In order to fix the issue, request additional resources for
+the trial using :ref:`placement groups `, as outlined in
+the section above.
+
+For example, if your trainable is using Modin dataframes, operations on those will spawn
+Ray tasks. By allocating an additional CPU bundle to the trial, those tasks will be able
+to run without being starved of resources.
+
+.. code-block:: python
+
+    import modin.pandas as pd
+
+    def train_fn(config, checkpoint_dir=None):
+        # some Modin operations here
+        tune.report(metric=metric)
+
+    tune.run(
+        train_fn,
+        resources_per_trial=tune.PlacementGroupFactory([
+            {"CPU": 1},  # this bundle will be used by the trainable itself
+            {"CPU": 1},  # this bundle will be used by Modin
+        ], strategy="PACK"))
 
 How can I pass further parameter values to my trainable?
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
@@ -286,8 +327,8 @@ also works with class trainables. Please see :ref:`here for further details `
 and examples.
-How can I reproduce experiments
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+How can I reproduce experiments?
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 Reproducing experiments and experiment results means that you get the exact same results when running
 an experiment again and again. To achieve this, the conditions have to be
 exactly the same each time you run the experiment.
diff --git a/doc/source/tune/api_docs/suggestion.rst b/doc/source/tune/api_docs/suggestion.rst
index 32728c4ab2273..4795f0c97816f 100644
--- a/doc/source/tune/api_docs/suggestion.rst
+++ b/doc/source/tune/api_docs/suggestion.rst
@@ -16,6 +16,7 @@ Summary
 -------
 
 .. list-table::
+   :widths: 5 5 2 10
    :header-rows: 1
 
    * - SearchAlgorithm
@@ -137,8 +138,6 @@ identifier.
     search_alg2.restore_from_dir(
         os.path.join("~/my_results", "my-experiment-1"))
 
-.. note:: This is currently not implemented for: AxSearch, TuneBOHB, SigOptSearch, and DragonflySearch.
-
 .. _tune-basicvariant:
 
 Random search and grid search (tune.suggest.basic_variant.BasicVariantGenerator)
diff --git a/doc/source/tune/user-guide.rst b/doc/source/tune/user-guide.rst
index a4f522add908e..962d53bdad848 100644
--- a/doc/source/tune/user-guide.rst
+++ b/doc/source/tune/user-guide.rst
@@ -50,7 +50,8 @@ the respective placement group. If not enough resources are available, this will
 If your trainable function starts more remote workers, you will need to pass
 placement groups factory objects to request these resources.
 See the :class:`PlacementGroupFactory documentation `
-for further information.
+for further information. This also applies if you are using other libraries making use of Ray, such
+as Modin. Failure to set resources correctly may result in a deadlock, "hanging" the cluster.
 
 Using GPUs
 ~~~~~~~~~~
@@ -870,6 +871,10 @@ These are the environment variables Ray Tune currently considers:
   Ctrl+C) to gracefully shutdown and do a final checkpoint. Setting this variable to ``1`` will disable signal handling and stop execution right away. Defaults to ``0``.
+* **TUNE_FORCE_TRIAL_CLEANUP_S**: By default, Ray Tune will gracefully terminate trials,
+  letting them finish the current training step and any user-defined cleanup.
+  Setting this variable to a non-zero, positive integer will cause trials to be forcefully
+  terminated after a grace period of that many seconds. Defaults to ``0``.
 * **TUNE_FUNCTION_THREAD_TIMEOUT_S**: Time in seconds the function API waits
   for threads to finish after instructing them to complete. Defaults to ``2``.
 * **TUNE_GLOBAL_CHECKPOINT_S**: Time in seconds that limits how often Tune's
@@ -903,6 +908,9 @@ These are the environment variables Ray Tune currently considers:
   to the driver. Enabling this might delay scheduling decisions, as trainables are speculatively continued. Setting this to ``0`` disables result buffering. Defaults to 1000 (results), or to 1 (no buffering) if used with ``checkpoint_at_end``.
+* **TUNE_RESULT_DELIM**: Delimiter used for nested entries in
+  :class:`ExperimentAnalysis ` dataframes. Defaults to ``.`` (but will be
+  changed to ``/`` in future versions of Ray).
 * **TUNE_RESULT_BUFFER_MAX_TIME_S**: Similarly, Ray Tune buffers results up to ``number_of_trial/10`` seconds,
   but never longer than this value. Defaults to 100 (seconds).
 * **TUNE_RESULT_BUFFER_MIN_TIME_S**: Additionally, you can specify a minimum time to buffer results. Defaults to 0.
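+
+For example, to force-terminate hanging trials after a 60 second grace period (a sketch; the
+variable must be set in the driver before Ray Tune starts):
+
+.. code-block:: python
+
+    import os
+
+    os.environ["TUNE_FORCE_TRIAL_CLEANUP_S"] = "60"
+
+    # ... then run your experiment as usual, e.g. tune.run(train_fn).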
diff --git a/java/BUILD.bazel b/java/BUILD.bazel
index 2d98abb9402ba..06a974befed02 100644
--- a/java/BUILD.bazel
+++ b/java/BUILD.bazel
@@ -147,9 +147,12 @@ define_java_module(
         ":io_ray_ray_api",
         ":io_ray_ray_runtime",
        ":io_ray_ray_serve",
+        "@maven//:com_google_code_gson_gson",
         "@maven//:com_google_guava_guava",
         "@maven//:com_google_protobuf_protobuf_java",
         "@maven//:org_apache_commons_commons_lang3",
+        "@maven//:org_apache_httpcomponents_client5_httpclient5",
+        "@maven//:org_apache_httpcomponents_core5_httpcore5",
         "@maven//:org_slf4j_slf4j_api",
         "@maven//:org_testng_testng",
     ],
@@ -157,9 +160,11 @@ define_java_module(
     deps = [
         ":io_ray_ray_api",
         ":io_ray_ray_runtime",
+        "@maven//:com_google_code_gson_gson",
         "@maven//:com_google_guava_guava",
         "@maven//:com_google_protobuf_protobuf_java",
         "@maven//:org_apache_commons_commons_lang3",
+        "@maven//:org_apache_httpcomponents_core5_httpcore5",
         "@maven//:org_slf4j_slf4j_api",
     ],
 )
diff --git a/java/dependencies.bzl b/java/dependencies.bzl
index 9c411a1bd9982..e6bb9e384d1cf 100644
--- a/java/dependencies.bzl
+++ b/java/dependencies.bzl
@@ -24,6 +24,8 @@ def gen_java_deps():
         "com.lmax:disruptor:3.3.4",
         "org.yaml:snakeyaml:1.26",
         "net.java.dev.jna:jna:5.5.0",
+        "org.apache.httpcomponents.client5:httpclient5:5.0.3",
+        "org.apache.httpcomponents.core5:httpcore5:5.0.2",
         maven.artifact(
             group = "org.testng",
             artifact = "testng",
diff --git a/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java
index acda82aa6f1d6..172ff78dfa397 100644
--- a/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java
+++ b/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java
@@ -1,6 +1,7 @@
 package io.ray.runtime;
 
 import com.google.common.base.Preconditions;
+import com.google.gson.Gson;
 import io.ray.api.BaseActorHandle;
 import io.ray.api.id.ActorId;
 import io.ray.api.id.JobId;
@@ -10,6 +11,7 @@
 import io.ray.runtime.exception.RayIntentionalSystemExitException;
 import io.ray.runtime.gcs.GcsClient;
 import io.ray.runtime.gcs.GcsClientOptions;
+import io.ray.runtime.generated.Common.RuntimeEnv;
 import io.ray.runtime.generated.Common.WorkerType;
 import io.ray.runtime.generated.Gcs.GcsNodeInfo;
 import io.ray.runtime.generated.Gcs.JobConfig;
@@ -20,6 +22,8 @@
 import io.ray.runtime.task.TaskExecutor;
 import io.ray.runtime.util.BinaryFileUtil;
 import io.ray.runtime.util.JniUtils;
+import java.util.HashMap;
+import java.util.Map;
 import java.util.Optional;
 import java.util.concurrent.locks.Lock;
 import java.util.concurrent.locks.ReadWriteLock;
@@ -102,8 +106,20 @@ public void start() {
           JobConfig.newBuilder()
               .setNumJavaWorkersPerProcess(rayConfig.numWorkersPerProcess)
               .addAllJvmOptions(rayConfig.jvmOptionsForJavaWorker)
-              .putAllWorkerEnv(rayConfig.workerEnv)
               .addAllCodeSearchPath(rayConfig.codeSearchPath);
+      RuntimeEnv.Builder runtimeEnvBuilder = RuntimeEnv.newBuilder();
+      if (!rayConfig.workerEnv.isEmpty()) {
+        // TODO(SongGuyang): Support complete runtime env interface for users.
+        // Set worker env to the serialized runtime env json.
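+        // For example, a worker env of {"FOO": "bar"} is serialized as
+        // {"env_vars":{"FOO":"bar"}}.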
+ Gson gson = new Gson(); + Map> runtimeEnv = new HashMap<>(); + runtimeEnv.put("env_vars", rayConfig.workerEnv); + String gsonString = gson.toJson(runtimeEnv); + runtimeEnvBuilder.setSerializedRuntimeEnv(gsonString); + } else { + runtimeEnvBuilder.setSerializedRuntimeEnv("{}"); + } + jobConfigBuilder.setRuntimeEnv(runtimeEnvBuilder.build()); serializedJobConfig = jobConfigBuilder.build().toByteArray(); } diff --git a/java/runtime/src/main/java/io/ray/runtime/object/LocalModeObjectStore.java b/java/runtime/src/main/java/io/ray/runtime/object/LocalModeObjectStore.java index 131d71c5fa2f9..fc139985955c9 100644 --- a/java/runtime/src/main/java/io/ray/runtime/object/LocalModeObjectStore.java +++ b/java/runtime/src/main/java/io/ray/runtime/object/LocalModeObjectStore.java @@ -117,7 +117,7 @@ public Address getOwnerAddress(ObjectId id) { } @Override - public byte[] promoteAndGetOwnershipInfo(ObjectId objectId) { + public byte[] getOwnershipInfo(ObjectId objectId) { return new byte[0]; } diff --git a/java/runtime/src/main/java/io/ray/runtime/object/NativeObjectStore.java b/java/runtime/src/main/java/io/ray/runtime/object/NativeObjectStore.java index 7e0ddc5c9aa74..136712c096cd8 100644 --- a/java/runtime/src/main/java/io/ray/runtime/object/NativeObjectStore.java +++ b/java/runtime/src/main/java/io/ray/runtime/object/NativeObjectStore.java @@ -78,8 +78,8 @@ public void removeLocalReference(UniqueId workerId, ObjectId objectId) { } @Override - public byte[] promoteAndGetOwnershipInfo(ObjectId objectId) { - return nativePromoteAndGetOwnershipInfo(objectId.getBytes()); + public byte[] getOwnershipInfo(ObjectId objectId) { + return nativeGetOwnershipInfo(objectId.getBytes()); } @Override @@ -132,7 +132,7 @@ private static native List nativeWait( private static native byte[] nativeGetOwnerAddress(byte[] objectId); - private static native byte[] nativePromoteAndGetOwnershipInfo(byte[] objectId); + private static native byte[] nativeGetOwnershipInfo(byte[] objectId); private static native void nativeRegisterOwnershipInfoAndResolveFuture( byte[] objectId, byte[] outerObjectId, byte[] ownerAddress); diff --git a/java/runtime/src/main/java/io/ray/runtime/object/ObjectRefImpl.java b/java/runtime/src/main/java/io/ray/runtime/object/ObjectRefImpl.java index cb9b35becd02d..a352ca22632ef 100644 --- a/java/runtime/src/main/java/io/ray/runtime/object/ObjectRefImpl.java +++ b/java/runtime/src/main/java/io/ray/runtime/object/ObjectRefImpl.java @@ -63,7 +63,7 @@ public void writeExternal(ObjectOutput out) throws IOException { out.writeObject(this.getId()); out.writeObject(this.getType()); RayRuntimeInternal runtime = (RayRuntimeInternal) Ray.internal(); - byte[] ownerAddress = runtime.getObjectStore().promoteAndGetOwnershipInfo(this.getId()); + byte[] ownerAddress = runtime.getObjectStore().getOwnershipInfo(this.getId()); out.writeInt(ownerAddress.length); out.write(ownerAddress); ObjectSerializer.addContainedObjectId(this.getId()); diff --git a/java/runtime/src/main/java/io/ray/runtime/object/ObjectStore.java b/java/runtime/src/main/java/io/ray/runtime/object/ObjectStore.java index d61694fab7e93..6db39cc1e4bd6 100644 --- a/java/runtime/src/main/java/io/ray/runtime/object/ObjectStore.java +++ b/java/runtime/src/main/java/io/ray/runtime/object/ObjectStore.java @@ -224,12 +224,12 @@ public WaitResult wait( public abstract Address getOwnerAddress(ObjectId id); /** - * Promote the given object to the underlying object store, and get the ownership info. + * Get the ownership info. 
   *
   * @param objectId The ID of the object
   * @return the serialized ownership address
   */
-  public abstract byte[] promoteAndGetOwnershipInfo(ObjectId objectId);
+  public abstract byte[] getOwnershipInfo(ObjectId objectId);
 
   /**
    * Add a reference to an ObjectID that will be deserialized. This will also start the process to
diff --git a/java/serve/pom.xml b/java/serve/pom.xml
index d945f8fe83172..7291d4ec79666 100644
--- a/java/serve/pom.xml
+++ b/java/serve/pom.xml
@@ -27,6 +27,11 @@
       <artifactId>ray-runtime</artifactId>
       <version>${project.version}</version>
     </dependency>
+    <dependency>
+      <groupId>com.google.code.gson</groupId>
+      <artifactId>gson</artifactId>
+      <version>2.8.5</version>
+    </dependency>
     <dependency>
       <groupId>com.google.guava</groupId>
       <artifactId>guava</artifactId>
@@ -42,6 +47,16 @@
       <artifactId>commons-lang3</artifactId>
       <version>3.4</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.httpcomponents.client5</groupId>
+      <artifactId>httpclient5</artifactId>
+      <version>5.0.3</version>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.httpcomponents.core5</groupId>
+      <artifactId>httpcore5</artifactId>
+      <version>5.0.2</version>
+    </dependency>
     <dependency>
       <groupId>org.slf4j</groupId>
       <artifactId>slf4j-api</artifactId>
diff --git a/java/serve/src/main/java/io/ray/serve/Constants.java b/java/serve/src/main/java/io/ray/serve/Constants.java
index 2d8ac4f702839..1ca1739f8d734 100644
--- a/java/serve/src/main/java/io/ray/serve/Constants.java
+++ b/java/serve/src/main/java/io/ray/serve/Constants.java
@@ -16,4 +16,10 @@ public class Constants {
 
   /** Name of controller listen_for_change method. */
   public static final String CONTROLLER_LISTEN_FOR_CHANGE_METHOD = "listen_for_change";
+
+  public static final String SERVE_CONTROLLER_NAME = "SERVE_CONTROLLER_ACTOR";
+
+  public static final String DEFAULT_CALL_METHOD = "call";
+
+  public static final String UTF8 = "UTF-8";
 }
diff --git a/java/serve/src/main/java/io/ray/serve/DeploymentInfo.java b/java/serve/src/main/java/io/ray/serve/DeploymentInfo.java
new file mode 100644
index 0000000000000..2ab02deeeeaeb
--- /dev/null
+++ b/java/serve/src/main/java/io/ray/serve/DeploymentInfo.java
@@ -0,0 +1,38 @@
+package io.ray.serve;
+
+import java.io.Serializable;
+
+public class DeploymentInfo implements Serializable {
+
+  private static final long serialVersionUID = -4198364411759931955L;
+
+  private byte[] backendConfig;
+
+  private ReplicaConfig replicaConfig;
+
+  private byte[] backendVersion;
+
+  public byte[] getBackendConfig() {
+    return backendConfig;
+  }
+
+  public void setBackendConfig(byte[] backendConfig) {
+    this.backendConfig = backendConfig;
+  }
+
+  public ReplicaConfig getReplicaConfig() {
+    return replicaConfig;
+  }
+
+  public void setReplicaConfig(ReplicaConfig replicaConfig) {
+    this.replicaConfig = replicaConfig;
+  }
+
+  public byte[] getBackendVersion() {
+    return backendVersion;
+  }
+
+  public void setBackendVersion(byte[] backendVersion) {
+    this.backendVersion = backendVersion;
+  }
+}
diff --git a/java/serve/src/main/java/io/ray/serve/DummyBackendReplica.java b/java/serve/src/main/java/io/ray/serve/DummyBackendReplica.java
new file mode 100644
index 0000000000000..874a71c26d6db
--- /dev/null
+++ b/java/serve/src/main/java/io/ray/serve/DummyBackendReplica.java
@@ -0,0 +1,12 @@
+package io.ray.serve;
+
+import java.util.concurrent.atomic.AtomicInteger;
+
+public class DummyBackendReplica {
+
+  private AtomicInteger counter = new AtomicInteger();
+
+  public String call() {
+    return String.valueOf(counter.incrementAndGet());
+  }
+}
diff --git a/java/serve/src/main/java/io/ray/serve/HandleOptions.java b/java/serve/src/main/java/io/ray/serve/HandleOptions.java
new file mode 100644
index 0000000000000..e301332976ea3
--- /dev/null
+++ b/java/serve/src/main/java/io/ray/serve/HandleOptions.java
@@ -0,0 +1,15 @@
+package io.ray.serve;
+
+/** Options for each ServeHandle instance. These fields are immutable.
*/ +public class HandleOptions { + + private String methodName = "call"; + + public String getMethodName() { + return methodName; + } + + public void setMethodName(String methodName) { + this.methodName = methodName; + } +} diff --git a/java/serve/src/main/java/io/ray/serve/HttpProxy.java b/java/serve/src/main/java/io/ray/serve/HttpProxy.java new file mode 100644 index 0000000000000..809337e75d902 --- /dev/null +++ b/java/serve/src/main/java/io/ray/serve/HttpProxy.java @@ -0,0 +1,161 @@ +package io.ray.serve; + +import com.google.common.collect.ImmutableMap; +import io.ray.api.Ray; +import io.ray.runtime.metric.Count; +import io.ray.runtime.metric.Metrics; +import io.ray.runtime.metric.TagKey; +import io.ray.runtime.serializer.MessagePackSerializer; +import io.ray.serve.util.LogUtil; +import io.ray.serve.util.SocketUtil; +import java.io.IOException; +import java.net.HttpURLConnection; +import java.net.InetAddress; +import java.nio.charset.Charset; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import org.apache.hc.core5.http.ClassicHttpRequest; +import org.apache.hc.core5.http.ClassicHttpResponse; +import org.apache.hc.core5.http.HttpEntity; +import org.apache.hc.core5.http.HttpException; +import org.apache.hc.core5.http.impl.bootstrap.HttpServer; +import org.apache.hc.core5.http.impl.bootstrap.ServerBootstrap; +import org.apache.hc.core5.http.io.HttpRequestHandler; +import org.apache.hc.core5.http.io.entity.ByteArrayEntity; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.apache.hc.core5.http.io.entity.StringEntity; +import org.apache.hc.core5.http.protocol.HttpContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class HttpProxy implements ServeProxy { + + private static final Logger LOGGER = LoggerFactory.getLogger(HttpProxy.class); + + public static final String PROXY_NAME = "HTTP_PROXY"; + + public static final String PROXY_HTTP_PORT = "ray.serve.proxy.http.port"; + + public static final String PROXY_HTTP_METHODS = "ray.serve.proxy.http.methods"; + + private int port; + + private Count requestCounter; + + private HttpServer httpServer; + + private ProxyRouter proxyRouter; + + private Object asyncContext = Ray.getAsyncContext(); + + @Override + public void init(Map config, ProxyRouter proxyRouter) { + this.port = + Optional.ofNullable(config) + .map(conf -> conf.get(PROXY_HTTP_PORT)) + .map(httpPort -> Integer.valueOf(httpPort)) + .orElse(SocketUtil.findAvailableTcpPort(8000)); + this.proxyRouter = proxyRouter; + RayServeMetrics.execute( + () -> + this.requestCounter = + Metrics.count() + .name("serve_num_http_requests") + .description("The number of HTTP requests processed.") + .unit("") + .tags(new HashMap<>()) + .register()); + startupHttpServer(port); + LOGGER.info("Proxy {} has been started with port:{}", getName(), this.port); + } + + private void startupHttpServer(int port) { + try { + this.httpServer = + ServerBootstrap.bootstrap() + .setListenerPort(port) + .register("*", new ServeHttpHandler()) + .registerVirtual( + InetAddress.getLocalHost().getHostAddress(), "*", new ServeHttpHandler()) + .create(); + this.httpServer.start(); + } catch (Throwable e) { + String errMsg = + LogUtil.format( + "Proxy {} failed to startup HTTP server on port {}.", getName(), this.port); + LOGGER.error(errMsg); + throw new RayServeException(errMsg, e); + } + } + + @Override + public String getName() { + return PROXY_NAME; + } + + private class ServeHttpHandler implements HttpRequestHandler { + + @Override + public void 
handle( + ClassicHttpRequest request, ClassicHttpResponse response, HttpContext context) + throws HttpException, IOException { + + Ray.setAsyncContext(asyncContext); + + int code = HttpURLConnection.HTTP_OK; + Object result = null; + String route = request.getPath(); + try { + RayServeMetrics.execute( + () -> + requestCounter.update( + 1.0, + ImmutableMap.of( + new TagKey(RayServeMetrics.TAG_ROUTE), + route))); // TODO the old tag will be covered, it may be a bug. + + Object[] parameters = null; + HttpEntity httpEntity = request.getEntity(); + if (null == httpEntity) { + parameters = new Object[0]; + } else { + byte[] body = EntityUtils.toByteArray(httpEntity); + parameters = MessagePackSerializer.decode(body, Object[].class); + } + + RayServeHandle rayServeHandle = proxyRouter.matchRoute(route); + if (rayServeHandle == null) { + code = HttpURLConnection.HTTP_NOT_FOUND; + } else { + result = rayServeHandle.remote(parameters).get(); + } + + } catch (Throwable e) { + LOGGER.error("HTTP Proxy failed to process request.", e); + code = HttpURLConnection.HTTP_INTERNAL_ERROR; + } finally { + response.setCode(code); + if (code == HttpURLConnection.HTTP_NOT_FOUND) { + response.setEntity( + new StringEntity( + LogUtil.format( + "Path '{}' not found. Please ping http://.../-/routes for route table.", + route), + Charset.forName(Constants.UTF8))); + } else if (result != null) { + response.setEntity( + new ByteArrayEntity(MessagePackSerializer.encode(result).getLeft(), null)); + } + } + } + } + + public int getPort() { + return port; + } + + public ProxyRouter getProxyRouter() { + return proxyRouter; + } +} diff --git a/java/serve/src/main/java/io/ray/serve/ProxyActor.java b/java/serve/src/main/java/io/ray/serve/ProxyActor.java new file mode 100644 index 0000000000000..ac5d1cf870ea9 --- /dev/null +++ b/java/serve/src/main/java/io/ray/serve/ProxyActor.java @@ -0,0 +1,175 @@ +package io.ray.serve; + +import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; +import io.ray.api.BaseActorHandle; +import io.ray.api.Ray; +import io.ray.serve.api.Serve; +import io.ray.serve.generated.EndpointInfo; +import io.ray.serve.generated.EndpointSet; +import io.ray.serve.poll.KeyListener; +import io.ray.serve.poll.KeyType; +import io.ray.serve.poll.LongPollClient; +import io.ray.serve.poll.LongPollNamespace; +import io.ray.serve.util.CollectionUtil; +import io.ray.serve.util.LogUtil; +import io.ray.serve.util.ReflectUtil; +import java.lang.reflect.InvocationTargetException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.ServiceLoader; +import java.util.concurrent.ConcurrentHashMap; +import org.apache.commons.lang3.StringUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ProxyActor { + + private static final Logger LOGGER = LoggerFactory.getLogger(ProxyActor.class); + + private Map config; + + private Map proxies = new ConcurrentHashMap<>(); + + /** Used only for displaying the route table. Key: route, value: endpoint. */ + private volatile Map routeInfo = new HashMap<>(); + + private LongPollClient longPollClient; + + private ProxyRouter proxyRouter = new ProxyRouter(); + + public ProxyActor(String controllerName, Map config) { + this.config = config; + + // Set the controller name so that serve will connect to the controller instance this proxy is + // running in. 
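+    // The proxy itself is not a replica, so the backend tag, replica tag, and servable
+    // object are passed as null below; only the controller name is recorded for later
+    // actor lookups.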
+ Serve.setInternalReplicaContext(null, null, controllerName, null); + + Optional optional = Ray.getActor(controllerName); + Preconditions.checkState(optional.isPresent(), "Controller does not exist"); + + Map keyListeners = new HashMap<>(); + keyListeners.put( + new KeyType(LongPollNamespace.ROUTE_TABLE, null), endpoints -> updateRoutes(endpoints)); + this.longPollClient = new LongPollClient(optional.get(), keyListeners); + this.longPollClient.start(); + this.run(); + } + + private void run() { + startupProxy(); + registerServiceDiscovery(); + } + + private void startupProxy() { + + List serveProxies = null; + + // Get proxy instances according to class names. + String proxyClassNames = config != null ? config.get(RayServeConfig.PROXY_CLASS) : null; + if (StringUtils.isNotBlank(proxyClassNames)) { + try { + serveProxies = ReflectUtil.getInstancesByClassNames(proxyClassNames, ServeProxy.class); + } catch (ClassNotFoundException + | InstantiationException + | IllegalAccessException + | IllegalArgumentException + | InvocationTargetException + | NoSuchMethodException + | SecurityException e) { + String errorMsg = + LogUtil.format("Failed to initialize proxies by class names : {}", proxyClassNames); + LOGGER.error(errorMsg, e); + throw new RayServeException(errorMsg, e); + } + } + + // Get proxy instances through SPI. + if (CollectionUtil.isEmpty(serveProxies)) { + List spiProxies = new ArrayList<>(); + ServiceLoader serviceLoader = ServiceLoader.load(ServeProxy.class); + serviceLoader.forEach(serveProxy -> spiProxies.add(serveProxy)); + serveProxies = spiProxies; + } + + // Set the default proxy if proxies still empty. + if (CollectionUtil.isEmpty(serveProxies)) { + serveProxies = Lists.newArrayList(new HttpProxy()); + } + + if (!CollectionUtil.isEmpty(serveProxies)) { + for (ServeProxy serveProxy : serveProxies) { + if (proxies.containsKey(serveProxy.getName())) { + String errorMsg = + LogUtil.format( + "Proxy {} name {} is duplicate with proxy {} name {}", + serveProxy.getClass().getName(), + serveProxy.getName(), + proxies.get(serveProxy.getName()).getClass().getName(), + proxies.get(serveProxy.getName()).getName()); + LOGGER.error(errorMsg); + throw new RayServeException(errorMsg); + } + proxies.put(serveProxy.getName(), serveProxy); + serveProxy.init(config, proxyRouter); + LOGGER.info("Proxy actor initialized proxy: {}", serveProxy.getName()); + } + } + } + + public void registerServiceDiscovery() { + proxies.forEach((key, value) -> value.registerServiceDiscovery()); + } + + public void updateRoutes(Object endpoints) { + Map endpointInfos = ((EndpointSet) endpoints).getEndpointsMap(); + Map routeInfo = new HashMap<>(); + if (endpointInfos != null) { + endpointInfos.forEach( + (key, value) -> + routeInfo.put( + StringUtils.isNotBlank(value.getRoute()) ? 
value.getRoute() : key, value));
+    }
+    this.routeInfo = routeInfo;
+    this.proxyRouter.updateRoutes(endpointInfos);
+  }
+
+  public void ready() {
+    return;
+  }
+
+  public void blockUntilEndpointExists(String endpoint, double timeoutS) {
+    long timeoutMs = (long) (timeoutS * 1000);
+    long startTime = System.currentTimeMillis();
+    while (true) {
+      if (System.currentTimeMillis() - startTime > timeoutMs) {
+        throw new RayServeException(
+            LogUtil.format("Waited {}s for endpoint {} to propagate.", timeoutS, endpoint));
+      }
+      for (EndpointInfo endpointInfo : routeInfo.values()) {
+        if (StringUtils.equals(endpointInfo.getEndpointName(), endpoint)) {
+          return;
+        }
+      }
+      try {
+        Thread.sleep(200);
+      } catch (InterruptedException e) {
+        LOGGER.error(
+            "Sleep was interrupted while waiting for endpoint {} to exist.", endpoint, e);
+      }
+    }
+  }
+
+  public ProxyRouter getProxyRouter() {
+    return proxyRouter;
+  }
+
+  public Map<String, ServeProxy> getProxies() {
+    return proxies;
+  }
+}
diff --git a/java/serve/src/main/java/io/ray/serve/ProxyRouter.java b/java/serve/src/main/java/io/ray/serve/ProxyRouter.java
new file mode 100644
index 0000000000000..041da46bfee08
--- /dev/null
+++ b/java/serve/src/main/java/io/ray/serve/ProxyRouter.java
@@ -0,0 +1,72 @@
+package io.ray.serve;
+
+import io.ray.serve.api.Serve;
+import io.ray.serve.generated.EndpointInfo;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.commons.lang3.StringUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** Default common router for the proxy to match incoming routes. */
+public class ProxyRouter {
+
+  private static final Logger LOGGER = LoggerFactory.getLogger(ProxyRouter.class);
+
+  /** Key: route, value: endpoint. */
+  private Map<String, EndpointInfo> routeInfo = new HashMap<>();
+
+  /** Key: endpointName, value: handle. */
+  private Map<String, RayServeHandle> handles = new ConcurrentHashMap<>();
+
+  public void updateRoutes(Map<String, EndpointInfo> endpoints) {
+    LOGGER.info("Got updated endpoints: {}.", endpoints);
+
+    Set<String> existingHandles = new HashSet<>(handles.keySet());
+    Map<String, EndpointInfo> routeInfo = new HashMap<>();
+
+    if (endpoints != null) {
+      for (Map.Entry<String, EndpointInfo> entry : endpoints.entrySet()) {
+        String route =
+            StringUtils.isNotBlank(entry.getValue().getRoute())
+                ? entry.getValue().getRoute()
+                : entry.getKey();
+        routeInfo.put(route, entry.getValue());
+
+        if (handles.containsKey(entry.getKey())) {
+          existingHandles.remove(entry.getKey());
+        } else {
+          handles.put(entry.getKey(), Serve.getGlobalClient().getHandle(entry.getKey(), true));
+        }
+      }
+    }
+
+    this.routeInfo = routeInfo;
+    for (String endpoint : existingHandles) {
+      handles.remove(endpoint);
+    }
+    LOGGER.info("The final route info: {}.", routeInfo);
+  }
+
+  /**
+   * Return the handle whose registered route exactly matches the given route.
+   *
+   * @param route route to match against.
+   * @return the RayServeHandle if found, else null.
+   */
+  public RayServeHandle matchRoute(String route) {
+    EndpointInfo endpointInfo = routeInfo.get(route);
+    return endpointInfo == null ?
null : handles.get(endpointInfo.getEndpointName()); + } + + public Map getRouteInfo() { + return routeInfo; + } + + public Map getHandles() { + return handles; + } +} diff --git a/java/serve/src/main/java/io/ray/serve/RayServeConfig.java b/java/serve/src/main/java/io/ray/serve/RayServeConfig.java new file mode 100644 index 0000000000000..5762aae40be4e --- /dev/null +++ b/java/serve/src/main/java/io/ray/serve/RayServeConfig.java @@ -0,0 +1,6 @@ +package io.ray.serve; + +public class RayServeConfig { + + public static final String PROXY_CLASS = "ray.serve.proxy.class"; +} diff --git a/java/serve/src/main/java/io/ray/serve/RayServeHandle.java b/java/serve/src/main/java/io/ray/serve/RayServeHandle.java new file mode 100644 index 0000000000000..abcf6ac5abdf2 --- /dev/null +++ b/java/serve/src/main/java/io/ray/serve/RayServeHandle.java @@ -0,0 +1,73 @@ +package io.ray.serve; + +import com.google.common.collect.ImmutableMap; +import io.ray.api.BaseActorHandle; +import io.ray.api.ObjectRef; +import io.ray.runtime.metric.Count; +import io.ray.runtime.metric.Metrics; +import io.ray.serve.generated.RequestMetadata; +import org.apache.commons.lang3.RandomStringUtils; + +public class RayServeHandle { + + private String endpointName; + + private HandleOptions handleOptions; + + private String handleTag; + + private Count requestCounter; + + private Router router; + + public RayServeHandle( + BaseActorHandle controllerHandle, + String endpointName, + HandleOptions handleOptions, + Router router) { + this.endpointName = endpointName; + this.handleOptions = handleOptions != null ? handleOptions : new HandleOptions(); + this.handleTag = endpointName + "#" + RandomStringUtils.randomAlphabetic(6); + this.router = router != null ? router : new Router(controllerHandle, endpointName); + RayServeMetrics.execute( + () -> + this.requestCounter = + Metrics.count() + .name(RayServeMetrics.SERVE_HANDLE_REQUEST_COUNTER.name()) + .description(RayServeMetrics.SERVE_HANDLE_REQUEST_COUNTER.getDescription()) + .unit("") + .tags( + ImmutableMap.of( + RayServeMetrics.TAG_HANDLE, + handleTag, + RayServeMetrics.TAG_ENDPOINT, + endpointName)) + .register()); + } + + /** + * Returns a Ray ObjectRef whose results can be waited for or retrieved using ray.wait or ray.get + * (or ``await object_ref``), respectively. + * + * @param parameters The input parameters of the specified method to invoke on the backend. + * @return ray.ObjectRef + */ + public ObjectRef remote(Object[] parameters) { + RayServeMetrics.execute(() -> requestCounter.inc(1.0)); + RequestMetadata.Builder requestMetadata = RequestMetadata.newBuilder(); + requestMetadata.setRequestId(RandomStringUtils.randomAlphabetic(10)); + requestMetadata.setEndpoint(endpointName); + requestMetadata.setCallMethod( + handleOptions != null ? 
handleOptions.getMethodName() : Constants.DEFAULT_CALL_METHOD); + return router.assignRequest(requestMetadata.build(), parameters); + } + + public RayServeHandle setMethodName(String methodName) { + handleOptions.setMethodName(methodName); + return this; + } + + public Router getRouter() { + return router; + } +} diff --git a/java/serve/src/main/java/io/ray/serve/RayServeMetrics.java b/java/serve/src/main/java/io/ray/serve/RayServeMetrics.java new file mode 100644 index 0000000000000..f7b1fac730da9 --- /dev/null +++ b/java/serve/src/main/java/io/ray/serve/RayServeMetrics.java @@ -0,0 +1,74 @@ +package io.ray.serve; + +import io.ray.api.Ray; + +public enum RayServeMetrics { + SERVE_HANDLE_REQUEST_COUNTER( + "serve_handle_request_counter", + "The number of handle.remote() calls that have been made on this handle."), + + SERVE_NUM_ROUTER_REQUESTS( + "serve_num_router_requests", "The number of requests processed by the router."), + + SERVE_DEPLOYMENT_QUEUED_QUERIES( + "serve_deployment_queued_queries", + "The current number of queries to this deployment waiting to be assigned to a replica."), + + SERVE_BACKEND_REQUEST_COUNTER( + "serve_backend_request_counter", + "The number of queries that have been processed in this replica."), + + SERVE_BACKEND_ERROR_COUNTER( + "serve_backend_error_counter", + "The number of exceptions that have occurred in this replica."), + + SERVE_BACKEND_REPLICA_STARTS( + "serve_backend_replica_starts", + "The number of times this replica has been restarted due to failure."), + + SERVE_BACKEND_PROCESSING_LATENCY_MS( + "serve_backend_processing_latency_ms", "The latency for queries to be processed."), + + SERVE_REPLICA_PROCESSING_QUERIES( + "serve_replica_processing_queries", "The current number of queries being processed."), + ; + + public static final String TAG_HANDLE = "handle"; + + public static final String TAG_ENDPOINT = "endpoint"; + + public static final String TAG_DEPLOYMENT = "deployment"; + + public static final String TAG_ROUTE = "route"; + + public static final String TAG_BACKEND = "backend"; + + public static final String TAG_REPLICA = "replica"; + + private static final boolean isMetricsEnabled = + Ray.isInitialized() && !Ray.getRuntimeContext().isSingleProcess(); + + private String name; + + private String description; + + private RayServeMetrics(String name, String description) { + this.name = name; + this.description = description; + } + + public static void execute(Runnable runnable) { + if (!isMetricsEnabled) { + return; + } + runnable.run(); + } + + public String getName() { + return name; + } + + public String getDescription() { + return description; + } +} diff --git a/java/serve/src/main/java/io/ray/serve/RayServeReplica.java b/java/serve/src/main/java/io/ray/serve/RayServeReplica.java index 9949115fbbd72..259c8555cf3e4 100644 --- a/java/serve/src/main/java/io/ray/serve/RayServeReplica.java +++ b/java/serve/src/main/java/io/ray/serve/RayServeReplica.java @@ -1,16 +1,16 @@ package io.ray.serve; import com.google.common.collect.ImmutableMap; +import com.google.protobuf.ByteString; import io.ray.api.BaseActorHandle; -import io.ray.api.Ray; import io.ray.runtime.metric.Count; import io.ray.runtime.metric.Gauge; import io.ray.runtime.metric.Histogram; -import io.ray.runtime.metric.MetricConfig; import io.ray.runtime.metric.Metrics; import io.ray.runtime.serializer.MessagePackSerializer; import io.ray.serve.api.Serve; import io.ray.serve.generated.BackendConfig; +import io.ray.serve.generated.BackendVersion; import 
io.ray.serve.generated.RequestWrapper; import io.ray.serve.poll.KeyListener; import io.ray.serve.poll.KeyType; @@ -18,7 +18,6 @@ import io.ray.serve.poll.LongPollNamespace; import io.ray.serve.util.LogUtil; import io.ray.serve.util.ReflectUtil; -import io.ray.serve.util.ServeProtoUtil; import java.lang.reflect.Method; import java.util.HashMap; import java.util.Map; @@ -41,8 +40,6 @@ public class RayServeReplica { private Object callable; - private boolean metricsRegistered = false; - private Count requestCounter; private Count errorCounter; @@ -55,13 +52,20 @@ public class RayServeReplica { private LongPollClient longPollClient; + private BackendVersion version; + + private boolean isDeleted = false; + public RayServeReplica( - Object callable, BackendConfig backendConfig, BaseActorHandle actorHandle) { + Object callable, + BackendConfig backendConfig, + BackendVersion version, + BaseActorHandle actorHandle) { this.backendTag = Serve.getReplicaContext().getBackendTag(); this.replicaTag = Serve.getReplicaContext().getReplicaTag(); this.callable = callable; this.config = backendConfig; - this.reconfigure(ServeProtoUtil.parseUserConfig(backendConfig)); + this.version = version; Map keyListeners = new HashMap<>(); keyListeners.put( @@ -73,55 +77,84 @@ public RayServeReplica( } private void registerMetrics() { - if (!Ray.isInitialized() || Ray.getRuntimeContext().isSingleProcess()) { - return; - } - - Metrics.init(MetricConfig.DEFAULT_CONFIG); - requestCounter = - Metrics.count() - .name("serve_backend_request_counter") - .description("The number of queries that have been processed in this replica.") - .unit("") - .tags(ImmutableMap.of("backend", backendTag, "replica", replicaTag)) - .register(); - - errorCounter = - Metrics.count() - .name("serve_backend_error_counter") - .description("The number of exceptions that have occurred in this replica.") - .unit("") - .tags(ImmutableMap.of("backend", backendTag, "replica", replicaTag)) - .register(); - - restartCounter = - Metrics.count() - .name("serve_backend_replica_starts") - .description("The number of times this replica has been restarted due to failure.") - .unit("") - .tags(ImmutableMap.of("backend", backendTag, "replica", replicaTag)) - .register(); - - processingLatencyTracker = - Metrics.histogram() - .name("serve_backend_processing_latency_ms") - .description("The latency for queries to be processed.") - .unit("") - .boundaries(Constants.DEFAULT_LATENCY_BUCKET_MS) - .tags(ImmutableMap.of("backend", backendTag, "replica", replicaTag)) - .register(); - - numProcessingItems = - Metrics.gauge() - .name("serve_replica_processing_queries") - .description("The current number of queries being processed.") - .unit("") - .tags(ImmutableMap.of("backend", backendTag, "replica", replicaTag)) - .register(); - - metricsRegistered = true; - - restartCounter.inc(1.0); + RayServeMetrics.execute( + () -> + requestCounter = + Metrics.count() + .name(RayServeMetrics.SERVE_BACKEND_REQUEST_COUNTER.getName()) + .description(RayServeMetrics.SERVE_BACKEND_REQUEST_COUNTER.getDescription()) + .unit("") + .tags( + ImmutableMap.of( + RayServeMetrics.TAG_BACKEND, + backendTag, + RayServeMetrics.TAG_REPLICA, + replicaTag)) + .register()); + + RayServeMetrics.execute( + () -> + errorCounter = + Metrics.count() + .name(RayServeMetrics.SERVE_BACKEND_ERROR_COUNTER.getName()) + .description(RayServeMetrics.SERVE_BACKEND_ERROR_COUNTER.getDescription()) + .unit("") + .tags( + ImmutableMap.of( + RayServeMetrics.TAG_BACKEND, + backendTag, + RayServeMetrics.TAG_REPLICA, + 
replicaTag)) + .register()); + + RayServeMetrics.execute( + () -> + restartCounter = + Metrics.count() + .name(RayServeMetrics.SERVE_BACKEND_REPLICA_STARTS.getName()) + .description(RayServeMetrics.SERVE_BACKEND_REPLICA_STARTS.getDescription()) + .unit("") + .tags( + ImmutableMap.of( + RayServeMetrics.TAG_BACKEND, + backendTag, + RayServeMetrics.TAG_REPLICA, + replicaTag)) + .register()); + + RayServeMetrics.execute( + () -> + processingLatencyTracker = + Metrics.histogram() + .name(RayServeMetrics.SERVE_BACKEND_PROCESSING_LATENCY_MS.getName()) + .description( + RayServeMetrics.SERVE_BACKEND_PROCESSING_LATENCY_MS.getDescription()) + .unit("") + .boundaries(Constants.DEFAULT_LATENCY_BUCKET_MS) + .tags( + ImmutableMap.of( + RayServeMetrics.TAG_BACKEND, + backendTag, + RayServeMetrics.TAG_REPLICA, + replicaTag)) + .register()); + + RayServeMetrics.execute( + () -> + numProcessingItems = + Metrics.gauge() + .name(RayServeMetrics.SERVE_REPLICA_PROCESSING_QUERIES.getName()) + .description(RayServeMetrics.SERVE_REPLICA_PROCESSING_QUERIES.getDescription()) + .unit("") + .tags( + ImmutableMap.of( + RayServeMetrics.TAG_BACKEND, + backendTag, + RayServeMetrics.TAG_REPLICA, + replicaTag)) + .register()); + + RayServeMetrics.execute(() -> restartCounter.inc(1.0)); } public Object handleRequest(Query request) { @@ -130,7 +163,7 @@ public Object handleRequest(Query request) { "Replica {} received request {}", replicaTag, request.getMetadata().getRequestId()); numOngoingRequests.incrementAndGet(); - reportMetrics(() -> numProcessingItems.update(numOngoingRequests.get())); + RayServeMetrics.execute(() -> numProcessingItems.update(numOngoingRequests.get())); Object result = invokeSingle(request); numOngoingRequests.decrementAndGet(); @@ -157,10 +190,10 @@ private Object invokeSingle(Query requestItem) { Object[] args = parseRequestItem(requestItem); methodToCall = getRunnerMethod(requestItem.getMetadata().getCallMethod(), args); Object result = methodToCall.invoke(callable, args); - reportMetrics(() -> requestCounter.inc(1.0)); + RayServeMetrics.execute(() -> requestCounter.inc(1.0)); return result; } catch (Throwable e) { - reportMetrics(() -> errorCounter.inc(1.0)); + RayServeMetrics.execute(() -> errorCounter.inc(1.0)); throw new RayServeException( LogUtil.format( "Replica {} failed to invoke method {}", @@ -168,7 +201,8 @@ private Object invokeSingle(Query requestItem) { methodToCall == null ? "unknown" : methodToCall.getName()), e); } finally { - reportMetrics(() -> processingLatencyTracker.update(System.currentTimeMillis() - start)); + RayServeMetrics.execute( + () -> processingLatencyTracker.update(System.currentTimeMillis() - start)); } } @@ -209,10 +243,12 @@ private Method getRunnerMethod(String methodName, Object[] args) { * Perform graceful shutdown. Trigger a graceful shutdown protocol that will wait for all the * queued tasks to be completed and return to the controller. */ - public void drainPendingQueries() { + public synchronized boolean prepareForShutdown() { while (true) { + // Sleep first because we want to make sure all the routers receive the notification to remove + // this replica first. 
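+      // Each pass sleeps for gracefulShutdownWaitLoopS seconds and re-checks the count
+      // of ongoing requests; the loop exits only once no requests are in flight.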
try {
-        Thread.sleep((long) (config.getExperimentalGracefulShutdownWaitLoopS() * 1000));
+        Thread.sleep((long) (config.getGracefulShutdownWaitLoopS() * 1000));
       } catch (InterruptedException e) {
         LOGGER.error(
             "Replica {} was interrupted in sheep when draining pending queries", replicaTag);
@@ -220,13 +256,27 @@ public void drainPendingQueries() {
       if (numOngoingRequests.get() == 0) {
         break;
       } else {
-        LOGGER.debug(
+        LOGGER.info(
             "Waiting for an additional {}s to shut down because there are {} ongoing requests.",
-            config.getExperimentalGracefulShutdownWaitLoopS(),
+            config.getGracefulShutdownWaitLoopS(),
             numOngoingRequests.get());
       }
     }
-    Ray.exitActor();
+
+    // Explicitly call the del method to trigger cleanup. We set isDeleted = true after
+    // successfully calling it so the destructor is called only once.
+    try {
+      if (!isDeleted) {
+        ReflectUtil.getMethod(callable.getClass(), "del").invoke(callable);
+      }
+    } catch (NoSuchMethodException e) {
+      LOGGER.warn("Deployment {} has no del method.", backendTag);
+    } catch (Throwable e) {
+      LOGGER.error("Exception during graceful shutdown of replica.", e);
+    } finally {
+      isDeleted = true;
+    }
+    return true;
   }

   /**
@@ -234,28 +284,34 @@ public void drainPendingQueries() {
    *
    * @param userConfig new user's configuration
    */
-  private void reconfigure(Object userConfig) {
-    if (userConfig == null) {
-      return;
+  public BackendVersion reconfigure(Object userConfig) {
+    BackendVersion.Builder builder = BackendVersion.newBuilder();
+    builder.setCodeVersion(version.getCodeVersion());
+    if (userConfig != null) {
+      builder.setUserConfig(ByteString.copyFrom((byte[]) userConfig));
     }
+    version = builder.build();
+
     try {
       Method reconfigureMethod =
           ReflectUtil.getMethod(
               callable.getClass(),
               Constants.BACKEND_RECONFIGURE_METHOD,
-              userConfig); // TODO cache reconfigureMethod
+              userConfig != null
                  ?
MessagePackSerializer.decode((byte[]) userConfig, Object[].class) + : new Object[0]); // TODO cache reconfigure method reconfigureMethod.invoke(callable, userConfig); } catch (NoSuchMethodException e) { - throw new RayServeException( - LogUtil.format( - "user_config specified but backend {} missing {} method", - backendTag, - Constants.BACKEND_RECONFIGURE_METHOD)); + LOGGER.warn( + "user_config specified but backend {} missing {} method", + backendTag, + Constants.BACKEND_RECONFIGURE_METHOD); } catch (Throwable e) { throw new RayServeException( LogUtil.format("Backend {} failed to reconfigure user_config {}", backendTag, userConfig), e); } + return version; } /** @@ -265,12 +321,9 @@ private void reconfigure(Object userConfig) { */ private void updateBackendConfigs(Object newConfig) { config = (BackendConfig) newConfig; - reconfigure(((BackendConfig) newConfig).getUserConfig()); } - private void reportMetrics(Runnable runnable) { - if (metricsRegistered) { - runnable.run(); - } + public BackendVersion getVersion() { + return version; } } diff --git a/java/serve/src/main/java/io/ray/serve/RayServeWrappedReplica.java b/java/serve/src/main/java/io/ray/serve/RayServeWrappedReplica.java index 9ccc6c6f7a448..53e0854044c71 100644 --- a/java/serve/src/main/java/io/ray/serve/RayServeWrappedReplica.java +++ b/java/serve/src/main/java/io/ray/serve/RayServeWrappedReplica.java @@ -7,6 +7,7 @@ import io.ray.runtime.serializer.MessagePackSerializer; import io.ray.serve.api.Serve; import io.ray.serve.generated.BackendConfig; +import io.ray.serve.generated.BackendVersion; import io.ray.serve.generated.RequestMetadata; import io.ray.serve.util.ReflectUtil; import io.ray.serve.util.ServeProtoUtil; @@ -27,6 +28,7 @@ public RayServeWrappedReplica( String backendDef, byte[] initArgsbytes, byte[] backendConfigBytes, + byte[] backendVersionBytes, String controllerName) throws ClassNotFoundException, NoSuchMethodException, InstantiationException, IllegalAccessException, IllegalArgumentException, InvocationTargetException, IOException { @@ -52,7 +54,26 @@ public RayServeWrappedReplica( Serve.setInternalReplicaContext(backendTag, replicaTag, controllerName, callable); // Construct worker replica. - backend = new RayServeReplica(callable, backendConfig, optional.get()); + backend = + new RayServeReplica( + callable, + backendConfig, + ServeProtoUtil.parseBackendVersion(backendVersionBytes), + optional.get()); + } + + public RayServeWrappedReplica( + String backendTag, String replicaTag, DeploymentInfo deploymentInfo, String controllerName) + throws ClassNotFoundException, NoSuchMethodException, InstantiationException, + IllegalAccessException, IllegalArgumentException, InvocationTargetException, IOException { + this( + backendTag, + replicaTag, + deploymentInfo.getReplicaConfig().getBackendDef(), + deploymentInfo.getReplicaConfig().getInitArgs(), + deploymentInfo.getBackendConfig(), + deploymentInfo.getBackendVersion(), + controllerName); } private Object[] parseInitArgs(byte[] initArgsbytes, BackendConfig backendConfig) @@ -101,8 +122,21 @@ public void ready() { return; } - /** Wait until there is no request in processing. It is used for stopping replica gracefully. */ - public void drainPendingQueries() { - backend.drainPendingQueries(); + /** + * Wait until there is no request in processing. It is used for stopping replica gracefully. + * + * @return true if it is ready for shutdown. 
+ */ + public boolean prepareForShutdown() { + return backend.prepareForShutdown(); + } + + public byte[] reconfigure(Object userConfig) { + BackendVersion backendVersion = backend.reconfigure(userConfig); + return backendVersion.toByteArray(); + } + + public byte[] getVersion() { + return backend.getVersion().toByteArray(); } } diff --git a/java/serve/src/main/java/io/ray/serve/ReplicaConfig.java b/java/serve/src/main/java/io/ray/serve/ReplicaConfig.java index ff19348098027..a24ceea124963 100644 --- a/java/serve/src/main/java/io/ray/serve/ReplicaConfig.java +++ b/java/serve/src/main/java/io/ray/serve/ReplicaConfig.java @@ -12,13 +12,13 @@ public class ReplicaConfig implements Serializable { private String backendDef; - private Object[] initArgs; + private byte[] initArgs; private Map rayActorOptions; private Map resource; - public ReplicaConfig(String backendDef, Object[] initArgs, Map rayActorOptions) { + public ReplicaConfig(String backendDef, byte[] initArgs, Map rayActorOptions) { this.backendDef = backendDef; this.initArgs = initArgs; this.rayActorOptions = rayActorOptions; @@ -89,11 +89,11 @@ public void setBackendDef(String backendDef) { this.backendDef = backendDef; } - public Object[] getInitArgs() { + public byte[] getInitArgs() { return initArgs; } - public void setInitArgs(Object[] initArgs) { + public void setInitArgs(byte[] initArgs) { this.initArgs = initArgs; } diff --git a/java/serve/src/main/java/io/ray/serve/ReplicaContext.java b/java/serve/src/main/java/io/ray/serve/ReplicaContext.java index 10c62cf7eb411..7bd768f7cdd53 100644 --- a/java/serve/src/main/java/io/ray/serve/ReplicaContext.java +++ b/java/serve/src/main/java/io/ray/serve/ReplicaContext.java @@ -3,7 +3,7 @@ /** Stores data for Serve API calls from within the user's backend code. */ public class ReplicaContext { - private String backendTag; + private String backendTag; // TODO deployment private String replicaTag; diff --git a/java/serve/src/main/java/io/ray/serve/ReplicaSet.java b/java/serve/src/main/java/io/ray/serve/ReplicaSet.java new file mode 100644 index 0000000000000..1c7e757bba449 --- /dev/null +++ b/java/serve/src/main/java/io/ray/serve/ReplicaSet.java @@ -0,0 +1,138 @@ +package io.ray.serve; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Sets; +import io.ray.api.ActorHandle; +import io.ray.api.ObjectRef; +import io.ray.api.Ray; +import io.ray.runtime.metric.Gauge; +import io.ray.runtime.metric.Metrics; +import io.ray.runtime.metric.TagKey; +import io.ray.serve.generated.ActorSet; +import io.ray.serve.generated.BackendConfig; +import io.ray.serve.util.CollectionUtil; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.commons.lang3.RandomUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Data structure representing a set of replica actor handles. 
*/
+public class ReplicaSet {
+
+  private static final Logger LOGGER = LoggerFactory.getLogger(ReplicaSet.class);
+
+  private volatile int maxConcurrentQueries = 8;
+
+  private final Map<ActorHandle<RayServeWrappedReplica>, Set<ObjectRef<Object>>> inFlightQueries;
+
+  private AtomicInteger numQueuedQueries = new AtomicInteger();
+
+  private Gauge numQueuedQueriesGauge;
+
+  public ReplicaSet(String backendTag) {
+    this.inFlightQueries = new ConcurrentHashMap<>();
+    RayServeMetrics.execute(
+        () ->
+            this.numQueuedQueriesGauge =
+                Metrics.gauge()
+                    .name(RayServeMetrics.SERVE_DEPLOYMENT_QUEUED_QUERIES.getName())
+                    .description(RayServeMetrics.SERVE_DEPLOYMENT_QUEUED_QUERIES.getDescription())
+                    .unit("")
+                    .tags(ImmutableMap.of(RayServeMetrics.TAG_DEPLOYMENT, backendTag))
+                    .register());
+  }
+
+  public void setMaxConcurrentQueries(Object backendConfig) {
+    int newValue = ((BackendConfig) backendConfig).getMaxConcurrentQueries();
+    if (newValue != this.maxConcurrentQueries) {
+      this.maxConcurrentQueries = newValue;
+      LOGGER.info("ReplicaSet: changing max_concurrent_queries to {}", newValue);
+    }
+  }
+
+  public int getMaxConcurrentQueries() {
+    return maxConcurrentQueries;
+  }
+
+  @SuppressWarnings("unchecked")
+  public synchronized void updateWorkerReplicas(Object actorSet) {
+    List<String> actorNames = ((ActorSet) actorSet).getNamesList();
+    Set<ActorHandle<RayServeWrappedReplica>> workerReplicas = new HashSet<>();
+    if (!CollectionUtil.isEmpty(actorNames)) {
+      actorNames.forEach(
+          name ->
+              workerReplicas.add(
+                  (ActorHandle<RayServeWrappedReplica>) Ray.getActor(name).get()));
+    }
+
+    Set<ActorHandle<RayServeWrappedReplica>> added =
+        new HashSet<>(Sets.difference(workerReplicas, inFlightQueries.keySet()));
+    Set<ActorHandle<RayServeWrappedReplica>> removed =
+        new HashSet<>(Sets.difference(inFlightQueries.keySet(), workerReplicas));
+
+    added.forEach(actorHandle -> inFlightQueries.put(actorHandle, Sets.newConcurrentHashSet()));
+    removed.forEach(actorHandle -> inFlightQueries.remove(actorHandle));
+
+    if (added.size() > 0 || removed.size() > 0) {
+      LOGGER.info("ReplicaSet: +{}, -{} replicas.", added.size(), removed.size());
+    }
+  }
+
+  /**
+   * Given a query, submit it to a replica and return the object ref. This method keeps track of
+   * the in-flight queries for each replica and only sends a query to an available replica
+   * (determined by the backend's max_concurrent_queries value).
+   *
+   * @param query the incoming query.
+   * @return ray.ObjectRef
+   */
+  public ObjectRef<Object> assignReplica(Query query) {
+    String endpoint = query.getMetadata().getEndpoint();
+    numQueuedQueries.incrementAndGet();
+    RayServeMetrics.execute(
+        () ->
+            numQueuedQueriesGauge.update(
+                numQueuedQueries.get(),
+                TagKey.tagsFromMap(ImmutableMap.of(RayServeMetrics.TAG_ENDPOINT, endpoint))));
+    ObjectRef<Object> assignedRef =
+        tryAssignReplica(query); // TODO control concurrency using maxConcurrentQueries
+    numQueuedQueries.decrementAndGet();
+    RayServeMetrics.execute(
+        () ->
+            numQueuedQueriesGauge.update(
+                numQueuedQueries.get(),
+                TagKey.tagsFromMap(ImmutableMap.of(RayServeMetrics.TAG_ENDPOINT, endpoint))));
+    return assignedRef;
+  }
+
+  /**
+   * Try to assign the query to a replica. A replica is currently chosen at random; a
+   * RayServeException is thrown if no replica is available.
+   *
+   * @param query the incoming query.
+   * @return ray.ObjectRef
+   */
+  private ObjectRef<Object> tryAssignReplica(Query query) {
+
+    List<ActorHandle<RayServeWrappedReplica>> handles = new ArrayList<>(inFlightQueries.keySet());
+    if (CollectionUtil.isEmpty(handles)) {
+      throw new RayServeException("ReplicaSet found no replica.");
+    }
+    int randomIndex = RandomUtils.nextInt(0, handles.size());
+    ActorHandle<RayServeWrappedReplica> replica =
+        handles.get(randomIndex); // TODO control concurrency using maxConcurrentQueries
+    LOGGER.debug("Assigned query {} to replica {}.", query.getMetadata().getRequestId(), replica);
+    return replica
+        .task(RayServeWrappedReplica::handleRequest, query.getMetadata(), query.getArgs())
+        .remote();
+  }
+
+  public Map<ActorHandle<RayServeWrappedReplica>, Set<ObjectRef<Object>>> getInFlightQueries() {
+    return inFlightQueries;
+  }
+}
diff --git a/java/serve/src/main/java/io/ray/serve/Router.java b/java/serve/src/main/java/io/ray/serve/Router.java
new file mode 100644
index 0000000000000..5ef339d77767c
--- /dev/null
+++ b/java/serve/src/main/java/io/ray/serve/Router.java
@@ -0,0 +1,64 @@
+package io.ray.serve;
+
+import com.google.common.collect.ImmutableMap;
+import io.ray.api.BaseActorHandle;
+import io.ray.api.ObjectRef;
+import io.ray.runtime.metric.Count;
+import io.ray.runtime.metric.Metrics;
+import io.ray.serve.generated.RequestMetadata;
+import io.ray.serve.poll.KeyListener;
+import io.ray.serve.poll.KeyType;
+import io.ray.serve.poll.LongPollClient;
+import io.ray.serve.poll.LongPollNamespace;
+import java.util.HashMap;
+import java.util.Map;
+
+/** The router processes incoming queries: it chooses a backend and assigns a replica. */
+public class Router {
+
+  private ReplicaSet replicaSet;
+
+  private Count numRouterRequests;
+
+  private LongPollClient longPollClient;
+
+  public Router(BaseActorHandle controllerHandle, String backendTag) {
+    this.replicaSet = new ReplicaSet(backendTag);
+
+    RayServeMetrics.execute(
+        () ->
+            this.numRouterRequests =
+                Metrics.count()
+                    .name(RayServeMetrics.SERVE_NUM_ROUTER_REQUESTS.getName())
+                    .description(RayServeMetrics.SERVE_NUM_ROUTER_REQUESTS.getDescription())
+                    .unit("")
+                    .tags(ImmutableMap.of(RayServeMetrics.TAG_DEPLOYMENT, backendTag))
+                    .register());
+
+    Map<KeyType, KeyListener> keyListeners = new HashMap<>();
+    keyListeners.put(
+        new KeyType(LongPollNamespace.BACKEND_CONFIGS, backendTag),
+        backendConfig -> replicaSet.setMaxConcurrentQueries(backendConfig)); // cross language
+    keyListeners.put(
+        new KeyType(LongPollNamespace.REPLICA_HANDLES, backendTag),
+        workerReplicas -> replicaSet.updateWorkerReplicas(workerReplicas)); // cross language
+    this.longPollClient = new LongPollClient(controllerHandle, keyListeners);
+    this.longPollClient.start();
+  }
+
+  /**
+   * Assign a query and return an ObjectRef representing the result.
+   *
+   * @param requestMetadata the metadata of incoming queries.
+   * @param requestArgs the request body of incoming queries.
+ * @return ray.ObjectRef + */ + public ObjectRef assignRequest(RequestMetadata requestMetadata, Object[] requestArgs) { + RayServeMetrics.execute(() -> numRouterRequests.inc(1.0)); + return replicaSet.assignReplica(new Query(requestMetadata, requestArgs)); + } + + public ReplicaSet getReplicaSet() { + return replicaSet; + } +} diff --git a/java/serve/src/main/java/io/ray/serve/ServeController.java b/java/serve/src/main/java/io/ray/serve/ServeController.java new file mode 100644 index 0000000000000..1589f4c73b4c2 --- /dev/null +++ b/java/serve/src/main/java/io/ray/serve/ServeController.java @@ -0,0 +1,6 @@ +package io.ray.serve; + +public interface ServeController { + + byte[] getAllEndpoints(); +} diff --git a/java/serve/src/main/java/io/ray/serve/ServeProxy.java b/java/serve/src/main/java/io/ray/serve/ServeProxy.java new file mode 100644 index 0000000000000..532a2413f9ba5 --- /dev/null +++ b/java/serve/src/main/java/io/ray/serve/ServeProxy.java @@ -0,0 +1,14 @@ +package io.ray.serve; + +import java.util.Map; + +public interface ServeProxy { + + void init(Map config, ProxyRouter proxyRouter); + + default String getName() { + return getClass().getName(); + } + + default void registerServiceDiscovery() {} +} diff --git a/java/serve/src/main/java/io/ray/serve/api/Client.java b/java/serve/src/main/java/io/ray/serve/api/Client.java new file mode 100644 index 0000000000000..e5c63b5c8e184 --- /dev/null +++ b/java/serve/src/main/java/io/ray/serve/api/Client.java @@ -0,0 +1,72 @@ +package io.ray.serve.api; + +import io.ray.api.ActorHandle; +import io.ray.api.BaseActorHandle; +import io.ray.api.PyActorHandle; +import io.ray.api.function.PyActorMethod; +import io.ray.serve.RayServeException; +import io.ray.serve.RayServeHandle; +import io.ray.serve.ServeController; +import io.ray.serve.generated.EndpointInfo; +import io.ray.serve.util.LogUtil; +import io.ray.serve.util.ServeProtoUtil; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class Client { + + private static final Logger LOGGER = LoggerFactory.getLogger(Client.class); + + private BaseActorHandle controller; + + private Map handleCache = new ConcurrentHashMap<>(); + + public Client(BaseActorHandle controller, String controllerName, boolean detached) { + this.controller = controller; + } + + /** + * Retrieve RayServeHandle for service endpoint to invoke it from Python. + * + * @param endpointName A registered service endpoint. + * @param missingOk If true, then Serve won't check the endpoint is registered. False by default. 
+ * @return + */ + @SuppressWarnings("unchecked") + public RayServeHandle getHandle(String endpointName, boolean missingOk) { + + String cacheKey = endpointName + "_" + missingOk; + if (handleCache.containsKey(cacheKey)) { + return handleCache.get(cacheKey); + } + + Map endpoints = null; + if (controller instanceof PyActorHandle) { + endpoints = + ServeProtoUtil.parseEndpointSet( + (byte[]) + ((PyActorHandle) controller) + .task(PyActorMethod.of("get_all_endpoints")) + .remote() + .get()); + } else { + LOGGER.warn("Client only support Python controller now."); + endpoints = + ServeProtoUtil.parseEndpointSet( + ((ActorHandle) controller) + .task(ServeController::getAllEndpoints) + .remote() + .get()); + } + + if (!missingOk && (endpoints == null || !endpoints.containsKey(endpointName))) { + throw new RayServeException(LogUtil.format("Endpoint {} does not exist.", endpointName)); + } + + RayServeHandle handle = new RayServeHandle(controller, endpointName, null, null); + handleCache.put(cacheKey, handle); + return handle; + } +} diff --git a/java/serve/src/main/java/io/ray/serve/api/Serve.java b/java/serve/src/main/java/io/ray/serve/api/Serve.java index 8133e5bd7f23e..3b2c0ed7a2833 100644 --- a/java/serve/src/main/java/io/ray/serve/api/Serve.java +++ b/java/serve/src/main/java/io/ray/serve/api/Serve.java @@ -1,12 +1,20 @@ package io.ray.serve.api; +import com.google.common.base.Preconditions; +import io.ray.api.BaseActorHandle; +import io.ray.api.Ray; +import io.ray.serve.Constants; import io.ray.serve.RayServeException; import io.ray.serve.ReplicaContext; +import io.ray.serve.util.LogUtil; +import java.util.Optional; /** Ray Serve global API. TODO: will be riched in the Java SDK/API PR. */ public class Serve { - public static ReplicaContext INTERNAL_REPLICA_CONTEXT; + private static ReplicaContext INTERNAL_REPLICA_CONTEXT; + + private static Client GLOBAL_CLIENT; /** * Set replica information to global context. @@ -18,11 +26,14 @@ public class Serve { */ public static void setInternalReplicaContext( String backendTag, String replicaTag, String controllerName, Object servableObject) { - // TODO singleton. INTERNAL_REPLICA_CONTEXT = new ReplicaContext(backendTag, replicaTag, controllerName, servableObject); } + public static void setInternalReplicaContext(ReplicaContext replicaContext) { + INTERNAL_REPLICA_CONTEXT = replicaContext; + } + /** * Get the global replica context. * @@ -35,4 +46,43 @@ public static ReplicaContext getReplicaContext() { } return INTERNAL_REPLICA_CONTEXT; } + + public static Client getGlobalClient() { + if (GLOBAL_CLIENT != null) { + return GLOBAL_CLIENT; + } + synchronized (Client.class) { + if (GLOBAL_CLIENT != null) { + return GLOBAL_CLIENT; + } + return connect(); + } + } + + public static void setGlobalClient(Client client) { + GLOBAL_CLIENT = client; + } + + public static Client connect() { + + if (!Ray.isInitialized()) { + Ray.init(); + } + + String controllerName = + INTERNAL_REPLICA_CONTEXT != null + ? INTERNAL_REPLICA_CONTEXT.getInternalControllerName() + : Constants.SERVE_CONTROLLER_NAME; + + Optional optional = Ray.getActor(controllerName); + Preconditions.checkState( + optional.isPresent(), + LogUtil.format( + "There is no instance running on this Ray cluster. 
" + + "Please call `serve.start(detached=True) to start one.")); + + Client client = new Client(optional.get(), controllerName, true); + setGlobalClient(client); + return client; + } } diff --git a/java/serve/src/main/java/io/ray/serve/poll/KeyListener.java b/java/serve/src/main/java/io/ray/serve/poll/KeyListener.java index 91e9ceca04723..514193e28c37d 100644 --- a/java/serve/src/main/java/io/ray/serve/poll/KeyListener.java +++ b/java/serve/src/main/java/io/ray/serve/poll/KeyListener.java @@ -4,5 +4,5 @@ @FunctionalInterface public interface KeyListener { - void notifyChanged(Object object); + void notifyChanged(Object updatedObject); } diff --git a/java/serve/src/main/java/io/ray/serve/poll/LongPollClient.java b/java/serve/src/main/java/io/ray/serve/poll/LongPollClient.java index 4017be3af9db9..308391254e109 100644 --- a/java/serve/src/main/java/io/ray/serve/poll/LongPollClient.java +++ b/java/serve/src/main/java/io/ray/serve/poll/LongPollClient.java @@ -1,6 +1,7 @@ package io.ray.serve.poll; import com.google.common.base.Preconditions; +import com.google.protobuf.InvalidProtocolBufferException; import io.ray.api.BaseActorHandle; import io.ray.api.ObjectRef; import io.ray.api.PyActorHandle; @@ -8,8 +9,16 @@ import io.ray.runtime.exception.RayActorException; import io.ray.runtime.exception.RayTaskException; import io.ray.serve.Constants; +import io.ray.serve.RayServeException; +import io.ray.serve.generated.ActorSet; +import io.ray.serve.generated.UpdatedObject; +import io.ray.serve.util.LogUtil; +import io.ray.serve.util.ServeProtoUtil; +import java.util.HashMap; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Function; +import org.apache.commons.lang3.builder.ReflectionToStringBuilder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -33,6 +42,26 @@ public class LongPollClient { /** An async thread to post the callback into. */ private Thread pollThread; + private static final Map> DESERIALIZERS = + new HashMap<>(); + + static { + DESERIALIZERS.put( + LongPollNamespace.BACKEND_CONFIGS, body -> ServeProtoUtil.parseBackendConfig(body)); + DESERIALIZERS.put( + LongPollNamespace.REPLICA_HANDLES, body -> ServeProtoUtil.parseEndpointSet(body)); + DESERIALIZERS.put( + LongPollNamespace.REPLICA_HANDLES, + body -> { + try { + return ActorSet.parseFrom(body); + } catch (InvalidProtocolBufferException e) { + throw new RayServeException( + LogUtil.format("Failed to parse ActorSet from protobuf bytes."), e); + } + }); + } + public LongPollClient(BaseActorHandle hostActor, Map keyListeners) { Preconditions.checkArgument(keyListeners != null && keyListeners.size() != 0); @@ -51,7 +80,7 @@ public LongPollClient(BaseActorHandle hostActor, Map keyLi try { pollNext(); } catch (RayActorException e) { - LOGGER.debug("LongPollClient failed to connect to host. Shutting down."); + LOGGER.error("LongPollClient failed to connect to host. Shutting down."); break; } catch (RayTaskException e) { LOGGER.error("LongPollHost errored", e); @@ -71,24 +100,44 @@ public void start() { pollThread.start(); } - /** Poll the update. */ - @SuppressWarnings("unchecked") - public void pollNext() { + /** + * Poll the update. + * + * @throws InvalidProtocolBufferException if the protobuf deserialization fails. 
+ */ + public void pollNext() throws InvalidProtocolBufferException { currentRef = ((PyActorHandle) hostActor) .task(PyActorMethod.of(Constants.CONTROLLER_LISTEN_FOR_CHANGE_METHOD), snapshotIds) .remote(); - processUpdate((Map) currentRef.get()); + processUpdate(ServeProtoUtil.parseUpdatedObjects((byte[]) currentRef.get())); } public void processUpdate(Map updates) { - - LOGGER.debug("LongPollClient received updates for keys: {}", updates.keySet()); - + if (updates == null || updates.isEmpty()) { + LOGGER.info("LongPollClient received nothing."); + return; + } + LOGGER.info("LongPollClient received updates for keys: {}", updates.keySet()); for (Map.Entry entry : updates.entrySet()) { - objectSnapshots.put(entry.getKey(), entry.getValue().getObjectSnapshot()); + KeyType keyType = entry.getKey(); + UpdatedObject updatedObject = entry.getValue(); + + Object objectSnapshot = + DESERIALIZERS + .get(keyType.getLongPollNamespace()) + .apply(updatedObject.getObjectSnapshot().toByteArray()); + + if (LOGGER.isDebugEnabled()) { + LOGGER.debug( + "The updated object for key {} is {}", + keyType, + ReflectionToStringBuilder.toString(objectSnapshot)); + } + + keyListeners.get(entry.getKey()).notifyChanged(objectSnapshot); + objectSnapshots.put(entry.getKey(), objectSnapshot); snapshotIds.put(entry.getKey(), entry.getValue().getSnapshotId()); - keyListeners.get(entry.getKey()).notifyChanged(entry.getValue().getObjectSnapshot()); } } diff --git a/java/serve/src/main/java/io/ray/serve/poll/LongPollNamespace.java b/java/serve/src/main/java/io/ray/serve/poll/LongPollNamespace.java index 466af829167e8..71b3a2e8baa1e 100644 --- a/java/serve/src/main/java/io/ray/serve/poll/LongPollNamespace.java +++ b/java/serve/src/main/java/io/ray/serve/poll/LongPollNamespace.java @@ -4,9 +4,7 @@ public enum LongPollNamespace { REPLICA_HANDLES, - TRAFFIC_POLICIES, - BACKEND_CONFIGS, - ROUTE_TABLE + ROUTE_TABLE; } diff --git a/java/serve/src/main/java/io/ray/serve/poll/UpdatedObject.java b/java/serve/src/main/java/io/ray/serve/poll/UpdatedObject.java deleted file mode 100644 index 3f3ddc63c1ae2..0000000000000 --- a/java/serve/src/main/java/io/ray/serve/poll/UpdatedObject.java +++ /dev/null @@ -1,33 +0,0 @@ -package io.ray.serve.poll; - -import java.io.Serializable; - -/** The updated object that long poll client received. */ -public class UpdatedObject implements Serializable { - - private static final long serialVersionUID = 6245682414826079438L; - - private Object objectSnapshot; - - /** - * The identifier for the object's version. There is not sequential relation among different - * object's snapshot_ids. 
- */ - private int snapshotId; - - public Object getObjectSnapshot() { - return objectSnapshot; - } - - public void setObjectSnapshot(Object objectSnapshot) { - this.objectSnapshot = objectSnapshot; - } - - public int getSnapshotId() { - return snapshotId; - } - - public void setSnapshotId(int snapshotId) { - this.snapshotId = snapshotId; - } -} diff --git a/java/serve/src/main/java/io/ray/serve/util/CollectionUtil.java b/java/serve/src/main/java/io/ray/serve/util/CollectionUtil.java new file mode 100644 index 0000000000000..cd66932f48276 --- /dev/null +++ b/java/serve/src/main/java/io/ray/serve/util/CollectionUtil.java @@ -0,0 +1,10 @@ +package io.ray.serve.util; + +import java.util.Collection; + +public class CollectionUtil { + + public static boolean isEmpty(Collection collection) { + return collection == null || collection.isEmpty(); + } +} diff --git a/java/serve/src/main/java/io/ray/serve/util/CommonUtil.java b/java/serve/src/main/java/io/ray/serve/util/CommonUtil.java new file mode 100644 index 0000000000000..a32ee212196d8 --- /dev/null +++ b/java/serve/src/main/java/io/ray/serve/util/CommonUtil.java @@ -0,0 +1,13 @@ +package io.ray.serve.util; + +import org.apache.commons.lang3.StringUtils; + +public class CommonUtil { + + public static String formatActorName(String controllerName, String actorName) { + if (StringUtils.isBlank(controllerName)) { + return actorName; + } + return controllerName + ":" + actorName; + } +} diff --git a/java/serve/src/main/java/io/ray/serve/util/ReflectUtil.java b/java/serve/src/main/java/io/ray/serve/util/ReflectUtil.java index 5de1142433008..ae449dd714733 100644 --- a/java/serve/src/main/java/io/ray/serve/util/ReflectUtil.java +++ b/java/serve/src/main/java/io/ray/serve/util/ReflectUtil.java @@ -2,6 +2,7 @@ import java.lang.reflect.Constructor; import java.lang.reflect.Executable; +import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.util.ArrayList; import java.util.List; @@ -178,4 +179,17 @@ public static List getMethodStrings(Class targetClass) { } return methodStrings; } + + @SuppressWarnings("unchecked") + public static List getInstancesByClassNames(String classNames, Class cls) + throws ClassNotFoundException, InstantiationException, IllegalAccessException, + IllegalArgumentException, InvocationTargetException, NoSuchMethodException, + SecurityException { + String[] classNameArray = StringUtils.split(classNames, ";"); + List isntances = new ArrayList<>(); + for (String className : classNameArray) { + isntances.add((T) Class.forName(className).getConstructor().newInstance()); + } + return isntances; + } } diff --git a/java/serve/src/main/java/io/ray/serve/util/ServeProtoUtil.java b/java/serve/src/main/java/io/ray/serve/util/ServeProtoUtil.java index b1d02a046063e..1a1c0c082d3f8 100644 --- a/java/serve/src/main/java/io/ray/serve/util/ServeProtoUtil.java +++ b/java/serve/src/main/java/io/ray/serve/util/ServeProtoUtil.java @@ -2,26 +2,42 @@ import com.google.common.base.Preconditions; import com.google.common.collect.Lists; +import com.google.gson.Gson; import com.google.protobuf.InvalidProtocolBufferException; import io.ray.runtime.serializer.MessagePackSerializer; +import io.ray.serve.Constants; import io.ray.serve.RayServeException; import io.ray.serve.generated.BackendConfig; import io.ray.serve.generated.BackendLanguage; +import io.ray.serve.generated.BackendVersion; +import io.ray.serve.generated.EndpointInfo; +import io.ray.serve.generated.EndpointSet; +import io.ray.serve.generated.LongPollResult; 
import io.ray.serve.generated.RequestMetadata; import io.ray.serve.generated.RequestWrapper; +import io.ray.serve.generated.UpdatedObject; +import io.ray.serve.poll.KeyType; +import java.util.HashMap; +import java.util.Map; import org.apache.commons.lang3.StringUtils; public class ServeProtoUtil { - public static BackendConfig parseBackendConfig(byte[] backendConfigBytes) - throws InvalidProtocolBufferException { + private static final Gson GSON = new Gson(); + + public static BackendConfig parseBackendConfig(byte[] backendConfigBytes) { // Get a builder from BackendConfig(bytes) or create a new one. BackendConfig.Builder builder = null; if (backendConfigBytes == null) { builder = BackendConfig.newBuilder(); } else { - BackendConfig backendConfig = BackendConfig.parseFrom(backendConfigBytes); + BackendConfig backendConfig = null; + try { + backendConfig = BackendConfig.parseFrom(backendConfigBytes); + } catch (InvalidProtocolBufferException e) { + throw new RayServeException("Failed to parse BackendConfig from protobuf bytes.", e); + } if (backendConfig == null) { builder = BackendConfig.newBuilder(); } else { @@ -40,12 +56,12 @@ public static BackendConfig parseBackendConfig(byte[] backendConfigBytes) builder.setMaxConcurrentQueries(100); } - if (builder.getExperimentalGracefulShutdownWaitLoopS() == 0) { - builder.setExperimentalGracefulShutdownWaitLoopS(2); + if (builder.getGracefulShutdownWaitLoopS() == 0) { + builder.setGracefulShutdownWaitLoopS(2); } - if (builder.getExperimentalGracefulShutdownTimeoutS() == 0) { - builder.setExperimentalGracefulShutdownTimeoutS(20); + if (builder.getGracefulShutdownTimeoutS() == 0) { + builder.setGracefulShutdownTimeoutS(20); } if (builder.getBackendLanguage() == BackendLanguage.UNRECOGNIZED) { @@ -84,7 +100,7 @@ public static RequestMetadata parseRequestMetadata(byte[] requestMetadataBytes) // Set default values. 
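+    // A blank call method falls back to Constants.DEFAULT_CALL_METHOD, i.e. "call".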
if (StringUtils.isBlank(builder.getCallMethod())) { - builder.setCallMethod("call"); + builder.setCallMethod(Constants.DEFAULT_CALL_METHOD); } return builder.build(); @@ -108,4 +124,47 @@ public static RequestWrapper parseRequestWrapper(byte[] httpRequestWrapperBytes) return builder.build(); } + + public static Map parseUpdatedObjects(byte[] longPollResultBytes) + throws InvalidProtocolBufferException { + if (longPollResultBytes == null) { + return null; + } + LongPollResult longPollResult = LongPollResult.parseFrom(longPollResultBytes); + Map updatedObjects = longPollResult.getUpdatedObjectsMap(); + if (updatedObjects == null || updatedObjects.isEmpty()) { + return null; + } + Map udpates = new HashMap<>(updatedObjects.size()); + updatedObjects.forEach( + (key, value) -> udpates.put(ServeProtoUtil.GSON.fromJson(key, KeyType.class), value)); + return udpates; + } + + public static Map parseEndpointSet(byte[] endpointSetBytes) { + if (endpointSetBytes == null) { + return null; + } + EndpointSet endpointSet = null; + try { + endpointSet = EndpointSet.parseFrom(endpointSetBytes); + } catch (InvalidProtocolBufferException e) { + throw new RayServeException("Failed to parse EndpointSet from protobuf bytes.", e); + } + if (endpointSet == null) { + return null; + } + return endpointSet.getEndpointsMap(); + } + + public static BackendVersion parseBackendVersion(byte[] backendVersionBytes) { + if (backendVersionBytes == null) { + return null; + } + try { + return BackendVersion.parseFrom(backendVersionBytes); + } catch (InvalidProtocolBufferException e) { + throw new RayServeException("Failed to parse BackendVersion from protobuf bytes.", e); + } + } } diff --git a/java/serve/src/main/java/io/ray/serve/util/SocketUtil.java b/java/serve/src/main/java/io/ray/serve/util/SocketUtil.java new file mode 100644 index 0000000000000..ab93a6e152210 --- /dev/null +++ b/java/serve/src/main/java/io/ray/serve/util/SocketUtil.java @@ -0,0 +1,49 @@ +package io.ray.serve.util; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.ServerSocket; + +public class SocketUtil { + + public static final int PORT_RANGE_MAX = 65535; + + public static int findAvailableTcpPort(int minPort) { + int portRange = PORT_RANGE_MAX - minPort; + int candidatePort = minPort; + int searchCounter = 0; + while (!isPortAvailable(candidatePort)) { + candidatePort++; + if (++searchCounter > portRange) { + throw new IllegalStateException( + String.format( + "Could not find an available tcp port in the range [%d, %d] after %d attempts.", + minPort, PORT_RANGE_MAX, searchCounter)); + } + } + return candidatePort; + } + + public static boolean isPortAvailable(int port) { + ServerSocket socket; + try { + socket = new ServerSocket(); + } catch (IOException e) { + throw new IllegalStateException("Unable to create ServerSocket.", e); + } + + try { + InetSocketAddress sa = new InetSocketAddress(port); + socket.bind(sa); + return true; + } catch (IOException ex) { + return false; + } finally { + try { + socket.close(); + } catch (IOException ex) { + // ignore this exception for now + } + } + } +} diff --git a/java/serve/src/test/java/io/ray/serve/DummyServeController.java b/java/serve/src/test/java/io/ray/serve/DummyServeController.java new file mode 100644 index 0000000000000..6ee319a477898 --- /dev/null +++ b/java/serve/src/test/java/io/ray/serve/DummyServeController.java @@ -0,0 +1,21 @@ +package io.ray.serve; + +import io.ray.serve.generated.EndpointInfo; +import io.ray.serve.generated.EndpointSet; +import 
java.util.Map; + +public class DummyServeController implements ServeController { + + private Map endpoints; + + @Override + public byte[] getAllEndpoints() { + EndpointSet.Builder builder = EndpointSet.newBuilder(); + builder.putAllEndpoints(endpoints); + return builder.build().toByteArray(); + } + + public void setEndpoints(Map endpoints) { + this.endpoints = endpoints; + } +} diff --git a/java/serve/src/test/java/io/ray/serve/HttpProxyTest.java b/java/serve/src/test/java/io/ray/serve/HttpProxyTest.java new file mode 100644 index 0000000000000..5166603662c82 --- /dev/null +++ b/java/serve/src/test/java/io/ray/serve/HttpProxyTest.java @@ -0,0 +1,74 @@ +package io.ray.serve; + +import io.ray.api.ActorHandle; +import io.ray.api.Ray; +import io.ray.serve.api.Serve; +import io.ray.serve.generated.EndpointInfo; +import io.ray.serve.util.CommonUtil; +import java.io.IOException; +import java.net.HttpURLConnection; +import java.util.HashMap; +import java.util.Map; +import org.apache.commons.lang3.RandomStringUtils; +import org.apache.hc.client5.http.classic.HttpClient; +import org.apache.hc.client5.http.classic.methods.HttpPost; +import org.apache.hc.client5.http.impl.classic.CloseableHttpResponse; +import org.apache.hc.client5.http.impl.classic.HttpClientBuilder; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class HttpProxyTest { + + @Test + public void test() throws IOException { + + boolean inited = Ray.isInitialized(); + Ray.init(); + + try { + String controllerName = + CommonUtil.formatActorName( + Constants.SERVE_CONTROLLER_NAME, RandomStringUtils.randomAlphabetic(6)); + String endpointName = "HTTPProxyTest"; + String route = "/route"; + + // Controller + ActorHandle controllerHandle = + Ray.actor(DummyServeController::new).setName(controllerName).remote(); + + Map endpointInfos = new HashMap<>(); + endpointInfos.put( + endpointName, + EndpointInfo.newBuilder().setEndpointName(endpointName).setRoute(route).build()); + controllerHandle.task(DummyServeController::setEndpoints, endpointInfos).remote(); + + Serve.setInternalReplicaContext(null, null, controllerName, null); + + // ProxyRouter updates routes. + ProxyRouter proxyRouter = new ProxyRouter(); + proxyRouter.updateRoutes(endpointInfos); + + // HTTP proxy. + HttpProxy httpProxy = new HttpProxy(); + httpProxy.init(null, proxyRouter); + + // Send request. + HttpClient httpClient = HttpClientBuilder.create().build(); + HttpPost httpPost = new HttpPost("http://localhost:" + httpProxy.getPort() + route); + try (CloseableHttpResponse httpResponse = + (CloseableHttpResponse) httpClient.execute(httpPost)) { + + // No Backend replica, so error is expected. 
+ int status = httpResponse.getCode(); + Assert.assertEquals(status, HttpURLConnection.HTTP_INTERNAL_ERROR); + } + + } finally { + if (!inited) { + Ray.shutdown(); + } + Serve.setInternalReplicaContext(null); + Serve.setGlobalClient(null); + } + } +} diff --git a/java/serve/src/test/java/io/ray/serve/ProxyActorTest.java b/java/serve/src/test/java/io/ray/serve/ProxyActorTest.java new file mode 100644 index 0000000000000..6b1daa11b1141 --- /dev/null +++ b/java/serve/src/test/java/io/ray/serve/ProxyActorTest.java @@ -0,0 +1,110 @@ +package io.ray.serve; + +import io.ray.api.ActorHandle; +import io.ray.api.Ray; +import io.ray.runtime.serializer.MessagePackSerializer; +import io.ray.serve.api.Serve; +import io.ray.serve.generated.ActorSet; +import io.ray.serve.generated.BackendConfig; +import io.ray.serve.generated.BackendVersion; +import io.ray.serve.generated.EndpointInfo; +import io.ray.serve.util.CommonUtil; +import java.io.IOException; +import java.net.HttpURLConnection; +import java.util.HashMap; +import java.util.Map; +import org.apache.commons.lang3.RandomStringUtils; +import org.apache.hc.client5.http.classic.HttpClient; +import org.apache.hc.client5.http.classic.methods.HttpPost; +import org.apache.hc.client5.http.impl.classic.CloseableHttpResponse; +import org.apache.hc.client5.http.impl.classic.HttpClientBuilder; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class ProxyActorTest { + + @Test + public void test() throws IOException { + boolean inited = Ray.isInitialized(); + Ray.init(); + + try { + String prefix = "ProxyActorTest"; + String controllerName = + CommonUtil.formatActorName( + Constants.SERVE_CONTROLLER_NAME, RandomStringUtils.randomAlphabetic(6)); + String backendTag = prefix; + String replicaTag = prefix; + String endpointName = prefix; + String route = "/route"; + String version = "v1"; + + // Controller + ActorHandle controller = + Ray.actor(DummyServeController::new).setName(controllerName).remote(); + Map endpointInfos = new HashMap<>(); + endpointInfos.put( + endpointName, + EndpointInfo.newBuilder().setEndpointName(endpointName).setRoute(route).build()); + controller.task(DummyServeController::setEndpoints, endpointInfos).remote(); + + // Replica + DeploymentInfo deploymentInfo = new DeploymentInfo(); + deploymentInfo.setBackendConfig(BackendConfig.newBuilder().build().toByteArray()); + deploymentInfo.setBackendVersion( + BackendVersion.newBuilder().setCodeVersion(version).build().toByteArray()); + deploymentInfo.setReplicaConfig( + new ReplicaConfig(DummyBackendReplica.class.getName(), null, new HashMap<>())); + + ActorHandle replica = + Ray.actor( + RayServeWrappedReplica::new, + backendTag, + replicaTag, + deploymentInfo, + controllerName) + .setName(replicaTag) + .remote(); + replica.task(RayServeWrappedReplica::ready).remote(); + + // ProxyActor + ProxyActor proxyActor = new ProxyActor(controllerName, null); + proxyActor.getProxyRouter().updateRoutes(endpointInfos); + proxyActor + .getProxyRouter() + .getHandles() + .get(endpointName) + .getRouter() + .getReplicaSet() + .updateWorkerReplicas(ActorSet.newBuilder().addNames(replicaTag).build()); + + // Send request. 
+ HttpClient httpClient = HttpClientBuilder.create().build(); + HttpPost httpPost = + new HttpPost( + "http://localhost:" + + ((HttpProxy) proxyActor.getProxies().get(HttpProxy.PROXY_NAME)).getPort() + + route); + try (CloseableHttpResponse httpResponse = + (CloseableHttpResponse) httpClient.execute(httpPost)) { + + int status = httpResponse.getCode(); + Assert.assertEquals(status, HttpURLConnection.HTTP_OK); + Object result = + MessagePackSerializer.decode( + EntityUtils.toByteArray(httpResponse.getEntity()), Object.class); + + Assert.assertNotNull(result); + Assert.assertEquals("1", result.toString()); + } + + } finally { + if (!inited) { + Ray.shutdown(); + } + Serve.setInternalReplicaContext(null); + Serve.setGlobalClient(null); + } + } +} diff --git a/java/serve/src/test/java/io/ray/serve/ProxyRouterTest.java b/java/serve/src/test/java/io/ray/serve/ProxyRouterTest.java new file mode 100644 index 0000000000000..03535a0575a79 --- /dev/null +++ b/java/serve/src/test/java/io/ray/serve/ProxyRouterTest.java @@ -0,0 +1,68 @@ +package io.ray.serve; + +import io.ray.api.ActorHandle; +import io.ray.api.Ray; +import io.ray.serve.api.Serve; +import io.ray.serve.generated.EndpointInfo; +import io.ray.serve.util.CommonUtil; +import java.util.HashMap; +import java.util.Map; +import org.apache.commons.lang3.RandomStringUtils; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class ProxyRouterTest { + + @Test + public void test() { + + boolean inited = Ray.isInitialized(); + Ray.init(); + + try { + String prefix = "ProxyRouterTest"; + String controllerName = + CommonUtil.formatActorName( + Constants.SERVE_CONTROLLER_NAME, RandomStringUtils.randomAlphabetic(6)); + String endpointName1 = prefix + "_1"; + String endpointName2 = prefix + "_2"; + String route1 = "/route1"; + + // Controller + ActorHandle controllerHandle = + Ray.actor(DummyServeController::new).setName(controllerName).remote(); + Map endpointInfos = new HashMap<>(); + endpointInfos.put( + endpointName1, + EndpointInfo.newBuilder().setEndpointName(endpointName1).setRoute(route1).build()); + endpointInfos.put( + endpointName2, EndpointInfo.newBuilder().setEndpointName(endpointName2).build()); + controllerHandle.task(DummyServeController::setEndpoints, endpointInfos).remote(); + + Serve.setInternalReplicaContext(null, null, controllerName, null); + + // ProxyRouter updates routes. + ProxyRouter proxyRouter = new ProxyRouter(); + proxyRouter.updateRoutes(endpointInfos); + + // Check result. 
+ Map routeInfo = proxyRouter.getRouteInfo(); + Assert.assertNotNull(routeInfo); + Assert.assertNotNull(routeInfo.get(route1)); + Assert.assertEquals(routeInfo.get(route1).getRoute(), route1); + Assert.assertEquals(routeInfo.get(route1).getEndpointName(), endpointName1); + Assert.assertNotNull(routeInfo.get(endpointName2)); + Assert.assertEquals(routeInfo.get(endpointName2).getEndpointName(), endpointName2); + Map handles = proxyRouter.getHandles(); + Assert.assertNotNull(handles); + Assert.assertNotNull(handles.get(endpointName1)); + Assert.assertNotNull(handles.get(endpointName2)); + } finally { + if (!inited) { + Ray.shutdown(); + } + Serve.setInternalReplicaContext(null); + Serve.setGlobalClient(null); + } + } +} diff --git a/java/serve/src/test/java/io/ray/serve/RayServeHandleTest.java b/java/serve/src/test/java/io/ray/serve/RayServeHandleTest.java new file mode 100644 index 0000000000000..9e4ac68b612fd --- /dev/null +++ b/java/serve/src/test/java/io/ray/serve/RayServeHandleTest.java @@ -0,0 +1,76 @@ +package io.ray.serve; + +import io.ray.api.ActorHandle; +import io.ray.api.ObjectRef; +import io.ray.api.Ray; +import io.ray.runtime.serializer.MessagePackSerializer; +import io.ray.serve.generated.ActorSet; +import io.ray.serve.generated.BackendConfig; +import io.ray.serve.generated.BackendLanguage; +import io.ray.serve.generated.BackendVersion; +import java.util.HashMap; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class RayServeHandleTest { + + @Test + public void test() { + boolean inited = Ray.isInitialized(); + Ray.init(); + + try { + String backendTag = "RayServeHandleTest"; + String controllerName = backendTag + "_controller"; + String replicaTag = backendTag + "_replica"; + String actorName = replicaTag; + String version = "v1"; + + // Controller + ActorHandle controllerHandle = + Ray.actor(DummyServeController::new).setName(controllerName).remote(); + + // Replica + BackendConfig.Builder backendConfigBuilder = BackendConfig.newBuilder(); + backendConfigBuilder.setBackendLanguage(BackendLanguage.JAVA); + byte[] backendConfigBytes = backendConfigBuilder.build().toByteArray(); + + Object[] initArgs = new Object[] {backendTag, replicaTag, controllerName, new Object()}; + byte[] initArgsBytes = MessagePackSerializer.encode(initArgs).getLeft(); + + DeploymentInfo deploymentInfo = new DeploymentInfo(); + deploymentInfo.setBackendConfig(backendConfigBytes); + deploymentInfo.setBackendVersion( + BackendVersion.newBuilder().setCodeVersion(version).build().toByteArray()); + deploymentInfo.setReplicaConfig( + new ReplicaConfig("io.ray.serve.ReplicaContext", initArgsBytes, new HashMap<>())); + + ActorHandle replicaHandle = + Ray.actor( + RayServeWrappedReplica::new, + backendTag, + replicaTag, + deploymentInfo, + controllerName) + .setName(actorName) + .remote(); + replicaHandle.task(RayServeWrappedReplica::ready).remote(); + + // RayServeHandle + RayServeHandle rayServeHandle = + new RayServeHandle(controllerHandle, backendTag, null, null) + .setMethodName("getBackendTag"); + ActorSet.Builder builder = ActorSet.newBuilder(); + builder.addNames(actorName); + rayServeHandle.getRouter().getReplicaSet().updateWorkerReplicas(builder.build()); + + // remote + ObjectRef resultRef = rayServeHandle.remote(null); + Assert.assertEquals((String) resultRef.get(), backendTag); + } finally { + if (!inited) { + Ray.shutdown(); + } + } + } +} diff --git a/java/serve/src/test/java/io/ray/serve/RayServeReplicaTest.java 
b/java/serve/src/test/java/io/ray/serve/RayServeReplicaTest.java index 7cc7746ff165c..065b74ac1fc0e 100644 --- a/java/serve/src/test/java/io/ray/serve/RayServeReplicaTest.java +++ b/java/serve/src/test/java/io/ray/serve/RayServeReplicaTest.java @@ -6,9 +6,12 @@ import io.ray.runtime.serializer.MessagePackSerializer; import io.ray.serve.generated.BackendConfig; import io.ray.serve.generated.BackendLanguage; +import io.ray.serve.generated.BackendVersion; import io.ray.serve.generated.RequestMetadata; import io.ray.serve.generated.RequestWrapper; import java.io.IOException; +import java.util.HashMap; +import org.apache.commons.lang3.RandomStringUtils; import org.testng.Assert; import org.testng.annotations.Test; @@ -17,7 +20,6 @@ public class RayServeReplicaTest { @SuppressWarnings("unused") @Test public void test() throws IOException { - boolean inited = Ray.isInitialized(); Ray.init(); @@ -25,38 +27,40 @@ public void test() throws IOException { String controllerName = "RayServeReplicaTest"; String backendTag = "b_tag"; String replicaTag = "r_tag"; + String version = "v1"; - ActorHandle controllerHandle = - Ray.actor(ReplicaContext::new, backendTag, replicaTag, controllerName, new Object()) - .setName(controllerName) - .remote(); + ActorHandle controllerHandle = + Ray.actor(DummyServeController::new).setName(controllerName).remote(); BackendConfig.Builder backendConfigBuilder = BackendConfig.newBuilder(); backendConfigBuilder.setBackendLanguage(BackendLanguage.JAVA); - byte[] backendConfigBytes = backendConfigBuilder.build().toByteArray(); - Object[] initArgs = new Object[] {backendTag, replicaTag, controllerName, new Object()}; - byte[] initArgsBytes = MessagePackSerializer.encode(initArgs).getLeft(); + DeploymentInfo deploymentInfo = new DeploymentInfo(); + deploymentInfo.setBackendConfig(backendConfigBytes); + deploymentInfo.setBackendVersion( + BackendVersion.newBuilder().setCodeVersion(version).build().toByteArray()); + deploymentInfo.setReplicaConfig( + new ReplicaConfig("io.ray.serve.ReplicaContext", initArgsBytes, new HashMap<>())); + ActorHandle backendHandle = Ray.actor( RayServeWrappedReplica::new, backendTag, replicaTag, - "io.ray.serve.ReplicaContext", - initArgsBytes, - backendConfigBytes, + deploymentInfo, controllerName) .remote(); + // ready backendHandle.task(RayServeWrappedReplica::ready).remote(); + // handle request RequestMetadata.Builder requestMetadata = RequestMetadata.newBuilder(); - requestMetadata.setRequestId("RayServeReplicaTest"); + requestMetadata.setRequestId(RandomStringUtils.randomAlphabetic(10)); requestMetadata.setCallMethod("getBackendTag"); - RequestWrapper.Builder requestWrapper = RequestWrapper.newBuilder(); ObjectRef resultRef = @@ -66,8 +70,22 @@ public void test() throws IOException { requestMetadata.build().toByteArray(), requestWrapper.build().toByteArray()) .remote(); - Assert.assertEquals((String) resultRef.get(), backendTag); + + // reconfigure + ObjectRef versionRef = + backendHandle.task(RayServeWrappedReplica::reconfigure, (Object) null).remote(); + Assert.assertEquals(BackendVersion.parseFrom(versionRef.get()).getCodeVersion(), version); + + // get version + versionRef = backendHandle.task(RayServeWrappedReplica::getVersion).remote(); + Assert.assertEquals(BackendVersion.parseFrom(versionRef.get()).getCodeVersion(), version); + + // prepare for shutdown + ObjectRef shutdownRef = + backendHandle.task(RayServeWrappedReplica::prepareForShutdown).remote(); + Assert.assertTrue(shutdownRef.get()); + } finally { if (!inited) { Ray.shutdown(); 
diff --git a/java/serve/src/test/java/io/ray/serve/ReplicaSetTest.java b/java/serve/src/test/java/io/ray/serve/ReplicaSetTest.java new file mode 100644 index 0000000000000..513d27e4bb6b1 --- /dev/null +++ b/java/serve/src/test/java/io/ray/serve/ReplicaSetTest.java @@ -0,0 +1,108 @@ +package io.ray.serve; + +import io.ray.api.ActorHandle; +import io.ray.api.ObjectRef; +import io.ray.api.Ray; +import io.ray.runtime.serializer.MessagePackSerializer; +import io.ray.serve.generated.ActorSet; +import io.ray.serve.generated.BackendConfig; +import io.ray.serve.generated.BackendLanguage; +import io.ray.serve.generated.BackendVersion; +import io.ray.serve.generated.RequestMetadata; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +import org.apache.commons.lang3.RandomStringUtils; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class ReplicaSetTest { + + private String backendTag = "ReplicaSetTest"; + + @Test + public void setMaxConcurrentQueriesTest() { + ReplicaSet replicaSet = new ReplicaSet(backendTag); + BackendConfig.Builder builder = BackendConfig.newBuilder(); + builder.setMaxConcurrentQueries(200); + + replicaSet.setMaxConcurrentQueries(builder.build()); + Assert.assertEquals(replicaSet.getMaxConcurrentQueries(), 200); + } + + @Test + public void updateWorkerReplicasTest() { + ReplicaSet replicaSet = new ReplicaSet(backendTag); + ActorSet.Builder builder = ActorSet.newBuilder(); + + replicaSet.updateWorkerReplicas(builder.build()); + Map, Set>> inFlightQueries = + replicaSet.getInFlightQueries(); + Assert.assertTrue(inFlightQueries.isEmpty()); + } + + @SuppressWarnings("unused") + @Test + public void assignReplicaTest() { + boolean inited = Ray.isInitialized(); + Ray.init(); + + try { + String controllerName = backendTag + "_controller"; + String replicaTag = backendTag + "_replica"; + String actorName = replicaTag; + String version = "v1"; + + // Controller + ActorHandle controllerHandle = + Ray.actor(DummyServeController::new).setName(controllerName).remote(); + + // Replica + BackendConfig.Builder backendConfigBuilder = BackendConfig.newBuilder(); + backendConfigBuilder.setBackendLanguage(BackendLanguage.JAVA); + byte[] backendConfigBytes = backendConfigBuilder.build().toByteArray(); + + Object[] initArgs = new Object[] {backendTag, replicaTag, controllerName, new Object()}; + byte[] initArgsBytes = MessagePackSerializer.encode(initArgs).getLeft(); + + DeploymentInfo deploymentInfo = new DeploymentInfo(); + deploymentInfo.setBackendConfig(backendConfigBytes); + deploymentInfo.setBackendVersion( + BackendVersion.newBuilder().setCodeVersion(version).build().toByteArray()); + deploymentInfo.setReplicaConfig( + new ReplicaConfig("io.ray.serve.ReplicaContext", initArgsBytes, new HashMap<>())); + + ActorHandle replicaHandle = + Ray.actor( + RayServeWrappedReplica::new, + backendTag, + replicaTag, + deploymentInfo, + controllerName) + .setName(actorName) + .remote(); + replicaHandle.task(RayServeWrappedReplica::ready).remote(); + + // ReplicaSet + ReplicaSet replicaSet = new ReplicaSet(backendTag); + ActorSet.Builder builder = ActorSet.newBuilder(); + builder.addNames(actorName); + replicaSet.updateWorkerReplicas(builder.build()); + + // assign + + RequestMetadata.Builder requestMetadata = RequestMetadata.newBuilder(); + requestMetadata.setRequestId(RandomStringUtils.randomAlphabetic(10)); + requestMetadata.setCallMethod("getBackendTag"); + + Query query = new Query(requestMetadata.build(), null); + ObjectRef resultRef = 
replicaSet.assignReplica(query); + + Assert.assertEquals((String) resultRef.get(), backendTag); + } finally { + if (!inited) { + Ray.shutdown(); + } + } + } +} diff --git a/java/serve/src/test/java/io/ray/serve/RouterTest.java b/java/serve/src/test/java/io/ray/serve/RouterTest.java new file mode 100644 index 0000000000000..3312179912e38 --- /dev/null +++ b/java/serve/src/test/java/io/ray/serve/RouterTest.java @@ -0,0 +1,80 @@ +package io.ray.serve; + +import io.ray.api.ActorHandle; +import io.ray.api.ObjectRef; +import io.ray.api.Ray; +import io.ray.runtime.serializer.MessagePackSerializer; +import io.ray.serve.generated.ActorSet; +import io.ray.serve.generated.BackendConfig; +import io.ray.serve.generated.BackendLanguage; +import io.ray.serve.generated.BackendVersion; +import io.ray.serve.generated.RequestMetadata; +import java.util.HashMap; +import org.apache.commons.lang3.RandomStringUtils; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class RouterTest { + + @Test + public void test() { + boolean inited = Ray.isInitialized(); + Ray.init(); + + try { + String backendTag = "RouterTest"; + String controllerName = backendTag + "_controller"; + String replicaTag = backendTag + "_replica"; + String actorName = replicaTag; + String version = "v1"; + + // Controller + ActorHandle controllerHandle = + Ray.actor(DummyServeController::new).setName(controllerName).remote(); + + // Replica + BackendConfig.Builder backendConfigBuilder = BackendConfig.newBuilder(); + backendConfigBuilder.setBackendLanguage(BackendLanguage.JAVA); + byte[] backendConfigBytes = backendConfigBuilder.build().toByteArray(); + + Object[] initArgs = new Object[] {backendTag, replicaTag, controllerName, new Object()}; + byte[] initArgsBytes = MessagePackSerializer.encode(initArgs).getLeft(); + + DeploymentInfo deploymentInfo = new DeploymentInfo(); + deploymentInfo.setBackendConfig(backendConfigBytes); + deploymentInfo.setBackendVersion( + BackendVersion.newBuilder().setCodeVersion(version).build().toByteArray()); + deploymentInfo.setReplicaConfig( + new ReplicaConfig("io.ray.serve.ReplicaContext", initArgsBytes, new HashMap<>())); + + ActorHandle replicaHandle = + Ray.actor( + RayServeWrappedReplica::new, + backendTag, + replicaTag, + deploymentInfo, + controllerName) + .setName(actorName) + .remote(); + replicaHandle.task(RayServeWrappedReplica::ready).remote(); + + // Router + Router router = new Router(controllerHandle, backendTag); + ActorSet.Builder builder = ActorSet.newBuilder(); + builder.addNames(actorName); + router.getReplicaSet().updateWorkerReplicas(builder.build()); + + // assign + RequestMetadata.Builder requestMetadata = RequestMetadata.newBuilder(); + requestMetadata.setRequestId(RandomStringUtils.randomAlphabetic(10)); + requestMetadata.setCallMethod("getBackendTag"); + + ObjectRef resultRef = router.assignRequest(requestMetadata.build(), null); + Assert.assertEquals((String) resultRef.get(), backendTag); + } finally { + if (!inited) { + Ray.shutdown(); + } + } + } +} diff --git a/java/serve/src/test/java/io/ray/serve/api/ClientTest.java b/java/serve/src/test/java/io/ray/serve/api/ClientTest.java new file mode 100644 index 0000000000000..c3489bc1a1a19 --- /dev/null +++ b/java/serve/src/test/java/io/ray/serve/api/ClientTest.java @@ -0,0 +1,47 @@ +package io.ray.serve.api; + +import io.ray.api.ActorHandle; +import io.ray.api.Ray; +import io.ray.serve.DummyServeController; +import io.ray.serve.RayServeHandle; +import io.ray.serve.generated.EndpointInfo; +import java.util.HashMap; 
+import java.util.Map; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class ClientTest { + + @Test + public void getHandleTest() { + + boolean inited = Ray.isInitialized(); + Ray.init(); + + try { + String prefix = "ClientTest"; + String controllerName = prefix + "_controller"; + String endpointName = prefix + "_endpoint"; + + // Controller. + ActorHandle controllerHandle = + Ray.actor(DummyServeController::new).setName(controllerName).remote(); + + // Mock endpoints. + Map endpoints = new HashMap<>(); + endpoints.put(endpointName, EndpointInfo.newBuilder().setEndpointName(endpointName).build()); + controllerHandle.task(DummyServeController::setEndpoints, endpoints).remote(); + + // Client. + Client client = new Client(controllerHandle, controllerName, true); + + // Get handle. + RayServeHandle rayServeHandle = client.getHandle(endpointName, false); + Assert.assertNotNull(rayServeHandle); + } finally { + if (!inited) { + Ray.shutdown(); + } + } + } +} diff --git a/java/serve/src/test/java/io/ray/serve/api/ServeTest.java b/java/serve/src/test/java/io/ray/serve/api/ServeTest.java index b63a709a167de..cf470e8ce2248 100644 --- a/java/serve/src/test/java/io/ray/serve/api/ServeTest.java +++ b/java/serve/src/test/java/io/ray/serve/api/ServeTest.java @@ -1,7 +1,12 @@ package io.ray.serve.api; -import io.ray.serve.RayServeException; +import io.ray.api.ActorHandle; +import io.ray.api.Ray; +import io.ray.serve.Constants; +import io.ray.serve.DummyServeController; import io.ray.serve.ReplicaContext; +import io.ray.serve.util.CommonUtil; +import org.apache.commons.lang3.RandomStringUtils; import org.testng.Assert; import org.testng.annotations.Test; @@ -10,31 +15,53 @@ public class ServeTest { @Test public void replicaContextTest() { - ReplicaContext preContext = Serve.INTERNAL_REPLICA_CONTEXT; - ReplicaContext replicaContext; - - // Test null replica context. - Serve.INTERNAL_REPLICA_CONTEXT = null; try { - replicaContext = Serve.getReplicaContext(); - Assert.assertTrue(false, "expect RayServeException"); - } catch (RayServeException e) { + // Test context setting and getting. + String backendTag = "backendTag"; + String replicaTag = "replicaTag"; + String controllerName = "controllerName"; + Object servableObject = new Object(); + Serve.setInternalReplicaContext(backendTag, replicaTag, controllerName, servableObject); + ReplicaContext replicaContext = Serve.getReplicaContext(); + Assert.assertNotNull(replicaContext, "no replica context"); + Assert.assertEquals(replicaContext.getBackendTag(), backendTag); + Assert.assertEquals(replicaContext.getReplicaTag(), replicaTag); + Assert.assertEquals(replicaContext.getInternalControllerName(), controllerName); + } finally { + // Recover context. + Serve.setInternalReplicaContext(null); } + } - // Test context setting and getting. 
- String backendTag = "backendTag"; - String replicaTag = "replicaTag"; - String controllerName = "controllerName"; - Object servableObject = new Object(); - Serve.setInternalReplicaContext(backendTag, replicaTag, controllerName, servableObject); - - replicaContext = Serve.getReplicaContext(); - Assert.assertNotNull(replicaContext, "no replica context"); - Assert.assertEquals(replicaContext.getBackendTag(), backendTag); - Assert.assertEquals(replicaContext.getReplicaTag(), replicaTag); - Assert.assertEquals(replicaContext.getInternalControllerName(), controllerName); + @SuppressWarnings("unused") + @Test + public void getGlobalClientTest() { + boolean inited = Ray.isInitialized(); + Ray.init(); + try { + Client client = null; + try { + client = Serve.getGlobalClient(); + Assert.assertTrue(false, "Expect IllegalStateException here!"); + } catch (IllegalStateException e) { + } + Assert.assertNull(client); - Serve.INTERNAL_REPLICA_CONTEXT = preContext; + String controllerName = + CommonUtil.formatActorName( + Constants.SERVE_CONTROLLER_NAME, RandomStringUtils.randomAlphabetic(6)); + ActorHandle actorHandle = + Ray.actor(DummyServeController::new).setName(controllerName).remote(); + Serve.setInternalReplicaContext(null, null, controllerName, null); + client = Serve.getGlobalClient(); + Assert.assertNotNull(client); + } finally { + if (!inited) { + Ray.shutdown(); + } + Serve.setInternalReplicaContext(null); + Serve.setGlobalClient(null); + } } } diff --git a/java/serve/src/test/java/io/ray/serve/poll/KeyTypeTest.java b/java/serve/src/test/java/io/ray/serve/poll/KeyTypeTest.java index 628f5ff4a89c4..710ad97128ede 100644 --- a/java/serve/src/test/java/io/ray/serve/poll/KeyTypeTest.java +++ b/java/serve/src/test/java/io/ray/serve/poll/KeyTypeTest.java @@ -1,12 +1,15 @@ package io.ray.serve.poll; +import com.google.gson.Gson; import org.testng.Assert; import org.testng.annotations.Test; public class KeyTypeTest { + private static final Gson GSON = new Gson(); + @Test - public void test() { + public void hashTest() { KeyType k1 = new KeyType(LongPollNamespace.BACKEND_CONFIGS, "k1"); KeyType k2 = new KeyType(LongPollNamespace.BACKEND_CONFIGS, "k1"); KeyType k3 = new KeyType(LongPollNamespace.BACKEND_CONFIGS, null); @@ -28,4 +31,14 @@ public void test() { Assert.assertNotEquals(k1.hashCode(), k4.hashCode()); Assert.assertFalse(k1.equals(k4)); } + + @Test + public void jsonTest() { + KeyType k1 = new KeyType(LongPollNamespace.BACKEND_CONFIGS, "k1"); + String json = GSON.toJson(k1); + + KeyType k2 = GSON.fromJson(json, KeyType.class); + Assert.assertEquals(k1, k2); + Assert.assertEquals(k1.hashCode(), k2.hashCode()); + } } diff --git a/java/serve/src/test/java/io/ray/serve/poll/LongPollClientTest.java b/java/serve/src/test/java/io/ray/serve/poll/LongPollClientTest.java index 3d172d87bedc7..7ee254806fad3 100644 --- a/java/serve/src/test/java/io/ray/serve/poll/LongPollClientTest.java +++ b/java/serve/src/test/java/io/ray/serve/poll/LongPollClientTest.java @@ -1,5 +1,8 @@ package io.ray.serve.poll; +import com.google.protobuf.ByteString; +import io.ray.serve.generated.BackendConfig; +import io.ray.serve.generated.UpdatedObject; import java.util.HashMap; import java.util.Map; import org.testng.Assert; @@ -10,25 +13,35 @@ public class LongPollClientTest { @Test public void test() throws Throwable { + String[] a = new String[] {"test"}; + + // Construct a listener map. 
     KeyType keyType = new KeyType(LongPollNamespace.BACKEND_CONFIGS, "backendTag");
-    int[] a = new int[] {0};
     Map<KeyType, KeyListener> keyListeners = new HashMap<>();
-    keyListeners.put(keyType, (object) -> a[0] = (Integer) object);
+    keyListeners.put(
+        keyType, (object) -> a[0] = String.valueOf(((BackendConfig) object).getNumReplicas()));
+
+    // Initialize LongPollClient.
     LongPollClient longPollClient = new LongPollClient(null, keyListeners);
 
+    // Construct updated object.
+    BackendConfig.Builder backendConfig = BackendConfig.newBuilder();
+    backendConfig.setNumReplicas(20);
     int snapshotId = 10;
-    int objectSnapshot = 20;
-    UpdatedObject updatedObject = new UpdatedObject();
+    UpdatedObject.Builder updatedObject = UpdatedObject.newBuilder();
     updatedObject.setSnapshotId(snapshotId);
-    updatedObject.setObjectSnapshot(objectSnapshot);
+    updatedObject.setObjectSnapshot(ByteString.copyFrom(backendConfig.build().toByteArray()));
 
+    // Process update.
     Map<KeyType, UpdatedObject> updates = new HashMap<>();
-    updates.put(keyType, updatedObject);
+    updates.put(keyType, updatedObject.build());
     longPollClient.processUpdate(updates);
 
+    // Validation.
     Assert.assertEquals(longPollClient.getSnapshotIds().get(keyType).intValue(), snapshotId);
     Assert.assertEquals(
-        ((Integer) longPollClient.getObjectSnapshots().get(keyType)).intValue(), objectSnapshot);
-    Assert.assertEquals(a[0], objectSnapshot);
+        ((BackendConfig) longPollClient.getObjectSnapshots().get(keyType)).getNumReplicas(),
+        backendConfig.getNumReplicas());
+    Assert.assertEquals(a[0], String.valueOf(backendConfig.getNumReplicas()));
   }
 }
diff --git a/python/build-wheel-windows.sh b/python/build-wheel-windows.sh
index cb36f901bd61c..c7c282acaa421 100755
--- a/python/build-wheel-windows.sh
+++ b/python/build-wheel-windows.sh
@@ -81,6 +81,13 @@ build_wheel_windows() {
   unset PYTHON2_BIN_PATH PYTHON3_BIN_PATH # make sure these aren't set by some chance
   install_ray
   cd "${WORKSPACE_DIR}"/python
+  # Set the commit SHA in __init__.py.
+  if [ -n "$TRAVIS_COMMIT" ]; then
+    sed -i.bak "s/{{RAY_COMMIT_SHA}}/$TRAVIS_COMMIT/g" ray/__init__.py && rm ray/__init__.py.bak
+  else
+    echo "TRAVIS_COMMIT variable not set - required to populate ray.__commit__."
+    exit 1
+  fi
   # build ray wheel
   python setup.py --quiet bdist_wheel
   # build ray-cpp wheel
diff --git a/python/ray/_private/client_mode_hook.py b/python/ray/_private/client_mode_hook.py
index e6abf5f5a98f0..ef3df68206303 100644
--- a/python/ray/_private/client_mode_hook.py
+++ b/python/ray/_private/client_mode_hook.py
@@ -1,6 +1,6 @@
 import os
 from contextlib import contextmanager
-from functools import wraps
+from functools import partial, wraps
 import threading
 
 # Attr set on func defs to mark they have been converted to client mode.
@@ -15,6 +15,8 @@
 is_client_mode_enabled_by_default = is_client_mode_enabled
 os.environ.update({"RAY_CLIENT_MODE": "0"})
 
+is_init_called = False
+
 # Local setting of whether to ignore client hook conversion. This defaults
 # to TRUE and is disabled when the underlying 'real' Ray function is needed.
 _client_hook_status_on_thread = threading.local()
@@ -75,13 +77,27 @@ def enable_client_mode():
         _explicitly_disable_client_mode()
 
 
-def client_mode_hook(func):
-    """Decorator for ray module methods to delegate to ray client"""
+def client_mode_hook(func=None, *, auto_init: bool):
+    """Decorator for whether to use the 'regular' ray version of a function,
+    or the Ray Client version of that function.
+
+    Args:
+        func (callable): This function. This is set when this function is used
+            as a decorator.
+ auto_init (bool): Whether `ray.init()` should be transparently called when + the wrapped function is called. This should be `True` for functions + that are *NOT* part of the initialization path (e.g. `init` or + `is_initialized`) or for functions that do not require Ray to be + initialized (e.g., KV operations, `shutdown`). + """ + if func is None: + return partial(client_mode_hook, auto_init=auto_init) + from ray.util.client import ray @wraps(func) def wrapper(*args, **kwargs): - if client_mode_should_convert(): + if client_mode_should_convert(auto_init=auto_init): # Legacy code # we only convert init function if RAY_CLIENT_MODE=1 if func.__name__ != "init" or is_client_mode_enabled_by_default: @@ -91,13 +107,23 @@ def wrapper(*args, **kwargs): return wrapper -def client_mode_should_convert(): - # This is for testing with RAY_CLIENT_MODE. - # When RAY_CLIENT_MODE=1, it means that for all the tests - # will run with client mode. - # is_client_mode_enabled will be set to be off when client is off +def client_mode_should_convert(*, auto_init: bool): + """Determines if functions should be converted to client mode & if + Ray should be auto-initialized. + + NOTE: `auto_init` must happen before we branch into regular ray or client + code because the initialization may result in either mode. + """ + if auto_init: + import ray + if os.environ.get("RAY_ENABLE_AUTO_CONNECT", + "") != "0" and not ray.is_initialized(): + ray.init() + + # `is_client_mode_enabled_by_default` is used for testing with + # `RAY_CLIENT_MODE=1`. This flag means all tests run with client mode. return (is_client_mode_enabled or is_client_mode_enabled_by_default) and \ - _get_client_hook_status_on_thread() + _get_client_hook_status_on_thread() def client_mode_wrap(func): @@ -115,7 +141,9 @@ def client_mode_wrap(func): @wraps(func) def wrapper(*args, **kwargs): - if client_mode_should_convert(): + # Directly pass this through since `client_mode_wrap` is for + # Placement Group APIs + if client_mode_should_convert(auto_init=True): f = ray.remote(num_cpus=0)(func) ref = f.remote(*args, **kwargs) return ray.get(ref) diff --git a/python/ray/_private/parameter.py b/python/ray/_private/parameter.py index 4303808609a48..24eace1a8e78f 100644 --- a/python/ray/_private/parameter.py +++ b/python/ray/_private/parameter.py @@ -72,8 +72,8 @@ class RayParams: be created. worker_path (str): The path of the source code that will be run by the worker. - setup_worker_path (str): The path of the Python file that will run - worker_setup_hook to set up the environment for the worker process. + setup_worker_path (str): The path of the Python file that will set up + the environment for the worker process. huge_pages: Boolean flag indicating whether to start the Object Store with hugetlbfs support. Requires plasma_directory. 
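# Aside: a minimal, self-contained sketch of the decorator pattern that
# `client_mode_hook` above relies on. Applying the decorator with only keyword
# arguments invokes it without `func`, so it returns a `partial` that waits
# for the real function. All names below are illustrative, not Ray APIs.
from functools import partial, wraps

def my_hook(func=None, *, auto_init: bool = False):
    if func is None:
        # Called as `@my_hook(auto_init=True)`; defer until we receive `func`.
        return partial(my_hook, auto_init=auto_init)

    @wraps(func)
    def wrapper(*args, **kwargs):
        if auto_init:
            print("would call ray.init() here")  # stand-in for the real side effect
        return func(*args, **kwargs)

    return wrapper

@my_hook(auto_init=True)
def get_resources():
    return {}

get_resources()  # prints the auto-init message, then runs the function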
include_dashboard: Boolean flag indicating whether to start the web diff --git a/python/ray/_private/runtime_env/__init__.py b/python/ray/_private/runtime_env/__init__.py index 20401cb96f021..e69de29bb2d1d 100644 --- a/python/ray/_private/runtime_env/__init__.py +++ b/python/ray/_private/runtime_env/__init__.py @@ -1,3 +0,0 @@ -from ray._private.runtime_env.context import RuntimeEnvContext # noqa: F401 -from ray._private.runtime_env.validation import ( # noqa: F401 - override_task_or_actor_runtime_env, RuntimeEnvDict) # noqa: F401 diff --git a/python/ray/_private/runtime_env/conda.py b/python/ray/_private/runtime_env/conda.py index d9c810b89b75b..92bc3d8cb1139 100644 --- a/python/ray/_private/runtime_env/conda.py +++ b/python/ray/_private/runtime_env/conda.py @@ -12,9 +12,9 @@ from pathlib import Path import ray -from ray._private.runtime_env import RuntimeEnvContext from ray._private.runtime_env.conda_utils import (get_conda_activate_commands, get_or_create_conda_env) +from ray._private.runtime_env.context import RuntimeEnvContext from ray._private.utils import (get_wheel_filename, get_master_wheel_url, get_release_wheel_url, try_to_create_directory) @@ -81,7 +81,7 @@ def get_conda_dict(runtime_env, runtime_env_dir) -> Optional[Dict[Any, Any]]: else: return None if runtime_env.get("pip"): - requirements_txt = runtime_env["pip"] + requirements_txt = "\n".join(runtime_env["pip"]) + "\n" pip_hash = hashlib.sha1(requirements_txt.encode("utf-8")).hexdigest() pip_hash_str = f"pip-generated-{pip_hash}" diff --git a/python/ray/_private/runtime_env/conda_utils.py b/python/ray/_private/runtime_env/conda_utils.py index 5d61c9e8c5f45..2339da036b60c 100644 --- a/python/ray/_private/runtime_env/conda_utils.py +++ b/python/ray/_private/runtime_env/conda_utils.py @@ -126,6 +126,21 @@ def get_or_create_conda_env(conda_env_path: str, return env_name +def get_conda_env_list() -> list: + """ + Get conda env list. 
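# The conda.py change above now stores `pip` as a list and joins it into
# requirements.txt form before hashing, so identical dependency lists always
# map to the same generated env name. A hedged sketch of that cache-key idea;
# the helper name below is made up:
import hashlib

def pip_cache_key(pip_packages):
    requirements_txt = "\n".join(pip_packages) + "\n"
    pip_hash = hashlib.sha1(requirements_txt.encode("utf-8")).hexdigest()
    return f"pip-generated-{pip_hash}"

assert pip_cache_key(["requests", "pandas"]) == pip_cache_key(["requests", "pandas"])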
+ """ + conda_path = get_conda_bin_executable("conda") + try: + exec_cmd([conda_path, "--help"], throw_on_error=False) + except EnvironmentError: + raise ValueError(f"Could not find Conda executable at {conda_path}.") + _, stdout, _ = exec_cmd([conda_path, "env", "list", "--json"]) + envs = json.loads(stdout)["envs"] + print(f"Conda env len {len(envs)}") + return envs + + class ShellCommandException(Exception): pass diff --git a/python/ray/_private/runtime_env/context.py b/python/ray/_private/runtime_env/context.py index af3409f310ca5..c5db64437ce2d 100644 --- a/python/ray/_private/runtime_env/context.py +++ b/python/ray/_private/runtime_env/context.py @@ -4,9 +4,13 @@ import sys from typing import Dict, List, Optional +from ray.util.annotations import DeveloperAPI +from ray.core.generated.common_pb2 import Language + logger = logging.getLogger(__name__) +@DeveloperAPI class RuntimeEnvContext: """A context used to describe the created runtime env.""" @@ -31,10 +35,13 @@ def serialize(self) -> str: def deserialize(json_string): return RuntimeEnvContext(**json.loads(json_string)) - def exec_worker(self, passthrough_args: List[str]): + def exec_worker(self, passthrough_args: List[str], language: Language): os.environ.update(self.env_vars) - exec_command = " ".join([f"exec {self.py_executable}"] + - passthrough_args) + if language == Language.PYTHON: + executable = f"exec {self.py_executable}" + else: + executable = "exec" + exec_command = " ".join([executable] + passthrough_args) command_str = " && ".join(self.command_prefix + [exec_command]) logger.info(f"Exec'ing worker with command: {command_str}") os.execvp("bash", ["bash", "-c", command_str]) diff --git a/python/ray/_private/runtime_env/plugin.py b/python/ray/_private/runtime_env/plugin.py new file mode 100644 index 0000000000000..5e411c141fc08 --- /dev/null +++ b/python/ray/_private/runtime_env/plugin.py @@ -0,0 +1,70 @@ +from abc import ABC, abstractstaticmethod + +from ray.util.annotations import DeveloperAPI +from ray._private.runtime_env.context import RuntimeEnvContext + + +@DeveloperAPI +class RuntimeEnvPlugin(ABC): + @abstractstaticmethod + def validate(runtime_env_dict: dict) -> str: + """Validate user entry and returns a URI uniquely describing resource. + + This method will be called at ``f.options(runtime_env=...)`` or + ``ray.init(runtime_env=...)`` time and it should check the runtime env + dictionary for any errors. For example, it can raise "TypeError: + expected string for "conda" field". + + Args: + runtime_env_dict(dict): the entire dictionary passed in by user. + + Returns: + uri(str): a URI uniquely describing this resource (e.g., a hash of + the conda spec). + """ + raise NotImplementedError() + + def create(uri: str, runtime_env_dict: dict, + ctx: RuntimeEnvContext) -> float: + """Create and install the runtime environment. + + Gets called in the runtime env agent at install time. The URI can be + used as a caching mechanism. + + Args: + uri(str): a URI uniquely describing this resource. + runtime_env_dict(dict): the entire dictionary passed in by user. + ctx(RuntimeEnvContext): auxiliary information supplied by Ray. + + Returns: + the disk space taken up by this plugin installation for this + environment. e.g. for working_dir, this downloads the files to the + local node. + """ + return 0 + + def modify_context(uri: str, runtime_env_dict: dict, + ctx: RuntimeEnvContext) -> None: + """Modify context to change worker startup behavior. 
+
+        For example, you can use this to prepend a "cd <working_dir>" command
+        to the worker startup, or add new environment variables.
+
+        Args:
+            uri(str): a URI uniquely describing this resource.
+            runtime_env_dict(dict): the entire dictionary passed in by user.
+            ctx(RuntimeEnvContext): auxiliary information supplied by Ray.
+        """
+        return
+
+    def delete(uri: str, ctx: RuntimeEnvContext) -> float:
+        """Delete the runtime environment given uri.
+
+        Args:
+            uri(str): a URI uniquely describing this resource.
+            ctx(RuntimeEnvContext): auxiliary information supplied by Ray.
+
+        Returns:
+            the amount of space reclaimed by the deletion.
+        """
+        return 0
diff --git a/python/ray/_private/runtime_env/validation.py b/python/ray/_private/runtime_env/validation.py
index e113e4151424d..0bc2a609762c5 100644
--- a/python/ray/_private/runtime_env/validation.py
+++ b/python/ray/_private/runtime_env/validation.py
@@ -1,12 +1,15 @@
+import copy
 import json
 import logging
 import os
 from pathlib import Path
 import sys
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional, Set, Union
 import yaml
 
 import ray
+from ray._private.runtime_env.plugin import RuntimeEnvPlugin
+from ray._private.utils import import_attr
 
 # We need to setup this variable before
 # using this module
@@ -20,19 +23,176 @@
 GCS_STORAGE_MAX_SIZE = 100 * 1024 * 1024  # 100MiB
 
 
-class RuntimeEnvDict:
-    """Parses and validates the runtime env dictionary from the user.
+def parse_and_validate_working_dir(working_dir: str,
+                                   is_task_or_actor: bool = False) -> str:
+    """Parses and validates a user-provided 'working_dir' option.
 
-    Attributes:
+    The working_dir may not be specified per-task or per-actor.
+
+    Otherwise, it should be a valid path to a local directory.
+    """
+    assert working_dir is not None
+
+    if is_task_or_actor:
+        raise NotImplementedError(
+            "Overriding working_dir for tasks and actors isn't supported. "
+            "Please use ray.init(runtime_env={'working_dir': ...}) "
+            "to configure the environment per-job instead.")
+    elif not isinstance(working_dir, str):
+        raise TypeError("`working_dir` must be a string, got "
+                        f"{type(working_dir)}.")
+    elif not Path(working_dir).is_dir():
+        raise ValueError(
+            f"working_dir {working_dir} is not a valid directory.")
+
+    return working_dir
+
+
+def parse_and_validate_conda(conda: Union[str, dict],
+                             is_task_or_actor: bool = False
+                             ) -> Union[str, dict]:
+    """Parses and validates a user-provided 'conda' option.
+
+    Conda can be one of three cases:
+        1) A dictionary describing the env. This is passed through directly.
+        2) A string referring to a preinstalled conda env.
+        3) A string pointing to a local conda YAML file. This is detected
+           by looking for a '.yaml' or '.yml' suffix. In this case, the file
+           will be read as YAML and passed through as a dictionary.
+    """
+    assert conda is not None
+
+    result = None
+    if sys.platform == "win32":
+        raise NotImplementedError("The 'conda' field in runtime_env "
+                                  "is not currently supported on "
+                                  "Windows.")
+    elif isinstance(conda, str):
+        yaml_file = Path(conda)
+        if yaml_file.suffix in (".yaml", ".yml"):
+            if not yaml_file.is_file():
+                raise ValueError(f"Can't find conda YAML file {yaml_file}.")
+            try:
+                result = yaml.safe_load(yaml_file.read_text())
+            except Exception as e:
+                raise ValueError(
+                    f"Failed to read conda file {yaml_file}: {e}.")
+        else:
+            # Assume it's a pre-existing conda environment name.
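# To make the RuntimeEnvPlugin contract from plugin.py above concrete, here is
# a hedged sketch of a trivial plugin. `MyEnvVarPlugin` and its class path are
# hypothetical, not part of Ray; the methods mirror the signatures defined in
# this patch (note they are written without `self`, matching the base class).
from ray._private.runtime_env.context import RuntimeEnvContext
from ray._private.runtime_env.plugin import RuntimeEnvPlugin

class MyEnvVarPlugin(RuntimeEnvPlugin):
    @staticmethod
    def validate(runtime_env_dict: dict) -> str:
        value = runtime_env_dict["plugins"]["my.module.MyEnvVarPlugin"]
        if not isinstance(value, str):
            raise TypeError("MyEnvVarPlugin expects a string value.")
        return f"my-env-var://{value}"  # the URI doubles as a cache key

    def create(uri: str, runtime_env_dict: dict, ctx: RuntimeEnvContext) -> float:
        return 0  # nothing is installed, so no disk space is consumed

    def modify_context(uri: str, runtime_env_dict: dict, ctx: RuntimeEnvContext) -> None:
        # Exported to the worker process by RuntimeEnvContext.exec_worker().
        ctx.env_vars["MY_PLUGIN_FLAG"] = "1"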
+            result = conda
+    elif isinstance(conda, dict):
+        result = conda
+    else:
+        raise TypeError("runtime_env['conda'] must be of type str or "
+                        f"dict, got {type(conda)}.")
+
+    return result
+
+
+def parse_and_validate_pip(pip: Union[str, List[str]],
+                           is_task_or_actor: bool = False
+                           ) -> Optional[List[str]]:
+    """Parses and validates a user-provided 'pip' option.
+
+    Pip can be one of two cases:
+        1) A List[str] describing the requirements. This is passed through.
+        2) A string pointing to a local requirements file. In this case, the
+           file contents will be read and split into a list.
+    """
+    assert pip is not None
+
+    result = None
+    if sys.platform == "win32":
+        raise NotImplementedError("The 'pip' field in runtime_env "
+                                  "is not currently supported on "
+                                  "Windows.")
+    elif isinstance(pip, str):
+        # We have been given a path to a requirements.txt file.
+        pip_file = Path(pip)
+        if not pip_file.is_file():
+            raise ValueError(f"{pip_file} is not a valid file")
+        result = pip_file.read_text().strip().split("\n")
+    elif isinstance(pip, list) and all(isinstance(dep, str) for dep in pip):
+        if len(pip) == 0:
+            result = None
+        else:
+            result = pip
+    else:
+        raise TypeError("runtime_env['pip'] must be of type str or "
+                        f"List[str], got {type(pip)}")
+
+    return result
+
+
+def parse_and_validate_uris(uris: List[str],
+                            is_task_or_actor: bool = False) -> List[str]:
+    """Parses and validates a user-provided 'uris' option.
+
+    These are passed through without validation (for now).
+    """
+    assert uris is not None
+    return uris
+
+
+def parse_and_validate_container(container: List[str],
+                                 is_task_or_actor: bool = False) -> List[str]:
+    """Parses and validates a user-provided 'container' option.
+
+    This is passed through without validation (for now).
+    """
+    assert container is not None
+    return container
+
+
+def parse_and_validate_env_vars(env_vars: Dict[str, str],
+                                is_task_or_actor: bool = False
+                                ) -> Optional[Dict[str, str]]:
+    """Parses and validates a user-provided 'env_vars' option.
+
+    This is validated to verify that all keys and vals are strings.
+
+    If an empty dictionary is passed, we return `None` for consistency.
+    """
+    assert env_vars is not None
+    if len(env_vars) == 0:
+        return None
+
+    if not (isinstance(env_vars, dict) and all(
+            isinstance(k, str) and isinstance(v, str)
+            for (k, v) in env_vars.items())):
+        raise TypeError("runtime_env['env_vars'] must be of type "
+                        "Dict[str, str]")
+
+    return env_vars
+
+
+# Dictionary mapping runtime_env options with the function to parse and
+# validate them.
+OPTION_TO_VALIDATION_FN = {
+    "working_dir": parse_and_validate_working_dir,
+    "conda": parse_and_validate_conda,
+    "pip": parse_and_validate_pip,
+    "uris": parse_and_validate_uris,
+    "env_vars": parse_and_validate_env_vars,
+    "container": parse_and_validate_container,
+}
+
+
+class ParsedRuntimeEnv(dict):
+    """An internal wrapper for runtime_env that is parsed and validated.
+
+    This should be constructed from user-provided input (the API runtime_env)
+    and used everywhere that the runtime_env is passed around internally.
+
+    All options in the resulting dictionary will have non-None values.
+
+    Currently supported options:
         working_dir (Path): Specifies the working directory of the worker.
             This can either be a local directory or zip file.
             Examples:
                 "."  # cwd
                 "local_project.zip"  # archive is unpacked into directory
-        py_modules (List[Path]): Similar to working_dir, but specifies python
-            modules to add to the `sys.path`.
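# A quick illustration of how the per-option validators above behave, assuming
# the functions from this patch are importable; only the example inputs are
# made up.
from ray._private.runtime_env.validation import (
    parse_and_validate_env_vars, parse_and_validate_pip)

# A list of requirement strings passes through unchanged.
assert parse_and_validate_pip(["requests", "pandas"]) == ["requests", "pandas"]
# An empty env_vars dict normalizes to None, so the option drops out of the
# parsed runtime_env entirely.
assert parse_and_validate_env_vars({}) is None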
- Examples: - ["/path/to/other_module", "/other_path/local_project.zip"] + uris (List[str]): A list of URIs that define the working_dir. pip (List[str] | str): Either a list of pip packages, or a string containing the path to a pip requirements.txt file. conda (dict | str): Either the conda YAML config, the name of a @@ -64,170 +224,136 @@ class RuntimeEnvDict: {"OMP_NUM_THREADS": "32", "TF_WARNINGS": "none"} """ + known_fields: Set[str] = { + "working_dir", "conda", "pip", "uris", "containers", "excludes", + "env_vars", "_ray_release", "_ray_commit", "_inject_current_ray", + "plugins" + } + def __init__(self, - runtime_env_json: dict, - working_dir: Optional[str] = None): - # Simple dictionary with all options validated. This will always - # contain all supported keys; values will be set to None if - # unspecified. However, if all values are None this is set to {}. - self._dict = dict() - - if "working_dir" in runtime_env_json: - self._dict["working_dir"] = runtime_env_json["working_dir"] - if not isinstance(self._dict["working_dir"], str): - raise TypeError("`working_dir` must be a string. Type " - f"{type(self._dict['working_dir'])} received.") - working_dir = Path(self._dict["working_dir"]).absolute() - else: - self._dict["working_dir"] = None - working_dir = Path(working_dir).absolute() if working_dir else None - - self._dict["conda"] = None - if "conda" in runtime_env_json: - if sys.platform == "win32": - raise NotImplementedError("The 'conda' field in runtime_env " - "is not currently supported on " - "Windows.") - conda = runtime_env_json["conda"] - if isinstance(conda, str): - yaml_file = Path(conda) - if yaml_file.suffix in (".yaml", ".yml"): - if working_dir and not yaml_file.is_absolute(): - yaml_file = working_dir / yaml_file - if not yaml_file.is_file(): - raise ValueError( - f"Can't find conda YAML file {yaml_file}") - try: - self._dict["conda"] = yaml.safe_load( - yaml_file.read_text()) - except Exception as e: - raise ValueError( - f"Invalid conda file {yaml_file} with error {e}") - else: - logger.info( - f"Using preinstalled conda environment: {conda}") - self._dict["conda"] = conda - elif isinstance(conda, dict): - self._dict["conda"] = conda - elif conda is not None: - raise TypeError("runtime_env['conda'] must be of type str or " - "dict") - - self._dict["pip"] = None - if "pip" in runtime_env_json: - if sys.platform == "win32": - raise NotImplementedError("The 'pip' field in runtime_env " - "is not currently supported on " - "Windows.") - if ("conda" in runtime_env_json - and runtime_env_json["conda"] is not None): - raise ValueError( - "The 'pip' field and 'conda' field of " - "runtime_env cannot both be specified.\n" - f"specified pip field: {runtime_env_json['pip']}\n" - f"specified conda field: {runtime_env_json['conda']}\n" - "To use pip with conda, please only set the 'conda' " - "field, and specify your pip dependencies " - "within the conda YAML config dict: see " - "https://conda.io/projects/conda/en/latest/" - "user-guide/tasks/manage-environments.html" - "#create-env-file-manually") - pip = runtime_env_json["pip"] - if isinstance(pip, str): - # We have been given a path to a requirements.txt file. 
- pip_file = Path(pip) - if working_dir and not pip_file.is_absolute(): - pip_file = working_dir / pip_file - if not pip_file.is_file(): - raise ValueError(f"{pip_file} is not a valid file") - self._dict["pip"] = pip_file.read_text() - elif isinstance(pip, list) and all( - isinstance(dep, str) for dep in pip): - # Construct valid pip requirements.txt from list of packages. - self._dict["pip"] = "\n".join(pip) + "\n" - else: - raise TypeError("runtime_env['pip'] must be of type str or " - "List[str]") - - if "uris" in runtime_env_json: - self._dict["uris"] = runtime_env_json["uris"] - - if "container" in runtime_env_json: - self._dict["container"] = runtime_env_json["container"] - - self._dict["env_vars"] = None - if "env_vars" in runtime_env_json: - env_vars = runtime_env_json["env_vars"] - self._dict["env_vars"] = env_vars - if not (isinstance(env_vars, dict) and all( - isinstance(k, str) and isinstance(v, str) - for (k, v) in env_vars.items())): - raise TypeError("runtime_env['env_vars'] must be of type" - "Dict[str, str]") - - if "_ray_release" in runtime_env_json: - self._dict["_ray_release"] = runtime_env_json["_ray_release"] - - if "_ray_commit" in runtime_env_json: - self._dict["_ray_commit"] = runtime_env_json["_ray_commit"] + runtime_env: Dict[str, Any], + is_task_or_actor: bool = False, + _validate: bool = True): + super().__init__() + + # Blindly trust that the runtime_env has already been validated. + # This is dangerous and should only be used internally (e.g., on the + # deserialization codepath. + if not _validate: + self.update(runtime_env) + return + + if runtime_env.get("conda") and runtime_env.get("pip"): + raise ValueError( + "The 'pip' field and 'conda' field of " + "runtime_env cannot both be specified.\n" + f"specified pip field: {runtime_env['pip']}\n" + f"specified conda field: {runtime_env['conda']}\n" + "To use pip with conda, please only set the 'conda' " + "field, and specify your pip dependencies " + "within the conda YAML config dict: see " + "https://conda.io/projects/conda/en/latest/" + "user-guide/tasks/manage-environments.html" + "#create-env-file-manually") + + for option, validate_fn in OPTION_TO_VALIDATION_FN.items(): + option_val = runtime_env.get(option) + if option_val is not None: + validated_option_val = validate_fn( + option_val, is_task_or_actor=is_task_or_actor) + if validated_option_val is not None: + self[option] = validated_option_val + + if "_ray_release" in runtime_env: + self["_ray_release"] = runtime_env["_ray_release"] + + if "_ray_commit" in runtime_env: + self["_ray_commit"] = runtime_env["_ray_commit"] else: - if self._dict.get("pip") or self._dict.get("conda"): - self._dict["_ray_commit"] = ray.__commit__ + if self.get("pip") or self.get("conda"): + self["_ray_commit"] = ray.__commit__ # Used for testing wheels that have not yet been merged into master. # If this is set to True, then we do not inject Ray into the conda # or pip dependencies. - if os.environ.get("RAY_RUNTIME_ENV_LOCAL_DEV_MODE"): - runtime_env_json["_inject_current_ray"] = True - if "_inject_current_ray" in runtime_env_json: - self._dict["_inject_current_ray"] = runtime_env_json[ - "_inject_current_ray"] + if "_inject_current_ray" in runtime_env: + self["_inject_current_ray"] = runtime_env["_inject_current_ray"] + elif "RAY_RUNTIME_ENV_LOCAL_DEV_MODE" in os.environ: + self["_inject_current_ray"] = True - # TODO(ekl) we should have better schema validation here. 
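# The ParsedRuntimeEnv constructor above rejects specifying 'pip' and 'conda'
# together. A hedged sketch of that behavior, assuming the class from this
# patch is importable; the env values are made up.
from ray._private.runtime_env.validation import ParsedRuntimeEnv

try:
    ParsedRuntimeEnv({"pip": ["requests"], "conda": "my_existing_env"})
except ValueError as e:
    print("rejected as expected:", e)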
-        # TODO(ekl) support py_modules
-        # TODO(architkulkarni) support docker
+        if "plugins" in runtime_env:
+            self["plugins"] = dict()
+            for class_path, plugin_field in runtime_env["plugins"].items():
+                plugin_class: RuntimeEnvPlugin = import_attr(class_path)
+                if not issubclass(plugin_class, RuntimeEnvPlugin):
+                    # TODO(simon): move the interface to public once ready.
+                    raise TypeError(
+                        f"{class_path} must inherit from "
+                        "ray._private.runtime_env.plugin.RuntimeEnvPlugin.")
+                # TODO(simon): implement uri support.
+                _ = plugin_class.validate(runtime_env)
+                # Validation passed, add the entry to parsed runtime env.
+                self["plugins"][class_path] = plugin_field
+
+        unknown_fields = (
+            set(runtime_env.keys()) - ParsedRuntimeEnv.known_fields)
+        if len(unknown_fields):
+            logger.warning(
+                "The following unknown entries in the runtime_env dictionary "
+                f"will be ignored: {unknown_fields}. If you intended to use "
+                "them as plugins, they must be nested in the `plugins` field.")
 
         # TODO(architkulkarni) This is to make it easy for the worker caching
         # code in C++ to check if the env is empty without deserializing and
         # parsing it. We should use a less confusing approach here.
-        if all(val is None for val in self._dict.values()):
+        if all(val is None for val in self.values()):
             self.clear()
 
-    def get_parsed_dict(self) -> dict:
-        return self._dict
+    @classmethod
+    def deserialize(cls, serialized: str) -> "ParsedRuntimeEnv":
+        return cls(json.loads(serialized), _validate=False)
 
     def serialize(self) -> str:
-        # Use sort_keys=True because we will use the output as a key to cache
-        # workers by, so we need the serialization to be independent of the
-        # dict order.
-        return json.dumps(self._dict, sort_keys=True)
-
-    def set_uris(self, uris):
-        self._dict["uris"] = uris
+        # Sort the keys so we can compare the serialized string for equality.
+        return json.dumps(self, sort_keys=True)
 
 
 def override_task_or_actor_runtime_env(
-        runtime_env: Optional[Dict[str, Any]],
-        parent_runtime_env: Dict[str, Any]) -> Dict[str, Any]:
-    if runtime_env:
-        if runtime_env.get("working_dir"):
-            raise NotImplementedError(
-                "Overriding working_dir for actors is not supported. "
-                "Please use ray.init(runtime_env={'working_dir': ...}) "
-                "to configure per-job environment instead.")
-        # NOTE(edoakes): this is sort of hacky, but we pass in the parent
-        # working_dir here so the relative path to a requirements.txt file
-        # works. The right solution would be to merge the runtime_env with the
-        # parent runtime env before validation.
-        runtime_env_dict = RuntimeEnvDict(
-            runtime_env, working_dir=parent_runtime_env.get(
-                "working_dir")).get_parsed_dict()
-    else:
-        runtime_env_dict = {}
+        child_runtime_env: ParsedRuntimeEnv,
+        parent_runtime_env: ParsedRuntimeEnv) -> ParsedRuntimeEnv:
+    """Merge the given child runtime env with the parent runtime env.
+
+    If running in a driver, the current runtime env comes from the
+    JobConfig. Otherwise, we are running in a worker for an actor or
+    task, and the current runtime env comes from the current TaskSpec.
+
+    By default, the child runtime env inherits non-specified options from the
+    parent. There are two exceptions to this:
+        - working_dir is not inherited (only URIs).
+        - The env_vars dictionaries are merged, so environment variables
+          not specified by the child are still inherited from the parent.
+
+    Returns:
+        The resulting merged ParsedRuntimeEnv.
+    """
+    assert child_runtime_env is not None
+    assert parent_runtime_env is not None
+
+    # Override environment variables.
+ result_env_vars = copy.deepcopy(parent_runtime_env.get("env_vars") or {}) + child_env_vars = child_runtime_env.get("env_vars") or {} + result_env_vars.update(child_env_vars) + + # Inherit all other non-specified options from the parent. + result = copy.deepcopy(parent_runtime_env) + result.update(child_runtime_env) + if len(result_env_vars) > 0: + result["env_vars"] = result_env_vars + if "working_dir" in result: + del result["working_dir"] # working_dir should not be in child env. - # If per-actor URIs aren't specified, override them with those in the - # job config. - if "uris" not in runtime_env_dict and "uris" in parent_runtime_env: - runtime_env_dict["uris"] = parent_runtime_env.get("uris") + # NOTE(architkulkarni): This allows worker caching code in C++ to + # check if a runtime env is empty without deserializing it. + assert all(val is not None for val in result.values()) - return runtime_env_dict + return result diff --git a/python/ray/_private/runtime_env/working_dir.py b/python/ray/_private/runtime_env/working_dir.py index 964cf4aafcf5d..e5034caf74a27 100644 --- a/python/ray/_private/runtime_env/working_dir.py +++ b/python/ray/_private/runtime_env/working_dir.py @@ -15,7 +15,7 @@ _internal_kv_initialized) from ray.job_config import JobConfig from ray._private.thirdparty.pathspec import PathSpec -from ray._private.runtime_env import RuntimeEnvContext +from ray._private.runtime_env.context import RuntimeEnvContext default_logger = logging.getLogger(__name__) diff --git a/python/ray/_private/services.py b/python/ray/_private/services.py index 8d135129b78fa..04b63ec920f5a 100644 --- a/python/ray/_private/services.py +++ b/python/ray/_private/services.py @@ -21,6 +21,7 @@ import ray import ray.ray_constants as ray_constants import redis +from ray.core.generated.common_pb2 import Language # Import psutil and colorama after ray so the packaged version is used. import colorama @@ -398,6 +399,11 @@ def node_ip_address_from_perspective(address): def get_node_ip_address(address="8.8.8.8:53"): if ray.worker._global_node is not None: return ray.worker._global_node.node_ip_address + if sys.platform == "darwin": + # Due to the mac osx firewall, + # we use loopback ip as the ip address + # to prevent security popups. + return "127.0.0.1" return node_ip_address_from_perspective(address) @@ -866,7 +872,8 @@ def start_redis(node_ip_address, stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, fate_share=fate_share, - port_denylist=port_denylist) + port_denylist=port_denylist, + listen_to_localhost_only=(node_ip_address == "127.0.0.1")) processes.append(p) redis_address = address(node_ip_address, port) primary_redis_client = redis.StrictRedis( @@ -922,7 +929,8 @@ def start_redis(node_ip_address, stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, fate_share=fate_share, - port_denylist=port_denylist) + port_denylist=port_denylist, + listen_to_localhost_only=(node_ip_address == "127.0.0.1")) processes.append(p) shard_address = address(node_ip_address, redis_shard_port) @@ -944,7 +952,8 @@ def _start_redis_instance(executable, password=None, redis_max_memory=None, fate_share=None, - port_denylist=None): + port_denylist=None, + listen_to_localhost_only=False): """Start a single Redis server. Notes: @@ -970,6 +979,9 @@ def _start_redis_instance(executable, will start LRU eviction of entries. port_denylist (set): A set of denylist ports that shouldn't be used when allocating a new port. 
+ listen_to_localhost_only (bool): Redis server only listens to + localhost (127.0.0.1) if it's true, + otherwise it listens to all network interfaces. Returns: A tuple of the port used by Redis and ProcessInfo for the process that @@ -990,6 +1002,8 @@ def _start_redis_instance(executable, raise ValueError("Spaces not permitted in redis password.") command += ["--requirepass", password] command += (["--port", str(port), "--loglevel", "warning"]) + if listen_to_localhost_only: + command += ["--bind", "127.0.0.1"] process_info = start_ray_process( command, ray_constants.PROCESS_TYPE_REDIS_SERVER, @@ -1360,8 +1374,8 @@ def start_raylet(redis_address, to. worker_path (str): The path of the Python file that new worker processes will execute. - setup_worker_path (str): The path of the Python file that will run - worker_setup_hook to set up the environment for the worker process. + setup_worker_path (str): The path of the Python file that will set up + the environment for the worker process. temp_dir (str): The path of the temporary directory Ray will use. session_dir (str): The path of this session. resource_dir(str): The path of resource of this session . @@ -1437,6 +1451,7 @@ def start_raylet(redis_address, redis_password, session_dir, node_ip_address, + setup_worker_path, ) else: java_worker_command = [] @@ -1591,6 +1606,7 @@ def build_java_worker_command( redis_password, session_dir, node_ip_address, + setup_worker_path, ): """This method assembles the command used to start a Java worker. @@ -1602,6 +1618,8 @@ def build_java_worker_command( redis_password (str): The password of connect to redis. session_dir (str): The path of this session. node_ip_address (str): The ip address for this node. + setup_worker_path (str): The path of the Python file that will set up + the environment for the worker process. Returns: The command string for starting Java worker. 
""" @@ -1626,7 +1644,9 @@ def build_java_worker_command( pairs.append(("ray.home", RAY_HOME)) pairs.append(("ray.logging.dir", os.path.join(session_dir, "logs"))) pairs.append(("ray.session-dir", session_dir)) - command = ["java"] + ["-D{}={}".format(*pair) for pair in pairs] + command = [sys.executable] + [setup_worker_path] + ["java"] + [ + "-D{}={}".format(*pair) for pair in pairs + ] # Add ray jars path to java classpath ray_jars = os.path.join(get_ray_jars_dir(), "*") @@ -1908,9 +1928,14 @@ def start_ray_client_server( ray_constants.SETUP_WORKER_FILENAME) command = [ - sys.executable, setup_worker_path, "-m", "ray.util.client.server", - f"--redis-address={redis_address}", f"--port={ray_client_server_port}", - f"--mode={server_type}" + sys.executable, + setup_worker_path, + "-m", + "ray.util.client.server", + f"--redis-address={redis_address}", + f"--port={ray_client_server_port}", + f"--mode={server_type}", + f"--language={Language.Name(Language.PYTHON)}", ] if redis_password: command.append(f"--redis-password={redis_password}") diff --git a/python/ray/_raylet.pxd b/python/ray/_raylet.pxd index 5c79c3b796459..4326d6cf943a9 100644 --- a/python/ray/_raylet.pxd +++ b/python/ray/_raylet.pxd @@ -114,7 +114,7 @@ cdef class CoreWorker: object async_event_loop object plasma_event_handler object job_config - object current_runtime_env_dict + object current_runtime_env c_bool is_local_mode cdef _create_put_buffer(self, shared_ptr[CBuffer] &metadata, diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index fdb9a7f51fef0..bbc064ec92938 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -7,18 +7,18 @@ from cpython.exc cimport PyErr_CheckSignals import asyncio -import copy import gc import inspect -import threading -import traceback -import time import logging +import msgpack import os import pickle +import setproctitle import sys +import threading +import time +import traceback import _thread -import setproctitle from libc.stdint cimport ( int32_t, @@ -100,13 +100,6 @@ from ray.includes.ray_config cimport RayConfig from ray.includes.global_state_accessor cimport CGlobalStateAccessor import ray -import ray._private.gcs_utils as gcs_utils -from ray import external_storage -from ray._private.async_compat import ( - sync_to_async, get_new_event_loop) -import ray._private.memory_monitor as memory_monitor -import ray.ray_constants as ray_constants -import ray._private.profiling as profiling from ray.exceptions import ( RayActorError, RayError, @@ -117,11 +110,15 @@ from ray.exceptions import ( TaskCancelledError, AsyncioActorExit, ) +from ray import external_storage +import ray.ray_constants as ray_constants +from ray._private.async_compat import sync_to_async, get_new_event_loop +from ray._private.client_mode_hook import disable_client_hook +import ray._private.gcs_utils as gcs_utils +from ray._private.runtime_env.validation import ParsedRuntimeEnv +import ray._private.memory_monitor as memory_monitor +import ray._private.profiling as profiling from ray._private.utils import decode -from ray._private.client_mode_hook import ( - disable_client_hook, -) -import msgpack cimport cpython @@ -1353,8 +1350,8 @@ cdef class CoreWorker: int64_t placement_group_bundle_index, c_bool placement_group_capture_child_tasks, c_string debugger_breakpoint, - runtime_env_dict, - override_environment_variables + c_string serialized_runtime_env, + runtime_env_uris, ): cdef: unordered_map[c_string, double] c_resources @@ -1362,15 +1359,10 @@ cdef class CoreWorker: c_vector[unique_ptr[CTaskArg]] 
args_vector CPlacementGroupID c_placement_group_id = \ placement_group_id.native() - c_string c_serialized_runtime_env - unordered_map[c_string, c_string] \ - c_override_environment_variables = \ - override_environment_variables + c_vector[c_string] c_runtime_env_uris = runtime_env_uris c_vector[CObjectReference] return_refs with self.profile_event(b"submit_task"): - c_serialized_runtime_env = \ - self.prepare_runtime_env(runtime_env_dict) prepare_resources(resources, &c_resources) ray_function = CRayFunction( language.lang, function_descriptor.descriptor) @@ -1383,8 +1375,8 @@ cdef class CoreWorker: ray_function, args_vector, CTaskOptions( name, num_returns, c_resources, b"", - c_serialized_runtime_env, - c_override_environment_variables), + serialized_runtime_env, + c_runtime_env_uris), max_retries, retry_exceptions, c_pair[CPlacementGroupID, int64_t]( c_placement_group_id, placement_group_bundle_index), @@ -1410,8 +1402,8 @@ cdef class CoreWorker: int64_t placement_group_bundle_index, c_bool placement_group_capture_child_tasks, c_string extension_data, - runtime_env_dict, - override_environment_variables + c_string serialized_runtime_env, + runtime_env_uris, ): cdef: CRayFunction ray_function @@ -1422,14 +1414,9 @@ cdef class CoreWorker: CActorID c_actor_id CPlacementGroupID c_placement_group_id = \ placement_group_id.native() - c_string c_serialized_runtime_env - unordered_map[c_string, c_string] \ - c_override_environment_variables = \ - override_environment_variables + c_vector[c_string] c_runtime_env_uris = runtime_env_uris with self.profile_event(b"submit_task"): - c_serialized_runtime_env = \ - self.prepare_runtime_env(runtime_env_dict) prepare_resources(resources, &c_resources) prepare_resources(placement_resources, &c_placement_resources) ray_function = CRayFunction( @@ -1449,8 +1436,8 @@ cdef class CoreWorker: c_placement_group_id, placement_group_bundle_index), placement_group_capture_child_tasks, - c_serialized_runtime_env, - c_override_environment_variables), + serialized_runtime_env, + c_runtime_env_uris), extension_data, &c_actor_id)) @@ -1725,12 +1712,11 @@ cdef class CoreWorker: return CCoreWorkerProcess.GetCoreWorker().GetOwnerAddress( c_object_id).SerializeAsString() - def serialize_and_promote_object_ref(self, ObjectRef object_ref): + def serialize_object_ref(self, ObjectRef object_ref): cdef: CObjectID c_object_id = object_ref.native() CAddress c_owner_address = CAddress() c_string serialized_object_status - CCoreWorkerProcess.GetCoreWorker().PromoteObjectToPlasma(c_object_id) CCoreWorkerProcess.GetCoreWorker().GetOwnershipInfo( c_object_id, &c_owner_address, &serialized_object_status) return (object_ref, @@ -1861,19 +1847,20 @@ cdef class CoreWorker: return (CCoreWorkerProcess.GetCoreWorker().GetWorkerContext() .CurrentActorIsAsync()) - def get_current_runtime_env_dict(self): + def get_current_runtime_env(self) -> ParsedRuntimeEnv: # This should never change, so we can safely cache it to avoid ser/de - if self.current_runtime_env_dict is None: + if self.current_runtime_env is None: if self.is_driver: - self.current_runtime_env_dict = \ - json.loads(self.get_job_config().serialized_runtime_env) + job_config = self.get_job_config() + serialized_env = job_config.runtime_env.serialized_runtime_env else: - self.current_runtime_env_dict = json.loads( - CCoreWorkerProcess.GetCoreWorker() - .GetWorkerContext() - .GetCurrentSerializedRuntimeEnv() - ) - return self.current_runtime_env_dict + serialized_env = CCoreWorkerProcess.GetCoreWorker() \ + 
.GetWorkerContext().GetCurrentSerializedRuntimeEnv() + + self.current_runtime_env = ParsedRuntimeEnv.deserialize( + serialized_env) + + return self.current_runtime_env def is_exiting(self): return CCoreWorkerProcess.GetCoreWorker().IsExiting() @@ -1901,6 +1888,26 @@ cdef class CoreWorker: return ref_counts + def get_actor_call_stats(self): + cdef: + unordered_map[c_string, c_vector[uint64_t]] c_tasks_count + + c_tasks_count = ( + CCoreWorkerProcess.GetCoreWorker().GetActorCallStats()) + it = c_tasks_count.begin() + + tasks_count = dict() + while it != c_tasks_count.end(): + func_name = dereference(it).first + counters = dereference(it).second + tasks_count[func_name] = { + "pending": counters[0], + "running": counters[1], + "finished": counters[2], + } + postincrement(it) + return tasks_count + def set_get_async_callback(self, ObjectRef object_ref, callback): cpython.Py_INCREF(callback) CCoreWorkerProcess.GetCoreWorker().GetAsync( @@ -1925,45 +1932,6 @@ cdef class CoreWorker: self.job_config.ParseFromString(c_job_config.SerializeAsString()) return self.job_config - def prepare_runtime_env(self, runtime_env_dict: dict) -> str: - """Merge the given new runtime env with the current runtime env. - - If running in a driver, the current runtime env comes from the - JobConfig. Otherwise, we are running in a worker for an actor or - task, and the current runtime env comes from the current TaskSpec. - - The child's runtime env dict is merged with the parents via a simple - dict update, except for runtime_env["env_vars"], which is merged - with runtime_env["env_vars"] of the parent rather than overwriting it. - This is so that env vars set in the parent propagate to child actors - and tasks even if a new env var is set in the child. - - Args: - runtime_env_dict (dict): A runtime env for a child actor or task. - Returns: - The resulting merged JSON-serialized runtime env. - """ - - result_dict = copy.deepcopy(self.get_current_runtime_env_dict()) - - result_env_vars = copy.deepcopy(result_dict.get("env_vars") or {}) - child_env_vars = runtime_env_dict.get("env_vars") or {} - result_env_vars.update(child_env_vars) - - result_dict.update(runtime_env_dict) - result_dict["env_vars"] = result_env_vars - - # NOTE(architkulkarni): This allows worker caching code in C++ to - # check if a runtime env is empty without deserializing it. - if result_dict["env_vars"] == {}: - result_dict["env_vars"] = None - if all(val is None for val in result_dict.values()): - result_dict = {} - - # TODO(architkulkarni): We should just use RuntimeEnvDict here - # so all the serialization and validation is done in one place - return json.dumps(result_dict, sort_keys=True) - def get_task_submission_stats(self): cdef: int64_t num_tasks_submitted diff --git a/python/ray/actor.py b/python/ray/actor.py index faec5fccc7dd7..f228389da72e0 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -5,7 +5,8 @@ import ray.ray_constants as ray_constants import ray._raylet import ray._private.signature as signature -import ray._private.runtime_env as runtime_support +from ray._private.runtime_env.validation import ( + override_task_or_actor_runtime_env, ParsedRuntimeEnv) import ray.worker from ray.util.annotations import PublicAPI from ray.util.placement_group import ( @@ -31,7 +32,7 @@ @PublicAPI -@client_mode_hook +@client_mode_hook(auto_init=False) def method(*args, **kwargs): """Annotate an actor method. 
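The caching in get_current_runtime_env() above is safe because serialize() and deserialize() round-trip through JSON with sorted keys, so logically equal envs always produce identical strings; that string is also what the C++ worker-caching code keys on. A quick illustration of the invariant:

    import json

    env_a = {"pip": ["requests"], "env_vars": {"A": "1"}}
    env_b = {"env_vars": {"A": "1"}, "pip": ["requests"]}  # same env, new order

    # Sorted keys make serialization order-independent, so both envs map to
    # the same cached worker environment.
    assert json.dumps(env_a, sort_keys=True) == json.dumps(env_b, sort_keys=True)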
@@ -388,11 +389,17 @@ class DerivedActorClass(cls, modified_class): PythonFunctionDescriptor.from_class( modified_class.__ray_actor_class__) + # Parse local pip/conda config files here. If we instead did it in + # .remote(), it would get run in the Ray Client server, which runs on + # a remote node where the files aren't available. + new_runtime_env = ParsedRuntimeEnv( + runtime_env or {}, is_task_or_actor=True) + self.__ray_metadata__ = ActorClassMetadata( Language.PYTHON, modified_class, actor_creation_function_descriptor, class_id, max_restarts, max_task_retries, num_cpus, num_gpus, memory, object_store_memory, - resources, accelerator_type, runtime_env) + resources, accelerator_type, new_runtime_env) return self @@ -403,10 +410,15 @@ def _ray_from_function_descriptor( resources, accelerator_type, runtime_env): self = ActorClass.__new__(ActorClass) + # Parse local pip/conda config files here. If we instead did it in + # .remote(), it would get run in the Ray Client server, which runs on + # a remote node where the files aren't available. + new_runtime_env = ParsedRuntimeEnv( + runtime_env or {}, is_task_or_actor=True) self.__ray_metadata__ = ActorClassMetadata( language, None, actor_creation_function_descriptor, None, max_restarts, max_task_retries, num_cpus, num_gpus, memory, - object_store_memory, resources, accelerator_type, runtime_env) + object_store_memory, resources, accelerator_type, new_runtime_env) return self @@ -442,8 +454,7 @@ def options(self, placement_group="default", placement_group_bundle_index=-1, placement_group_capture_child_tasks=None, - runtime_env=None, - override_environment_variables=None): + runtime_env=None): """Configures and overrides the actor instantiation parameters. The arguments are the same as those that can be passed @@ -464,6 +475,12 @@ def method(self): actor_cls = self + # Parse local pip/conda config files here. If we instead did it in + # .remote(), it would get run in the Ray Client server, which runs on + # a remote node where the files aren't available. + new_runtime_env = ParsedRuntimeEnv( + runtime_env or {}, is_task_or_actor=True) + class ActorOptionWrapper: def remote(self, *args, **kwargs): return actor_cls._remote( @@ -485,9 +502,7 @@ def remote(self, *args, **kwargs): placement_group_bundle_index=placement_group_bundle_index, placement_group_capture_child_tasks=( placement_group_capture_child_tasks), - runtime_env=runtime_env, - override_environment_variables=( - override_environment_variables)) + runtime_env=new_runtime_env) return ActorOptionWrapper() @@ -510,8 +525,7 @@ def _remote(self, placement_group="default", placement_group_bundle_index=-1, placement_group_capture_child_tasks=None, - runtime_env=None, - override_environment_variables=None): + runtime_env=None): """Create an actor. This method allows more flexibility than the remote method because @@ -557,9 +571,6 @@ def _remote(self, this actor or task and its children (see :ref:`runtime-environments` for details). This API is in beta and may change before becoming stable. - override_environment_variables: Environment variables to override - and/or introduce for this actor. This is a dictionary mapping - variable names to their values. Returns: A handle to the newly created actor. 
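The _remote() hunks below pass the parsed child env through override_task_or_actor_runtime_env. Its merge rule, reduced to plain dicts (a sketch mirroring the docstring earlier in this patch, not the actual implementation):

    import copy
    from typing import Any, Dict

    def merge_runtime_envs(child: Dict[str, Any],
                           parent: Dict[str, Any]) -> Dict[str, Any]:
        # env_vars are merged key-by-key; other fields override wholesale.
        env_vars = copy.deepcopy(parent.get("env_vars") or {})
        env_vars.update(child.get("env_vars") or {})

        result = copy.deepcopy(parent)
        result.update(child)
        if env_vars:
            result["env_vars"] = env_vars
        result.pop("working_dir", None)  # working_dir is never inherited.
        return result

    parent = {"working_dir": "/tmp/job", "env_vars": {"A": "1", "B": "2"}}
    child = {"env_vars": {"B": "3"}, "pip": ["requests"]}
    assert merge_runtime_envs(child, parent) == {
        "env_vars": {"A": "1", "B": "3"},
        "pip": ["requests"],
    }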
@@ -584,7 +595,7 @@ def _remote(self, if max_concurrency < 1: raise ValueError("max_concurrency must be >= 1") - if client_mode_should_convert(): + if client_mode_should_convert(auto_init=True): return client_mode_convert_actor( self, args, @@ -605,9 +616,7 @@ def _remote(self, placement_group_bundle_index=placement_group_bundle_index, placement_group_capture_child_tasks=( placement_group_capture_child_tasks), - runtime_env=runtime_env, - override_environment_variables=( - override_environment_variables)) + runtime_env=runtime_env) worker = ray.worker.global_worker worker.check_connected() @@ -723,18 +732,16 @@ def _remote(self, creation_args = signature.flatten_args(function_signature, args, kwargs) - if runtime_env is None: + if runtime_env and not isinstance(runtime_env, ParsedRuntimeEnv): + runtime_env = ParsedRuntimeEnv(runtime_env) + elif isinstance(runtime_env, ParsedRuntimeEnv): + pass + else: runtime_env = meta.runtime_env - job_runtime_env = worker.core_worker.get_current_runtime_env_dict() - runtime_env_dict = runtime_support.override_task_or_actor_runtime_env( - runtime_env, job_runtime_env) - - if override_environment_variables: - logger.warning("override_environment_variables is deprecated and " - "will be removed in Ray 1.6. Please use " - ".options(runtime_env={'env_vars': {...}}).remote()" - "instead.") + parent_runtime_env = worker.core_worker.get_current_runtime_env() + parsed_runtime_env = override_task_or_actor_runtime_env( + runtime_env, parent_runtime_env) actor_id = worker.core_worker.create_actor( meta.language, @@ -754,9 +761,8 @@ def _remote(self, placement_group_capture_child_tasks, # Store actor_method_cpu in actor handle's extension data. extension_data=str(actor_method_cpu), - runtime_env_dict=runtime_env_dict, - override_environment_variables=override_environment_variables - or dict()) + serialized_runtime_env=parsed_runtime_env.serialize(), + runtime_env_uris=parsed_runtime_env.get("uris") or []) actor_handle = ActorHandle( meta.language, diff --git a/python/ray/autoscaler/_private/autoscaler.py b/python/ray/autoscaler/_private/autoscaler.py index 3b26b845e8070..f626de070a0b5 100644 --- a/python/ray/autoscaler/_private/autoscaler.py +++ b/python/ray/autoscaler/_private/autoscaler.py @@ -485,7 +485,8 @@ def _report_pending_infeasible(self, unfulfilled: List[ResourceDict]): pending = [] infeasible = [] for bundle in unfulfilled: - placement_group = any("_group_" in k for k in bundle) + placement_group = any( + "_group_" in k or k == "bundle" for k in bundle) if placement_group: continue if self.resource_demand_scheduler.is_feasible(bundle): diff --git a/python/ray/autoscaler/_private/docker.py b/python/ray/autoscaler/_private/docker.py index 8d94759549217..92dd16ad5001f 100644 --- a/python/ray/autoscaler/_private/docker.py +++ b/python/ray/autoscaler/_private/docker.py @@ -18,7 +18,7 @@ def _check_docker_file_mounts(file_mounts: Dict[str, str]) -> None: if Path(local).is_file(): cli_logger.warning( f"File Mount: ({remote}:{local}) refers to a file.\n To ensure" - "this mount updates properly, please use a directory.") + " this mount updates properly, please use a directory.") def validate_docker_config(config: Dict[str, Any]) -> None: diff --git a/python/ray/autoscaler/_private/gcp/node.py b/python/ray/autoscaler/_private/gcp/node.py index 93a9933ddc186..69a456ac56c0e 100644 --- a/python/ray/autoscaler/_private/gcp/node.py +++ b/python/ray/autoscaler/_private/gcp/node.py @@ -437,8 +437,26 @@ def create_instance(self, "name": name }) + # Allow Google Compute 
Engine instance templates. + # + # Config example: + # + # ... + # node_config: + # sourceInstanceTemplate: global/instanceTemplates/worker-16 + # machineType: e2-standard-16 + # ... + # + # node_config parameters override matching template parameters, if any. + # + # https://cloud.google.com/compute/docs/instance-templates + # https://cloud.google.com/compute/docs/reference/rest/v1/instances/insert + source_instance_template = config.pop("sourceInstanceTemplate", None) + operation = self.resource.instances().insert( - project=self.project_id, zone=self.availability_zone, + project=self.project_id, + zone=self.availability_zone, + sourceInstanceTemplate=source_instance_template, body=config).execute() if wait_for_operation: diff --git a/python/ray/autoscaler/_private/monitor.py b/python/ray/autoscaler/_private/monitor.py index 172bd5b74b57d..abe0754a31502 100644 --- a/python/ray/autoscaler/_private/monitor.py +++ b/python/ray/autoscaler/_private/monitor.py @@ -223,7 +223,7 @@ def update_load_metrics(self): request = gcs_service_pb2.GetAllResourceUsageRequest() response = self.gcs_node_resources_stub.GetAllResourceUsage( - request, timeout=4) + request, timeout=60) resources_batch_data = response.resource_usage_data # Tell the readonly node provider what nodes to report. @@ -244,8 +244,7 @@ def update_load_metrics(self): resource_message.node_id.hex()) resources = {} for k, v in resource_message.resources_total.items(): - if not k.startswith("node:"): - resources[k] = v + resources[k] = v mirror_node_types[node_type] = { "resources": resources, "node_config": {}, diff --git a/python/ray/autoscaler/_private/resource_demand_scheduler.py b/python/ray/autoscaler/_private/resource_demand_scheduler.py index 517f49f63281c..bc3fe5925140f 100644 --- a/python/ray/autoscaler/_private/resource_demand_scheduler.py +++ b/python/ray/autoscaler/_private/resource_demand_scheduler.py @@ -764,7 +764,11 @@ def _utilization_score(node_resources: ResourceDict, return None fittable = [] + resource_types = set() for r in resources: + for k, v in r.items(): + if v > 0: + resource_types.add(k) if _fits(remaining, r): fittable.append(r) _inplace_subtract(remaining, r) @@ -772,12 +776,15 @@ def _utilization_score(node_resources: ResourceDict, return None util_by_resources = [] + num_matching_resource_types = 0 for k, v in node_resources.items(): # Don't divide by zero. if v < 1: # Could test v == 0 on the nose, but v < 1 feels safer. # (Note that node resources are integers.) continue + if k in resource_types: + num_matching_resource_types += 1 util = (v - remaining[k]) / v util_by_resources.append(v * (util**3)) @@ -785,9 +792,11 @@ def _utilization_score(node_resources: ResourceDict, if not util_by_resources: return None - # Prioritize using all resources first, then prioritize overall balance + # Prioritize matching multiple resource types first, then prioritize + # using all resources, then prioritize overall balance # of multiple resources. - return (min(util_by_resources), np.mean(util_by_resources)) + return (num_matching_resource_types, min(util_by_resources), + np.mean(util_by_resources)) def get_bin_pack_residual(node_resources: List[ResourceDict], @@ -818,7 +827,16 @@ def get_bin_pack_residual(node_resources: List[ResourceDict], nodes = copy.deepcopy(node_resources) # List of nodes that cannot be used again due to strict spread. used = [] - for demand in resource_demands: + # We order the resource demands in the following way: + # More complex demands first. + # Break ties: heavier demands first. 
+    # Break ties: lexicographically (to ensure stable ordering).
+    for demand in sorted(
+            resource_demands,
+            key=lambda demand: (len(demand.values()),
+                                sum(demand.values()),
+                                sorted(demand.items())),
+            reverse=True):
         found = False
         node = None
         for i in range(len(nodes)):
diff --git a/python/ray/autoscaler/gcp/tpu.yaml b/python/ray/autoscaler/gcp/tpu.yaml
index 34726cb2205b4..a963e62c1898d 100644
--- a/python/ray/autoscaler/gcp/tpu.yaml
+++ b/python/ray/autoscaler/gcp/tpu.yaml
@@ -32,9 +32,9 @@ available_node_types:
             # Support for TPU pods will be added in the future.
             acceleratorType: v2-8
             runtimeVersion: v2-alpha
-            # Uncomment to use preemptible TPUs
-            # schedulingConfig:
-            #   preemptible: true
+            schedulingConfig:
+                # Set to false to use non-preemptible TPUs
+                preemptible: true

 provider:
     type: gcp
@@ -51,15 +51,21 @@ head_node_type: ray_head_default
 # Compute instances have python 3.7, but TPUs have 3.8 - need to update
 # Install Jax and other dependencies on the Compute head node
 head_setup_commands:
-    - conda create -y -n "ray" python=3.8.5 && sudo update-alternatives --install /opt/conda/bin/python python /opt/conda/envs/ray/bin/python 10 && sudo update-alternatives --install /opt/conda/bin/pip pip /opt/conda/envs/ray/bin/pip 10
-    - export PATH="$PATH:/opt/conda/envs/ray/bin" && echo 'export PATH="$PATH:/opt/conda/envs/ray/bin"' >> ~/.bashrc
-    - python -m pip install --upgrade "jax[cpu]==0.2.14" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+    # The first two lines are a workaround for ssh timing out
+    - sleep 2
+    - sleep 2
+    - sudo chown -R $(whoami) /opt/conda/*
+    - conda create -y -n "ray" python=3.8.5
+    - conda activate ray && echo 'conda activate ray' >> ~/.bashrc
+    - python -m pip install --upgrade pip
+    - python -m pip install --upgrade "jax[cpu]==0.2.14"
     - python -m pip install --upgrade fabric dataclasses optax==0.0.6 git+https://github.com/deepmind/dm-haiku google-api-python-client cryptography tensorboardX ray[default]
     - python -m pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-2.0.0.dev0-cp38-cp38-manylinux2014_x86_64.whl
    - git clone https://github.com/Yard1/swarm-jax.git && cd swarm-jax && python -m pip install .
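Stepping back to the get_bin_pack_residual() change above (before this tpu.yaml hunk), the new demand ordering can be exercised standalone; the demand dicts here are toy values:

    demands = [
        {"CPU": 1},
        {"CPU": 4},
        {"CPU": 1, "GPU": 1},
        {"CPU": 1, "accelerator_type:V100": 0.01},
    ]

    # More complex demands first, then heavier ones, then a lexicographic
    # tiebreak on the items for a stable ordering.
    ordered = sorted(
        demands,
        key=lambda d: (len(d.values()), sum(d.values()), sorted(d.items())),
        reverse=True)

    assert ordered[0] == {"CPU": 1, "GPU": 1}  # two resource types, heaviest
    assert ordered[-1] == {"CPU": 1}           # one resource type, lightest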
 # Install Jax and other dependencies on TPU
 worker_setup_commands:
+    - pip3 install --upgrade pip
     - pip3 install --upgrade "jax[tpu]==0.2.14" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
     - pip3 install --upgrade fabric dataclasses optax==0.0.6 git+https://github.com/deepmind/dm-haiku tensorboardX ray[default]
     - python3 -c "import jax; jax.device_count(); jax.numpy.add(1, 1)" # test if Jax has been installed correctly
diff --git a/python/ray/cross_language.py b/python/ray/cross_language.py
index eae3939ffa11e..f3ad1de0a3903 100644
--- a/python/ray/cross_language.py
+++ b/python/ray/cross_language.py
@@ -79,7 +79,8 @@ def java_function(class_name, function_name):
         None,  # max_calls,
         None,  # max_retries,
         None,  # retry_exceptions,
-        None)  # runtime_env
+        None,  # runtime_env
+        None)  # placement_group


 @PublicAPI(stability="beta")
diff --git a/python/ray/data/__init__.py b/python/ray/data/__init__.py
index 521add717220c..c6e411fadf86b 100644
--- a/python/ray/data/__init__.py
+++ b/python/ray/data/__init__.py
@@ -1,7 +1,8 @@
 from ray.data.read_api import from_items, range, range_arrow, \
     range_tensor, read_parquet, read_json, read_csv, read_binary_files, \
-    from_dask, from_modin, from_mars, from_pandas, from_numpy, from_arrow, \
-    from_spark, read_datasource, read_numpy, read_text
+    from_dask, from_modin, from_mars, from_pandas, from_pandas_refs, \
+    from_numpy, from_arrow, from_arrow_refs, from_spark, read_datasource, \
+    read_numpy, read_text
 from ray.data.datasource import Datasource, ReadTask
 from ray.data.dataset import Dataset
 from ray.data.impl.progress_bar import set_progress_bars
@@ -18,10 +19,12 @@
     "from_dask",
     "from_items",
     "from_arrow",
+    "from_arrow_refs",
     "from_mars",
     "from_modin",
     "from_numpy",
     "from_pandas",
+    "from_pandas_refs",
     "from_spark",
     "range",
     "range_arrow",
diff --git a/python/ray/data/block.py b/python/ray/data/block.py
index 35b99780c5e0d..e7edab74863ad 100644
--- a/python/ray/data/block.py
+++ b/python/ray/data/block.py
@@ -16,8 +16,8 @@
 # Represents a batch of records to be stored in the Ray object store.
 #
 # Block data can be accessed in a uniform way via ``BlockAccessors`` such as
-# ``SimpleBlockAccessor``, ``ArrowBlockAccessor``, and ``TensorBlockAccessor``.
-Block = Union[List[T], np.ndarray, "pyarrow.Table", bytes]
+# ``SimpleBlockAccessor`` and ``ArrowBlockAccessor``.
+Block = Union[List[T], "pyarrow.Table", bytes]


 @DeveloperAPI
@@ -52,8 +52,8 @@ class BlockAccessor(Generic[T]):
     as a top-level Ray object, without a wrapping class (issue #17186).

     There are two types of block accessors: ``SimpleBlockAccessor``, which
-    operates over a plain Python list, ``ArrowBlockAccessor``, for
-    ``pyarrow.Table`` type blocks, and ``TensorBlockAccessor``, for tensors.
+    operates over a plain Python list, and ``ArrowBlockAccessor`` for
+    ``pyarrow.Table`` type blocks.
     """

     def num_rows(self) -> int:
@@ -85,12 +85,16 @@ def to_pandas(self) -> "pandas.DataFrame":
         """Convert this block into a Pandas dataframe."""
         raise NotImplementedError

-    def to_numpy(self) -> np.ndarray:
-        """Convert this block into a NumPy ndarray."""
+    def to_numpy(self, column: str = None) -> np.ndarray:
+        """Convert this block (or column of block) into a NumPy ndarray.
+
+        Args:
+            column: Name of column to convert, or None.
+ """ raise NotImplementedError - def to_arrow(self) -> Union["pyarrow.Table", "pyarrow.Tensor"]: - """Convert this block into an Arrow table or tensor.""" + def to_arrow(self) -> "pyarrow.Table": + """Convert this block into an Arrow table.""" raise NotImplementedError def size_bytes(self) -> int: @@ -136,10 +140,6 @@ def for_block(block: Block) -> "BlockAccessor[T]": from ray.data.impl.simple_block import \ SimpleBlockAccessor return SimpleBlockAccessor(block) - elif isinstance(block, np.ndarray): - from ray.data.impl.tensor_block import \ - TensorBlockAccessor - return TensorBlockAccessor(block) else: raise TypeError("Not a block type: {}".format(block)) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 11d0a13c9cbae..6ecabe4c5ce37 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -51,8 +51,7 @@ class Dataset(Generic[T]): Datasets are implemented as a list of ``ObjectRef[Block]``. The block also determines the unit of parallelism. The default block type is the - ``pyarrow.Table``. Tensor objects are held in ``np.ndarray`` blocks, - and other Arrow-incompatible objects are held in ``list`` blocks. + ``pyarrow.Table``. Arrow-incompatible objects are held in ``list`` blocks. Since Datasets are just lists of Ray object refs, they can be passed between Ray tasks and actors just like any other object. Datasets support @@ -169,7 +168,7 @@ def map_batches(self, tasks, or "actors" to use an autoscaling Ray actor pool. batch_format: Specify "native" to use the native block format, "pandas" to select ``pandas.DataFrame`` as the batch format, - or "pyarrow" to select ``pyarrow.Table/Tensor``. + or "pyarrow" to select ``pyarrow.Table``. ray_remote_args: Additional resource requirements to request from ray (e.g., num_gpus=1 to request GPUs for the map tasks). """ @@ -205,19 +204,15 @@ def transform(block: Block) -> Block: "or 'pyarrow', got: {}".format(batch_format)) applied = fn(view) - if (isinstance(applied, list) or isinstance(applied, pa.Table) - or isinstance(applied, np.ndarray)): + if isinstance(applied, list) or isinstance(applied, pa.Table): applied = applied elif isinstance(applied, pd.core.frame.DataFrame): applied = pa.Table.from_pandas(applied) - elif isinstance(applied, pa.Tensor): - applied = applied.to_numpy() else: raise ValueError("The map batches UDF returned a type " f"{type(applied)}, which is not allowed. " "The return type must be either list, " - "pandas.DataFrame, np.ndarray, " - "pyarrow.Tensor, or pyarrow.Table") + "pandas.DataFrame, or pyarrow.Table") builder.add_block(applied) return builder.build() @@ -352,8 +347,13 @@ def random_shuffle( Returns: The shuffled dataset. """ + curr_num_blocks = self.num_blocks() + # Handle empty dataset. + if curr_num_blocks == 0: + return self + if num_blocks is None: - num_blocks = self.num_blocks() + num_blocks = curr_num_blocks new_blocks = simple_shuffle( self._move_blocks() if _move else self._blocks, num_blocks, @@ -402,24 +402,150 @@ def split(self, if n <= 0: raise ValueError(f"The number of splits {n} is not positive.") - if n > self.num_blocks() and equal: - raise NotImplementedError( - f"The number of splits {n} > the number of dataset blocks " - f"{self.num_blocks()}, yet an equal split was requested.") - if locality_hints and len(locality_hints) != n: raise ValueError( f"The length of locality_hints {len(locality_hints)} " "doesn't equal the number of splits {n}.") - # TODO(ekl) we could do better than truncation here. 
This could be a
-        # problem if block sizes are very skewed.
-        def equalize(splits: List[Dataset[T]]) -> List[Dataset[T]]:
+        def _partition_splits(splits: List[Dataset[T]], part_size: int,
+                              counts_cache: Dict[str, int]):
+            """Partition splits into two sets: splits that are smaller than
+            the target size and splits that are at least the target size.
+            """
+            splits = sorted(splits, key=lambda s: counts_cache[s._get_uuid()])
+            idx = next(i for i, split in enumerate(splits)
+                       if counts_cache[split._get_uuid()] >= part_size)
+            return splits[:idx], splits[idx:]
+
+        def _equalize_larger_splits(splits: List[Dataset[T]], target_size: int,
+                                    counts_cache: Dict[str, int],
+                                    num_splits_required: int):
+            """Split each split into one or more subsplits that are each the
+            target size, with at most one leftover split that's smaller
+            than the target size.
+
+            This assumes that the given splits are sorted in ascending order.
+            """
+            new_splits = []
+            leftovers = []
+            for split in splits:
+                size = counts_cache[split._get_uuid()]
+                if size == target_size:
+                    new_splits.append(split)
+                    continue
+                split_indices = list(range(target_size, size, target_size))
+                split_splits = split.split_at_indices(split_indices)
+                last_split_size = split_splits[-1].count()
+                if last_split_size < target_size:
+                    # Last split is smaller than the target size, save it for
+                    # our unioning of small splits.
+                    leftover = split_splits.pop()
+                    leftovers.append(leftover)
+                    counts_cache[leftover._get_uuid()] = leftover.count()
+                if len(new_splits) + len(split_splits) >= num_splits_required:
+                    # Short-circuit if the new splits will make us reach the
+                    # desired number of splits.
+                    new_splits.extend(
+                        split_splits[:num_splits_required - len(new_splits)])
+                    break
+                new_splits.extend(split_splits)
+            return new_splits, leftovers
+
+        def _equalize_smaller_splits(
+                splits: List[Dataset[T]], target_size: int,
+                counts_cache: Dict[str, int], num_splits_required: int):
+            """Union small splits up to the target split size.
+
+            This assumes that the given splits are sorted in ascending order.
+            """
+            new_splits = []
+            union_buffer = []
+            union_buffer_size = 0
+            low = 0
+            high = len(splits) - 1
+            while low <= high:
+                # Union small splits up to the target split size.
+                low_split = splits[low]
+                low_count = counts_cache[low_split._get_uuid()]
+                high_split = splits[high]
+                high_count = counts_cache[high_split._get_uuid()]
+                if union_buffer_size + high_count <= target_size:
+                    # Try to add the larger split to the union buffer first.
+                    union_buffer.append(high_split)
+                    union_buffer_size += high_count
+                    high -= 1
+                elif union_buffer_size + low_count <= target_size:
+                    union_buffer.append(low_split)
+                    union_buffer_size += low_count
+                    low += 1
+                else:
+                    # Neither the larger nor smaller split fit in the union
+                    # buffer, so we split the smaller split into a subsplit
+                    # that will fit into the union buffer and a leftover
+                    # subsplit that we add back into the candidate split list.
+                    diff = target_size - union_buffer_size
+                    diff_split, new_low_split = low_split.split_at_indices(
+                        [diff])
+                    union_buffer.append(diff_split)
+                    union_buffer_size += diff
+                    # We overwrite the old low split and don't advance the low
+                    # pointer since (1) the old low split can be discarded,
+                    # (2) the leftover subsplit is guaranteed to be smaller
+                    # than the old low split, and (3) the low split should be
+                    # the smallest split in the candidate split list, which is
+                    # this subsplit.
+ splits[low] = new_low_split + counts_cache[new_low_split._get_uuid()] = low_count - diff + if union_buffer_size == target_size: + # Once the union buffer is full, we union together the + # splits. + assert len(union_buffer) > 1, union_buffer + first_ds = union_buffer[0] + new_split = first_ds.union(*union_buffer[1:]) + new_splits.append(new_split) + # Clear the union buffer. + union_buffer = [] + union_buffer_size = 0 + if len(new_splits) == num_splits_required: + # Short-circuit if we've reached the desired number of + # splits. + break + return new_splits + + def equalize(splits: List[Dataset[T]], + num_splits: int) -> List[Dataset[T]]: if not equal: return splits - lower_bound = min([s.count() for s in splits]) - assert lower_bound > 0, splits - return [s.limit(lower_bound) for s in splits] + counts = {s._get_uuid(): s.count() for s in splits} + total_rows = sum(counts.values()) + # Number of rows for each split. + target_size = total_rows // num_splits + + # Partition splits. + smaller_splits, larger_splits = _partition_splits( + splits, target_size, counts) + if len(smaller_splits) == 0 and num_splits < len(splits): + # All splits are already equal. + return splits + + # Split larger splits. + new_splits, leftovers = _equalize_larger_splits( + larger_splits, target_size, counts, num_splits) + # Short-circuit if we've already reached the desired number of + # splits. + if len(new_splits) == num_splits: + return new_splits + # Add leftovers to small splits and re-sort. + smaller_splits += leftovers + smaller_splits = sorted( + smaller_splits, key=lambda s: counts[s._get_uuid()]) + + # Union smaller splits. + new_splits_small = _equalize_smaller_splits( + smaller_splits, target_size, counts, + num_splits - len(new_splits)) + new_splits.extend(new_splits_small) + return new_splits block_refs = list(self._blocks) metadata_mapping = { @@ -433,7 +559,8 @@ def equalize(splits: List[Dataset[T]]) -> List[Dataset[T]]: BlockList( list(blocks), [metadata_mapping[b] for b in blocks])) for blocks in np.array_split(block_refs, n) - ]) + if not equal or len(blocks) > 0 + ], n) # If the locality_hints is set, we use a two-round greedy algorithm # to co-locate the blocks with the actors based on block @@ -532,7 +659,7 @@ def build_node_id_by_actor(actors: List[Any]) -> Dict[Any, str]: [metadata_mapping[b] for b in allocation_per_actor[actor]])) for actor in locality_hints - ]) + ], n) def split_at_indices(self, indices: List[int]) -> List["Dataset[T]"]: """Split the dataset at the given indices (like np.split). @@ -580,6 +707,9 @@ def split_at_indices(self, indices: List[int]) -> List["Dataset[T]"]: def union(self, *other: List["Dataset[T]"]) -> "Dataset[T]": """Combine this dataset with others of the same type. + The order of the blocks in the datasets is preserved, as is the + relative ordering between the datasets passed in the argument list. + Args: other: List of datasets to combine with this one. The datasets must have the same schema as this dataset, otherwise the @@ -589,35 +719,21 @@ def union(self, *other: List["Dataset[T]"]) -> "Dataset[T]": A new dataset holding the union of their data. 
""" - blocks: List[ObjectRef[Block]] = [] + calls: List[Callable[[], ObjectRef[Block]]] = [] metadata: List[BlockMetadata] = [] - pending_blocks: List[Callable[[], ObjectRef[Block]]] = [] - pending_metadata: List[BlockMetadata] = [] + blocks: List[ObjectRef[Block]] = [] datasets = [self] + list(other) for ds in datasets: bl = ds._blocks if isinstance(bl, LazyBlockList): - for block, meta in zip(bl._blocks, bl._metadata): - blocks.append(block) - metadata.append(meta) - lim = len(bl._blocks) - for call, meta in zip(bl._calls[lim:], bl._metadata[lim:]): - pending_blocks.append(call) - pending_metadata.append(meta) + calls.extend(bl._calls) else: - assert isinstance(bl, BlockList), bl - blocks.extend(list(bl._blocks)) - metadata.extend(bl.get_metadata()) - - result = LazyBlockList([], []) - result._calls = ([None] * len(blocks)) + pending_blocks - result._blocks = blocks - result._metadata = metadata + pending_metadata + calls.extend([None] * len(bl)) + metadata.extend(bl._metadata) + blocks.extend(bl._blocks) - assert len(result._calls) == len(result._metadata), result - assert len(result._blocks) <= len(result._calls), result - return Dataset(result) + return Dataset(LazyBlockList(calls, metadata, blocks)) def sort(self, key: Union[None, str, List[str], Callable[[T], Any]] = None, @@ -653,6 +769,9 @@ def sort(self, Returns: A new, sorted dataset. """ + # Handle empty dataset. + if self.num_blocks() == 0: + return self return Dataset(sort_impl(self._blocks, key, descending)) def zip(self, other: "Dataset[U]") -> "Dataset[(T, U)]": @@ -678,8 +797,8 @@ def zip(self, other: "Dataset[U]") -> "Dataset[(T, U)]": comes from the first dataset and v comes from the second. """ - blocks1 = self.get_blocks() - blocks2 = other.get_blocks() + blocks1 = self.get_internal_block_refs() + blocks2 = other.get_internal_block_refs() if len(blocks1) != len(blocks2): # TODO(ekl) consider supporting if num_rows are equal. @@ -761,6 +880,9 @@ def count(self) -> int: Returns: The number of records in the dataset. """ + # Handle empty dataset. + if self.num_blocks() == 0: + return 0 # For parquet, we can return the count directly from metadata. meta_count = self._meta_count() @@ -849,6 +971,7 @@ def write_parquet(self, path: str, *, filesystem: Optional["pyarrow.fs.FileSystem"] = None, + try_create_dir: bool = True, **arrow_parquet_args) -> None: """Write the dataset to parquet. @@ -867,6 +990,8 @@ def write_parquet(self, path: The path to the destination root directory, where Parquet files will be written to. filesystem: The filesystem implementation to write to. + try_create_dir: Try to create all directories in destination path + if True. Does nothing if all directories already exist. arrow_parquet_args: Options to pass to pyarrow.parquet.write_table(), which is used to write out each block to a file. @@ -876,12 +1001,14 @@ def write_parquet(self, path=path, dataset_uuid=self._uuid, filesystem=filesystem, + try_create_dir=try_create_dir, **arrow_parquet_args) def write_json(self, path: str, *, filesystem: Optional["pyarrow.fs.FileSystem"] = None, + try_create_dir: bool = True, **pandas_json_args) -> None: """Write the dataset to json. @@ -900,6 +1027,8 @@ def write_json(self, path: The path to the destination root directory, where json files will be written to. filesystem: The filesystem implementation to write to. + try_create_dir: Try to create all directories in destination path + if True. Does nothing if all directories already exist. 
pandas_json_args: These args will be passed to
                pandas.DataFrame.to_json(), which we use under the hood to
                write out each Datasets block. These
@@ -910,12 +1039,14 @@ def write_json(self,
             path=path,
             dataset_uuid=self._uuid,
             filesystem=filesystem,
+            try_create_dir=try_create_dir,
             **pandas_json_args)

     def write_csv(self,
                   path: str,
                   *,
                   filesystem: Optional["pyarrow.fs.FileSystem"] = None,
+                  try_create_dir: bool = True,
                   **arrow_csv_args) -> None:
         """Write the dataset to csv.

@@ -934,6 +1065,8 @@ def write_csv(self,
             path: The path to the destination root directory, where csv
                 files will be written to.
             filesystem: The filesystem implementation to write to.
+            try_create_dir: Try to create all directories in destination path
+                if True. Does nothing if all directories already exist.
             arrow_csv_args: Other CSV write options to pass to pyarrow.
         """
         self.write_datasource(
@@ -941,17 +1074,20 @@ def write_csv(self,
             path=path,
             dataset_uuid=self._uuid,
             filesystem=filesystem,
+            try_create_dir=try_create_dir,
             **arrow_csv_args)

-    def write_numpy(
-            self,
-            path: str,
-            *,
-            filesystem: Optional["pyarrow.fs.FileSystem"] = None) -> None:
-        """Write the dataset to npy files.
+    def write_numpy(self,
+                    path: str,
+                    *,
+                    column: str = "value",
+                    filesystem: Optional["pyarrow.fs.FileSystem"] = None,
+                    try_create_dir: bool = True) -> None:
+        """Write a tensor column of the dataset to npy files.

-        This is only supported for datasets of Tensor records.
-        To control the number of files, use ``.repartition()``.
+        This is only supported for datasets convertible to Arrow records that
+        contain a TensorArray column. To control the number of files, use
+        ``.repartition()``.

         The format of the output files will be {self._uuid}_{block_idx}.npy,
         where ``uuid`` is a unique id for the dataset.
@@ -964,13 +1100,19 @@ def write_numpy(
         Args:
             path: The path to the destination root directory, where npy
                 files will be written to.
+            column: The name of the table column that contains the tensor to
+                be written. This defaults to "value".
             filesystem: The filesystem implementation to write to.
+            try_create_dir: Try to create all directories in destination path
+                if True. Does nothing if all directories already exist.
         """
         self.write_datasource(
             NumpyDatasource(),
             path=path,
             dataset_uuid=self._uuid,
-            filesystem=filesystem)
+            column=column,
+            filesystem=filesystem,
+            try_create_dir=try_create_dir)

     def write_datasource(self, datasource: Datasource[T],
                          **write_args) -> None:
@@ -1042,7 +1184,7 @@ def iter_batches(self,
             batch_format: The format in which to return each batch.
                 Specify "native" to use the current block format, "pandas" to
                 select ``pandas.DataFrame`` or "pyarrow" to select
-                ``pyarrow.Table/Tensor``. Default is "native".
+                ``pyarrow.Table``. Default is "native".
             drop_last: Whether to drop the last batch if it's incomplete.

         Returns:
@@ -1310,14 +1452,15 @@ def to_modin(self) -> "modin.DataFrame":
         """Convert this dataset into a Modin dataframe.

         This works by first converting this dataset into a distributed set of
-        Pandas dataframes (using ``.to_pandas()``). Please see caveats there.
-        Then the individual dataframes are used to create the modin DataFrame
-        using
+        Pandas dataframes (using ``.to_pandas_refs()``). Please see caveats
+        there. Then the individual dataframes are used to create the modin
+        DataFrame using
         ``modin.distributed.dataframe.pandas.partitions.from_partitions()``.

         This is only supported for datasets convertible to Arrow records.
         This function induces a copy of the data.
For zero-copy access to the - underlying data, consider using ``.to_arrow()`` or ``.get_blocks()``. + underlying data, consider using ``.to_arrow()`` or + ``.get_internal_block_refs()``. Time complexity: O(dataset size / parallelism) @@ -1327,7 +1470,7 @@ def to_modin(self) -> "modin.DataFrame": from modin.distributed.dataframe.pandas.partitions import ( from_partitions) - pd_objs = self.to_pandas() + pd_objs = self.to_pandas_refs() return from_partitions(pd_objs, axis=0) def to_spark(self, @@ -1343,17 +1486,45 @@ def to_spark(self, core_worker = ray.worker.global_worker.core_worker locations = [ core_worker.get_owner_address(block) - for block in self.get_blocks() + for block in self.get_internal_block_refs() ] return raydp.spark.ray_dataset_to_spark_dataframe( - spark, self.schema(), self.get_blocks(), locations) + spark, self.schema(), self.get_internal_block_refs(), locations) + + def to_pandas(self, limit: int = 1000) -> "pandas.DataFrame": + """Convert this dataset into a single Pandas DataFrame. - def to_pandas(self) -> List[ObjectRef["pandas.DataFrame"]]: + This is only supported for datasets convertible to Arrow records. This + limits the number of records returned to the provided limit. + + Time complexity: O(limit) + + Args: + limit: The maximum number of records to return. + + Returns: + A Pandas DataFrame created from this dataset, containing a limited + number of records. + """ + + if self.count() > limit: + logger.warning(f"Only returning the first {limit} records from " + "to_pandas()") + limited_ds = self.limit(limit) + blocks = limited_ds.get_internal_block_refs() + output = DelegatingArrowBlockBuilder() + for block in ray.get(blocks): + output.add_block(block) + return output.build().to_pandas() + + @DeveloperAPI + def to_pandas_refs(self) -> List[ObjectRef["pandas.DataFrame"]]: """Convert this dataset into a distributed set of Pandas dataframes. This is only supported for datasets convertible to Arrow records. This function induces a copy of the data. For zero-copy access to the - underlying data, consider using ``.to_arrow()`` or ``.get_blocks()``. + underlying data, consider using ``.to_arrow()`` or + ``.get_internal_block_refs()``. Time complexity: O(dataset size / parallelism) @@ -1364,23 +1535,48 @@ def to_pandas(self) -> List[ObjectRef["pandas.DataFrame"]]: block_to_df = cached_remote_fn(_block_to_df) return [block_to_df.remote(block) for block in self._blocks] - def to_numpy(self) -> List[ObjectRef[np.ndarray]]: + def to_numpy(self, *, + column: Optional[str] = None) -> List[ObjectRef[np.ndarray]]: """Convert this dataset into a distributed set of NumPy ndarrays. This is only supported for datasets convertible to NumPy ndarrays. This function induces a copy of the data. For zero-copy access to the - underlying data, consider using ``.to_arrow()`` or ``.get_blocks()``. + underlying data, consider using ``.to_arrow()`` or + ``.get_internal_block_refs()``. Time complexity: O(dataset size / parallelism) + Args: + column: The name of the column to convert to numpy, or None to + specify the entire row. Required for Arrow tables. + Returns: A list of remote NumPy ndarrays created from this dataset. """ block_to_ndarray = cached_remote_fn(_block_to_ndarray) - return [block_to_ndarray.remote(block) for block in self._blocks] + return [ + block_to_ndarray.remote(block, column=column) + for block in self._blocks + ] - def to_arrow(self) -> List[ObjectRef["pyarrow.Table"]]: + def to_arrow(self) -> List["pyarrow.Table"]: + """Convert this dataset into a list of Arrow tables. 
+ + This is only supported for datasets convertible to Arrow records. + This function is zero-copy if the existing data is already in Arrow + format. Otherwise, the data will be converted to Arrow format. + + Time complexity: O(1) unless conversion is required. + + Returns: + A list of Arrow tables created from this dataset. + """ + + return ray.get(self.to_arrow_refs()) + + @DeveloperAPI + def to_arrow_refs(self) -> List[ObjectRef["pyarrow.Table"]]: """Convert this dataset into a distributed set of Arrow tables. This is only supported for datasets convertible to Arrow records. @@ -1450,28 +1646,32 @@ def __init__(self, ds: "Dataset[T]"): def __iter__(self): return Iterator(self._ds) - return DatasetPipeline(Iterable(self), length=times) + return DatasetPipeline(Iterable(self), length=times or float("inf")) def pipeline(self, *, parallelism: int = 10) -> "DatasetPipeline[T]": - """Pipeline the dataset execution by splitting its blocks into groups. + raise DeprecationWarning("Use .window(blocks_per_window=n) instead of " + ".pipeline(parallelism=n)") - Transformations prior to the call to ``pipeline()`` are evaluated in + def window(self, *, blocks_per_window: int = 10) -> "DatasetPipeline[T]": + """Convert this into a DatasetPipeline by windowing over data blocks. + + Transformations prior to the call to ``window()`` are evaluated in bulk on the entire dataset. Transformations done on the returned - pipeline are evaluated incrementally per group of blocks as data is + pipeline are evaluated incrementally per window of blocks as data is read from the output of the pipeline. - Pipelining execution allows for output to be read sooner without + Windowing execution allows for output to be read sooner without waiting for all transformations to fully execute, and can also improve efficiency if transforms use different resources (e.g., GPUs). - Without pipelining:: + Without windowing:: [preprocessing......] [inference.......] [write........] Time -----------------------------------------------------------> - With pipelining:: + With windowing:: [prep1] [prep2] [prep3] [infer1] [infer2] [infer3] @@ -1481,20 +1681,20 @@ def pipeline(self, *, parallelism: int = 10) -> "DatasetPipeline[T]": Examples: >>> # Create an inference pipeline. >>> ds = ray.data.read_binary_files(dir) - >>> pipe = ds.pipeline(parallelism=10).map(infer) - DatasetPipeline(num_stages=2, length=40) + >>> pipe = ds.window(blocks_per_window=10).map(infer) + DatasetPipeline(num_windows=40, num_stages=2) >>> # The higher the stage parallelism, the shorter the pipeline. - >>> pipe = ds.pipeline(parallelism=20).map(infer) - DatasetPipeline(num_stages=2, length=20) + >>> pipe = ds.window(blocks_per_window=20).map(infer) + DatasetPipeline(num_windows=20, num_stages=2) >>> # Outputs can be incrementally read from the pipeline. >>> for item in pipe.iter_rows(): ... print(item) Args: - parallelism: The parallelism (number of blocks) per stage. - Increasing parallelism increases pipeline throughput, but also + blocks_per_window: The window size (parallelism) in blocks. + Increasing window size increases pipeline throughput, but also increases the latency to initial output, since it decreases the length of the pipeline. Setting this to infinity effectively disables pipelining. 
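A short usage sketch for the windowing API described above (method names as renamed in this patch; the dataset here is a stand-in):

    import ray

    ray.init()
    ds = ray.data.range(1000)

    # Later stages run one window (group of blocks) at a time, so output is
    # available before the whole dataset has been transformed.
    pipe = ds.window(blocks_per_window=10).map(lambda x: x * 2)
    for row in pipe.iter_rows():
        pass  # consume rows incrementally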
@@ -1518,7 +1718,7 @@ def gen():

         class Iterable:
             def __init__(self, blocks):
-                self._splits = blocks.split(split_size=parallelism)
+                self._splits = blocks.split(split_size=blocks_per_window)

             def __iter__(self):
                 return Iterator(self._splits)
@@ -1527,7 +1727,7 @@ def __iter__(self):
         return DatasetPipeline(it, length=len(it._splits))

     @DeveloperAPI
-    def get_blocks(self) -> List[ObjectRef[Block]]:
+    def get_internal_block_refs(self) -> List[ObjectRef[Block]]:
         """Get a list of references to the underlying blocks of this dataset.

         This function can be used for zero-copy access to the data.
@@ -1581,13 +1781,14 @@ def _split(self, index: int,
             right = None
         return left, right

+    def _divide(self, block_idx: int) -> ("Dataset[T]", "Dataset[T]"):
+        left, right = self._blocks.divide(block_idx)
+        return Dataset(left), Dataset(right)
+
     def __repr__(self) -> str:
         schema = self.schema()
         if schema is None:
             schema_str = "Unknown schema"
-        elif isinstance(schema, dict):
-            schema_str = "<Tensor: shape={}, dtype={}>".format(
-                schema["shape"], schema["dtype"])
         elif isinstance(schema, type):
             schema_str = str(schema)
         else:
@@ -1599,8 +1800,6 @@ def __repr__(self) -> str:
             schema_str = ", ".join(schema_str)
             schema_str = "{" + schema_str + "}"
         count = self._meta_count()
-        if count is None:
-            count = "?"
         return "Dataset(num_blocks={}, num_rows={}, schema={})".format(
             len(self._blocks), count, schema_str)

@@ -1640,9 +1839,9 @@ def _block_to_df(block: Block):
     return block.to_pandas()


-def _block_to_ndarray(block: Block):
+def _block_to_ndarray(block: Block, column: Optional[str]):
     block = BlockAccessor.for_block(block)
-    return block.to_numpy()
+    return block.to_numpy(column)


 def _block_to_arrow(block: Block):
diff --git a/python/ray/data/dataset_pipeline.py b/python/ray/data/dataset_pipeline.py
index 158905e70e9f9..962961105f895 100644
--- a/python/ray/data/dataset_pipeline.py
+++ b/python/ray/data/dataset_pipeline.py
@@ -1,7 +1,7 @@
 import functools
 import time
 from typing import Any, Callable, List, Iterator, Iterable, Generic, Union, \
-    TYPE_CHECKING
+    Optional, TYPE_CHECKING

 import ray
 from ray.data.dataset import Dataset, T, U, BatchType
@@ -13,13 +13,15 @@
 if TYPE_CHECKING:
     import pyarrow

-# Operations that can be naively applied per dataset in the pipeline.
+# Operations that can be naively applied per dataset row in the pipeline.
 PER_DATASET_OPS = [
-    "map", "map_batches", "flat_map", "filter", "repartition",
-    "random_shuffle", "sort", "write_json", "write_csv", "write_parquet",
-    "write_datasource"
+    "map", "map_batches", "flat_map", "filter", "write_json", "write_csv",
+    "write_parquet", "write_datasource"
 ]

+# Operations that apply to each dataset holistically in the pipeline.
+HOLISTIC_PER_DATASET_OPS = ["repartition", "random_shuffle", "sort"]
+
 # Similar to above but we should force evaluation immediately.
 PER_DATASET_OUTPUT_OPS = [
     "write_json", "write_csv", "write_parquet", "write_datasource"
@@ -40,7 +42,7 @@ class DatasetPipeline(Generic[T]):

     A DatasetPipeline can be created by either repeating a Dataset
     (``ds.repeat(times=None)``), by turning a single Dataset into a pipeline
-    (``ds.pipeline(parallelism=10)``), or defined explicitly using
+    (``ds.window(blocks_per_window=10)``), or defined explicitly using
     ``DatasetPipeline.from_iterable()``.

     DatasetPipeline supports all the per-record transforms of Datasets
@@ -57,7 +59,7 @@ def __init__(self,
         """Construct a DatasetPipeline (internal API).

         The constructor is not part of the DatasetPipeline API.
Use the
-        ``Dataset.repeat()``, ``Dataset.pipeline()``, or
+        ``Dataset.repeat()``, ``Dataset.window()``, or
         ``DatasetPipeline.from_iterable()`` methods to construct a pipeline.
         """
         self._base_iterable = base_iterable
@@ -240,6 +242,124 @@ def __next__(self):
             for idx in range(n)
         ]

+    def rewindow(self, *, blocks_per_window: int) -> "DatasetPipeline[T]":
+        """Change the windowing (blocks per dataset) of this pipeline.
+
+        Changes the windowing of this pipeline to the specified size. For
+        example, if the current pipeline has two blocks per dataset, and
+        `.rewindow(blocks_per_window=4)` is requested, adjacent datasets will
+        be merged until each dataset is 4 blocks. If a smaller
+        ``blocks_per_window`` is requested, the datasets will instead be
+        split into smaller windows.
+
+        Args:
+            blocks_per_window: The new target blocks per window.
+        """
+
+        class WindowIterator:
+            def __init__(self, original_iter):
+                self._original_iter = original_iter
+                self._buffer: Optional[Dataset[T]] = None
+
+            def __next__(self) -> Dataset[T]:
+                try:
+                    # Merge windows until we meet the requested window size.
+                    if self._buffer is None:
+                        self._buffer = next(self._original_iter)
+                    while self._buffer.num_blocks() < blocks_per_window:
+                        self._buffer = self._buffer.union(
+                            next(self._original_iter))
+                    # Slice off the left-most chunk and return it.
+                    res, self._buffer = self._buffer._divide(blocks_per_window)
+                    assert res.num_blocks() <= blocks_per_window, res
+                    return lambda: res
+                except StopIteration:
+                    # Return the left-over data as a single window.
+                    if self._buffer and self._buffer.num_blocks() > 0:
+                        res = self._buffer
+                        assert res.num_blocks() <= blocks_per_window, res
+                        self._buffer = None
+                        return lambda: res
+                    else:
+                        raise
+
+        class WindowIterable:
+            def __init__(self, original_iter):
+                self._original_iter = original_iter
+
+            def __iter__(self):
+                return WindowIterator(self._original_iter)
+
+        return DatasetPipeline(
+            WindowIterable(self.iter_datasets()), length=None)
+
+    def repeat(self, times: int = None) -> "DatasetPipeline[T]":
+        """Repeat this pipeline a given number of times, or indefinitely.
+
+        This operation is only allowed for pipelines of a finite length. An
+        error will be raised for pipelines of infinite length.
+
+        Transformations prior to the call to ``repeat()`` are evaluated once.
+        Transformations done on the repeated pipeline are evaluated on each
+        loop of the pipeline over the base pipeline.
+
+        Args:
+            times: The number of times to loop over this pipeline, or None
+                to repeat indefinitely.
+        """
+
+        if self._length == float("inf"):
+            raise ValueError("Cannot repeat a pipeline of infinite length.")
+
+        class RepeatIterator:
+            def __init__(self, original_iter):
+                self._original_iter = original_iter
+                # Holds results to repeat.
+                self._results = []
+                # Incrementing cursor over results.
+                self._i = 0
+                # This is calculated later.
+                self._max_i = None
+
+            def __next__(self) -> Dataset[T]:
+                # Still going through the original pipeline.
+                if self._original_iter:
+                    try:
+                        res = next(self._original_iter)
+                        self._results.append(res)
+                        return lambda: res
+                    except StopIteration:
+                        self._original_iter = None
+                        # Calculate the cursor limit.
+                        if times:
+                            self._max_i = len(self._results) * (times - 1)
+                        else:
+                            self._max_i = float("inf")
+                # Going through a repeat of the pipeline.
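+                # Illustrative note: with, say, 3 source windows and
+                # times=4, len(self._results) == 3 and _max_i == 9, so the
+                # cursor below replays windows 0, 1, 2 three more times
+                # (3 + 9 == 12 yields in total).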
+ if self._i < self._max_i: + res = self._results[self._i % len(self._results)] + self._i += 1 + return lambda: res + else: + raise StopIteration + + class RepeatIterable: + def __init__(self, original_iter): + self._original_iter = original_iter + + def __iter__(self): + return RepeatIterator(self._original_iter) + + if not times: + length = float("inf") + elif times and self._length: + length = times * self._length + else: + length = None + + return DatasetPipeline( + RepeatIterable(self.iter_datasets()), length=length) + def schema(self) -> Union[type, "pyarrow.lib.Schema"]: """Return the schema of the dataset pipeline. @@ -287,6 +407,19 @@ def sum(self) -> int: total += elem return total + def show_windows(self, limit_per_dataset: int = 10) -> None: + """Print up to the given number of records from each window/dataset. + + This is helpful as a debugging tool for understanding the structure of + dataset pipelines. + + Args: + limit_per_dataset: Rows to print per window/dataset. + """ + for i, ds in enumerate(self.iter_datasets()): + print("=== Window {} ===".format(i)) + ds.show(limit_per_dataset) + @DeveloperAPI def iter_datasets(self) -> Iterator[Dataset[T]]: """Iterate over the output datasets of this pipeline. @@ -300,9 +433,9 @@ def iter_datasets(self) -> Iterator[Dataset[T]]: return PipelineExecutor(self) @DeveloperAPI - def foreach_dataset(self, fn: Callable[[Dataset[T]], Dataset[U]] - ) -> "DatasetPipeline[U]": - """Apply a transform to each dataset in this pipeline. + def foreach_window(self, fn: Callable[[Dataset[T]], Dataset[U]] + ) -> "DatasetPipeline[U]": + """Apply a transform to each dataset/window in this pipeline. Args: fn: The function to transform each dataset with. @@ -319,6 +452,10 @@ def foreach_dataset(self, fn: Callable[[Dataset[T]], Dataset[U]] self._progress_bars, _executed=self._executed) + def foreach_dataset(self, *a, **kw) -> None: + raise DeprecationWarning( + "`foreach_dataset` has been renamed to `foreach_window`.") + @staticmethod def from_iterable(iterable: Iterable[Callable[[], Dataset[T]]], ) -> "DatasetPipeline[T]": @@ -335,7 +472,7 @@ def from_iterable(iterable: Iterable[Callable[[], Dataset[T]]], return DatasetPipeline(iterable, length=length) def __repr__(self) -> str: - return "DatasetPipeline(length={}, num_stages={})".format( + return "DatasetPipeline(num_windows={}, num_stages={})".format( self._length, 1 + len(self._stages)) def __str__(self) -> str: @@ -355,7 +492,7 @@ def make_impl(method): @functools.wraps(delegate) def impl(self, *args, **kwargs): - return self.foreach_dataset( + return self.foreach_window( lambda ds: getattr(ds, method)(*args, **kwargs)) if impl.__annotations__.get("return"): @@ -366,6 +503,33 @@ def impl(self, *args, **kwargs): setattr(DatasetPipeline, method, make_impl(method)) +for method in HOLISTIC_PER_DATASET_OPS: + + def make_impl(method): + delegate = getattr(Dataset, method) + + @functools.wraps(delegate) + def impl(self, *args, **kwargs): + return self.foreach_window( + lambda ds: getattr(ds, method)(*args, **kwargs)) + + if impl.__annotations__.get("return"): + impl.__annotations__["return"] = impl.__annotations__[ + "return"].replace("Dataset", "DatasetPipeline") + + return impl + + def deprecation_warning(method: str): + def impl(*a, **kw): + raise DeprecationWarning( + "`{}` has been renamed to `{}_each_window`.".format( + method, method)) + + return impl + + setattr(DatasetPipeline, method, deprecation_warning(method)) + setattr(DatasetPipeline, method + "_each_window", make_impl(method)) + for method in 
PER_DATASET_OUTPUT_OPS: def make_impl(method): diff --git a/python/ray/data/datasource/datasource.py b/python/ray/data/datasource/datasource.py index 46b313ab3bfd0..b45b3ab3930b2 100644 --- a/python/ray/data/datasource/datasource.py +++ b/python/ray/data/datasource/datasource.py @@ -130,10 +130,11 @@ def make_block(start: int, count: int) -> Block: return pyarrow.Table.from_arrays( [np.arange(start, start + count)], names=["value"]) elif block_format == "tensor": - return np.ones( - tensor_shape, dtype=np.int64) * np.expand_dims( + tensor = TensorArray( + np.ones(tensor_shape, dtype=np.int64) * np.expand_dims( np.arange(start, start + count), - tuple(range(1, 1 + len(tensor_shape)))) + tuple(range(1, 1 + len(tensor_shape))))) + return pyarrow.Table.from_pydict({"value": tensor}) else: return list(builtins.range(start, start + count)) @@ -145,7 +146,14 @@ def make_block(start: int, count: int) -> Block: import pyarrow schema = pyarrow.Table.from_pydict({"value": [0]}).schema elif block_format == "tensor": - schema = {"dtype": "int64", "shape": (None, ) + tensor_shape} + _check_pyarrow_version() + from ray.data.extensions import TensorArray + import pyarrow + tensor = TensorArray( + np.ones(tensor_shape, dtype=np.int64) * np.expand_dims( + np.arange(0, 10), tuple( + range(1, 1 + len(tensor_shape))))) + schema = pyarrow.Table.from_pydict({"value": tensor}).schema elif block_format == "list": schema = int else: diff --git a/python/ray/data/datasource/file_based_datasource.py b/python/ray/data/datasource/file_based_datasource.py index 9a326ebdcf62d..af678b1511888 100644 --- a/python/ray/data/datasource/file_based_datasource.py +++ b/python/ray/data/datasource/file_based_datasource.py @@ -115,12 +115,14 @@ def do_write(self, path: str, dataset_uuid: str, filesystem: Optional["pyarrow.fs.FileSystem"] = None, + try_create_dir: bool = True, _block_udf: Optional[Callable[[Block], Block]] = None, **write_args) -> List[ObjectRef[WriteResult]]: """Creates and returns write tasks for a file-based datasource.""" path, filesystem = _resolve_paths_and_filesystem(path, filesystem) path = path[0] - filesystem.create_dir(path, recursive=True) + if try_create_dir: + filesystem.create_dir(path, recursive=True) filesystem = _wrap_s3_serialization_workaround(filesystem) _write_block_to_file = self._write_block @@ -133,7 +135,8 @@ def write_block(write_path: str, block: Block): if _block_udf is not None: block = _block_udf(block) with fs.open_output_stream(write_path) as f: - _write_block_to_file(f, BlockAccessor.for_block(block)) + _write_block_to_file(f, BlockAccessor.for_block(block), + **write_args) write_block = cached_remote_fn(write_block) @@ -188,9 +191,8 @@ def _resolve_paths_and_filesystem( compatibility. """ import pyarrow as pa - from pyarrow.fs import (FileSystem, PyFileSystem, FSSpecHandler, - _resolve_filesystem_and_path) - import fsspec + from pyarrow.fs import FileSystem, PyFileSystem, FSSpecHandler, \ + _resolve_filesystem_and_path if isinstance(paths, str): paths = [paths] @@ -202,11 +204,20 @@ def _resolve_paths_and_filesystem( raise ValueError("Must provide at least one path.") if filesystem and not isinstance(filesystem, FileSystem): + err_msg = f"The filesystem passed must either conform to " \ + f"pyarrow.fs.FileSystem, or " \ + f"fsspec.spec.AbstractFileSystem. 
The provided " \ + f"filesystem was: {filesystem}" + try: + import fsspec + except ModuleNotFoundError: + # If filesystem is not a pyarrow filesystem and fsspec isn't + # installed, then filesystem is neither a pyarrow filesystem nor + # an fsspec filesystem, so we raise a TypeError. + raise TypeError(err_msg) if not isinstance(filesystem, fsspec.spec.AbstractFileSystem): - raise TypeError(f"The filesystem passed must either conform to " - f"pyarrow.fs.FileSystem, or " - f"fsspec.spec.AbstractFileSystem. The provided " - f"filesystem was: {filesystem}") + raise TypeError(err_msg) + filesystem = PyFileSystem(FSSpecHandler(filesystem)) resolved_paths = [] @@ -266,9 +277,10 @@ def _expand_paths(paths: Union[str, List[str]], return expanded_paths, file_infos -def _expand_directory(path: str, - filesystem: "pyarrow.fs.FileSystem", - exclude_prefixes: List[str] = [".", "_"]) -> List[str]: +def _expand_directory( + path: str, + filesystem: "pyarrow.fs.FileSystem", + exclude_prefixes: Optional[List[str]] = None) -> List[str]: """ Expand the provided directory path to a list of file paths. @@ -283,6 +295,9 @@ def _expand_directory(path: str, Returns: A list of file paths contained in the provided directory. """ + if exclude_prefixes is None: + exclude_prefixes = [".", "_"] + from pyarrow.fs import FileSelector selector = FileSelector(path, recursive=True) files = filesystem.get_file_info(selector) @@ -295,7 +310,7 @@ def _expand_directory(path: str, if not file_path.startswith(base_path): continue relative = file_path[len(base_path):] - if any(relative.startswith(prefix) for prefix in [".", "_"]): + if any(relative.startswith(prefix) for prefix in exclude_prefixes): continue filtered_paths.append((file_path, file_)) # We sort the paths to guarantee a stable order. diff --git a/python/ray/data/datasource/numpy_datasource.py b/python/ray/data/datasource/numpy_datasource.py index 08bc7f2c0916e..8ba02e9d40cc5 100644 --- a/python/ray/data/datasource/numpy_datasource.py +++ b/python/ray/data/datasource/numpy_datasource.py @@ -7,7 +7,7 @@ import pyarrow from ray.data.block import BlockAccessor -from ray.data.datasource.file_based_datasource import (FileBasedDatasource) +from ray.data.datasource.file_based_datasource import FileBasedDatasource class NumpyDatasource(FileBasedDatasource): @@ -21,17 +21,22 @@ class NumpyDatasource(FileBasedDatasource): """ def _read_file(self, f: "pyarrow.NativeFile", path: str, **reader_args): + from ray.data.extensions import TensorArray + import pyarrow as pa # TODO(ekl) Ideally numpy can read directly from the file, but it # seems like it requires the file to be seekable. 
        buf = BytesIO()
        data = f.readall()
        buf.write(data)
        buf.seek(0)
-        return np.load(buf)
+        return pa.Table.from_pydict({
+            "value": TensorArray(np.load(buf, allow_pickle=True))
+        })

     def _write_block(self, f: "pyarrow.NativeFile", block: BlockAccessor,
-                     **writer_args):
-        np.save(f, block.to_arrow())
+                     column: str, **writer_args):
+        value = block.to_numpy(column)
+        np.save(f, value)

     def _file_format(self):
         return "npy"
diff --git a/python/ray/data/examples/demo_infer.py b/python/ray/data/examples/demo_infer.py
index 18237f7898541..352d8ddf31ec6 100644
--- a/python/ray/data/examples/demo_infer.py
+++ b/python/ray/data/examples/demo_infer.py
@@ -18,7 +18,7 @@ def __call__(self, x):
         return x


-ds = ds.pipeline(parallelism=10) \
+ds = ds.window(blocks_per_window=10) \
     .map(preprocess) \
     .map(Model, compute="actors", num_gpus=1)

diff --git a/python/ray/data/extensions/tensor_extension.py b/python/ray/data/extensions/tensor_extension.py
index 3c80fed64242f..9872cf7e225ef 100644
--- a/python/ray/data/extensions/tensor_extension.py
+++ b/python/ray/data/extensions/tensor_extension.py
@@ -140,7 +140,7 @@ class TensorDtype(pd.api.extensions.ExtensionDtype):
        one: int64
        two: extension<arrow.py_extension_type<ArrowTensorType>>

-        >>> read_df = ray.get(read_ds.to_pandas())[0]
+        >>> read_df = ray.get(read_ds.to_pandas_refs())[0]
        >>> read_df.dtypes
        one          int64
        two    TensorDtype
@@ -422,7 +422,7 @@ class TensorArray(pd.api.extensions.ExtensionArray, TensorOpsMixin):
        one: int64
        two: extension<arrow.py_extension_type<ArrowTensorType>>

-        >>> read_df = ray.get(read_ds.to_pandas())[0]
+        >>> read_df = ray.get(read_ds.to_pandas_refs())[0]
        >>> read_df.dtypes
        one          int64
        two    TensorDtype
@@ -1155,6 +1155,10 @@ def __arrow_ext_class__(self):
         """
         return ArrowTensorArray

+    def __str__(self):
+        return "<ArrowTensorType: shape={}, dtype={}>".format(
+            self.shape, self.storage_type.value_type)
+

 @PublicAPI(stability="beta")
 class ArrowTensorArray(pa.ExtensionArray):
diff --git a/python/ray/data/impl/arrow_block.py b/python/ray/data/impl/arrow_block.py
index 41c5875bb6c16..a9d0634930a49 100644
--- a/python/ray/data/impl/arrow_block.py
+++ b/python/ray/data/impl/arrow_block.py
@@ -13,7 +13,6 @@
 from ray.data.block import Block, BlockAccessor, BlockMetadata
 from ray.data.impl.block_builder import BlockBuilder
 from ray.data.impl.simple_block import SimpleBlockBuilder
-from ray.data.impl.tensor_block import TensorBlockBuilder

 if TYPE_CHECKING:
     import pandas
@@ -78,8 +77,6 @@ def add(self, item: Any) -> None:
                 self._builder = ArrowBlockBuilder()
             except (TypeError, pyarrow.lib.ArrowInvalid):
                 self._builder = SimpleBlockBuilder()
-        elif isinstance(item, np.ndarray):
-            self._builder = TensorBlockBuilder()
         else:
             self._builder = SimpleBlockBuilder()
         self._builder.add(item)
@@ -188,8 +185,21 @@ def schema(self) -> "pyarrow.lib.Schema":
     def to_pandas(self) -> "pandas.DataFrame":
         return self._table.to_pandas()

-    def to_numpy(self) -> np.ndarray:
-        return np.array(self._table)
+    def to_numpy(self, column: str = None) -> np.ndarray:
+        if not column:
+            raise ValueError(
+                "`column` must be specified when calling .to_numpy() "
+                "on Arrow blocks.")
+        if column not in self._table.column_names:
+            raise ValueError(
+                "Cannot find column {}, available columns: {}".format(
+                    column, self._table.column_names))
+        array = self._table[column]
+        if array.num_chunks > 1:
+            # TODO(ekl) combine fails since we can't concat ArrowTensorType?
+ array = array.combine_chunks() + assert array.num_chunks == 1, array + return self._table[column].chunk(0).to_numpy() def to_arrow(self) -> "pyarrow.Table": return self._table diff --git a/python/ray/data/impl/block_list.py b/python/ray/data/impl/block_list.py index b6e88c8fe4fc0..691a710b5faa6 100644 --- a/python/ray/data/impl/block_list.py +++ b/python/ray/data/impl/block_list.py @@ -42,6 +42,13 @@ def split(self, split_size: int) -> List["BlockList"]: output.append(BlockList(b.tolist(), m.tolist())) return output + def divide(self, block_idx: int) -> ("BlockList", "BlockList"): + self._check_if_cleared() + return (BlockList(self._blocks[:block_idx], + self._metadata[:block_idx]), + BlockList(self._blocks[block_idx:], + self._metadata[block_idx:])) + def __len__(self): self._check_if_cleared() return len(self._blocks) diff --git a/python/ray/data/impl/compute.py b/python/ray/data/impl/compute.py index e52aa3bce13d1..8f0a7fb8e41f0 100644 --- a/python/ray/data/impl/compute.py +++ b/python/ray/data/impl/compute.py @@ -35,6 +35,10 @@ def _map_block(block: Block, meta: BlockMetadata, class TaskPool(ComputeStrategy): def apply(self, fn: Any, remote_args: dict, blocks: BlockList[Any]) -> BlockList[Any]: + # Handle empty datasets. + if len(blocks) == 0: + return blocks + map_bar = ProgressBar("Map Progress", total=len(blocks)) kwargs = remote_args.copy() @@ -47,8 +51,23 @@ def apply(self, fn: Any, remote_args: dict, ] new_blocks, new_metadata = zip(*refs) - map_bar.block_until_complete(list(new_blocks)) - new_metadata = ray.get(list(new_metadata)) + new_metadata = list(new_metadata) + try: + new_metadata = map_bar.fetch_until_complete(new_metadata) + except (ray.exceptions.RayTaskError, KeyboardInterrupt) as e: + # One or more mapper tasks failed, or we received a SIGINT signal + # while waiting; either way, we cancel all map tasks. + for ref in new_metadata: + ray.cancel(ref) + # Wait until all tasks have failed or been cancelled. + for ref in new_metadata: + try: + ray.get(ref) + except (ray.exceptions.RayTaskError, + ray.exceptions.TaskCancelledError): + pass + # Reraise the original task failure exception. + raise e from None return BlockList(list(new_blocks), list(new_metadata)) diff --git a/python/ray/data/impl/lazy_block_list.py b/python/ray/data/impl/lazy_block_list.py index 7ccf8e58295ae..0bfd1e0ac1093 100644 --- a/python/ray/data/impl/lazy_block_list.py +++ b/python/ray/data/impl/lazy_block_list.py @@ -9,19 +9,25 @@ class LazyBlockList(BlockList[T]): - def __init__(self, calls: Callable[[], ObjectRef[Block]], - metadata: List[BlockMetadata]): - assert len(calls) == len(metadata), (calls, metadata) + def __init__(self, + calls: Callable[[], ObjectRef[Block]], + metadata: List[BlockMetadata], + blocks: List[ObjectRef[Block]] = None): self._calls = calls - self._blocks = [calls[0]()] if calls else [] self._metadata = metadata + if blocks: + self._blocks = blocks + else: + self._blocks = [None] * len(calls) + # Immediately compute the first block at least. 
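+            # (Eager block 0 lets metadata lookups such as schema() succeed
+            # without materializing the rest of the dataset.)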
+ if calls: + self._blocks[0] = calls[0]() + assert len(calls) == len(metadata), (calls, metadata) + assert len(calls) == len(self._blocks), (calls, self._blocks) def copy(self) -> "LazyBlockList": - new_list = LazyBlockList.__new__(LazyBlockList) - new_list._calls = self._calls - new_list._blocks = self._blocks - new_list._metadata = self._metadata - return new_list + return LazyBlockList(self._calls.copy(), self._metadata.copy(), + self._blocks.copy()) def clear(self): super().clear() @@ -32,11 +38,22 @@ def split(self, split_size: int) -> List["LazyBlockList"]: num_splits = math.ceil(len(self._calls) / split_size) calls = np.array_split(self._calls, num_splits) meta = np.array_split(self._metadata, num_splits) + blocks = np.array_split(self._blocks, num_splits) output = [] - for c, m in zip(calls, meta): - output.append(LazyBlockList(c.tolist(), m.tolist())) + for c, m, b in zip(calls, meta, blocks): + output.append(LazyBlockList(c.tolist(), m.tolist(), b.tolist())) return output + def divide(self, block_idx: int) -> ("BlockList", "BlockList"): + self._check_if_cleared() + left = LazyBlockList(self._calls[:block_idx], + self._metadata[:block_idx], + self._blocks[:block_idx]) + right = LazyBlockList(self._calls[block_idx:], + self._metadata[block_idx:], + self._blocks[block_idx:]) + return left, right + def __len__(self): self._check_if_cleared() return len(self._calls) @@ -64,9 +81,19 @@ def _get_or_compute(self, i: int) -> ObjectRef[Block]: self._check_if_cleared() assert i < len(self._calls), i # Check if we need to compute more blocks. - if i >= len(self._blocks): - start = len(self._blocks) + if not self._blocks[i]: # Exponentially increase the number of blocks computed per batch. - for c in self._calls[start:max(i + 1, start * 2)]: - self._blocks.append(c()) + for j in range(max(i + 1, i * 2)): + if j >= len(self._blocks): + break + if not self._blocks[j]: + self._blocks[j] = self._calls[j]() + assert self._blocks[i], self._blocks return self._blocks[i] + + def _num_computed(self): + i = 0 + for b in self._blocks: + if b is not None: + i += 1 + return i diff --git a/python/ray/data/impl/pipeline_executor.py b/python/ray/data/impl/pipeline_executor.py index c02b04ffdabb4..7eeacc0a8cac1 100644 --- a/python/ray/data/impl/pipeline_executor.py +++ b/python/ray/data/impl/pipeline_executor.py @@ -10,7 +10,7 @@ from ray.data.dataset_pipeline import DatasetPipeline -@ray.remote +@ray.remote(num_cpus=0, placement_group=None) def pipeline_stage(fn: Callable[[], Dataset[T]]) -> Dataset[T]: try: prev = set_progress_bars(False) @@ -27,12 +27,15 @@ def __init__(self, pipeline: "DatasetPipeline[T]"): self._iter = iter(self._pipeline._base_iterable) self._stages[0] = pipeline_stage.remote(next(self._iter)) + if self._pipeline._length and self._pipeline._length != float("inf"): + length = self._pipeline._length + else: + length = 1 + if self._pipeline._progress_bars: self._bars = [ - ProgressBar( - "Stage {}".format(i), - self._pipeline._length or 1, - position=i) for i in range(len(self._stages)) + ProgressBar("Stage {}".format(i), length, position=i) + for i in range(len(self._stages)) ] else: self._bars = None @@ -84,7 +87,7 @@ def __next__(self): return output -@ray.remote +@ray.remote(num_cpus=0, placement_group=None) class PipelineSplitExecutorCoordinator: def __init__(self, pipeline: "DatasetPipeline[T]", n: int, splitter: Callable[[Dataset], "DatasetPipeline[T]"]): diff --git a/python/ray/data/impl/progress_bar.py b/python/ray/data/impl/progress_bar.py index c9c1caa43cb5b..fc28da681f3ee 
100644 --- a/python/ray/data/impl/progress_bar.py +++ b/python/ray/data/impl/progress_bar.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Any import ray from ray.types import ObjectRef @@ -50,6 +50,16 @@ def block_until_complete(self, remaining: List[ObjectRef]) -> None: done, remaining = ray.wait(remaining, fetch_local=False) self.update(len(done)) + def fetch_until_complete(self, refs: List[ObjectRef]) -> List[Any]: + ref_to_result = {} + remaining = refs + while remaining: + done, remaining = ray.wait(remaining, fetch_local=True) + for ref, result in zip(done, ray.get(done)): + ref_to_result[ref] = result + self.update(len(done)) + return [ref_to_result[ref] for ref in refs] + def set_description(self, name: str) -> None: if self._bar: self._bar.set_description(name) diff --git a/python/ray/data/impl/remote_fn.py b/python/ray/data/impl/remote_fn.py index 968380e187c50..a6b4eb06d0f46 100644 --- a/python/ray/data/impl/remote_fn.py +++ b/python/ray/data/impl/remote_fn.py @@ -13,7 +13,10 @@ def cached_remote_fn(fn: Any, **ray_remote_args) -> Any: which means ray.remote cannot be used top-level in ray.data). """ if fn not in CACHED_FUNCTIONS: - default_ray_remote_args = {"retry_exceptions": True} + default_ray_remote_args = { + "retry_exceptions": True, + "placement_group": None, + } CACHED_FUNCTIONS[fn] = ray.remote(**{ **default_ray_remote_args, **ray_remote_args diff --git a/python/ray/data/impl/simple_block.py b/python/ray/data/impl/simple_block.py index ba20d1334b06b..f609c65bd28b8 100644 --- a/python/ray/data/impl/simple_block.py +++ b/python/ray/data/impl/simple_block.py @@ -58,7 +58,9 @@ def to_pandas(self) -> "pandas.DataFrame": import pandas return pandas.DataFrame(self._items) - def to_numpy(self) -> np.ndarray: + def to_numpy(self, column: str = None) -> np.ndarray: + if column: + raise ValueError("`column` arg not supported for list block") return np.array(self._items) def to_arrow(self) -> "pyarrow.Table": diff --git a/python/ray/data/impl/tensor_block.py b/python/ray/data/impl/tensor_block.py deleted file mode 100644 index 3ad8d8afad71b..0000000000000 --- a/python/ray/data/impl/tensor_block.py +++ /dev/null @@ -1,80 +0,0 @@ -from typing import Iterator, List, TypeVar, Dict, TYPE_CHECKING - -import numpy as np - -if TYPE_CHECKING: - import pandas - import pyarrow - -from ray.data.block import Block, BlockAccessor -from ray.data.impl.block_builder import BlockBuilder - -T = TypeVar("T") - - -# TODO(ekl) switch to pyarrow.Tensor as the block type; currently there is a -# serialization issue with pyarrow tensors. 
-class TensorBlockBuilder(BlockBuilder[T]):
-    def __init__(self):
-        self._rows = []
-        self._tensors: List[np.ndarray] = []
-        self._num_rows = 0
-
-    def add(self, row: np.ndarray) -> None:
-        self._rows.append(row)
-        self._num_rows += 1
-
-    def add_block(self, block: np.ndarray) -> None:
-        assert isinstance(block, np.ndarray), block
-        self._tensors.append(block)
-        self._num_rows += len(block)
-
-    def build(self) -> Block:
-        tensors = self._tensors.copy()
-        if self._rows:
-            tensors.append(np.stack(self._rows, axis=0))
-        return np.concatenate(tensors, axis=0)
-
-    def num_rows(self) -> int:
-        return self._num_rows
-
-
-class TensorBlockAccessor(BlockAccessor):
-    def __init__(self, tensor: np.ndarray):
-        self._tensor = tensor
-
-    def iter_rows(self) -> Iterator[np.ndarray]:
-        return iter(self._tensor)
-
-    def slice(self, start: int, end: int,
-              copy: bool) -> "TensorBlockAccessor[T]":
-        view = self._tensor[start:end]
-        if copy:
-            view = view.copy()
-        return view
-
-    def to_pandas(self) -> "pandas.DataFrame":
-        import pandas
-        return pandas.DataFrame(self._tensor)
-
-    def to_numpy(self) -> np.ndarray:
-        return self._tensor
-
-    def to_arrow(self) -> "pyarrow.Tensor":
-        import pyarrow
-        return pyarrow.Tensor.from_numpy(self._tensor)
-
-    def schema(self) -> Dict:
-        shape = self._tensor.shape
-        shape = (None, ) + shape[1:]
-        return {"shape": shape, "dtype": self._tensor.dtype.name}
-
-    def num_rows(self) -> int:
-        return len(self._tensor)
-
-    def size_bytes(self) -> int:
-        return self._tensor.nbytes
-
-    @staticmethod
-    def builder() -> TensorBlockBuilder[T]:
-        return TensorBlockBuilder()
diff --git a/python/ray/data/read_api.py b/python/ray/data/read_api.py
index 887e08baa1495..8d1b66d04c044 100644
--- a/python/ray/data/read_api.py
+++ b/python/ray/data/read_api.py
@@ -14,7 +14,7 @@

 import ray
 from ray.types import ObjectRef
-from ray.util.annotations import PublicAPI
+from ray.util.annotations import PublicAPI, DeveloperAPI
 from ray.data.block import Block, BlockAccessor, BlockMetadata
 from ray.data.dataset import Dataset
 from ray.data.datasource import Datasource, RangeDatasource, \
@@ -392,7 +392,7 @@ def read_numpy(paths: Union[str, List[str]],
                *,
                filesystem: Optional["pyarrow.fs.FileSystem"] = None,
                parallelism: int = 200,
-               **numpy_load_args) -> Dataset[np.ndarray]:
+               **numpy_load_args) -> Dataset[ArrowRow]:
    """Create an Arrow dataset from numpy files.

     Examples:
@@ -509,12 +509,27 @@ def from_modin(df: "modin.DataFrame") -> Dataset[ArrowRow]:
     from modin.distributed.dataframe.pandas.partitions import unwrap_partitions

     parts = unwrap_partitions(df, axis=0)
-    return from_pandas(parts)
+    return from_pandas_refs(parts)


 @PublicAPI(stability="beta")
-def from_pandas(dfs: List[ObjectRef["pandas.DataFrame"]]) -> Dataset[ArrowRow]:
-    """Create a dataset from a set of Pandas dataframes.
+def from_pandas(dfs: List["pandas.DataFrame"]) -> Dataset[ArrowRow]:
+    """Create a dataset from a list of Pandas dataframes.
+
+    Args:
+        dfs: A list of Pandas dataframes.
+
+    Returns:
+        Dataset holding Arrow records read from the dataframes.
+    """
+    return from_pandas_refs([ray.put(df) for df in dfs])
+
+
+@DeveloperAPI
+def from_pandas_refs(
+        dfs: List[ObjectRef["pandas.DataFrame"]]) -> Dataset[ArrowRow]:
+    """Create a dataset from a list of Ray object references to Pandas
+    dataframes.

     Args:
         dfs: A list of Ray object references to pandas dataframes.
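A quick sketch of the resulting API split (illustrative; the dataframe contents are arbitrary)::

    import pandas as pd
    import ray

    df = pd.DataFrame({"one": [1, 2, 3]})
    ds = ray.data.from_pandas([df])                 # public API: plain dataframes
    ds2 = ray.data.from_pandas_refs([ray.put(df)])  # developer API: object refs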
@@ -529,7 +544,7 @@ def from_pandas(dfs: List[ObjectRef["pandas.DataFrame"]]) -> Dataset[ArrowRow]:
     return Dataset(BlockList(blocks, ray.get(list(metadata))))


-def from_numpy(ndarrays: List[ObjectRef[np.ndarray]]) -> Dataset[np.ndarray]:
+def from_numpy(ndarrays: List[ObjectRef[np.ndarray]]) -> Dataset[ArrowRow]:
     """Create a dataset from a set of NumPy ndarrays.

     Args:
@@ -546,8 +561,23 @@


 @PublicAPI(stability="beta")
-def from_arrow(tables: List[ObjectRef[Union["pyarrow.Table", bytes]]]
-               ) -> Dataset[ArrowRow]:
+def from_arrow(
+        tables: List[Union["pyarrow.Table", bytes]]) -> Dataset[ArrowRow]:
+    """Create a dataset from a list of Arrow tables.
+
+    Args:
+        tables: A list of Arrow tables, or their serialized bytes
+            in the Arrow streaming format.
+
+    Returns:
+        Dataset holding Arrow records from the tables.
+    """
+    return from_arrow_refs([ray.put(t) for t in tables])
+
+
+@DeveloperAPI
+def from_arrow_refs(tables: List[ObjectRef[Union["pyarrow.Table", bytes]]]
+                    ) -> Dataset[ArrowRow]:
     """Create a dataset from a set of Arrow tables.

     Args:
@@ -590,8 +620,11 @@ def _df_to_block(df: "pandas.DataFrame") -> Block[ArrowRow]:


 def _ndarray_to_block(ndarray: np.ndarray) -> Block[np.ndarray]:
-    return (ndarray,
-            BlockAccessor.for_block(ndarray).get_metadata(input_files=None))
+    import pyarrow as pa
+    from ray.data.extensions import TensorArray
+    table = pa.Table.from_pydict({"value": TensorArray(ndarray)})
+    return (table,
+            BlockAccessor.for_block(table).get_metadata(input_files=None))


 def _get_schema(block: Block) -> Any:
diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py
index 7562e2c5a7105..91aa91e5c2eb2 100644
--- a/python/ray/data/tests/test_dataset.py
+++ b/python/ray/data/tests/test_dataset.py
@@ -17,9 +17,11 @@

 import ray
 from ray.tests.conftest import *  # noqa
+from ray.data.dataset import Dataset
 from ray.data.datasource import DummyOutputDatasource
 from ray.data.datasource.csv_datasource import CSVDatasource
 from ray.data.block import BlockAccessor
+from ray.data.impl.block_list import BlockList
 from ray.data.datasource.file_based_datasource import _unwrap_protocol
 from ray.data.extensions.tensor_extension import (
     TensorArray, TensorDtype, ArrowTensorType, ArrowTensorArray)
@@ -29,7 +31,7 @@

 def maybe_pipeline(ds, enabled):
     if enabled:
-        return ds.pipeline(parallelism=1)
+        return ds.window(blocks_per_window=1)
     else:
         return ds

@@ -58,7 +60,10 @@ def run():
         assert sorted(ds.iter_rows()) == [0, 1, 2, 3, 4]

     pg = ray.util.placement_group([{"CPU": 1}])
-    ray.get(run.options(placement_group=pg).remote())
+    ray.get(
+        run.options(
+            placement_group=pg,
+            placement_group_capture_child_tasks=True).remote())


 @pytest.mark.parametrize("pipelined", [False, True])
@@ -142,6 +147,102 @@ def __call__(self, x):
     assert len(actor_reuse) == 10, actor_reuse


+def test_transform_failure(shutdown_only):
+    ray.init(num_cpus=2)
+    ds = ray.data.from_items([0, 10], parallelism=2)
+
+    def mapper(x):
+        time.sleep(x)
+        raise ValueError("oops")
+        return x
+
+    with pytest.raises(ray.exceptions.RayTaskError):
+        ds.map(mapper)
+
+
+@pytest.mark.parametrize(
+    "block_sizes,num_splits",
+    [
+        (  # Test baseline.
+            [3, 6, 3], 3),
+        (  # Already balanced.
+            [3, 3, 3], 3),
+        (  # Row truncation.
+            [3, 6, 4], 3),
+        (  # Row truncation, smaller number of blocks.
+            [3, 6, 2, 3], 3),
+        (  # Row truncation, larger number of blocks.
+            [5, 6, 2, 5], 5),
+        (  # All smaller but one.
+ [1, 1, 1, 1, 6], 5), + ( # All larger but one. + [4, 4, 4, 4, 1], 5), + ( # Single block. + [2], 2), + ( # Single split. + [2, 5], 1), + ]) +def test_equal_split_balanced(ray_start_regular_shared, block_sizes, + num_splits): + _test_equal_split_balanced(block_sizes, num_splits) + + +def _test_equal_split_balanced(block_sizes, num_splits): + blocks = [] + metadata = [] + total_rows = 0 + for block_size in block_sizes: + block = list(range(total_rows, total_rows + block_size)) + blocks.append(ray.put(block)) + metadata.append(BlockAccessor.for_block(block).get_metadata(None)) + total_rows += block_size + block_list = BlockList(blocks, metadata) + ds = Dataset(block_list) + + splits = ds.split(num_splits, equal=True) + split_counts = [split.count() for split in splits] + assert len(split_counts) == num_splits + expected_block_size = total_rows // num_splits + # Check that all splits are the expected size. + assert all([count == expected_block_size for count in split_counts]) + expected_total_rows = sum(split_counts) + # Check that the expected number of rows were dropped. + assert total_rows - expected_total_rows == total_rows % num_splits + # Check that all rows are unique (content check). + split_rows = [row for split in splits for row in split.take(total_rows)] + assert len(set(split_rows)) == len(split_rows) + + +def test_equal_split_balanced_grid(ray_start_regular_shared): + + # Tests balanced equal splitting over a grid of configurations. + # Grid: num_blocks x num_splits x num_rows_block_1 x ... x num_rows_block_n + seed = int(time.time()) + print(f"Seeding RNG for test_equal_split_balanced_grid with: {seed}") + random.seed(seed) + max_num_splits = 20 + num_splits_samples = 5 + max_num_blocks = 50 + max_num_rows_per_block = 100 + num_blocks_samples = 5 + block_sizes_samples = 5 + for num_splits in np.random.randint( + 2, max_num_splits + 1, size=num_splits_samples): + for num_blocks in np.random.randint( + 1, max_num_blocks + 1, size=num_blocks_samples): + block_sizes_list = [ + np.random.randint( + 1, max_num_rows_per_block + 1, size=num_blocks) + for _ in range(block_sizes_samples) + ] + for block_sizes in block_sizes_list: + if sum(block_sizes) < num_splits: + min_ = math.ceil(num_splits / num_blocks) + block_sizes = np.random.randint( + min_, max_num_rows_per_block + 1, size=num_blocks) + _test_equal_split_balanced(block_sizes, num_splits) + + @pytest.mark.parametrize("pipelined", [False, True]) def test_basic(ray_start_regular_shared, pipelined): ds0 = ray.data.range(5) @@ -195,30 +296,15 @@ def test_batch_tensors(ray_start_regular_shared): def test_tensors(ray_start_regular_shared): # Create directly. ds = ray.data.range_tensor(5, shape=(3, 5)) - assert str(ds) == ("Dataset(num_blocks=5, num_rows=5, " - "schema=)") - - # Transform. - ds = ds.map_batches(lambda t: np.expand_dims(t, 3)) - assert str(ds) == ("Dataset(num_blocks=5, num_rows=5, " - "schema=)") + assert str(ds) == ( + "Dataset(num_blocks=5, num_rows=5, " + "schema={value: })") # Pandas conversion. res = ray.data.range_tensor(10).map_batches( lambda t: t + 2, batch_format="pandas").take(2) - assert str(res) == "[ArrowRow({'0': 2}), ArrowRow({'0': 3})]", res - - # From other formats. 
- ds = ray.data.range(10).map_batches(lambda x: np.array(x)) - assert str(ds) == ("Dataset(num_blocks=10, num_rows=10, " - "schema=)") - ds = ray.data.range(10).map(lambda x: np.array(x)) - assert str(ds) == ("Dataset(num_blocks=10, num_rows=10, " - "schema=)") - ds = ray.data.from_items([np.zeros(shape=(2, 2, 2)) for _ in range(4)]) - assert str(ds) == ( - "Dataset(num_blocks=4, num_rows=4, " - "schema=)"), ds + assert str(res) == \ + "[ArrowRow({'value': array([2])}), ArrowRow({'value': array([3])})]" def test_tensor_array_ops(ray_start_regular_shared): @@ -308,7 +394,7 @@ def test_tensors_in_tables_from_pandas(ray_start_regular_shared): df = pd.DataFrame({"one": list(range(outer_dim)), "two": list(arr)}) # Cast column to tensor extension dtype. df["two"] = df["two"].astype(TensorDtype()) - ds = ray.data.from_pandas([ray.put(df)]) + ds = ray.data.from_pandas([df]) values = [[s["one"], s["two"]] for s in ds.take()] expected = list(zip(list(range(outer_dim)), arr)) for v, e in zip(sorted(values), expected): @@ -322,8 +408,8 @@ def test_tensors_in_tables_pandas_roundtrip(ray_start_regular_shared): num_items = np.prod(np.array(shape)) arr = np.arange(num_items).reshape(shape) df = pd.DataFrame({"one": list(range(outer_dim)), "two": TensorArray(arr)}) - ds = ray.data.from_pandas([ray.put(df)]) - ds_df = ray.get(ds.to_pandas())[0] + ds = ray.data.from_pandas([df]) + ds_df = ds.to_pandas() assert ds_df.equals(df) @@ -335,7 +421,7 @@ def test_tensors_in_tables_parquet_roundtrip(ray_start_regular_shared, num_items = np.prod(np.array(shape)) arr = np.arange(num_items).reshape(shape) df = pd.DataFrame({"one": list(range(outer_dim)), "two": TensorArray(arr)}) - ds = ray.data.from_pandas([ray.put(df)]) + ds = ray.data.from_pandas([df]) ds.write_parquet(str(tmp_path)) ds = ray.data.read_parquet(str(tmp_path)) values = [[s["one"], s["two"]] for s in ds.take()] @@ -352,7 +438,7 @@ def test_tensors_in_tables_parquet_with_schema(ray_start_regular_shared, num_items = np.prod(np.array(shape)) arr = np.arange(num_items).reshape(shape) df = pd.DataFrame({"one": list(range(outer_dim)), "two": TensorArray(arr)}) - ds = ray.data.from_pandas([ray.put(df)]) + ds = ray.data.from_pandas([df]) ds.write_parquet(str(tmp_path)) schema = pa.schema([ ("one", pa.int32()), @@ -378,7 +464,7 @@ def test_tensors_in_tables_parquet_pickle_manual_serde( "one": list(range(outer_dim)), "two": [pickle.dumps(a) for a in arr] }) - ds = ray.data.from_pandas([ray.put(df)]) + ds = ray.data.from_pandas([df]) ds.write_parquet(str(tmp_path)) ds = ray.data.read_parquet(str(tmp_path)) @@ -421,7 +507,7 @@ def test_tensors_in_tables_parquet_bytes_manual_serde(ray_start_regular_shared, "one": list(range(outer_dim)), "two": [a.tobytes() for a in arr] }) - ds = ray.data.from_pandas([ray.put(df)]) + ds = ray.data.from_pandas([df]) ds.write_parquet(str(tmp_path)) ds = ray.data.read_parquet(str(tmp_path)) @@ -460,7 +546,7 @@ def test_tensors_in_tables_parquet_bytes_manual_serde_udf( "one": list(range(outer_dim)), tensor_col_name: [a.tobytes() for a in arr] }) - ds = ray.data.from_pandas([ray.put(df)]) + ds = ray.data.from_pandas([df]) ds.write_parquet(str(tmp_path)) # Manually deserialize the tensor bytes and cast to a TensorArray. 
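The tensor-column tests above all exercise the same round-trip pattern; a condensed, illustrative sketch (the output path and array shapes are arbitrary)::

    import numpy as np
    import pandas as pd
    import ray
    from ray.data.extensions import TensorArray

    arr = np.arange(24).reshape((4, 2, 3))  # 4 rows of 2x3 tensors
    df = pd.DataFrame({"one": list(range(4)), "two": TensorArray(arr)})
    ds = ray.data.from_pandas([df])
    ds.write_parquet("/tmp/tensor_ds")
    back = ray.data.read_parquet("/tmp/tensor_ds")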
@@ -499,7 +585,7 @@ def test_tensors_in_tables_parquet_bytes_manual_serde_col_schema( "one": list(range(outer_dim)), tensor_col_name: [a.tobytes() for a in arr] }) - ds = ray.data.from_pandas([ray.put(df)]) + ds = ray.data.from_pandas([df]) ds.write_parquet(str(tmp_path)) def _block_udf(block: pa.Table): @@ -536,7 +622,7 @@ def test_tensors_in_tables_parquet_bytes_with_schema(ray_start_regular_shared, "one": list(range(outer_dim)), "two": [a.tobytes() for a in arr] }) - ds = ray.data.from_pandas([ray.put(df)]) + ds = ray.data.from_pandas([df]) ds.write_parquet(str(tmp_path)) schema = pa.schema([ ("one", pa.int32()), @@ -574,7 +660,7 @@ def test_tensors_in_tables_to_torch(ray_start_regular_shared, pipelined): "label": [4.0, 5.0, 6.0] }) df = pd.concat([df1, df2]) - ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)]) + ds = ray.data.from_pandas([df1, df2]) ds = maybe_pipeline(ds, pipelined) torchd = ds.to_torch(label_column="label", batch_size=2) @@ -614,7 +700,7 @@ def test_tensors_in_tables_to_tf(ray_start_regular_shared, pipelined): "label": TensorArray(arr2), }) df = pd.concat([df1, df2]) - ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)]) + ds = ray.data.from_pandas([df1, df2]) ds = maybe_pipeline(ds, pipelined) tfd = ds.to_tf( label_column="label", @@ -639,13 +725,11 @@ def test_numpy_roundtrip(ray_start_regular_shared, fs, data_path): ds = ray.data.range_tensor(10, parallelism=2) ds.write_numpy(data_path, filesystem=fs) ds = ray.data.read_numpy(data_path, filesystem=fs) - assert str(ds) == ("Dataset(num_blocks=2, num_rows=?, " - "schema=)") - - assert str( - ds.take()) == ("[array([0]), array([1]), array([2]), " - "array([3]), array([4]), array([5]), array([6]), " - "array([7]), array([8]), array([9])]"), ds.take() + assert str(ds) == ( + "Dataset(num_blocks=2, num_rows=None, " + "schema={value: })") + assert str(ds.take(2)) == \ + "[ArrowRow({'value': array([0])}), ArrowRow({'value': array([1])})]" def test_numpy_read(ray_start_regular_shared, tmp_path): @@ -654,13 +738,11 @@ def test_numpy_read(ray_start_regular_shared, tmp_path): np.save( os.path.join(path, "test.npy"), np.expand_dims(np.arange(0, 10), 1)) ds = ray.data.read_numpy(path) - assert str(ds) == ("Dataset(num_blocks=1, num_rows=?, " - "schema=)") - - assert str( - ds.take()) == ("[array([0]), array([1]), array([2]), " - "array([3]), array([4]), array([5]), array([6]), " - "array([7]), array([8]), array([9])]"), ds.take() + assert str(ds) == ( + "Dataset(num_blocks=1, num_rows=None, " + "schema={value: })") + assert str(ds.take(2)) == \ + "[ArrowRow({'value': array([0])}), ArrowRow({'value': array([1])})]" @pytest.mark.parametrize("fs,data_path,endpoint_url", [ @@ -682,7 +764,12 @@ def test_numpy_write(ray_start_regular_shared, fs, data_path, endpoint_url): s3 = S3FileSystem(client_kwargs={"endpoint_url": endpoint_url}) arr1 = np.load(s3.open(file_path1)) arr2 = np.load(s3.open(file_path2)) - np.testing.assert_equal(np.concatenate((arr1, arr2)), ds.take()) + assert ds.count() == 10 + assert len(arr1) == 5 + assert len(arr2) == 5 + assert arr1.sum() == 10 + assert arr2.sum() == 35 + assert str(ds.take(1)) == "[ArrowRow({'value': array([0])})]" def test_read_text(ray_start_regular_shared, tmp_path): @@ -733,6 +820,16 @@ def test_empty_dataset(ray_start_regular_shared): assert str(ds) == \ "Dataset(num_blocks=1, num_rows=0, schema=Unknown schema)" + # Test map on empty dataset. + ds = ray.data.from_items([]) + ds = ds.map(lambda x: x) + assert ds.count() == 0 + + # Test filter on empty dataset. 
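+    # (On an empty dataset the filter UDF is never invoked, which is why
+    # the zero-argument lambda below completes without error.)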
+ ds = ray.data.from_items([]) + ds = ds.filter(lambda: True) + assert ds.count() == 0 + def test_schema(ray_start_regular_shared): ds = ray.data.range(10) @@ -751,17 +848,17 @@ def test_schema(ray_start_regular_shared): def test_lazy_loading_exponential_rampup(ray_start_regular_shared): ds = ray.data.range(100, parallelism=20) - assert len(ds._blocks._blocks) == 1 + assert ds._blocks._num_computed() == 1 assert ds.take(10) == list(range(10)) - assert len(ds._blocks._blocks) == 2 + assert ds._blocks._num_computed() == 2 assert ds.take(20) == list(range(20)) - assert len(ds._blocks._blocks) == 4 + assert ds._blocks._num_computed() == 4 assert ds.take(30) == list(range(30)) - assert len(ds._blocks._blocks) == 8 + assert ds._blocks._num_computed() == 8 assert ds.take(50) == list(range(50)) - assert len(ds._blocks._blocks) == 16 + assert ds._blocks._num_computed() == 16 assert ds.take(100) == list(range(100)) - assert len(ds._blocks._blocks) == 20 + assert ds._blocks._num_computed() == 20 def test_limit(ray_start_regular_shared): @@ -834,7 +931,16 @@ def test_repartition_arrow(ray_start_regular_shared): def test_from_pandas(ray_start_regular_shared): df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]}) df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]}) - ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)]) + ds = ray.data.from_pandas([df1, df2]) + values = [(r["one"], r["two"]) for r in ds.take(6)] + rows = [(r.one, r.two) for _, r in pd.concat([df1, df2]).iterrows()] + assert values == rows + + +def test_from_pandas_refs(ray_start_regular_shared): + df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]}) + df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]}) + ds = ray.data.from_pandas_refs([ray.put(df1), ray.put(df2)]) values = [(r["one"], r["two"]) for r in ds.take(6)] rows = [(r.one, r.two) for _, r in pd.concat([df1, df2]).iterrows()] assert values == rows @@ -845,13 +951,27 @@ def test_from_numpy(ray_start_regular_shared): arr2 = np.expand_dims(np.arange(4, 8), 1) ds = ray.data.from_numpy([ray.put(arr1), ray.put(arr2)]) values = np.array(ds.take(8)) - np.testing.assert_equal(np.concatenate((arr1, arr2)), values) + for i in range(4): + assert values[i]["value"] == arr1[i] + for i in range(4, 8): + assert values[i]["value"] == arr2[i - 4] def test_from_arrow(ray_start_regular_shared): df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]}) df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]}) - ds = ray.data.from_arrow([ + ds = ray.data.from_arrow( + [pa.Table.from_pandas(df1), + pa.Table.from_pandas(df2)]) + values = [(r["one"], r["two"]) for r in ds.take(6)] + rows = [(r.one, r.two) for _, r in pd.concat([df1, df2]).iterrows()] + assert values == rows + + +def test_from_arrow_refs(ray_start_regular_shared): + df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]}) + df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]}) + ds = ray.data.from_arrow_refs([ ray.put(pa.Table.from_pandas(df1)), ray.put(pa.Table.from_pandas(df2)) ]) @@ -864,20 +984,36 @@ def test_to_pandas(ray_start_regular_shared): n = 5 df = pd.DataFrame({"value": list(range(n))}) ds = ray.data.range_arrow(n) - dfds = pd.concat(ray.get(ds.to_pandas()), ignore_index=True) + dfds = ds.to_pandas() + assert df.equals(dfds) + + # Test limit. + dfds = ds.to_pandas(limit=3) + assert df[:3].equals(dfds) + + # Test limit greater than number of rows. 
+ dfds = ds.to_pandas(limit=6) + assert df.equals(dfds) + + +def test_to_pandas_refs(ray_start_regular_shared): + n = 5 + df = pd.DataFrame({"value": list(range(n))}) + ds = ray.data.range_arrow(n) + dfds = pd.concat(ray.get(ds.to_pandas_refs()), ignore_index=True) assert df.equals(dfds) def test_to_numpy(ray_start_regular_shared): # Tensor Dataset ds = ray.data.range_tensor(10, parallelism=2) - arr = np.concatenate(ray.get(ds.to_numpy())) + arr = np.concatenate(ray.get(ds.to_numpy(column="value"))) np.testing.assert_equal(arr, np.expand_dims(np.arange(0, 10), 1)) # Table Dataset ds = ray.data.range_arrow(10) - arr = np.concatenate(ray.get(ds.to_numpy())) - np.testing.assert_equal(arr, np.expand_dims(np.arange(0, 10), 1)) + arr = np.concatenate(ray.get(ds.to_numpy(column="value"))) + np.testing.assert_equal(arr, np.arange(0, 10)) # Simple Dataset ds = ray.data.range(10) @@ -888,23 +1024,41 @@ def test_to_numpy(ray_start_regular_shared): def test_to_arrow(ray_start_regular_shared): n = 5 + # Zero-copy. + df = pd.DataFrame({"value": list(range(n))}) + ds = ray.data.range_arrow(n) + dfds = pd.concat([t.to_pandas() for t in ds.to_arrow()], ignore_index=True) + assert df.equals(dfds) + + # Conversion. + df = pd.DataFrame({0: list(range(n))}) + ds = ray.data.range(n) + dfds = pd.concat([t.to_pandas() for t in ds.to_arrow()], ignore_index=True) + assert df.equals(dfds) + + +def test_to_arrow_refs(ray_start_regular_shared): + n = 5 + # Zero-copy. df = pd.DataFrame({"value": list(range(n))}) ds = ray.data.range_arrow(n) dfds = pd.concat( - [t.to_pandas() for t in ray.get(ds.to_arrow())], ignore_index=True) + [t.to_pandas() for t in ray.get(ds.to_arrow_refs())], + ignore_index=True) assert df.equals(dfds) # Conversion. df = pd.DataFrame({0: list(range(n))}) ds = ray.data.range(n) dfds = pd.concat( - [t.to_pandas() for t in ray.get(ds.to_arrow())], ignore_index=True) + [t.to_pandas() for t in ray.get(ds.to_arrow_refs())], + ignore_index=True) assert df.equals(dfds) -def test_get_blocks(ray_start_regular_shared): - blocks = ray.data.range(10).get_blocks() +def test_get_internal_block_refs(ray_start_regular_shared): + blocks = ray.data.range(10).get_internal_block_refs() assert len(blocks) == 10 out = [] for b in ray.get(blocks): @@ -916,9 +1070,9 @@ def test_get_blocks(ray_start_regular_shared): def test_pandas_roundtrip(ray_start_regular_shared, tmp_path): df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]}) df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]}) - ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)]) - dfds = pd.concat(ray.get(ds.to_pandas())) - assert pd.concat([df1, df2]).equals(dfds) + ds = ray.data.from_pandas([df1, df2]) + dfds = ds.to_pandas() + assert pd.concat([df1, df2], ignore_index=True).equals(dfds) def test_fsspec_filesystem(ray_start_regular_shared, tmp_path): @@ -942,7 +1096,7 @@ def test_fsspec_filesystem(ray_start_regular_shared, tmp_path): ds = ray.data.read_parquet([path1, path2], filesystem=fs) # Test metadata-only parquet ops. - assert len(ds._blocks._blocks) == 1 + assert ds._blocks._num_computed() == 1 assert ds.count() == 6 out_path = os.path.join(tmp_path, "out") @@ -981,7 +1135,7 @@ def test_parquet_read(ray_start_regular_shared, fs, data_path): ds = ray.data.read_parquet(data_path, filesystem=fs) # Test metadata-only parquet ops. 
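+    # (_num_computed() counts blocks that have actually been materialized;
+    # metadata-only operations should leave it at the initial 1.)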
- assert len(ds._blocks._blocks) == 1 + assert ds._blocks._num_computed() == 1 assert ds.count() == 6 assert ds.size_bytes() > 0 assert ds.schema() is not None @@ -995,11 +1149,11 @@ def test_parquet_read(ray_start_regular_shared, fs, data_path): assert repr(ds) == \ "Dataset(num_blocks=2, num_rows=6, " \ "schema={one: int64, two: string})", ds - assert len(ds._blocks._blocks) == 1 + assert ds._blocks._num_computed() == 1 # Forces a data read. values = [[s["one"], s["two"]] for s in ds.take()] - assert len(ds._blocks._blocks) == 2 + assert ds._blocks._num_computed() == 2 assert sorted(values) == [[1, "a"], [2, "b"], [3, "c"], [4, "e"], [5, "f"], [6, "g"]] @@ -1030,7 +1184,7 @@ def test_parquet_read_partitioned(ray_start_regular_shared, fs, data_path): ds = ray.data.read_parquet(data_path, filesystem=fs) # Test metadata-only parquet ops. - assert len(ds._blocks._blocks) == 1 + assert ds._blocks._num_computed() == 1 assert ds.count() == 6 assert ds.size_bytes() > 0 assert ds.schema() is not None @@ -1044,11 +1198,11 @@ def test_parquet_read_partitioned(ray_start_regular_shared, fs, data_path): "Dataset(num_blocks=2, num_rows=6, " \ "schema={two: string, " \ "one: dictionary})", ds - assert len(ds._blocks._blocks) == 1 + assert ds._blocks._num_computed() == 1 # Forces a data read. values = [[s["one"], s["two"]] for s in ds.take()] - assert len(ds._blocks._blocks) == 2 + assert ds._blocks._num_computed() == 2 assert sorted(values) == [[1, "a"], [1, "b"], [1, "c"], [3, "e"], [3, "f"], [3, "g"]] @@ -1077,7 +1231,7 @@ def test_parquet_read_partitioned_with_filter(ray_start_regular_shared, str(tmp_path), parallelism=1, filter=(pa.dataset.field("two") == "a")) values = [[s["one"], s["two"]] for s in ds.take()] - assert len(ds._blocks._blocks) == 1 + assert ds._blocks._num_computed() == 1 assert sorted(values) == [[1, "a"], [1, "a"]] # 2 partitions, 1 empty partition, 2 block/read tasks, 1 empty block @@ -1086,7 +1240,7 @@ def test_parquet_read_partitioned_with_filter(ray_start_regular_shared, str(tmp_path), parallelism=2, filter=(pa.dataset.field("two") == "a")) values = [[s["one"], s["two"]] for s in ds.take()] - assert len(ds._blocks._blocks) == 2 + assert ds._blocks._num_computed() == 2 assert sorted(values) == [[1, "a"], [1, "a"]] @@ -1114,7 +1268,7 @@ def _block_udf(block: pa.Table): str(tmp_path), parallelism=1, _block_udf=_block_udf) ones, twos = zip(*[[s["one"], s["two"]] for s in ds.take()]) - assert len(ds._blocks._blocks) == 1 + assert ds._blocks._num_computed() == 1 np.testing.assert_array_equal(sorted(ones), np.array(one_data) + 1) # 2 blocks/read tasks @@ -1123,7 +1277,7 @@ def _block_udf(block: pa.Table): str(tmp_path), parallelism=2, _block_udf=_block_udf) ones, twos = zip(*[[s["one"], s["two"]] for s in ds.take()]) - assert len(ds._blocks._blocks) == 2 + assert ds._blocks._num_computed() == 2 np.testing.assert_array_equal(sorted(ones), np.array(one_data) + 1) # 2 blocks/read tasks, 1 empty block @@ -1135,7 +1289,7 @@ def _block_udf(block: pa.Table): _block_udf=_block_udf) ones, twos = zip(*[[s["one"], s["two"]] for s in ds.take()]) - assert len(ds._blocks._blocks) == 2 + assert ds._blocks._num_computed() == 2 np.testing.assert_array_equal(sorted(ones), np.array(one_data[:2]) + 1) @@ -1152,7 +1306,7 @@ def test_parquet_write(ray_start_regular_shared, fs, data_path, endpoint_url): df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]}) df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]}) df = pd.concat([df1, df2]) - ds = ray.data.from_pandas([ray.put(df1), 
ray.put(df2)]) + ds = ray.data.from_pandas([df1, df2]) path = os.path.join(data_path, "test_parquet_dir") if fs is None: os.mkdir(path) @@ -1187,7 +1341,7 @@ def test_parquet_write_create_dir(ray_start_regular_shared, fs, data_path, df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]}) df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]}) df = pd.concat([df1, df2]) - ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)]) + ds = ray.data.from_pandas([df1, df2]) path = os.path.join(data_path, "test_parquet_dir") ds._set_uuid("data") ds.write_parquet(path, filesystem=fs) @@ -1241,7 +1395,7 @@ def test_parquet_write_with_udf(ray_start_regular_shared, tmp_path): df1 = pd.DataFrame({"one": one_data[:3], "two": ["a", "b", "c"]}) df2 = pd.DataFrame({"one": one_data[3:], "two": ["e", "f", "g"]}) df = pd.concat([df1, df2]) - ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)]) + ds = ray.data.from_pandas([df1, df2]) def _block_udf(block: pa.Table): df = block.to_pandas() @@ -1266,7 +1420,7 @@ def _block_udf(block: pa.Table): def test_parquet_roundtrip(ray_start_regular_shared, fs, data_path): df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]}) df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]}) - ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)]) + ds = ray.data.from_pandas([df1, df2]) ds._set_uuid("data") path = os.path.join(data_path, "test_parquet_dir") if fs is None: @@ -1275,8 +1429,8 @@ def test_parquet_roundtrip(ray_start_regular_shared, fs, data_path): fs.create_dir(_unwrap_protocol(path)) ds.write_parquet(path, filesystem=fs) ds2 = ray.data.read_parquet(path, parallelism=2, filesystem=fs) - ds2df = pd.concat(ray.get(ds2.to_pandas())) - assert pd.concat([df1, df2]).equals(ds2df) + ds2df = ds2.to_pandas() + assert pd.concat([df1, df2], ignore_index=True).equals(ds2df) # Test metadata ops. for block, meta in zip(ds2._blocks, ds2._blocks.get_metadata()): BlockAccessor.for_block(ray.get(block)).size_bytes() == meta.size_bytes @@ -1357,9 +1511,7 @@ def test_iter_batches_basic(ray_start_regular_shared): df3 = pd.DataFrame({"one": [7, 8, 9], "two": [8, 9, 10]}) df4 = pd.DataFrame({"one": [10, 11, 12], "two": [11, 12, 13]}) dfs = [df1, df2, df3, df4] - ds = ray.data.from_pandas( - [ray.put(df1), ray.put(df2), - ray.put(df3), ray.put(df4)]) + ds = ray.data.from_pandas(dfs) # Default. for batch, df in zip(ds.iter_batches(batch_format="pandas"), dfs): @@ -1469,7 +1621,7 @@ def test_iter_batches_grid(ray_start_regular_shared): })) running_size += block_size num_rows = running_size - ds = ray.data.from_pandas([ray.put(df) for df in dfs]) + ds = ray.data.from_pandas(dfs) for batch_size in np.random.randint( 1, num_rows + 1, size=batch_size_samples): for drop_last in (False, True): @@ -1485,10 +1637,7 @@ def test_iter_batches_grid(ray_start_regular_shared): # Concatenated batches should equal the DataFrame # representation of the entire dataset. assert pd.concat( - batches, ignore_index=True).equals( - pd.concat( - ray.get(ds.to_pandas()), - ignore_index=True)) + batches, ignore_index=True).equals(ds.to_pandas()) else: # Number of batches should be equal to # num_rows / batch_size, rounded down. @@ -1498,9 +1647,8 @@ def test_iter_batches_grid(ray_start_regular_shared): # remainder sliced off. 
assert pd.concat( batches, ignore_index=True).equals( - pd.concat( - ray.get(ds.to_pandas()), ignore_index=True) - [:batch_size * (num_rows // batch_size)]) + ds.to_pandas()[:batch_size * + (num_rows // batch_size)]) if num_rows % batch_size == 0 or drop_last: assert all( len(batch) == batch_size for batch in batches) @@ -1515,7 +1663,7 @@ def test_lazy_loading_iter_batches_exponential_rampup( ds = ray.data.range(32, parallelism=8) expected_num_blocks = [1, 2, 4, 4, 8, 8, 8, 8] for _, expected in zip(ds.iter_batches(), expected_num_blocks): - assert len(ds._blocks._blocks) == expected + assert ds._blocks._num_computed() == expected def test_map_batch(ray_start_regular_shared, tmp_path): @@ -1769,7 +1917,7 @@ def test_from_dask(ray_start_regular_shared): df = pd.DataFrame({"one": list(range(100)), "two": list(range(100))}) ddf = dd.from_pandas(df, npartitions=10) ds = ray.data.from_dask(ddf) - dfds = pd.concat(ray.get(ds.to_pandas())) + dfds = ds.to_pandas() assert df.equals(dfds) @@ -1778,7 +1926,7 @@ def test_to_dask(ray_start_regular_shared): df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]}) df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]}) df = pd.concat([df1, df2]) - ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)]) + ds = ray.data.from_pandas([df1, df2]) ddf = ds.to_dask() # Explicit Dask-on-Ray assert df.equals(ddf.compute(scheduler=ray_dask_get)) @@ -1791,7 +1939,7 @@ def test_from_modin(ray_start_regular_shared): df = pd.DataFrame({"one": list(range(100)), "two": list(range(100))}, ) modf = mopd.DataFrame(df) ds = ray.data.from_modin(modf) - dfds = pd.concat(ray.get(ds.to_pandas())) + dfds = ds.to_pandas() assert df.equals(dfds) @@ -1823,7 +1971,7 @@ def test_to_tf(ray_start_regular_shared, pipelined): }) df3 = pd.DataFrame({"one": [7, 8], "two": [7.0, 8.0], "label": [7.0, 8.0]}) df = pd.concat([df1, df2, df3]) - ds = ray.data.from_pandas([ray.put(df1), ray.put(df2), ray.put(df3)]) + ds = ray.data.from_pandas([df1, df2, df3]) ds = maybe_pipeline(ds, pipelined) tfd = ds.to_tf( label_column="label", @@ -1851,7 +1999,7 @@ def test_to_tf_feature_columns(ray_start_regular_shared): }) df3 = pd.DataFrame({"one": [7, 8], "two": [7.0, 8.0], "label": [7.0, 8.0]}) df = pd.concat([df1, df2, df3]).drop("two", axis=1) - ds = ray.data.from_pandas([ray.put(df1), ray.put(df2), ray.put(df3)]) + ds = ray.data.from_pandas([df1, df2, df3]) tfd = ds.to_tf( label_column="label", feature_columns=["one"], @@ -1880,7 +2028,7 @@ def test_to_torch(ray_start_regular_shared, pipelined): }) df3 = pd.DataFrame({"one": [7, 8], "two": [7.0, 8.0], "label": [7.0, 8.0]}) df = pd.concat([df1, df2, df3]) - ds = ray.data.from_pandas([ray.put(df1), ray.put(df2), ray.put(df3)]) + ds = ray.data.from_pandas([df1, df2, df3]) ds = maybe_pipeline(ds, pipelined) torchd = ds.to_torch(label_column="label", batch_size=3) @@ -1907,7 +2055,7 @@ def test_to_torch_feature_columns(ray_start_regular_shared): }) df3 = pd.DataFrame({"one": [7, 8], "two": [7.0, 8.0], "label": [7.0, 8.0]}) df = pd.concat([df1, df2, df3]).drop("two", axis=1) - ds = ray.data.from_pandas([ray.put(df1), ray.put(df2), ray.put(df3)]) + ds = ray.data.from_pandas([df1, df2, df3]) torchd = ds.to_torch( label_column="label", feature_columns=["one"], batch_size=3) iterations = [] @@ -1934,7 +2082,7 @@ def test_json_read(ray_start_regular_shared, fs, data_path, endpoint_url): df1.to_json( path1, orient="records", lines=True, storage_options=storage_options) ds = ray.data.read_json(path1, filesystem=fs) - dsdf = ray.get(ds.to_pandas())[0] 
@@ -1934,7 +2082,7 @@ def test_json_read(ray_start_regular_shared, fs, data_path, endpoint_url):
     df1.to_json(
         path1, orient="records", lines=True, storage_options=storage_options)
     ds = ray.data.read_json(path1, filesystem=fs)
-    dsdf = ray.get(ds.to_pandas())[0]
+    dsdf = ds.to_pandas()
     assert df1.equals(dsdf)
     # Test metadata ops.
     assert ds.count() == 3
@@ -1947,8 +2095,8 @@ def test_json_read(ray_start_regular_shared, fs, data_path, endpoint_url):
     df2.to_json(
         path2, orient="records", lines=True, storage_options=storage_options)
     ds = ray.data.read_json([path1, path2], parallelism=2, filesystem=fs)
-    dsdf = pd.concat(ray.get(ds.to_pandas()))
-    df = pd.concat([df1, df2])
+    dsdf = ds.to_pandas()
+    df = pd.concat([df1, df2], ignore_index=True)
     assert df.equals(dsdf)
     # Test metadata ops.
     for block, meta in zip(ds._blocks, ds._blocks.get_metadata()):
@@ -1962,7 +2110,7 @@ def test_json_read(ray_start_regular_shared, fs, data_path, endpoint_url):
     ds = ray.data.read_json(
         [path1, path2, path3], parallelism=2, filesystem=fs)
     df = pd.concat([df1, df2, df3], ignore_index=True)
-    dsdf = pd.concat(ray.get(ds.to_pandas()), ignore_index=True)
+    dsdf = ds.to_pandas()
     assert df.equals(dsdf)

     # Directory, two files.
@@ -1980,8 +2128,8 @@ def test_json_read(ray_start_regular_shared, fs, data_path, endpoint_url):
     df2.to_json(
         path2, orient="records", lines=True, storage_options=storage_options)
     ds = ray.data.read_json(path, filesystem=fs)
-    df = pd.concat([df1, df2])
-    dsdf = pd.concat(ray.get(ds.to_pandas()))
+    df = pd.concat([df1, df2], ignore_index=True)
+    dsdf = ds.to_pandas()
     assert df.equals(dsdf)
     if fs is None:
         shutil.rmtree(path)
@@ -2019,8 +2167,8 @@ def test_json_read(ray_start_regular_shared, fs, data_path, endpoint_url):
         lines=True,
         storage_options=storage_options)
     ds = ray.data.read_json([path1, path2], filesystem=fs)
-    df = pd.concat([df1, df2, df3])
-    dsdf = pd.concat(ray.get(ds.to_pandas()))
+    df = pd.concat([df1, df2, df3], ignore_index=True)
+    dsdf = ds.to_pandas()
     assert df.equals(dsdf)
     if fs is None:
         shutil.rmtree(path1)
@@ -2044,8 +2192,8 @@ def test_json_read(ray_start_regular_shared, fs, data_path, endpoint_url):
     df2.to_json(
         path2, orient="records", lines=True, storage_options=storage_options)
     ds = ray.data.read_json([dir_path, path2], filesystem=fs)
-    df = pd.concat([df1, df2])
-    dsdf = pd.concat(ray.get(ds.to_pandas()))
+    df = pd.concat([df1, df2], ignore_index=True)
+    dsdf = ds.to_pandas()
     assert df.equals(dsdf)
     if fs is None:
         shutil.rmtree(dir_path)
@@ -2059,7 +2207,7 @@ def test_zipped_json_read(ray_start_regular_shared, tmp_path):
     path1 = os.path.join(tmp_path, "test1.json.gz")
     df1.to_json(path1, compression="gzip", orient="records", lines=True)
     ds = ray.data.read_json(path1)
-    assert df1.equals(ray.get(ds.to_pandas())[0])
+    assert df1.equals(ds.to_pandas())
     # Test metadata ops.
     assert ds.count() == 3
     assert ds.input_files() == [path1]
@@ -2069,8 +2217,8 @@ def test_zipped_json_read(ray_start_regular_shared, tmp_path):
     path2 = os.path.join(tmp_path, "test2.json.gz")
     df2.to_json(path2, compression="gzip", orient="records", lines=True)
     ds = ray.data.read_json([path1, path2], parallelism=2)
-    dsdf = pd.concat(ray.get(ds.to_pandas()))
-    assert pd.concat([df1, df2]).equals(dsdf)
+    dsdf = ds.to_pandas()
+    assert pd.concat([df1, df2], ignore_index=True).equals(dsdf)
     # Test metadata ops.
     for block, meta in zip(ds._blocks, ds._blocks.get_metadata()):
         BlockAccessor.for_block(ray.get(block)).size_bytes()
@@ -2085,8 +2233,8 @@ def test_zipped_json_read(ray_start_regular_shared, tmp_path):
     path2 = os.path.join(tmp_path, "data1.json.gz")
     df2.to_json(path2, compression="gzip", orient="records", lines=True)
     ds = ray.data.read_json([dir_path, path2])
-    df = pd.concat([df1, df2])
-    dsdf = pd.concat(ray.get(ds.to_pandas()))
+    df = pd.concat([df1, df2], ignore_index=True)
+    dsdf = ds.to_pandas()
     assert df.equals(dsdf)
     shutil.rmtree(dir_path)
@@ -2103,7 +2251,7 @@ def test_json_write(ray_start_regular_shared, fs, data_path, endpoint_url):
         storage_options = dict(client_kwargs=dict(endpoint_url=endpoint_url))
     # Single block.
     df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
-    ds = ray.data.from_pandas([ray.put(df1)])
+    ds = ray.data.from_pandas([df1])
     ds._set_uuid("data")
     ds.write_json(data_path, filesystem=fs)
     file_path = os.path.join(data_path, "data_000000.json")
@@ -2116,7 +2264,7 @@ def test_json_write(ray_start_regular_shared, fs, data_path, endpoint_url):

     # Two blocks.
     df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]})
-    ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)])
+    ds = ray.data.from_pandas([df1, df2])
     ds._set_uuid("data")
     ds.write_json(data_path, filesystem=fs)
     file_path2 = os.path.join(data_path, "data_000001.json")
@@ -2143,12 +2291,12 @@ def test_json_write(ray_start_regular_shared, fs, data_path, endpoint_url):
 def test_json_roundtrip(ray_start_regular_shared, fs, data_path):
     # Single block.
     df = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
-    ds = ray.data.from_pandas([ray.put(df)])
+    ds = ray.data.from_pandas([df])
     ds._set_uuid("data")
     ds.write_json(data_path, filesystem=fs)
     file_path = os.path.join(data_path, "data_000000.json")
     ds2 = ray.data.read_json([file_path], filesystem=fs)
-    ds2df = pd.concat(ray.get(ds2.to_pandas()))
+    ds2df = ds2.to_pandas()
     assert ds2df.equals(df)
     # Test metadata ops.
     for block, meta in zip(ds2._blocks, ds2._blocks.get_metadata()):
@@ -2161,12 +2309,12 @@ def test_json_roundtrip(ray_start_regular_shared, fs, data_path):

     # Two blocks.
     df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]})
-    ds = ray.data.from_pandas([ray.put(df), ray.put(df2)])
+    ds = ray.data.from_pandas([df, df2])
     ds._set_uuid("data")
     ds.write_json(data_path, filesystem=fs)
     ds2 = ray.data.read_json(data_path, parallelism=2, filesystem=fs)
-    ds2df = pd.concat(ray.get(ds2.to_pandas()))
-    assert pd.concat([df, df2]).equals(ds2df)
+    ds2df = ds2.to_pandas()
+    assert pd.concat([df, df2], ignore_index=True).equals(ds2df)
     # Test metadata ops.
     for block, meta in zip(ds2._blocks, ds2._blocks.get_metadata()):
         BlockAccessor.for_block(ray.get(block)).size_bytes() == meta.size_bytes
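
These roundtrip tests all encode the same API change: ray.data.from_pandas now accepts DataFrames directly instead of ObjectRefs, and Dataset.to_pandas returns a single concatenated DataFrame rather than a list of refs. A minimal sketch of the new convention:

    import pandas as pd
    import ray

    df1 = pd.DataFrame({"one": [1, 2, 3]})
    df2 = pd.DataFrame({"one": [4, 5, 6]})
    ds = ray.data.from_pandas([df1, df2])   # no ray.put() wrappers needed
    combined = ds.to_pandas()               # a DataFrame, not a list of refs
    assert combined.equals(pd.concat([df1, df2], ignore_index=True))
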
@@ -2190,7 +2338,7 @@ def test_csv_read(ray_start_regular_shared, fs, data_path, endpoint_url):
     path1 = os.path.join(data_path, "test1.csv")
     df1.to_csv(path1, index=False, storage_options=storage_options)
     ds = ray.data.read_csv(path1, filesystem=fs)
-    dsdf = ray.get(ds.to_pandas())[0]
+    dsdf = ds.to_pandas()
     assert df1.equals(dsdf)
     # Test metadata ops.
     assert ds.count() == 3
@@ -2202,8 +2350,8 @@ def test_csv_read(ray_start_regular_shared, fs, data_path, endpoint_url):
     path2 = os.path.join(data_path, "test2.csv")
     df2.to_csv(path2, index=False, storage_options=storage_options)
     ds = ray.data.read_csv([path1, path2], parallelism=2, filesystem=fs)
-    dsdf = pd.concat(ray.get(ds.to_pandas()))
-    df = pd.concat([df1, df2])
+    dsdf = ds.to_pandas()
+    df = pd.concat([df1, df2], ignore_index=True)
     assert df.equals(dsdf)
     # Test metadata ops.
     for block, meta in zip(ds._blocks, ds._blocks.get_metadata()):
@@ -2215,7 +2363,7 @@ def test_csv_read(ray_start_regular_shared, fs, data_path, endpoint_url):
     df3.to_csv(path3, index=False, storage_options=storage_options)
     ds = ray.data.read_csv([path1, path2, path3], parallelism=2, filesystem=fs)
     df = pd.concat([df1, df2, df3], ignore_index=True)
-    dsdf = pd.concat(ray.get(ds.to_pandas()), ignore_index=True)
+    dsdf = ds.to_pandas()
     assert df.equals(dsdf)

     # Directory, two files.
@@ -2231,8 +2379,8 @@ def test_csv_read(ray_start_regular_shared, fs, data_path, endpoint_url):
     path2 = os.path.join(path, "data1.csv")
     df2.to_csv(path2, index=False, storage_options=storage_options)
     ds = ray.data.read_csv(path, filesystem=fs)
-    df = pd.concat([df1, df2])
-    dsdf = pd.concat(ray.get(ds.to_pandas()))
+    df = pd.concat([df1, df2], ignore_index=True)
+    dsdf = ds.to_pandas()
     assert df.equals(dsdf)
     if fs is None:
         shutil.rmtree(path)
@@ -2258,8 +2406,8 @@ def test_csv_read(ray_start_regular_shared, fs, data_path, endpoint_url):
     file_path3 = os.path.join(path2, "data2.csv")
     df3.to_csv(file_path3, index=False, storage_options=storage_options)
     ds = ray.data.read_csv([path1, path2], filesystem=fs)
-    df = pd.concat([df1, df2, df3])
-    dsdf = pd.concat(ray.get(ds.to_pandas()))
+    df = pd.concat([df1, df2, df3], ignore_index=True)
+    dsdf = ds.to_pandas()
     assert df.equals(dsdf)
     if fs is None:
         shutil.rmtree(path1)
@@ -2281,8 +2429,8 @@ def test_csv_read(ray_start_regular_shared, fs, data_path, endpoint_url):
     path2 = os.path.join(data_path, "data1.csv")
     df2.to_csv(path2, index=False, storage_options=storage_options)
     ds = ray.data.read_csv([dir_path, path2], filesystem=fs)
-    df = pd.concat([df1, df2])
-    dsdf = pd.concat(ray.get(ds.to_pandas()))
+    df = pd.concat([df1, df2], ignore_index=True)
+    dsdf = ds.to_pandas()
     assert df.equals(dsdf)
     if fs is None:
         shutil.rmtree(dir_path)
@@ -2302,7 +2450,7 @@ def test_csv_write(ray_start_regular_shared, fs, data_path, endpoint_url):
         storage_options = dict(client_kwargs=dict(endpoint_url=endpoint_url))
     # Single block.
     df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
-    ds = ray.data.from_pandas([ray.put(df1)])
+    ds = ray.data.from_pandas([df1])
     ds._set_uuid("data")
     ds.write_csv(data_path, filesystem=fs)
     file_path = os.path.join(data_path, "data_000000.csv")
@@ -2310,7 +2458,7 @@ def test_csv_write(ray_start_regular_shared, fs, data_path, endpoint_url):

     # Two blocks.
     df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]})
-    ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)])
+    ds = ray.data.from_pandas([df1, df2])
     ds._set_uuid("data")
     ds.write_csv(data_path, filesystem=fs)
     file_path2 = os.path.join(data_path, "data_000001.csv")
@@ -2329,12 +2477,12 @@ def test_csv_roundtrip(ray_start_regular_shared, fs, data_path):
     # Single block.
df = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]}) - ds = ray.data.from_pandas([ray.put(df)]) + ds = ray.data.from_pandas([df]) ds._set_uuid("data") ds.write_csv(data_path, filesystem=fs) file_path = os.path.join(data_path, "data_000000.csv") ds2 = ray.data.read_csv([file_path], filesystem=fs) - ds2df = pd.concat(ray.get(ds2.to_pandas())) + ds2df = ds2.to_pandas() assert ds2df.equals(df) # Test metadata ops. for block, meta in zip(ds2._blocks, ds2._blocks.get_metadata()): @@ -2342,12 +2490,12 @@ def test_csv_roundtrip(ray_start_regular_shared, fs, data_path): # Two blocks. df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]}) - ds = ray.data.from_pandas([ray.put(df), ray.put(df2)]) + ds = ray.data.from_pandas([df, df2]) ds._set_uuid("data") ds.write_csv(data_path, filesystem=fs) ds2 = ray.data.read_csv(data_path, parallelism=2, filesystem=fs) - ds2df = pd.concat(ray.get(ds2.to_pandas())) - assert pd.concat([df, df2]).equals(ds2df) + ds2df = ds2.to_pandas() + assert pd.concat([df, df2], ignore_index=True).equals(ds2df) # Test metadata ops. for block, meta in zip(ds2._blocks, ds2._blocks.get_metadata()): BlockAccessor.for_block(ray.get(block)).size_bytes() == meta.size_bytes @@ -2365,13 +2513,21 @@ def test_sort_simple(ray_start_regular_shared): assert ds.sort(key=lambda x: -x).take(num_items) == list( reversed(range(num_items))) + # Test empty dataset. + ds = ray.data.from_items([]) + s1 = ds.sort() + assert s1.count() == 0 + assert s1 == ds + @pytest.mark.parametrize("pipelined", [False, True]) def test_random_shuffle(shutdown_only, pipelined): def range(n, parallelism=200): ds = ray.data.range(n, parallelism=parallelism) if pipelined: - return ds.repeat(2) + pipe = ds.repeat(2) + pipe.random_shuffle = pipe.random_shuffle_each_window + return pipe else: return ds @@ -2416,6 +2572,12 @@ def range(n, parallelism=200): r2 = range(100).random_shuffle(_move=True).take(999) assert r1 != r2, (r1, r2) + # Test empty dataset. + ds = ray.data.from_items([]) + r1 = ds.random_shuffle() + assert r1.count() == 0 + assert r1 == ds + def test_random_shuffle_spread(ray_start_cluster): cluster = ray_start_cluster @@ -2437,7 +2599,7 @@ def get_node_id(): ds = ray.data.range( 100, parallelism=2).random_shuffle(_spread_resource_prefix="bar:") - blocks = ds.get_blocks() + blocks = ds.get_internal_block_refs() ray.wait(blocks, num_returns=len(blocks), fetch_local=False) location_data = ray.experimental.get_object_locations(blocks) locations = [] @@ -2478,7 +2640,7 @@ def get_node_id(): ds = ray.data.read_parquet(data_path, _spread_resource_prefix="bar:") # Force reads. 
-    blocks = ds.get_blocks()
+    blocks = ds.get_internal_block_refs()
     assert len(blocks) == 2
     ray.wait(blocks, num_returns=len(blocks), fetch_local=False)
@@ -2505,7 +2667,7 @@ def test_sort_arrow(ray_start_regular, num_items, parallelism):
         offset += shard
     if offset < num_items:
         dfs.append(pd.DataFrame({"a": a[offset:], "b": b[offset:]}))
-    ds = ray.data.from_pandas([ray.put(df) for df in dfs])
+    ds = ray.data.from_pandas(dfs)

     def assert_sorted(sorted_ds, expected_rows):
         assert [tuple(row.values())
@@ -2535,7 +2697,7 @@ def __init__(self):
     def _read_file(self, f: "pa.NativeFile", path: str, **reader_args):
         count = self.counter.increment.remote()
         if ray.get(count) == 1:
-            raise ValueError()
+            raise ValueError("oops")
         else:
             return CSVDatasource._read_file(self, f, path, **reader_args)
@@ -2543,7 +2705,7 @@ def _write_block(self, f: "pa.NativeFile", block: BlockAccessor,
                      **writer_args):
         count = self.counter.increment.remote()
         if ray.get(count) == 1:
-            raise ValueError()
+            raise ValueError("oops")
         else:
             CSVDatasource._write_block(self, f, block, **writer_args)
@@ -2563,7 +2725,7 @@ def _write_block(self, f: "pa.NativeFile", block: BlockAccessor,
     def flaky_mapper(x):
         count = counter.increment.remote()
         if ray.get(count) == 1:
-            raise ValueError()
+            raise ValueError("oops")
         else:
             return ray.get(count)

diff --git a/python/ray/data/tests/test_dataset_pipeline.py b/python/ray/data/tests/test_dataset_pipeline.py
index b199374f80437..cffb378f36861 100644
--- a/python/ray/data/tests/test_dataset_pipeline.py
+++ b/python/ray/data/tests/test_dataset_pipeline.py
@@ -30,14 +30,14 @@ def block_on_ones(x: int) -> int:
         time.sleep(999999)
         return x

-    pipe = ray.data.range(2).pipeline(parallelism=1)
+    pipe = ray.data.range(2).window(blocks_per_window=1)
     pipe = pipe.map(block_on_ones)
     assert pipe.take(1) == [0]


 def test_cannot_read_twice(ray_start_regular_shared):
     ds = ray.data.range(10)
-    pipe = ds.pipeline(parallelism=1)
+    pipe = ds.window(blocks_per_window=1)
     assert pipe.count() == 10
     with pytest.raises(RuntimeError):
         pipe.count()
@@ -52,25 +52,70 @@ def test_cannot_read_twice(ray_start_regular_shared):

 def test_basic_pipeline(ray_start_regular_shared):
     ds = ray.data.range(10)

-    pipe = ds.pipeline(parallelism=1)
-    assert str(pipe) == "DatasetPipeline(length=10, num_stages=1)"
+    pipe = ds.window(blocks_per_window=1)
+    assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=1)"
     assert pipe.count() == 10

-    pipe = ds.pipeline(parallelism=1).map(lambda x: x).map(lambda x: x)
-    assert str(pipe) == "DatasetPipeline(length=10, num_stages=3)"
+    pipe = ds.window(blocks_per_window=1).map(lambda x: x).map(lambda x: x)
+    assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=3)"
     assert pipe.take() == list(range(10))

-    pipe = ds.pipeline(parallelism=999)
-    assert str(pipe) == "DatasetPipeline(length=1, num_stages=1)"
+    pipe = ds.window(blocks_per_window=999)
+    assert str(pipe) == "DatasetPipeline(num_windows=1, num_stages=1)"
     assert pipe.count() == 10

     pipe = ds.repeat(10)
-    assert str(pipe) == "DatasetPipeline(length=10, num_stages=1)"
+    assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=1)"
     assert pipe.count() == 100

     pipe = ds.repeat(10)
     assert pipe.sum() == 450


+def test_window(ray_start_regular_shared):
+    ds = ray.data.range(10)
+    pipe = ds.window(blocks_per_window=1)
+    assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=1)"
+    pipe = pipe.rewindow(blocks_per_window=3)
+    assert str(pipe) == "DatasetPipeline(num_windows=None, num_stages=1)"
+    datasets = list(pipe.iter_datasets())
+    assert len(datasets) == 4
+    assert datasets[0].take() == [0, 1, 2]
+    assert datasets[1].take() == [3, 4, 5]
+    assert datasets[2].take() == [6, 7, 8]
+    assert datasets[3].take() == [9]
+
+    ds = ray.data.range(10)
+    pipe = ds.window(blocks_per_window=5)
+    assert str(pipe) == "DatasetPipeline(num_windows=2, num_stages=1)"
+    pipe = pipe.rewindow(blocks_per_window=3)
+    assert str(pipe) == "DatasetPipeline(num_windows=None, num_stages=1)"
+    datasets = list(pipe.iter_datasets())
+    assert len(datasets) == 4
+    assert datasets[0].take() == [0, 1, 2]
+    assert datasets[1].take() == [3, 4, 5]
+    assert datasets[2].take() == [6, 7, 8]
+    assert datasets[3].take() == [9]
+
+
+def test_repeat(ray_start_regular_shared):
+    ds = ray.data.range(5)
+    pipe = ds.window(blocks_per_window=1)
+    assert str(pipe) == "DatasetPipeline(num_windows=5, num_stages=1)"
+    pipe = pipe.repeat(2)
+    assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=1)"
+    assert pipe.take() == (list(range(5)) + list(range(5)))
+
+    ds = ray.data.range(5)
+    pipe = ds.window(blocks_per_window=1)
+    pipe = pipe.repeat()
+    assert str(pipe) == "DatasetPipeline(num_windows=inf, num_stages=1)"
+    assert len(pipe.take(99)) == 99
+
+    pipe = ray.data.range(5).repeat()
+    with pytest.raises(ValueError):
+        pipe.repeat()
+
+
 def test_from_iterable(ray_start_regular_shared):
     pipe = DatasetPipeline.from_iterable(
         [lambda: ray.data.range(3), lambda: ray.data.range(2)])
@@ -80,7 +125,7 @@ def test_from_iterable(ray_start_regular_shared):
 def test_repeat_forever(ray_start_regular_shared):
     ds = ray.data.range(10)
     pipe = ds.repeat()
-    assert str(pipe) == "DatasetPipeline(length=None, num_stages=1)"
+    assert str(pipe) == "DatasetPipeline(num_windows=inf, num_stages=1)"
     for i, v in enumerate(pipe.iter_rows()):
         assert v == i % 10, (v, i, i % 10)
         if i > 1000:
@@ -89,38 +134,38 @@ def test_repeat_forever(ray_start_regular_shared):

 def test_repartition(ray_start_regular_shared):
     pipe = ray.data.range(10).repeat(10)
-    assert pipe.repartition(1).sum() == 450
+    assert pipe.repartition_each_window(1).sum() == 450
     pipe = ray.data.range(10).repeat(10)
-    assert pipe.repartition(10).sum() == 450
+    assert pipe.repartition_each_window(10).sum() == 450
     pipe = ray.data.range(10).repeat(10)
-    assert pipe.repartition(100).sum() == 450
+    assert pipe.repartition_each_window(100).sum() == 450


 def test_iter_batches(ray_start_regular_shared):
-    pipe = ray.data.range(10).pipeline(parallelism=2)
+    pipe = ray.data.range(10).window(blocks_per_window=2)
     batches = list(pipe.iter_batches())
     assert len(batches) == 10
     assert all(len(e) == 1 for e in batches)


 def test_iter_datasets(ray_start_regular_shared):
-    pipe = ray.data.range(10).pipeline(parallelism=2)
+    pipe = ray.data.range(10).window(blocks_per_window=2)
     ds = list(pipe.iter_datasets())
     assert len(ds) == 5

-    pipe = ray.data.range(10).pipeline(parallelism=5)
+    pipe = ray.data.range(10).window(blocks_per_window=5)
     ds = list(pipe.iter_datasets())
     assert len(ds) == 2


-def test_foreach_dataset(ray_start_regular_shared):
-    pipe = ray.data.range(5).pipeline(parallelism=2)
-    pipe = pipe.foreach_dataset(lambda ds: ds.map(lambda x: x * 2))
+def test_foreach_window(ray_start_regular_shared):
+    pipe = ray.data.range(5).window(blocks_per_window=2)
+    pipe = pipe.foreach_window(lambda ds: ds.map(lambda x: x * 2))
     assert pipe.take() == [0, 2, 4, 6, 8]


 def test_schema(ray_start_regular_shared):
-    pipe = ray.data.range(5).pipeline(parallelism=2)
+    pipe = ray.data.range(5).window(blocks_per_window=2)
     assert pipe.schema() == int
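
The renamed API splits a Dataset into windows of blocks, and rewindow regroups the stream without recomputing it. A small sketch of the arithmetic these assertions check, assuming ray.data.range(10) produces 10 blocks as in the tests above:

    import ray

    ds = ray.data.range(10)                       # 10 blocks in these tests
    pipe = ds.window(blocks_per_window=2)
    assert len(list(pipe.iter_datasets())) == 5   # 10 blocks / 2 per window

    # rewindow regroups lazily, so num_windows reads as None until the
    # pipeline is actually consumed.
    pipe = ds.window(blocks_per_window=5).rewindow(blocks_per_window=3)
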
@@ -178,8 +223,8 @@ def test_parquet_write(ray_start_regular_shared, tmp_path):
     df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
     df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]})
     df = pd.concat([df1, df2])
-    ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)])
-    ds = ds.pipeline(parallelism=1)
+    ds = ray.data.from_pandas([df1, df2])
+    ds = ds.window(blocks_per_window=1)
     path = os.path.join(tmp_path, "test_parquet_dir")
     os.mkdir(path)
     ds._set_uuid("data")
diff --git a/python/ray/data/tests/test_raydp_dataset.py b/python/ray/data/tests/test_raydp_dataset.py
index c86c6a0803c13..c23b672f97e38 100644
--- a/python/ray/data/tests/test_raydp_dataset.py
+++ b/python/ray/data/tests/test_raydp_dataset.py
@@ -16,6 +16,10 @@ def stop_all():
     return spark


+@pytest.mark.skip(
+    reason=(
+        "raydp.spark.spark_dataframe_to_ray_dataset needs to be updated to "
+        "use ray.data.from_arrow_refs."))
 def test_raydp_roundtrip(spark_on_ray_small):
     spark = spark_on_ray_small
     spark_df = spark.createDataFrame([(1, "a"), (2, "b"), (3, "c")],
diff --git a/python/ray/exceptions.py b/python/ray/exceptions.py
index 58d1706549d2a..f46be7c0a1a15 100644
--- a/python/ray/exceptions.py
+++ b/python/ray/exceptions.py
@@ -28,7 +28,11 @@ def from_bytes(b):
         ray_exception = RayException()
         ray_exception.ParseFromString(b)
         if ray_exception.language == PYTHON:
-            return pickle.loads(ray_exception.serialized_exception)
+            try:
+                return pickle.loads(ray_exception.serialized_exception)
+            except Exception as e:
+                msg = "Failed to unpickle serialized exception"
+                raise RuntimeError(msg) from e
         else:
             return CrossLanguageError(ray_exception)

diff --git a/python/ray/experimental/array/remote/core.py b/python/ray/experimental/array/remote/core.py
index f4572da82babe..7b6d24f75b283 100644
--- a/python/ray/experimental/array/remote/core.py
+++ b/python/ray/experimental/array/remote/core.py
@@ -68,8 +68,8 @@ def diag(v, k=0):


 @ray.remote
-def transpose(a, axes=[]):
-    axes = None if axes == [] else axes
+def transpose(a, axes=None):
+    axes = None if (axes == [] or axes is None) else axes
     return np.transpose(a, axes=axes)
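
The transpose fix above is the standard remedy for Python's mutable-default pitfall: a default like axes=[] is evaluated once at definition time and shared across every call, so any mutation leaks between invocations. The usual pattern, sketched generically with a hypothetical helper:

    def append_item(item, items=None):
        # Use None as the sentinel and allocate inside the body, so each
        # call that omits `items` gets a fresh list.
        if items is None:
            items = []
        items.append(item)
        return items

    assert append_item(1) == [1]
    assert append_item(2) == [2]   # would be [1, 2] with items=[] as default
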
diff --git a/python/ray/experimental/internal_kv.py b/python/ray/experimental/internal_kv.py
index e434c3cf5f979..456adabcb66ca 100644
--- a/python/ray/experimental/internal_kv.py
+++ b/python/ray/experimental/internal_kv.py
@@ -35,7 +35,7 @@ def _initialize_internal_kv(gcs_client: "ray._raylet.GcsClient" = None):
     return global_gcs_client


-@client_mode_hook
+@client_mode_hook(auto_init=False)
 def _internal_kv_initialized():
     gcs_client = _initialize_internal_kv()
@@ -46,7 +46,7 @@ def _internal_kv_initialized():
     return hasattr(worker, "mode") and worker.mode is not None


-@client_mode_hook
+@client_mode_hook(auto_init=False)
 def _internal_kv_get(key: Union[str, bytes]) -> bytes:
     """Fetch the value of a binary key."""
     gcs_client = _initialize_internal_kv()
@@ -57,7 +57,7 @@ def _internal_kv_get(key: Union[str, bytes]) -> bytes:
         return ray.worker.global_worker.redis_client.hget(key, "value")


-@client_mode_hook
+@client_mode_hook(auto_init=False)
 def _internal_kv_exists(key: Union[str, bytes]) -> bool:
     """Check key exists or not."""
     gcs_client = _initialize_internal_kv()
@@ -67,7 +67,7 @@ def _internal_kv_exists(key: Union[str, bytes]) -> bool:
         return ray.worker.global_worker.redis_client.hexists(key, "value")


-@client_mode_hook
+@client_mode_hook(auto_init=False)
 def _internal_kv_put(key: Union[str, bytes],
                      value: Union[str, bytes],
                      overwrite: bool = True) -> bool:
@@ -91,7 +91,7 @@ def _internal_kv_put(key: Union[str, bytes],
         return updated == 0  # already exists


-@client_mode_hook
+@client_mode_hook(auto_init=False)
 def _internal_kv_del(key: Union[str, bytes]):
     gcs_client = _initialize_internal_kv()
     if gcs_client is not None:
@@ -100,7 +100,7 @@ def _internal_kv_del(key: Union[str, bytes]):
         return ray.worker.global_worker.redis_client.delete(key)


-@client_mode_hook
+@client_mode_hook(auto_init=False)
 def _internal_kv_list(prefix: Union[str, bytes]) -> List[bytes]:
     """List all keys in the internal KV store that start with the prefix.
     """
diff --git a/python/ray/experimental/raysort/constants.py b/python/ray/experimental/raysort/constants.py
index 5ab3b2df29831..9c32b5f07330e 100644
--- a/python/ray/experimental/raysort/constants.py
+++ b/python/ray/experimental/raysort/constants.py
@@ -1,12 +1,15 @@
 import os

-from ray.experimental.raysort.types import ByteCount, RecordCount
+from ray.experimental.raysort.types import ByteCount, PartId, RecordCount

 __DIR__ = os.path.dirname(os.path.abspath(__file__))

 # Basics
 RECORD_SIZE = 100  # bytes

+# Progress Tracker Actor
+PROGRESS_TRACKER_ACTOR = "ProgressTrackerActor"
+
 # Executable locations
 GENSORT_PATH = os.path.join(__DIR__, "bin/gensort/64/gensort")
 VALSORT_PATH = os.path.join(__DIR__, "bin/gensort/64/valsort")
@@ -18,10 +21,12 @@
 DATA_DIR_FMT = {
     "input": "{mnt}/tmp/input/",
     "output": "{mnt}/tmp/output/",
+    "temp": "{mnt}/tmp/temp/",
 }
 FILENAME_FMT = {
     "input": "input-{part_id:08}",
     "output": "output-{part_id:08}",
+    "temp": "temp-{part_id:08}",
 }

 # Prometheus config
@@ -33,3 +38,7 @@ def bytes_to_records(n_bytes: ByteCount) -> RecordCount:
     assert n_bytes % RECORD_SIZE == 0
     return int(n_bytes / RECORD_SIZE)
+
+
+def merge_part_ids(reducer_id: PartId, mapper_id: PartId) -> PartId:
+    return reducer_id * 1_000_000 + mapper_id
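
merge_part_ids packs a (reducer_id, mapper_id) pair into a single part id by reserving six decimal digits for the mapper, so the scheme assumes mapper ids stay below 1,000,000. The encoding is reversible with divmod:

    part_id = 3 * 1_000_000 + 42                      # merge_part_ids(3, 42)
    reducer_id, mapper_id = divmod(part_id, 1_000_000)
    assert (reducer_id, mapper_id) == (3, 42)
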
diff --git a/python/ray/experimental/raysort/main.py b/python/ray/experimental/raysort/main.py
index 1cc8d0df1c5af..0df5bfb59ec75 100644
--- a/python/ray/experimental/raysort/main.py
+++ b/python/ray/experimental/raysort/main.py
@@ -1,10 +1,12 @@
 import argparse
+import contextlib
 import csv
 import logging
 import os
 import random
 import subprocess
-from typing import Iterable, List
+import tempfile
+from typing import Callable, Dict, Iterable, List

 import numpy as np
 import ray
@@ -13,14 +15,17 @@
 from ray.experimental.raysort import logging_utils
 from ray.experimental.raysort import sortlib
 from ray.experimental.raysort import tracing_utils
-from ray.experimental.raysort.types import BlockInfo, ByteCount, RecordCount, PartId, PartitionInfo, Path  # noqa: E501
+from ray.experimental.raysort.types import \
+    BlockInfo, ByteCount, RecordCount, PartId, PartInfo, Path
+
+Args = argparse.Namespace

 # ------------------------------------------------------------
 #     Parse Arguments
 # ------------------------------------------------------------


-def get_args():
+def get_args(*args, **kwargs):
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--ray_address",
@@ -30,27 +35,39 @@ def get_args():
     )
     parser.add_argument(
         "--total_data_size",
-        default=1_000_000_000,
+        default=1 * 1000 * 1024 * 1024 * 1024,
         type=ByteCount,
-        help="partition size in bytes",
+        help="total data size in bytes",
     )
     parser.add_argument(
         "--num_mappers",
-        default=4,
+        default=256,
         type=int,
         help="number of map tasks",
     )
+    parser.add_argument(
+        "--num_mappers_per_round",
+        default=16,
+        type=int,
+        help="number of map tasks per first-stage merge round",
+    )
     parser.add_argument(
         "--num_reducers",
+        default=16,
+        type=int,
+        help="number of second-stage reduce tasks",
+    )
+    parser.add_argument(
+        "--num_concurrent_rounds",
         default=4,
         type=int,
-        help="number of reduce tasks",
+        help="max number of rounds of map/merge tasks in flight",
    )
     parser.add_argument(
-        "--reducer_batch_num_records",
-        default=1_000_000,
-        type=RecordCount,
-        help="number of bytes to buffer before writing the output to EBS",
+        "--reducer_input_chunk",
+        default=100 * 1024 * 1024,
+        type=ByteCount,
+        help="bytes to read from each file in reduce tasks",
     )
     parser.add_argument(
         "--skip_sorting",
@@ -75,13 +92,13 @@ def get_args():
         "tasks to run", "if no task is specified, will run all tasks")
     tasks = ["generate_input", "sort", "validate_output"]
     for task in tasks:
-        tasks_group.add_argument(
-            f"--{task}", action="store_true", help=f"run task {task}")
+        tasks_group.add_argument(f"--{task}", action="store_true")

-    args = parser.parse_args()
+    args = parser.parse_args(*args, **kwargs)
     # Derive additional arguments.
     args.input_part_size = ByteCount(args.total_data_size / args.num_mappers)
-    args.output_part_size = ByteCount(args.total_data_size / args.num_reducers)
+    assert args.num_mappers % args.num_mappers_per_round == 0
+    args.num_rounds = int(args.num_mappers / args.num_mappers_per_round)
     args.mount_points = _get_mount_points()
     # If no tasks are specified, run all tasks.
     args_dict = vars(args)
@@ -92,28 +109,29 @@ def get_args():


 def _get_mount_points():
+    default_ret = [tempfile.gettempdir()]
     mnt = "/mnt"
-    if not os.path.exists(mnt):
-        return []
-    return [os.path.join(mnt, d) for d in os.listdir(mnt)]
+    if os.path.exists(mnt):
+        ret = [os.path.join(mnt, d) for d in os.listdir(mnt)]
+        if len(ret) > 0:
+            return ret
+    return default_ret


-args = None
-
 # ------------------------------------------------------------
 #     Generate Input
 # ------------------------------------------------------------


-def _make_partition_info(part_id: PartId, kind="input") -> PartitionInfo:
+def _part_info(args: Args, part_id: PartId, kind="input") -> PartInfo:
     node = ray.worker.global_worker.node_ip_address
     mnt = random.choice(args.mount_points)
     filepath = _get_part_path(mnt, part_id, kind)
-    return PartitionInfo(part_id, node, filepath)
+    return PartInfo(part_id, node, filepath)


 def _get_part_path(mnt: Path, part_id: PartId, kind="input") -> Path:
-    assert kind in {"input", "output"}
+    assert kind in {"input", "output", "temp"}
     dir_fmt = constants.DATA_DIR_FMT[kind]
     dirpath = dir_fmt.format(mnt=mnt)
     os.makedirs(dirpath, exist_ok=True)
@@ -124,26 +142,25 @@ def _get_part_path(mnt: Path, part_id: PartId, kind="input") -> Path:


 @ray.remote
-def generate_part(part_id: PartId, size: RecordCount,
-                  offset: RecordCount) -> PartitionInfo:
+def generate_part(args: Args, part_id: PartId, size: RecordCount,
+                  offset: RecordCount) -> PartInfo:
     logging_utils.init()
-    pinfo = _make_partition_info(part_id)
-    if not args.skip_input:
-        subprocess.run(
-            [constants.GENSORT_PATH, f"-b{offset}", f"{size}", pinfo.path],
-            check=True)
-        logging.info(f"Generated input {pinfo}")
+    pinfo = _part_info(args, part_id)
+    subprocess.run(
+        [constants.GENSORT_PATH, f"-b{offset}", f"{size}", pinfo.path],
+        check=True)
+    logging.info(f"Generated input {pinfo}")
     return pinfo


-def generate_input():
+def generate_input(args: Args):
     if args.skip_input:
         return
     size = constants.bytes_to_records(args.input_part_size)
     offset = 0
     tasks = []
     for part_id in range(args.num_mappers):
-        tasks.append(generate_part.remote(part_id, size, offset))
+        tasks.append(generate_part.remote(args, part_id, size, offset))
         offset += size
     assert offset == constants.bytes_to_records(args.total_data_size), args
     logging.info(f"Generating {len(tasks)} partitions")
@@ -158,22 +175,21 @@ def generate_input():
 # ------------------------------------------------------------


-def _load_manifest(path: Path) -> List[PartitionInfo]:
+def _load_manifest(args: Args, path: Path) -> List[PartInfo]:
     if args.skip_input:
-        return _load_dummy_manifest()
+        return [PartInfo(i, None, None) for i in range(args.num_mappers)]
     with open(path) as fin:
         reader = csv.reader(fin)
         return [
-            PartitionInfo(int(part_id), node, path)
+            PartInfo(int(part_id), node, path)
             for part_id, node, path in reader
         ]


-def _load_dummy_manifest() -> List[PartitionInfo]:
-    return [PartitionInfo(i, "", "") for i in range(args.num_mappers)]
-
-
-def _load_partition(path: Path) -> np.ndarray:
+def _load_partition(args: Args, path: Path) -> np.ndarray:
+    if args.skip_input:
+        return np.frombuffer(
+            np.random.bytes(args.input_part_size), dtype=np.uint8).copy()
     return np.fromfile(path, dtype=np.uint8)
@@ -190,115 +206,214 @@ def _dummy_sort_and_partition(part: np.ndarray,


 @ray.remote
-def mapper(boundaries: List[int], mapper_id: PartId,
-           path: Path) -> List[ray.ObjectRef]:
+@tracing_utils.timeit("map")
+def mapper(args: Args, mapper_id: PartId, boundaries: List[int],
+           path: Path) -> List[np.ndarray]:
     logging_utils.init()
-    task_id = f"M-{mapper_id} Mapper"
-    logging.info(f"{task_id} starting {args}")
-    if args.skip_input:
-        block_size = int(np.ceil(args.input_part_size / args.num_reducers))
-        return [
-            ray.put(
-                np.frombuffer(np.random.bytes(block_size), dtype=np.uint8))
-            for _ in range(args.num_reducers)
-        ]
-
-    part = _load_partition(path)
+    part = _load_partition(args, path)
     sort_fn = _dummy_sort_and_partition \
         if args.skip_sorting else sortlib.sort_and_partition
     blocks = sort_fn(part, boundaries)
-    logging.info(f"{task_id} saving to object store")
-    return [ray.put(part[offset:offset + size]) for offset, size in blocks]
+    return [part[offset:offset + size] for offset, size in blocks]


-def _dummy_merge(blocks: List[np.ndarray], _n: int) -> Iterable[memoryview]:
-    for block in blocks:
+def _dummy_merge(
+        num_blocks: int, _n: int,
+        get_block: Callable[[int, int], np.ndarray]) -> Iterable[np.ndarray]:
+    blocks = [((i, 0), get_block(i, 0)) for i in range(num_blocks)]
+    while len(blocks) > 0:
+        (m, d), block = blocks.pop(random.randrange(len(blocks)))
         yield block
-
-
-@ray.remote
-def reducer(reducer_id: PartId, *blocks: List[ray.ObjectRef]) -> PartitionInfo:
-    logging_utils.init()
-    task_id = f"R-{reducer_id} Reducer"
-    logging.info(f"{task_id} starting")
-    blocks = [np.copy(ray.get(block)) for block in blocks]
+        d_ = d + 1
+        block = get_block(m, d_)
+        if block is None:
+            continue
+        blocks.append(((m, d_), block))
+
+
+def _merge_impl(args: Args,
+                M: int,
+                pinfo: PartInfo,
+                get_block: Callable[[int, int], np.ndarray],
+                skip_output=False):
     merge_fn = _dummy_merge if args.skip_sorting else sortlib.merge_partitions
-    merger = merge_fn(blocks, args.reducer_batch_num_records)
-    if args.skip_output:
+    merger = merge_fn(M, get_block)
+
+    if skip_output:
         for datachunk in merger:
             del datachunk
-        logging.info(f"{task_id} done")
-        return None
     else:
-        pinfo = _make_partition_info(reducer_id, "output")
         with open(pinfo.path, "wb") as fout:
             for datachunk in merger:
                 fout.write(datachunk)
-        logging.info(f"{task_id} done")
-        return pinfo
+    return pinfo

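
Both merge paths now share a pull-based protocol: instead of receiving materialized blocks, the merger calls get_block(i, d) for the d-th chunk of input i and treats None as end-of-input. A toy get_block over in-memory chunk lists illustrates the contract (the data here is made up):

    def make_get_block(inputs):
        # inputs: list of per-input chunk lists, e.g. [[b"a", b"b"], [b"c"]].
        def get_block(i, d):
            chunks = inputs[i]
            return chunks[d] if d < len(chunks) else None  # None = exhausted
        return get_block

    get_block = make_get_block([[b"a", b"b"], [b"c"]])
    assert get_block(0, 1) == b"b"
    assert get_block(1, 1) is None
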
-@tracing_utils.timeit("sorting")
-def sort_main():
-    partitions = _load_manifest(constants.INPUT_MANIFEST_FILE)
+# See worker_placement_groups() for why `num_cpus=0`.
+@ray.remote(num_cpus=0, resources={"worker": 1})
+@tracing_utils.timeit("merge")
+def merge_mapper_blocks(args: Args, reducer_id: PartId, mapper_id: PartId,
+                        *blocks: List[np.ndarray]) -> PartInfo:
+    part_id = constants.merge_part_ids(reducer_id, mapper_id)
+    pinfo = _part_info(args, part_id, kind="temp")
+    M = len(blocks)
+
+    def get_block(i, d):
+        if i >= M or d > 0:
+            return None
+        return blocks[i]
+
+    return _merge_impl(args, M, pinfo, get_block)
+
+
+# See worker_placement_groups() for why `num_cpus=0`.
+@ray.remote(num_cpus=0, resources={"worker": 1})
+@tracing_utils.timeit("reduce")
+def final_merge(args: Args, reducer_id: PartId,
+                *merged_parts: List[PartInfo]) -> PartInfo:
+    M = len(merged_parts)
+
+    def _load_block_chunk(pinfo: PartInfo, d: int) -> np.ndarray:
+        return np.fromfile(
+            pinfo.path,
+            dtype=np.uint8,
+            count=args.reducer_input_chunk,
+            offset=d * args.reducer_input_chunk)
+
+    def get_block(i, d):
+        ret = _load_block_chunk(merged_parts[i], d)
+        if ret.size == 0:
+            return None
+        return ret
+
+    pinfo = _part_info(args, reducer_id, "output")
+    return _merge_impl(args, M, pinfo, get_block, args.skip_output)
+
+
+def _node_res(node: str) -> Dict[str, float]:
+    return {"resources": {f"node:{node}": 1e-3}}
+
+
+@contextlib.contextmanager
+def worker_placement_groups(args: Args) -> List[ray.PlacementGroupID]:
+    """
+    Returns one placement group per node with a `worker` resource. To run
+    tasks in the placement group, use
+    `@ray.remote(num_cpus=0, resources={"worker": 1})`. Ray does not
+    automatically reserve CPU resources, so tasks must specify `num_cpus=0`
+    in order to run in a placement group.
+    """
+    pgs = [
+        ray.util.placement_group([{
+            "worker": 1
+        }]) for _ in range(args.num_reducers)
+    ]
+    ray.get([pg.ready() for pg in pgs])
+    try:
+        yield pgs
+    finally:
+        for pg in pgs:
+            ray.util.remove_placement_group(pg)
+
+
+@tracing_utils.timeit("sort", report_time=True)
+def sort_main(args: Args):
+    parts = _load_manifest(args, constants.INPUT_MANIFEST_FILE)
+    assert len(parts) == args.num_mappers
     boundaries = sortlib.get_boundaries(args.num_reducers)
-    mapper_results = np.empty(
-        (args.num_mappers, args.num_reducers), dtype=object)
-    for part_id, node, path in partitions:
-        opt = {} if args.skip_input else {
-            "resources": {
-                f"node:{node}": 1 / args.num_mappers
-            },
-            "memory": args.input_part_size * 1.2,
-        }
-        opt.update(num_returns=args.num_reducers)
-        mapper_results[part_id, :] = mapper.options(**opt).remote(
-            boundaries, part_id, path)
-
-    reducer_results = []
-    for r in range(args.num_reducers):
-        opt = {
-            "memory": args.output_part_size * 1.0,
-        }
-        blocks = mapper_results[:, r].tolist()
-        ret = reducer.options(**opt).remote(r, *blocks)
-        reducer_results.append(ret)
-
-    reducer_results = ray.get(reducer_results)
+
+    mapper_opt = {
+        "num_returns": args.num_reducers,
+        "num_cpus": os.cpu_count() / args.num_concurrent_rounds,
+    }  # Load balance across worker nodes by setting `num_cpus`.
+    merge_results = np.empty(
+        (args.num_rounds, args.num_reducers), dtype=object)
+
+    part_id = 0
+    with worker_placement_groups(args) as pgs:
+        for round in range(args.num_rounds):
+            # Limit the number of in-flight rounds.
+            num_extra_rounds = round - args.num_concurrent_rounds + 1
+            if num_extra_rounds > 0:
+                ray.wait(
+                    [f for f in merge_results.flatten() if f is not None],
+                    num_returns=num_extra_rounds * args.num_reducers)
+
+            # Submit map tasks.
+            mapper_results = np.empty(
+                (args.num_mappers_per_round, args.num_reducers), dtype=object)
+            for _ in range(args.num_mappers_per_round):
+                _, node, path = parts[part_id]
+                m = part_id % args.num_mappers_per_round
+                mapper_results[m, :] = mapper.options(**mapper_opt).remote(
+                    args, part_id, boundaries, path)
+                part_id += 1
+
+            # Submit merge tasks.
+            merge_results[round, :] = [
+                merge_mapper_blocks.options(placement_group=pgs[r]).remote(
+                    args, r, round, *mapper_results[:, r].tolist())
+                for r in range(args.num_reducers)
+            ]
+
+            # Delete local references to mapper results.
+            mapper_results = None
+
+        # Submit second-stage reduce tasks.
+        reducer_results = [
+            final_merge.options(placement_group=pgs[r]).remote(
+                args, r, *merge_results[:, r].tolist())
+            for r in range(args.num_reducers)
+        ]
+        reducer_results = ray.get(reducer_results)
+
     if not args.skip_output:
         with open(constants.OUTPUT_MANIFEST_FILE, "w") as fout:
             writer = csv.writer(fout)
             writer.writerows(reducer_results)

+    logging.info(ray.internal.internal_api.memory_summary(stats_only=True))
+

 # ------------------------------------------------------------
 #     Validate Output
 # ------------------------------------------------------------


+def _run_valsort(args: List[str]):
+    proc = subprocess.run([constants.VALSORT_PATH] + args, capture_output=True)
+    if proc.returncode != 0:
+        logging.critical("\n" + proc.stderr.decode("ascii"))
+        raise RuntimeError(f"Validation failed: {args}")
+
+
 @ray.remote
 def validate_part(path: Path):
     logging_utils.init()
-    proc = subprocess.run([constants.VALSORT_PATH, path], capture_output=True)
-    if proc.returncode != 0:
-        logging.critical("\n" + proc.stderr.decode("ascii"))
-        raise RuntimeError(f"Validation failed: {path}")
+    sum_path = path + ".sum"
+    _run_valsort(["-o", sum_path, path])
     logging.info(f"Validated output {path}")
+    with open(sum_path, "rb") as fin:
+        return os.path.getsize(path), fin.read()


-def validate_output():
-    if args.skip_output:
+def validate_output(args: Args):
+    if args.skip_sorting or args.skip_output:
         return
-    partitions = _load_manifest(constants.OUTPUT_MANIFEST_FILE)
-    tasks = []
+    partitions = _load_manifest(args, constants.OUTPUT_MANIFEST_FILE)
+    results = []
     for _, node, path in partitions:
-        tasks.append(
-            validate_part.options(resources={
-                f"node:{node}": 1 / args.num_reducers
-            }).remote(path))
-    logging.info(f"Validating {len(tasks)} partitions")
-    ray.get(tasks)
-    logging.info("All done!")
+        results.append(validate_part.options(**_node_res(node)).remote(path))
+    logging.info(f"Validating {len(results)} partitions")
+    results = ray.get(results)
+    total = sum(s for s, _ in results)
+    assert total == args.total_data_size, total - args.total_data_size
+    all_checksum = b"".join(c for _, c in results)
+    with tempfile.NamedTemporaryFile() as fout:
+        fout.write(all_checksum)
+        fout.flush()
+        _run_valsort(["-s", fout.name])
+    logging.info("All OK!")


 # ------------------------------------------------------------
@@ -306,30 +421,34 @@ def validate_output():
 # ------------------------------------------------------------


-def init():
-    if args.ray_address is None:
-        ray.init()
+def init(args: Args):
+    if not args.ray_address:
+        ray.init(resources={"worker": os.cpu_count()})
     else:
         ray.init(address=args.ray_address)
     logging_utils.init()
     logging.info(args)
-    logging.info(ray.available_resources())
     os.makedirs(constants.WORK_DIR, exist_ok=True)
+    resources = ray.cluster_resources()
+    logging.info(resources)
+    args.num_workers = resources["worker"]
+    progress_tracker = tracing_utils.create_progress_tracker(args)
+    return progress_tracker


-def main():
-    init()
+def main(args: Args):
+    # Keep the actor handle in scope for the duration of the program.
+    _progress_tracker = init(args)  # noqa: F841

     if args.generate_input:
-        generate_input()
+        generate_input(args)

     if args.sort:
-        sort_main()
+        sort_main(args)

     if args.validate_output:
-        validate_output()
+        validate_output(args)


 if __name__ == "__main__":
-    args = get_args()
-    main()
+    main(get_args())
diff --git a/python/ray/experimental/raysort/sortlib.py b/python/ray/experimental/raysort/sortlib.py
index ea79ec7168de4..6242867286d5f 100644
--- a/python/ray/experimental/raysort/sortlib.py
+++ b/python/ray/experimental/raysort/sortlib.py
@@ -1,4 +1,4 @@
-from typing import Iterable, List
+from typing import Callable, Iterable, List

 import numpy as np
@@ -21,7 +21,9 @@ def sort_and_partition(part: np.ndarray,
     return blocks


-def merge_partitions(blocks: List[np.ndarray],
-                     _n: int) -> Iterable[memoryview]:
+def merge_partitions(
+        num_blocks: int,
+        get_block: Callable[[int, int], np.ndarray]) -> Iterable[memoryview]:
+    blocks = [get_block(i, 0) for i in range(num_blocks)]
     for block in blocks:
         yield block
diff --git a/python/ray/experimental/raysort/tracing_utils.py b/python/ray/experimental/raysort/tracing_utils.py
index e75bae8297429..e67584b62588c 100644
--- a/python/ray/experimental/raysort/tracing_utils.py
+++ b/python/ray/experimental/raysort/tracing_utils.py
@@ -1,13 +1,122 @@
-import contextlib
+import datetime
+import functools
 import logging
 import time
+from typing import List, Tuple

+import ray
+from ray.util.metrics import Gauge, Histogram

-@contextlib.contextmanager
-def timeit(event="operation", args={}):
-    start = time.time()
-    yield
-    end = time.time()
-    duration = end - start
-    args = {"duration": duration}
-    logging.info(f"{event} {args}")
+from ray.experimental.raysort import constants
+from ray.experimental.raysort import logging_utils
+
+HISTOGRAM_BOUNDARIES = list(range(50, 200, 50))
+
+
+def timeit(
+        event: str,
+        report_time=False,
+        report_in_progress=True,
+        report_completed=True,
+):
+    def decorator(f):
+        @functools.wraps(f)
+        def wrapped_f(*args, **kwargs):
+            progress_tracker = ray.get_actor(constants.PROGRESS_TRACKER_ACTOR)
+            progress_tracker.inc.remote(
+                f"{event}_in_progress", echo=report_in_progress)
+            try:
+                start = time.time()
+                ret = f(*args, **kwargs)
+                end = time.time()
+                duration = end - start
+                progress_tracker.observe.remote(
+                    f"{event}_time",
+                    duration,
+                    echo=report_time,
+                )
+                progress_tracker.inc.remote(
+                    f"{event}_completed", echo=report_completed)
+                return ret
+            finally:
+                progress_tracker.dec.remote(f"{event}_in_progress")
+
+        return wrapped_f
+
+    return decorator
+
+
+def get_metrics(_args):
+    return {
+        "gauges": [
+            "map_in_progress",
+            "merge_in_progress",
+            "reduce_in_progress",
+            "sort_in_progress",
+            "map_completed",
+            "merge_completed",
+            "reduce_completed",
+            "sort_completed",
+        ],
+        "histograms": [
+            ("map_time", HISTOGRAM_BOUNDARIES),
+            ("merge_time", HISTOGRAM_BOUNDARIES),
+            ("reduce_time", HISTOGRAM_BOUNDARIES),
+            ("sort_time", HISTOGRAM_BOUNDARIES),
+        ],
+    }
+
+
+def create_progress_tracker(args):
+    return ProgressTracker.options(
+        name=constants.PROGRESS_TRACKER_ACTOR).remote(**get_metrics(args))
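
timeit is now a decorator factory rather than a context manager: the wrapped function looks up the ProgressTracker by its registered actor name at call time, so the tracker must be created (and its handle kept alive) before any decorated task runs. A usage sketch mirroring the raysort tasks, with `args` standing in for the namespace produced by get_args():

    tracker = create_progress_tracker(args)  # registers the named actor

    @ray.remote
    @timeit("map")
    def mapper_task(x):
        # inc/observe/dec calls are sent to the named tracker around this body.
        return x
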
+@ray.remote
+class ProgressTracker:
+    def __init__(
+            self,
+            gauges: List[str],
+            histograms: List[Tuple[str, List[int]]],
+    ):
+        self.counts = {m: 0 for m in gauges}
+        self.gauges = {m: Gauge(m) for m in gauges}
+        self.reset_gauges()
+        self.histograms = {
+            m: Histogram(m, boundaries=b)
+            for m, b in histograms
+        }
+        logging_utils.init()
+
+    def reset_gauges(self):
+        for g in self.gauges.values():
+            g.set(0)
+
+    def inc(self, metric_name, value=1, echo=False):
+        gauge = self.gauges.get(metric_name)
+        if gauge is None:
+            logging.warning(f"No such Gauge: {metric_name}")
+            return
+        self.counts[metric_name] += value
+        gauge.set(self.counts[metric_name])
+        if echo:
+            logging.info(f"{metric_name} {self.counts[metric_name]}")
+
+    def dec(self, metric_name, value=1, echo=False):
+        return self.inc(metric_name, -value, echo)
+
+    def observe(self, metric_name, value, echo=False):
+        histogram = self.histograms.get(metric_name)
+        if histogram is None:
+            logging.warning(f"No such Histogram: {metric_name}")
+            return
+        histogram.observe(value)
+        if echo:
+            logging.info(f"{metric_name} {value}")
+
+
+def export_timeline():
+    timestr = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+    filename = f"/tmp/ray-timeline-{timestr}.json"
+    ray.timeline(filename=filename)
+    logging.info(f"Exported Ray timeline to {filename}")
diff --git a/python/ray/experimental/raysort/types.py b/python/ray/experimental/raysort/types.py
index 02c6f70e5004a..5d1c39a33a521 100644
--- a/python/ray/experimental/raysort/types.py
+++ b/python/ray/experimental/raysort/types.py
@@ -7,6 +7,12 @@
 RecordCount = int

 BlockInfo = Tuple[int, int]
-PartitionInfo = NamedTuple("PartitionInfo",
-                           [("part_id", PartId), ("node", NodeAddress),
-                            ("path", Path)])
+
+
+class PartInfo(NamedTuple):
+    part_id: PartId
+    node: NodeAddress
+    path: Path
+
+    def __repr__(self):
+        return f"Part({self.node}:{self.path})"
diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd
index 33d9ce1a92fd4..17b2f3879f05b 100644
--- a/python/ray/includes/common.pxd
+++ b/python/ray/includes/common.pxd
@@ -260,8 +260,7 @@ cdef extern from "ray/core_worker/common.h" nogil:
             unordered_map[c_string, double] &resources,
             c_string concurrency_group_name,
            c_string serialized_runtime_env,
-            const unordered_map[c_string, c_string]
-            &override_environment_variables)
+            c_vector[c_string] runtime_env_uris)

     cdef cppclass CActorCreationOptions "ray::core::ActorCreationOptions":
         CActorCreationOptions()
@@ -277,8 +276,7 @@ cdef extern from "ray/core_worker/common.h" nogil:
             c_pair[CPlacementGroupID, int64_t] placement_options,
             c_bool placement_group_capture_child_tasks,
            c_string serialized_runtime_env,
-            const unordered_map[c_string, c_string]
-            &override_environment_variables)
+            c_vector[c_string] runtime_env_uris)

     cdef cppclass CPlacementGroupCreationOptions \
             "ray::core::PlacementGroupCreationOptions":
diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd
index 7e56dab60965a..a95eaee2c228f 100644
--- a/python/ray/includes/libcoreworker.pxd
+++ b/python/ray/includes/libcoreworker.pxd
@@ -2,7 +2,7 @@
 # distutils: language = c++
 # cython: embedsignature = True

-from libc.stdint cimport int64_t
+from libc.stdint cimport int64_t, uint64_t
 from libcpp cimport bool as c_bool
 from libcpp.memory cimport shared_ptr, unique_ptr
 from libcpp.pair cimport pair as c_pair
@@ -177,7 +177,6 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
         c_vector[CObjectReference] GetObjectRefs(
            const c_vector[CObjectID] &object_ids) const

-        void PromoteObjectToPlasma(const CObjectID &object_id)
         void GetOwnershipInfo(const CObjectID &object_id,
                               CAddress *owner_address,
                               c_string *object_status)
@@ -254,6 +253,8 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
         int64_t GetNumLeasesRequested() const

+        unordered_map[c_string, c_vector[uint64_t]] GetActorCallStats() const
+
     cdef cppclass CCoreWorkerOptions "ray::core::CoreWorkerOptions":
         CWorkerType worker_type
         CLanguage language
diff --git a/python/ray/job_config.py b/python/ray/job_config.py
index 9ba513f71195e..e9dc6b3d7cd7d 100644
--- a/python/ray/job_config.py
+++ b/python/ray/job_config.py
@@ -1,17 +1,13 @@
 from typing import Any, Dict, Optional
 import uuid
-import json

 import ray._private.gcs_utils as gcs_utils
-from ray.core.generated.common_pb2 import RuntimeEnv as RuntimeEnvPB


 class JobConfig:
     """A class used to store the configurations of a job.

     Attributes:
-        worker_env (dict): Environment variables to be set on worker
-            processes.
         num_java_workers_per_process (int): The number of java workers per
             worker process.
         jvm_options (str[]): The jvm options for java workers of the job.
@@ -24,7 +20,6 @@ class JobConfig:
     """

     def __init__(self,
-                 worker_env=None,
                  num_java_workers_per_process=1,
                  jvm_options=None,
                  code_search_path=None,
@@ -32,10 +27,6 @@ def __init__(self,
                  client_job=False,
                  metadata=None,
                  ray_namespace=None):
-        if worker_env is None:
-            self.worker_env = dict()
-        else:
-            self.worker_env = worker_env
         self.num_java_workers_per_process = num_java_workers_per_process
         self.jvm_options = jvm_options or []
         self.code_search_path = code_search_path or []
@@ -54,21 +45,23 @@ def set_metadata(self, key: str, value: str) -> None:

     def serialize(self):
         """Serialize the struct into protobuf string"""
-        job_config = self.get_proto_job_config()
-        return job_config.SerializeToString()
+        return self.get_proto_job_config().SerializeToString()

     def set_runtime_env(self, runtime_env: Optional[Dict[str, Any]]) -> None:
-        # Lazily import this to avoid circular dependencies.
-        import ray._private.runtime_env as runtime_support
-        if runtime_env:
-            self._parsed_runtime_env = runtime_support.RuntimeEnvDict(
-                runtime_env)
-            self.worker_env.update(
-                self._parsed_runtime_env.get_parsed_dict().get("env_vars")
-                or {})
-        else:
-            self._parsed_runtime_env = runtime_support.RuntimeEnvDict({})
+        # TODO(edoakes): this is really unfortunate, but JobConfig is imported
+        # all over the place so this causes circular imports. We should remove
+        # this dependency and pass in a validated runtime_env instead.
+        from ray._private.runtime_env.validation import ParsedRuntimeEnv
+        self._parsed_runtime_env = ParsedRuntimeEnv(runtime_env or {})
         self.runtime_env = runtime_env or dict()
+        eager_install = False
+        if runtime_env and "eager_install" in runtime_env:
+            eager_install = runtime_env["eager_install"]
+        self.runtime_env_eager_install = eager_install
+        assert isinstance(self.runtime_env_eager_install, bool), \
+            f"The type of eager_install is incorrect: " \
+            f"{type(self.runtime_env_eager_install)}" \
+            f"; a bool is needed."
         self._cached_pb = None

     def set_ray_namespace(self, ray_namespace: str) -> None:
@@ -84,35 +77,27 @@ def get_proto_job_config(self):
                 self._cached_pb.ray_namespace = str(uuid.uuid4())
             else:
                 self._cached_pb.ray_namespace = self.ray_namespace
-            for key in self.worker_env:
-                self._cached_pb.worker_env[key] = self.worker_env[key]
             self._cached_pb.num_java_workers_per_process = (
                 self.num_java_workers_per_process)
             self._cached_pb.jvm_options.extend(self.jvm_options)
             self._cached_pb.code_search_path.extend(self.code_search_path)
-            self._cached_pb.runtime_env.CopyFrom(self._get_proto_runtime())
-            self._cached_pb.serialized_runtime_env = \
-                self.get_serialized_runtime_env()
+            self._cached_pb.runtime_env.uris[:] = self.get_runtime_env_uris()
+            serialized_env = self.get_serialized_runtime_env()
+            self._cached_pb.runtime_env.serialized_runtime_env = serialized_env
             for k, v in self.metadata.items():
                 self._cached_pb.metadata[k] = v
+            self._cached_pb.runtime_env.runtime_env_eager_install = \
+                self.runtime_env_eager_install
         return self._cached_pb

     def get_runtime_env_uris(self):
         """Get the uris of runtime environment"""
-        if self.runtime_env.get("uris"):
-            return self.runtime_env.get("uris")
-        return []
-
-    def set_runtime_env_uris(self, uris):
-        self.runtime_env["uris"] = uris
-        self._parsed_runtime_env.set_uris(uris)
+        return self._parsed_runtime_env.get("uris") or []

     def get_serialized_runtime_env(self) -> str:
         """Return the JSON-serialized parsed runtime env dict"""
         return self._parsed_runtime_env.serialize()

-    def _get_proto_runtime(self) -> RuntimeEnvPB:
-        runtime_env = RuntimeEnvPB()
-        runtime_env.uris[:] = self.get_runtime_env_uris()
-        runtime_env.raw_json = json.dumps(self.runtime_env)
-        return runtime_env
+    def set_runtime_env_uris(self, uris):
+        self.runtime_env["uris"] = uris
+        self._parsed_runtime_env["uris"] = uris
diff --git a/python/ray/node.py b/python/ray/node.py
index cee0f8bfebeac..0ff15b709b102 100644
--- a/python/ray/node.py
+++ b/python/ray/node.py
@@ -572,7 +572,10 @@ def _get_log_file_names(self, name, unique=False):
         log_stderr = os.path.join(self._logs_dir, f"{name}.err")
         return log_stdout, log_stderr

-    def _get_unused_port(self, close_on_exit=True):
+    def _get_unused_port(self, allocated_ports=None):
+        if allocated_ports is None:
+            allocated_ports = set()
+
         s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         s.bind(("", 0))
         port = s.getsockname()[1]
@@ -582,6 +585,10 @@ def _get_unused_port(self, close_on_exit=True):
         # from this method has been used by a different process.
         for _ in range(NUM_PORT_RETRIES):
             new_port = random.randint(port, 65535)
+            if new_port in allocated_ports:
+                # This port is allocated for other usage already,
+                # so we shouldn't use it even if it's not in use right now.
+                continue
             new_s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
             try:
                 new_s.bind(("", new_port))
@@ -589,13 +596,11 @@ def _get_unused_port(self, close_on_exit=True):
                 new_s.close()
                 continue
             s.close()
-            if close_on_exit:
-                new_s.close()
-            return new_port, new_s
+            new_s.close()
+            return new_port
         logger.error("Unable to succeed in selecting a random port.")
-        if close_on_exit:
-            s.close()
-        return port, s
+        s.close()
+        return port

     def _prepare_socket_file(self, socket_path, default_prefix):
         """Prepare the socket file for raylet and plasma.
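
_get_unused_port now returns a bare port number, skips anything in allocated_ports, and closes its probe sockets before returning; the port can still be claimed by another process between probing and binding, which is why callers cache their choices. A condensed, hypothetical sketch of the selection loop (no retry cap, unlike the real method):

    import socket

    def pick_unused_port(allocated_ports=()):
        # Ask the OS for a free port, then reject ports this node has
        # already promised to other components.
        while True:
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.bind(("", 0))
            port = s.getsockname()[1]
            s.close()
            if port not in allocated_ports:
                return port
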
@@ -613,7 +618,7 @@ def _prepare_socket_file(self, socket_path, default_prefix):
         if sys.platform == "win32":
             if socket_path is None:
                 result = (f"tcp://{self._localhost}"
-                          f":{self._get_unused_port()[0]}")
+                          f":{self._get_unused_port()}")
         else:
             if socket_path is None:
                 result = self._make_inc_temp(
@@ -665,7 +670,8 @@ def _get_cached_port(self,
             port = int(ports_by_node[self.unique_id][port_name])
         else:
             # Pick a new port to use and cache it at this node.
-            port = (default_port or self._get_unused_port()[0])
+            port = (default_port or self._get_unused_port(
+                set(ports_by_node[self.unique_id].values())))
             ports_by_node[self.unique_id][port_name] = port
             with open(file_path, "w") as f:
                 json.dump(ports_by_node, f)
diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py
index 6854a93535b9e..ea3df3acb2f9a 100644
--- a/python/ray/remote_function.py
+++ b/python/ray/remote_function.py
@@ -1,7 +1,7 @@
-import uuid
-import logging
-import inspect
 from functools import wraps
+import inspect
+import logging
+import uuid

 from ray import cloudpickle as pickle
 from ray._raylet import PythonFunctionDescriptor
@@ -14,7 +14,8 @@
     get_current_placement_group,
 )
 import ray._private.signature
-import ray._private.runtime_env as runtime_support
+from ray._private.runtime_env.validation import (
+    override_task_or_actor_runtime_env, ParsedRuntimeEnv)
 from ray.util.tracing.tracing_helper import (_tracing_task_invocation,
                                              _inject_tracing_into_function)

@@ -78,7 +79,7 @@ class RemoteFunction:
     def __init__(self, language, function, function_descriptor, num_cpus,
                  num_gpus, memory, object_store_memory, resources,
                  accelerator_type, num_returns, max_calls, max_retries,
-                 retry_exceptions, runtime_env):
+                 retry_exceptions, runtime_env, placement_group):
         if inspect.iscoroutinefunction(function):
             raise ValueError("'async def' should not be used for remote "
                              "tasks. You can wrap the async function with "
@@ -108,7 +109,12 @@ def __init__(self, language, function, function_descriptor, num_cpus,
         self._retry_exceptions = (DEFAULT_REMOTE_FUNCTION_RETRY_EXCEPTIONS
                                   if retry_exceptions is None else
                                   retry_exceptions)
-        self._runtime_env = runtime_env
+        # Parse local pip/conda config files here. If we instead did it in
+        # .remote(), it would get run in the Ray Client server, which runs on
+        # a remote node where the files aren't available.
+        self._runtime_env = ParsedRuntimeEnv(
+            runtime_env or {}, is_task_or_actor=True)
+        self._placement_group = placement_group
         self._decorator = getattr(function, "__ray_invocation_decorator__",
                                   None)
         self._function_signature = ray._private.signature.extract_signature(
@@ -145,7 +151,6 @@ def options(self,
                 placement_group_bundle_index=-1,
                 placement_group_capture_child_tasks=None,
                 runtime_env=None,
-                override_environment_variables=None,
                 name=""):
         """Configures and overrides the task invocation parameters.

@@ -164,6 +169,11 @@ def f():
         """
         func_cls = self
+        # Parse local pip/conda config files here. If we instead did it in
+        # .remote(), it would get run in the Ray Client server, which runs on
+        # a remote node where the files aren't available.
+        new_runtime_env = ParsedRuntimeEnv(
+            runtime_env or {}, is_task_or_actor=True)

         class FuncWrapper:
             def remote(self, *args, **kwargs):
@@ -183,9 +193,7 @@ def remote(self, *args, **kwargs):
                     placement_group_bundle_index=placement_group_bundle_index,
                     placement_group_capture_child_tasks=(
                         placement_group_capture_child_tasks),
-                    runtime_env=runtime_env,
-                    override_environment_variables=(
-                        override_environment_variables),
+                    runtime_env=new_runtime_env,
                     name=name)

         return FuncWrapper()
@@ -207,10 +215,10 @@ def _remote(self,
                 placement_group_bundle_index=-1,
                 placement_group_capture_child_tasks=None,
                 runtime_env=None,
-                override_environment_variables=None,
                 name=""):
         """Submit the remote function for execution."""
-        if client_mode_should_convert():
+
+        if client_mode_should_convert(auto_init=True):
             return client_mode_convert_function(
                 self,
                 args,
@@ -229,7 +237,6 @@ def _remote(self,
                 placement_group_capture_child_tasks=(
                     placement_group_capture_child_tasks),
                 runtime_env=runtime_env,
-                override_environment_variables=override_environment_variables,
                 name=name)

         worker = ray.worker.global_worker
@@ -270,7 +277,12 @@ def _remote(self,
             placement_group_capture_child_tasks = (
                 worker.should_capture_child_tasks_in_placement_group)

-        if placement_group == "default":
+        if self._placement_group != "default":
+            if self._placement_group:
+                placement_group = self._placement_group
+            else:
+                placement_group = PlacementGroup.empty()
+        elif placement_group == "default":
             if placement_group_capture_child_tasks:
                 placement_group = get_current_placement_group()
             else:
@@ -288,18 +300,16 @@ def _remote(self,
             num_cpus, num_gpus, memory, object_store_memory, resources,
             accelerator_type)

-        if runtime_env is None:
+        if runtime_env and not isinstance(runtime_env, ParsedRuntimeEnv):
+            runtime_env = ParsedRuntimeEnv(runtime_env)
+        elif isinstance(runtime_env, ParsedRuntimeEnv):
+            pass
+        else:
             runtime_env = self._runtime_env

-        job_runtime_env = worker.core_worker.get_current_runtime_env_dict()
-        runtime_env_dict = runtime_support.override_task_or_actor_runtime_env(
-            runtime_env, job_runtime_env)
-
-        if override_environment_variables:
-            logger.warning("override_environment_variables is deprecated and "
-                           "will be removed in Ray 1.6. Please use "
-                           ".options(runtime_env={'env_vars': {...}}).remote()"
-                           "instead.")
+        parent_runtime_env = worker.core_worker.get_current_runtime_env()
+        parsed_runtime_env = override_task_or_actor_runtime_env(
+            runtime_env, parent_runtime_env)

         def invocation(args, kwargs):
             if self._is_cross_language:
@@ -315,21 +325,12 @@ def invocation(args, kwargs):
                     "Cross language remote function " \
                     "cannot be executed locally."
             object_refs = worker.core_worker.submit_task(
-                self._language,
-                self._function_descriptor,
-                list_args,
-                name,
-                num_returns,
-                resources,
-                max_retries,
-                retry_exceptions,
-                placement_group.id,
-                placement_group_bundle_index,
+                self._language, self._function_descriptor, list_args, name,
+                num_returns, resources, max_retries, retry_exceptions,
+                placement_group.id, placement_group_bundle_index,
                 placement_group_capture_child_tasks,
-                worker.debugger_breakpoint,
-                runtime_env_dict,
-                override_environment_variables=override_environment_variables
-                or dict())
+                worker.debugger_breakpoint, parsed_runtime_env.serialize(),
+                parsed_runtime_env.get("uris") or [])

             # Reset worker's debug context from the last "remote" command
             # (which applies only to this .remote call).
worker.debugger_breakpoint = b""
diff --git a/python/ray/runtime_context.py b/python/ray/runtime_context.py
index 750e213cc12b0..64bee3fc7cf79 100644
--- a/python/ray/runtime_context.py
+++ b/python/ray/runtime_context.py
@@ -152,7 +152,7 @@ def should_capture_child_tasks_in_placement_group(self):
 
     @property
     def runtime_env(self):
-        """Get the runtime env passed to job_config
+        """Get the runtime env used for the current driver or worker.
 
         Returns:
             The runtime env currently using by this worker.
@@ -172,12 +172,24 @@ def current_actor(self):
         worker.check_connected()
         return worker.core_worker.get_actor_handle(self.actor_id)
 
+    def _get_actor_call_stats(self):
+        """Get the current worker's task counters.
+
+        Returns:
+            A dictionary keyed by the function name. The values are
+            dictionaries with form ``{"received": 0, "executing": 1,
+            "executed": 2}``.
+        """
+        worker = self.worker
+        worker.check_connected()
+        return worker.core_worker.get_actor_call_stats()
+
 
 _runtime_context = None
 
 
 @PublicAPI(stability="beta")
-@client_mode_hook
+@client_mode_hook(auto_init=False)
 def get_runtime_context():
     """Get the runtime context of the current driver/worker.
 
diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py
index 08465f7a422e0..f2b0cfe7b4ed8 100644
--- a/python/ray/scripts/scripts.py
+++ b/python/ray/scripts/scripts.py
@@ -598,10 +598,9 @@ def start(node_ip_address, address, port, redis_password, redis_shard_ports,
             "password.", cf.bold("--redis-password"),
             cf.bold("--address"))
 
-    node_ip_address = services.get_node_ip_address()
-
     # Get the node IP address if one is not provided.
-    ray_params.update_if_absent(node_ip_address=node_ip_address)
+    ray_params.update_if_absent(
+        node_ip_address=services.get_node_ip_address())
     cli_logger.labeled_value("Local node IP", ray_params.node_ip_address)
     ray_params.update_if_absent(
         redis_port=port,
@@ -614,7 +613,7 @@ def start(node_ip_address, address, port, redis_password, redis_shard_ports,
 
     # Fail early when starting a new cluster when one is already running
     if address is None:
-        default_address = f"{node_ip_address}:{port}"
+        default_address = f"{ray_params.node_ip_address}:{port}"
         redis_addresses = services.find_redis_address(default_address)
         if len(redis_addresses) > 0:
             raise ConnectionError(
diff --git a/python/ray/serialization.py b/python/ray/serialization.py
index bc335e4a8c539..5bf8c0d1437f3 100644
--- a/python/ray/serialization.py
+++ b/python/ray/serialization.py
@@ -91,7 +91,7 @@ def object_ref_reducer(obj):
             worker = ray.worker.global_worker
             worker.check_connected()
             obj, owner_address, object_status = (
-                worker.core_worker.serialize_and_promote_object_ref(obj))
+                worker.core_worker.serialize_object_ref(obj))
             return _object_ref_deserializer, \
                 (obj.binary(), obj.call_site(), owner_address, object_status)
 
diff --git a/python/ray/serve/BUILD b/python/ray/serve/BUILD
index cfab567726e77..9417f7c3798a3 100644
--- a/python/ray/serve/BUILD
+++ b/python/ray/serve/BUILD
@@ -106,7 +106,7 @@ py_test(
 
 py_test(
     name = "test_ray_client",
-    size = "small",
+    size = "medium",
     srcs = serve_tests_srcs,
     tags = ["exclusive", "team:serverless"],
     deps = [":serve_lib"],
@@ -338,3 +338,11 @@ py_test(
     tags = ["exclusive", "team:serve"],
     deps = [":serve_lib"]
 )
+
+py_test(
+    name = "conda_env",
+    size = "medium",
+    srcs = glob(["examples/doc/*.py"]),
+    tags = ["exclusive", "post_wheel_build", "team:serve"],
+    deps = [":serve_lib"]
+)
diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py
index cd0ea1b033816..05a40dc34df0a 100644
--- a/python/ray/serve/api.py
+++ b/python/ray/serve/api.py
@@ -7,8 +7,7 @@ import time
 from dataclasses import dataclass
 from functools import wraps
-from typing import (Any, Callable, Dict, List, Optional, Tuple, Type, Union,
-                    overload)
+from typing import Any, Callable, Dict, Optional, Tuple, Type, Union, overload
 from weakref import WeakValueDictionary
 
 from fastapi import APIRouter, FastAPI
@@ -189,7 +188,8 @@ def _wait_for_goal(self,
     def deploy(self,
                name: str,
                backend_def: Union[Callable, Type[Callable], str],
-               *init_args: Any,
+               init_args: Tuple[Any],
+               init_kwargs: Dict[Any, Any],
                ray_actor_options: Optional[Dict] = None,
                config: Optional[Union[BackendConfig, Dict[str, Any]]] = None,
                version: Optional[str] = None,
@@ -213,7 +213,10 @@ def deploy(self,
             del ray_actor_options["runtime_env"]["working_dir"]
 
         replica_config = ReplicaConfig(
-            backend_def, *init_args, ray_actor_options=ray_actor_options)
+            backend_def,
+            init_args=init_args,
+            init_kwargs=init_kwargs,
+            ray_actor_options=ray_actor_options)
 
         if isinstance(config, dict):
             backend_config = BackendConfig.parse_obj(config)
@@ -222,16 +225,10 @@
         else:
             raise TypeError("config must be a BackendConfig or a dictionary.")
 
-        python_methods = []
-        if inspect.isclass(backend_def):
-            for method_name, _ in inspect.getmembers(backend_def,
-                                                     inspect.isfunction):
-                python_methods.append(method_name)
-
         goal_id, updating = ray.get(
             self._controller.deploy.remote(
-                name, backend_config.to_proto_bytes(), replica_config,
-                python_methods, version, prev_version, route_prefix,
+                name, backend_config.to_proto_bytes(), replica_config, version,
+                prev_version, route_prefix,
                 ray.get_runtime_context().job_id))
 
         tag = f"component=serve deployment={name}"
@@ -318,27 +315,16 @@ def get_handle(
                 "to create sync handle. Learn more at https://docs.ray.io/en/"
                 "master/serve/http-servehandle.html#sync-and-async-handles")
 
-        if endpoint_name in all_endpoints:
-            this_endpoint = all_endpoints[endpoint_name]
-            python_methods: List[str] = this_endpoint["python_methods"]
-        else:
-            # This can happen in the missing_ok=True case.
-            # handle.method_name.remote won't work and user must
-            # use the legacy handle.options(method).remote().
-            python_methods: List[str] = []
-
         if sync:
             handle = RayServeSyncHandle(
                 self._controller,
                 endpoint_name,
-                known_python_methods=python_methods,
                 _internal_pickled_http_request=_internal_pickled_http_request,
             )
         else:
             handle = RayServeHandle(
                 self._controller,
                 endpoint_name,
-                known_python_methods=python_methods,
                 _internal_pickled_http_request=_internal_pickled_http_request,
             )
 
@@ -619,6 +605,7 @@ def __init__(self,
                  version: Optional[str] = None,
                  prev_version: Optional[str] = None,
                  init_args: Optional[Tuple[Any]] = None,
+                 init_kwargs: Optional[Dict[Any, Any]] = None,
                  route_prefix: Optional[str] = None,
                  ray_actor_options: Optional[Dict] = None,
                  _internal=False) -> None:
@@ -644,6 +631,8 @@ def __init__(self,
             raise TypeError("prev_version must be a string.")
         if not (init_args is None or isinstance(init_args, tuple)):
             raise TypeError("init_args must be a tuple.")
+        if not (init_kwargs is None or isinstance(init_kwargs, dict)):
+            raise TypeError("init_kwargs must be a dict.")
         if route_prefix is not None:
             if not isinstance(route_prefix, str):
                 raise TypeError("route_prefix must be a string.")
@@ -660,6 +649,16 @@ def __init__(self,
 
         if init_args is None:
             init_args = ()
+        if init_kwargs is None:
+            init_kwargs = {}
+
+        # TODO(architkulkarni): Enforce that autoscaling_config and
+        # user-provided num_replicas should be mutually exclusive.
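# Usage sketch for the restriction added below: autoscaled deployments must
# currently be versioned, matching the pattern the autoscaling tests later in
# this patch use. The config values here are illustrative only.
from ray import serve

@serve.deployment(
    version="v1",
    _autoscaling_config={"min_replicas": 1, "max_replicas": 2})
class AutoscaledModel:
    def __call__(self, request):
        return "ok"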
+        if version is None and config.autoscaling_config is not None:
+            # TODO(architkulkarni): Remove this restriction.
+            raise ValueError(
+                "Currently autoscaling is only supported for "
+                "versioned deployments. Try @serve.deployment(version=...).")
 
         self._func_or_class = func_or_class
         self._name = name
@@ -667,6 +666,7 @@ def __init__(self,
         self._prev_version = prev_version
         self._config = config
         self._init_args = init_args
+        self._init_kwargs = init_kwargs
         self._route_prefix = route_prefix
         self._ray_actor_options = ray_actor_options
 
@@ -724,7 +724,12 @@ def ray_actor_options(self) -> Optional[Dict]:
 
     @property
     def init_args(self) -> Tuple[Any]:
-        """Arguments passed to the underlying class's constructor."""
-        return self._init_args
+        """Positional args passed to the underlying class's constructor."""
+        return self._init_args
+
+    @property
+    def init_kwargs(self) -> Dict[Any, Any]:
+        """Keyword args passed to the underlying class's constructor."""
+        return self._init_kwargs
 
     @property
@@ -738,20 +743,25 @@ def __call__(self):
             "Use `deployment.deploy() instead.`")
 
     @PublicAPI
-    def deploy(self, *init_args, _blocking=True):
+    def deploy(self, *init_args, _blocking=True, **init_kwargs):
         """Deploy or update this deployment.
 
         Args:
             init_args (optional): args to pass to the class __init__ method.
                 Not valid if this deployment wraps a function.
+            init_kwargs (optional): kwargs to pass to the class __init__
+                method. Not valid if this deployment wraps a function.
         """
         if len(init_args) == 0 and self._init_args is not None:
             init_args = self._init_args
+        if len(init_kwargs) == 0 and self._init_kwargs is not None:
+            init_kwargs = self._init_kwargs
 
         return _get_global_client().deploy(
             self._name,
             self._func_or_class,
-            *init_args,
+            init_args,
+            init_kwargs,
             ray_actor_options=self._ray_actor_options,
             config=self._config,
             version=self._version,
@@ -783,19 +793,23 @@ def get_handle(self, sync: Optional[bool] = True
             self._name, missing_ok=True, sync=sync)
 
     @PublicAPI
-    def options(
-            self,
-            func_or_class: Optional[Callable] = None,
-            name: Optional[str] = None,
-            version: Optional[str] = None,
-            prev_version: Optional[str] = None,
-            init_args: Optional[Tuple[Any]] = None,
-            route_prefix: Optional[str] = None,
-            num_replicas: Optional[int] = None,
-            ray_actor_options: Optional[Dict] = None,
-            user_config: Optional[Any] = None,
-            max_concurrent_queries: Optional[int] = None,
-    ) -> "Deployment":
+    def options(self,
+                func_or_class: Optional[Callable] = None,
+                name: Optional[str] = None,
+                version: Optional[str] = None,
+                prev_version: Optional[str] = None,
+                init_args: Optional[Tuple[Any]] = None,
+                init_kwargs: Optional[Dict[Any, Any]] = None,
+                route_prefix: Optional[str] = None,
+                num_replicas: Optional[int] = None,
+                ray_actor_options: Optional[Dict] = None,
+                user_config: Optional[Any] = None,
+                max_concurrent_queries: Optional[int] = None,
+                _autoscaling_config: Optional[Union[Dict,
+                                                    AutoscalingConfig]] = None,
+                _graceful_shutdown_wait_loop_s: Optional[float] = None,
+                _graceful_shutdown_timeout_s: Optional[float] = None
+                ) -> "Deployment":
         """Return a copy of this deployment with updated options.
         Only those options passed in will be updated, all others will remain
@@ -821,6 +835,9 @@ def options(
         if init_args is None:
             init_args = self._init_args
 
+        if init_kwargs is None:
+            init_kwargs = self._init_kwargs
+
         if route_prefix is None:
             if self._route_prefix == f"/{self._name}":
                 route_prefix = None
@@ -830,6 +847,17 @@ def options(
         if ray_actor_options is None:
             ray_actor_options = self._ray_actor_options
 
+        if _autoscaling_config is not None:
+            new_config.autoscaling_config = _autoscaling_config
+
+        if _graceful_shutdown_wait_loop_s is not None:
+            new_config.graceful_shutdown_wait_loop_s = (
+                _graceful_shutdown_wait_loop_s)
+
+        if _graceful_shutdown_timeout_s is not None:
+            new_config.graceful_shutdown_timeout_s = (
+                _graceful_shutdown_timeout_s)
+
         return Deployment(
             func_or_class,
             name,
@@ -837,6 +865,7 @@ def options(
             version=version,
             prev_version=prev_version,
             init_args=init_args,
+            init_kwargs=init_kwargs,
             route_prefix=route_prefix,
             ray_actor_options=ray_actor_options,
             _internal=True,
@@ -848,6 +877,7 @@ def __eq__(self, other):
             self._version == other._version,
             self._config == other._config,
             self._init_args == other._init_args,
+            self._init_kwargs == other._init_kwargs,
             self._route_prefix == other._route_prefix,
             self._ray_actor_options == self._ray_actor_options,
         ])
@@ -871,16 +901,20 @@ def deployment(func_or_class: Callable) -> Deployment:
 
 
 @overload
-def deployment(name: Optional[str] = None,
-               version: Optional[str] = None,
-               prev_version: Optional[str] = None,
-               num_replicas: Optional[int] = None,
-               init_args: Optional[Tuple[Any]] = None,
-               ray_actor_options: Optional[Dict] = None,
-               user_config: Optional[Any] = None,
-               max_concurrent_queries: Optional[int] = None,
-               _autoscaling_config: Optional[dict] = None
-               ) -> Callable[[Callable], Deployment]:
+def deployment(
+        name: Optional[str] = None,
+        version: Optional[str] = None,
+        prev_version: Optional[str] = None,
+        num_replicas: Optional[int] = None,
+        init_args: Optional[Tuple[Any]] = None,
+        init_kwargs: Optional[Dict[Any, Any]] = None,
+        ray_actor_options: Optional[Dict] = None,
+        user_config: Optional[Any] = None,
+        max_concurrent_queries: Optional[int] = None,
+        _autoscaling_config: Optional[Union[Dict, AutoscalingConfig]] = None,
+        _graceful_shutdown_wait_loop_s: Optional[float] = None,
+        _graceful_shutdown_timeout_s: Optional[float] = None
+) -> Callable[[Callable], Deployment]:
     pass
 
 
@@ -892,11 +926,14 @@ def deployment(
         prev_version: Optional[str] = None,
         num_replicas: Optional[int] = None,
         init_args: Optional[Tuple[Any]] = None,
+        init_kwargs: Optional[Dict[Any, Any]] = None,
        route_prefix: Optional[str] = None,
         ray_actor_options: Optional[Dict] = None,
         user_config: Optional[Any] = None,
         max_concurrent_queries: Optional[int] = None,
-        _autoscaling_config: Optional[dict] = None,
+        _autoscaling_config: Optional[Union[Dict, AutoscalingConfig]] = None,
+        _graceful_shutdown_wait_loop_s: Optional[float] = None,
+        _graceful_shutdown_timeout_s: Optional[float] = None
 ) -> Callable[[Callable], Deployment]:
     """Define a Serve deployment.
 
@@ -915,7 +952,10 @@ def deployment(
             not check the existing deployment's version.
         num_replicas (Optional[int]): The number of processes to start up that
             will handle requests to this deployment. Defaults to 1.
-        init_args (Optional[Tuple]): Arguments to be passed to the class
+        init_args (Optional[Tuple]): Positional args to be passed to the class
+            constructor when starting up deployment replicas. These can also be
+            passed when you call `.deploy()` on the returned Deployment.
+ init_kwargs (Optional[Dict]): Keyword args to be passed to the class constructor when starting up deployment replicas. These can also be passed when you call `.deploy()` on the returned Deployment. route_prefix (Optional[str]): Requests to paths under this HTTP path @@ -962,8 +1002,13 @@ class MyDeployment: config.max_concurrent_queries = max_concurrent_queries if _autoscaling_config is not None: - config.autoscaling_config = AutoscalingConfig.parse_obj( - _autoscaling_config) + config.autoscaling_config = _autoscaling_config + + if _graceful_shutdown_wait_loop_s is not None: + config.graceful_shutdown_wait_loop_s = _graceful_shutdown_wait_loop_s + + if _graceful_shutdown_timeout_s is not None: + config.graceful_shutdown_timeout_s = _graceful_shutdown_timeout_s def decorator(_func_or_class): return Deployment( @@ -973,6 +1018,7 @@ def decorator(_func_or_class): version=version, prev_version=prev_version, init_args=init_args, + init_kwargs=init_kwargs, route_prefix=route_prefix, ray_actor_options=ray_actor_options, _internal=True, @@ -1014,6 +1060,7 @@ def get_deployment(name: str) -> Deployment: backend_info.backend_config, version=backend_info.version, init_args=backend_info.replica_config.init_args, + init_kwargs=backend_info.replica_config.init_kwargs, route_prefix=route_prefix, ray_actor_options=backend_info.replica_config.ray_actor_options, _internal=True, @@ -1037,6 +1084,7 @@ def list_deployments() -> Dict[str, Deployment]: backend_info.backend_config, version=backend_info.version, init_args=backend_info.replica_config.init_args, + init_kwargs=backend_info.replica_config.init_kwargs, route_prefix=route_prefix, ray_actor_options=backend_info.replica_config.ray_actor_options, _internal=True, diff --git a/python/ray/serve/autoscaling_metrics.py b/python/ray/serve/autoscaling_metrics.py index 084996760d297..4b0d030d700cc 100644 --- a/python/ray/serve/autoscaling_metrics.py +++ b/python/ray/serve/autoscaling_metrics.py @@ -74,7 +74,7 @@ def add_metrics_point(self, data_points: Dict[str, float], Args: data_points(dict): dictionary containing the metrics values. The - key should be a string that uniquely identitify this time series + key should be a string that uniquely identifies this time series and to be used to perform aggregation. timestamp(float): the unix epoch timestamp the metrics are collected at. @@ -98,6 +98,9 @@ def window_average(self, do_compact(bool): whether or not to delete the datapoints that's before `window_start_timestamp_s` to save memory. Default is true. + Returns: + The average of all the datapoints for the key on and after time + window_start_timestamp_s, or None if there are no such points. """ datapoints = self.data[key] diff --git a/python/ray/serve/autoscaling_policy.py b/python/ray/serve/autoscaling_policy.py index 6a9887fb7497c..23dbbf65159e9 100644 --- a/python/ray/serve/autoscaling_policy.py +++ b/python/ray/serve/autoscaling_policy.py @@ -16,7 +16,6 @@ def calculate_desired_num_replicas(autoscaling_config: AutoscalingConfig, current_num_ongoing_requests (List[float]): A list of the number of ongoing requests for each replica. Assumes each entry has already been time-averaged over the desired lookback window. - current_num_replicas (int): The current number of active replicas. 
Returns: desired_num_replicas: The desired number of replicas to scale to, based diff --git a/python/ray/serve/backend_state.py b/python/ray/serve/backend_state.py index 2ab4c5e41d99d..6068887f9bd7a 100644 --- a/python/ray/serve/backend_state.py +++ b/python/ray/serve/backend_state.py @@ -76,6 +76,7 @@ def __init__(self, actor_name: str, detached: bool, controller_name: str, self._ready_obj_ref = None self._graceful_shutdown_ref = None + self._graceful_shutdown_timeout_s = None self._actor_resources = None self._health_check_ref = None @@ -147,6 +148,8 @@ def start(self, backend_info: BackendInfo, version: BackendVersion): Start a new actor for current BackendReplica instance. """ self._actor_resources = backend_info.replica_config.resource_dict + self._graceful_shutdown_timeout_s = ( + backend_info.backend_config.graceful_shutdown_timeout_s) if USE_PLACEMENT_GROUP: self._placement_group = self.create_placement_group( self._placement_group_name, self._actor_resources) @@ -164,6 +167,7 @@ def start(self, backend_info: BackendInfo, version: BackendVersion): **backend_info.replica_config.ray_actor_options).remote( self.backend_tag, self.replica_tag, backend_info.replica_config.init_args, + backend_info.replica_config.init_kwargs, backend_info.backend_config.to_proto_bytes(), version, self._controller_name, self._detached) @@ -243,14 +247,19 @@ def actor_resources(self) -> Dict[str, float]: def available_resources(self) -> Dict[str, float]: return ray.available_resources() - def graceful_stop(self) -> None: - """Request the actor to exit gracefully.""" + def graceful_stop(self) -> Duration: + """Request the actor to exit gracefully. + + Returns the timeout after which to kill the actor. + """ try: handle = ray.get_actor(self._actor_name) self._graceful_shutdown_ref = handle.prepare_for_shutdown.remote() except ValueError: pass + return self._graceful_shutdown_timeout_s + def check_stopped(self) -> bool: """Check if the actor has exited.""" try: @@ -386,14 +395,15 @@ def check_started(self) -> ReplicaStartupStatus: return status - def stop(self, graceful_shutdown_timeout_s: Duration = 0) -> None: + def stop(self, graceful: bool = True) -> None: """Stop the replica. Should handle the case where the replica is already stopped. """ - self._actor.graceful_stop() - self._graceful_shutdown_timeout_s = graceful_shutdown_timeout_s - self._shutdown_deadline = time.time() + graceful_shutdown_timeout_s + timeout_s = self._actor.graceful_stop() + if not graceful: + timeout_s = 0 + self._shutdown_deadline = time.time() + timeout_s def check_stopped(self) -> bool: """Check if the replica has finished stopping.""" @@ -402,14 +412,13 @@ def check_stopped(self) -> bool: self._actor.cleanup() return True - timeout_passed = time.time() >= self._shutdown_deadline - + timeout_passed = time.time() > self._shutdown_deadline if timeout_passed: # Graceful period passed, kill it forcefully. # This will be called repeatedly until the replica shuts down. logger.debug( - f"Replica {self.replica_tag} did not shutdown after " - f"{self._graceful_shutdown_timeout_s}s, force-killing. " + f"Replica {self.replica_tag} did not shut down after grace " + "period, force-killing it. " f"component=serve deployment={self.backend_tag} " f"replica={self.replica_tag}") @@ -722,9 +731,9 @@ def deploy(self, backend_info: BackendInfo) -> Tuple[Optional[GoalId], bool]: """Deploy the backend. - If the backend already exists with the same version, this is a no-op - and returns the GoalId corresponding to the existing update if there - is one. 
+ If the backend already exists with the same version and BackendConfig, + this is a no-op and returns the GoalId corresponding to the existing + update if there is one. Returns: GoalId, bool: The GoalId for the client to wait for and whether or @@ -760,11 +769,8 @@ def deploy(self, self._goal_manager.complete_goal(existing_goal_id) return new_goal_id, True - def delete(self, force_kill: bool = False) -> Optional[GoalId]: + def delete(self) -> Optional[GoalId]: new_goal_id, existing_goal_id = self._set_backend_goal(None) - if force_kill: - self._target_info.backend_config.\ - experimental_graceful_shutdown_timeout_s = 0 self._save_checkpoint_func() self._notify_backend_configs_changed() @@ -822,9 +828,6 @@ def _stop_wrong_version_replicas(self) -> int: states=[ReplicaState.STARTING, ReplicaState.RUNNING], max_replicas=max_to_stop) - graceful_shutdown_timeout_s = ( - self._target_info.backend_config. - experimental_graceful_shutdown_timeout_s) code_version_changes = 0 user_config_changes = 0 for replica in replicas_to_update: @@ -834,8 +837,7 @@ def _stop_wrong_version_replicas(self) -> int: if (replica.version.code_version != self._target_version.code_version): code_version_changes += 1 - replica.stop( - graceful_shutdown_timeout_s=graceful_shutdown_timeout_s) + replica.stop() self._replicas.add(ReplicaState.STOPPING, replica) # If only the user_config is a mismatch, we update it dynamically # without restarting the replica. @@ -869,10 +871,6 @@ def _scale_backend_replicas(self) -> bool: assert self._target_replicas >= 0, ("Number of replicas must be" " greater than or equal to 0.") - graceful_shutdown_timeout_s = ( - self._target_info.backend_config. - experimental_graceful_shutdown_timeout_s) - self._stop_wrong_version_replicas() current_replicas = self._replicas.count(states=[ @@ -924,8 +922,7 @@ def _scale_backend_replicas(self) -> bool: for replica in replicas_to_stop: logger.debug(f"Adding STOPPING to replica_tag: {replica}, " f"backend_tag: {self._name}") - replica.stop( - graceful_shutdown_timeout_s=graceful_shutdown_timeout_s) + replica.stop() self._replicas.add(ReplicaState.STOPPING, replica) return True @@ -1014,7 +1011,7 @@ def _check_startup_replicas(self, # Increase startup failure counter if we're tracking it self._replica_constructor_retry_counter += 1 - replica.stop(graceful_shutdown_timeout_s=0) + replica.stop(graceful=False) self._replicas.add(ReplicaState.STOPPING, replica) transitioned = True elif start_status == ReplicaStartupStatus.PENDING: @@ -1026,7 +1023,7 @@ def _check_startup_replicas(self, if not stop_on_slow: self._replicas.add(original_state, replica) else: - replica.stop(graceful_shutdown_timeout_s=0) + replica.stop(graceful=False) self._replicas.add(ReplicaState.STOPPING, replica) transitioned = True slow_replicas.append(replica) @@ -1049,7 +1046,7 @@ def _check_and_update_replicas(self) -> bool: f"{self._name} failed health check, stopping it. " f"component=serve deployment={self._name} " f"replica={replica.replica_tag}") - replica.stop(graceful_shutdown_timeout_s=0) + replica.stop(graceful=False) self._replicas.add(ReplicaState.STOPPING, replica) slow_start_replicas = [] @@ -1073,8 +1070,9 @@ def _check_and_update_replicas(self) -> bool: f"Deployment '{self._name}' has " f"{len(slow_start_replicas)} replicas that have taken " f"more than {SLOW_STARTUP_WARNING_S}s to start up. This " - "may be caused by waiting for the cluster to auto-scale " - "or because the constructor is slow. 
Resources required " + "may be caused by waiting for the cluster to auto-scale, " + "waiting for a runtime environment to install, or a slow " + "constructor. Resources required " f"for each replica: {required}, resources available: " f"{available}. component=serve deployment={self._name}") @@ -1236,7 +1234,7 @@ def shutdown(self) -> List[GoalId]: shutdown_goals = [] for backend_state in self._backend_states.values(): - goal = backend_state.delete(force_kill=True) + goal = backend_state.delete() if goal is not None: shutdown_goals.append(goal) @@ -1302,9 +1300,9 @@ def deploy_backend(self, backend_tag: BackendTag, backend_info: BackendInfo ) -> Tuple[Optional[GoalId], bool]: """Deploy the backend. - If the backend already exists with the same version, this is a no-op - and returns the GoalId corresponding to the existing update if there - is one. + If the backend already exists with the same version and BackendConfig, + this is a no-op and returns the GoalId corresponding to the existing + update if there is one. Returns: GoalId, bool: The GoalId for the client to wait for and whether or @@ -1319,15 +1317,14 @@ def deploy_backend(self, backend_tag: BackendTag, backend_info: BackendInfo return self._backend_states[backend_tag].deploy(backend_info) - def delete_backend(self, backend_tag: BackendTag, - force_kill: bool = False) -> Optional[GoalId]: + def delete_backend(self, backend_tag: BackendTag) -> Optional[GoalId]: # This method must be idempotent. We should validate that the # specified backend exists on the client. if backend_tag not in self._backend_states: return None backend_state = self._backend_states[backend_tag] - return backend_state.delete(force_kill=force_kill) + return backend_state.delete() def update(self) -> bool: """Updates the state of all backends to match their goal state.""" diff --git a/python/ray/serve/common.py b/python/ray/serve/common.py index be13503c97334..0e97b5cf98eb7 100644 --- a/python/ray/serve/common.py +++ b/python/ray/serve/common.py @@ -1,7 +1,7 @@ import ray -from dataclasses import dataclass, field -from typing import List, Optional +from dataclasses import dataclass +from typing import Optional from uuid import UUID from ray.actor import ActorClass @@ -17,7 +17,6 @@ @dataclass class EndpointInfo: - python_methods: Optional[List[str]] = field(default_factory=list) route: Optional[str] = None diff --git a/python/ray/serve/config.py b/python/ray/serve/config.py index 4002550ae109f..b7d5c08457691 100644 --- a/python/ray/serve/config.py +++ b/python/ray/serve/config.py @@ -1,7 +1,7 @@ import inspect import pickle from enum import Enum -from typing import Any, List, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple import pydantic from google.protobuf.json_format import MessageToDict @@ -24,8 +24,10 @@ class AutoscalingConfig(BaseModel): # Private options below # Metrics scraping options + + # How often to scrape for metrics metrics_interval_s: float = 10.0 - loop_period_s: float = 30.0 + # Time window to average over for metrics. 
look_back_period_s: float = 30.0 # Internal autoscaling configuration options @@ -34,6 +36,7 @@ class AutoscalingConfig(BaseModel): smoothing_factor: float = 1.0 # TODO(architkulkarni): implement below + # loop_period_s = 30 # How frequently to make autoscaling decisions # How long to wait before scaling down replicas # downscale_delay_s: float = 600.0 # How long to wait before scaling up replicas @@ -52,8 +55,6 @@ class AutoscalingConfig(BaseModel): class BackendConfig(BaseModel): """Configuration options for a backend, to be set by the user. - DEPRECATED. Will be removed in Ray 1.5. See docs for details. - Args: num_replicas (Optional[int]): The number of processes to start up that will handle requests to this backend. Defaults to 1. @@ -63,10 +64,10 @@ class BackendConfig(BaseModel): user_config (Optional[Any]): Arguments to pass to the reconfigure method of the backend. The reconfigure method is called if user_config is not None. - experimental_graceful_shutdown_wait_loop_s (Optional[float]): Duration + graceful_shutdown_wait_loop_s (Optional[float]): Duration that backend workers will wait until there is no more work to be done before shutting down. Defaults to 2s. - experimental_graceful_shutdown_timeout_s (Optional[float]): + graceful_shutdown_timeout_s (Optional[float]): Controller waits for this duration to forcefully kill the replica for shutdown. Defaults to 20s. """ @@ -75,8 +76,8 @@ class BackendConfig(BaseModel): max_concurrent_queries: Optional[int] = None user_config: Any = None - experimental_graceful_shutdown_wait_loop_s: NonNegativeFloat = 2.0 - experimental_graceful_shutdown_timeout_s: NonNegativeFloat = 20.0 + graceful_shutdown_wait_loop_s: NonNegativeFloat = 2.0 + graceful_shutdown_timeout_s: NonNegativeFloat = 20.0 autoscaling_config: Optional[AutoscalingConfig] = None @@ -121,16 +122,23 @@ def from_proto_bytes(cls, proto_bytes: bytes): class ReplicaConfig: - def __init__(self, backend_def, *init_args, ray_actor_options=None): + def __init__(self, + backend_def: Callable, + init_args: Optional[Tuple[Any]] = None, + init_kwargs: Optional[Dict[Any, Any]] = None, + ray_actor_options=None): # Validate that backend_def is an import path, function, or class. 
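# Usage sketch for the new explicit ReplicaConfig signature shown above; the
# Model class and its constructor arguments here are hypothetical.
from ray.serve.config import ReplicaConfig

class Model:
    def __init__(self, path, device="cpu"):
        self.path = path
        self.device = device

replica_config = ReplicaConfig(
    Model,
    init_args=("/tmp/model.bin",),
    init_kwargs={"device": "cpu"},
    ray_actor_options={"num_cpus": 1})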
if isinstance(backend_def, str): self.func_or_class_name = backend_def pass elif inspect.isfunction(backend_def): self.func_or_class_name = backend_def.__name__ - if len(init_args) != 0: + if init_args: raise ValueError( "init_args not supported for function backend.") + if init_kwargs: + raise ValueError( + "init_kwargs not supported for function backend.") elif inspect.isclass(backend_def): self.func_or_class_name = backend_def.__name__ else: @@ -139,7 +147,8 @@ def __init__(self, backend_def, *init_args, ray_actor_options=None): format(type(backend_def))) self.serialized_backend_def = cloudpickle.dumps(backend_def) - self.init_args = init_args + self.init_args = init_args if init_args is not None else () + self.init_kwargs = init_kwargs if init_kwargs is not None else {} if ray_actor_options is None: self.ray_actor_options = {} else: @@ -158,12 +167,13 @@ def _validate(self): raise TypeError("ray_actor_options must be a dictionary.") elif "lifetime" in self.ray_actor_options: raise ValueError( - "Specifying lifetime in init_args is not allowed.") + "Specifying lifetime in ray_actor_options is not allowed.") elif "name" in self.ray_actor_options: - raise ValueError("Specifying name in init_args is not allowed.") + raise ValueError( + "Specifying name in ray_actor_options is not allowed.") elif "max_restarts" in self.ray_actor_options: raise ValueError("Specifying max_restarts in " - "init_args is not allowed.") + "ray_actor_options is not allowed.") else: # Ray defaults to zero CPUs for placement, we default to one here. if "num_cpus" not in self.ray_actor_options: diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py index c367dc4232b81..cdaf1cf008151 100644 --- a/python/ray/serve/controller.py +++ b/python/ray/serve/controller.py @@ -8,8 +8,8 @@ import ray from ray.actor import ActorHandle from ray.serve.async_goal_manager import AsyncGoalManager +from ray.serve.autoscaling_policy import calculate_desired_num_replicas from ray.serve.backend_state import ReplicaState, BackendStateManager -from ray.serve.backend_worker import create_backend_replica from ray.serve.common import ( BackendInfo, BackendTag, @@ -20,9 +20,10 @@ ReplicaTag, ) from ray.serve.config import BackendConfig, HTTPOptions, ReplicaConfig -from ray.serve.constants import (CONTROL_LOOP_PERIOD_S, SERVE_ROOT_URL_ENV_KEY) +from ray.serve.constants import CONTROL_LOOP_PERIOD_S, SERVE_ROOT_URL_ENV_KEY from ray.serve.endpoint_state import EndpointState from ray.serve.http_state import HTTPState +from ray.serve.replica import create_replica_wrapper from ray.serve.storage.checkpoint_path import make_kv_store from ray.serve.long_poll import LongPollHost from ray.serve.utils import logger @@ -104,6 +105,10 @@ def record_autoscaling_metrics(self, data: Dict[str, float], def _dump_autoscaling_metrics_for_testing(self): return self.autoscaling_metrics_store.data + def _dump_replica_states_for_testing(self, deployment_name): + return self.backend_state_manager._backend_states[ + deployment_name]._replicas + async def wait_for_goal(self, goal_id: GoalId) -> Optional[Exception]: return await self.goal_manager.wait_for_goal(goal_id) @@ -129,8 +134,55 @@ def get_http_proxies(self) -> Dict[NodeId, ActorHandle]: """Returns a dictionary of node ID to http_proxy actor handles.""" return self.http_state.get_http_proxy_handles() + def autoscale(self) -> None: + """Update autoscaling deployments with calculated num_replicas.""" + for deployment_name, (backend_info, + route_prefix) in self.list_deployments().items(): + 
backend_config = backend_info.backend_config + autoscaling_config = backend_config.autoscaling_config + + if autoscaling_config is None: + continue + + replicas = self.backend_state_manager._backend_states[ + deployment_name]._replicas + running_replicas = replicas.get([ReplicaState.RUNNING]) + + current_num_ongoing_requests = [] + for replica in running_replicas: + replica_tag = replica.replica_tag + num_ongoing_requests = ( + self.autoscaling_metrics_store.window_average( + replica_tag, + time.time() - autoscaling_config.look_back_period_s)) + if num_ongoing_requests is not None: + current_num_ongoing_requests.append(num_ongoing_requests) + + if len(current_num_ongoing_requests) == 0: + continue + + new_backend_config = backend_config.copy() + new_backend_config.num_replicas = calculate_desired_num_replicas( + autoscaling_config, current_num_ongoing_requests) + + replica_config = backend_info.replica_config + deployer_job_id = backend_info.deployer_job_id + backend_config_proto_bytes = new_backend_config.to_proto_bytes() + goal_id, updating = self.deploy( + deployment_name, + backend_config_proto_bytes, + replica_config, + version=backend_info.version, + prev_version=backend_info.version, + route_prefix=route_prefix, + deployer_job_id=deployer_job_id) + async def run_control_loop(self) -> None: while True: + try: + self.autoscale() + except Exception: + logger.exception("Exception while autoscaling deployments.") async with self.write_lock: try: self.http_state.update() @@ -218,57 +270,56 @@ async def shutdown(self) -> List[GoalId]: return goal_ids - async def deploy(self, - name: str, - backend_config_proto_bytes: bytes, - replica_config: ReplicaConfig, - python_methods: List[str], - version: Optional[str], - prev_version: Optional[str], - route_prefix: Optional[str], - deployer_job_id: "Optional[ray._raylet.JobID]" = None - ) -> Tuple[Optional[GoalId], bool]: + def deploy(self, + name: str, + backend_config_proto_bytes: bytes, + replica_config: ReplicaConfig, + version: Optional[str], + prev_version: Optional[str], + route_prefix: Optional[str], + deployer_job_id: "Optional[ray._raylet.JobID]" = None + ) -> Tuple[Optional[GoalId], bool]: if route_prefix is not None: assert route_prefix.startswith("/") backend_config = BackendConfig.from_proto_bytes( backend_config_proto_bytes) - async with self.write_lock: - if prev_version is not None: - existing_backend_info = self.backend_state_manager.get_backend( - name) - if (existing_backend_info is None - or not existing_backend_info.version): - raise ValueError( - f"prev_version '{prev_version}' is specified but " - "there is no existing deployment.") - if existing_backend_info.version != prev_version: - raise ValueError( - f"prev_version '{prev_version}' " - "does not match with the existing " - f"version '{existing_backend_info.version}'.") - backend_info = BackendInfo( - actor_def=ray.remote( - create_backend_replica( - name, replica_config.serialized_backend_def)), - version=version, - backend_config=backend_config, - replica_config=replica_config, - deployer_job_id=deployer_job_id, - start_time_ms=int(time.time() * 1000)) - - goal_id, updating = self.backend_state_manager.deploy_backend( - name, backend_info) - endpoint_info = EndpointInfo( - route=route_prefix, python_methods=python_methods) - self.endpoint_state.update_endpoint(name, endpoint_info) - return goal_id, updating + if prev_version is not None: + existing_backend_info = self.backend_state_manager.get_backend( + name) + if (existing_backend_info is None + or not 
existing_backend_info.version): + raise ValueError( + f"prev_version '{prev_version}' is specified but " + "there is no existing deployment.") + if existing_backend_info.version != prev_version: + raise ValueError(f"prev_version '{prev_version}' " + "does not match with the existing " + f"version '{existing_backend_info.version}'.") + backend_info = BackendInfo( + actor_def=ray.remote( + create_replica_wrapper(name, + replica_config.serialized_backend_def)), + version=version, + backend_config=backend_config, + replica_config=replica_config, + deployer_job_id=deployer_job_id, + start_time_ms=int(time.time() * 1000)) + # TODO(architkulkarni): When a deployment is redeployed, even if + # the only change was num_replicas, the start_time_ms is refreshed. + # This is probably not the desired behavior for an autoscaling + # deployment, which redeploys very often to change num_replicas. + + goal_id, updating = self.backend_state_manager.deploy_backend( + name, backend_info) + endpoint_info = EndpointInfo(route=route_prefix) + self.endpoint_state.update_endpoint(name, endpoint_info) + return goal_id, updating def delete_deployment(self, name: str) -> Optional[GoalId]: self.endpoint_state.delete_endpoint(name) - return self.backend_state_manager.delete_backend( - name, force_kill=False) + return self.backend_state_manager.delete_backend(name) def get_deployment_info(self, name: str) -> Tuple[BackendInfo, str]: """Get the current information about a deployment. diff --git a/python/ray/serve/endpoint_state.py b/python/ray/serve/endpoint_state.py index 6483f7355ff0e..5bba277001c54 100644 --- a/python/ray/serve/endpoint_state.py +++ b/python/ray/serve/endpoint_state.py @@ -79,7 +79,6 @@ def get_endpoints(self) -> Dict[EndpointTag, Dict[str, Any]]: for endpoint, info in self._endpoints.items(): endpoints[endpoint] = { "route": info.route, - "python_methods": info.python_methods, } return endpoints diff --git a/python/ray/serve/examples/doc/conda_env.py b/python/ray/serve/examples/doc/conda_env.py index c431964bb0772..1607cf7d60e37 100644 --- a/python/ray/serve/examples/doc/conda_env.py +++ b/python/ray/serve/examples/doc/conda_env.py @@ -1,27 +1,28 @@ import requests from ray import serve -import tensorflow as tf serve.start() @serve.deployment -def tf_version(request): - return ("Tensorflow " + tf.__version__) +def requests_version(request): + return requests.__version__ -tf_version.options( - name="tf1", ray_actor_options={ +requests_version.options( + name="25", + ray_actor_options={ "runtime_env": { - "conda": "ray-tf1" + "pip": ["ray[serve]", "requests==2.25.1"] } }).deploy() -tf_version.options( - name="tf2", ray_actor_options={ +requests_version.options( + name="26", + ray_actor_options={ "runtime_env": { - "conda": "ray-tf2" + "pip": ["ray[serve]", "requests==2.26.0"] } }).deploy() -print(requests.get("http://127.0.0.1:8000/tf1").text) # Tensorflow 1.15.0 -print(requests.get("http://127.0.0.1:8000/tf2").text) # Tensorflow 2.3.0 +assert requests.get("http://127.0.0.1:8000/25").text == "2.25.1" +assert requests.get("http://127.0.0.1:8000/26").text == "2.26.0" diff --git a/python/ray/serve/handle.py b/python/ray/serve/handle.py index 7c315f66605f4..340be1f987a7c 100644 --- a/python/ray/serve/handle.py +++ b/python/ray/serve/handle.py @@ -1,7 +1,7 @@ import asyncio import concurrent.futures from dataclasses import dataclass, field -from typing import Dict, List, Optional, Union, Coroutine +from typing import Dict, Optional, Union, Coroutine import threading from enum import Enum @@ -75,14 +75,12 @@ 
def __init__( endpoint_name: EndpointTag, handle_options: Optional[HandleOptions] = None, *, - known_python_methods: List[str] = [], _router: Optional[Router] = None, _internal_pickled_http_request: bool = False, ): self.controller_handle = controller_handle self.endpoint_name = endpoint_name self.handle_options = handle_options or HandleOptions() - self.known_python_methods = known_python_methods self.handle_tag = f"{self.endpoint_name}#{get_random_letters()}" self._pickled_http_request = _internal_pickled_http_request @@ -181,21 +179,11 @@ def __reduce__(self): "controller_handle": self.controller_handle, "endpoint_name": self.endpoint_name, "handle_options": self.handle_options, - "known_python_methods": self.known_python_methods, "_internal_pickled_http_request": self._pickled_http_request, } return lambda kwargs: RayServeHandle(**kwargs), (serialized_data, ) def __getattr__(self, name): - if name not in self.known_python_methods: - raise AttributeError( - f"ServeHandle for endpoint {self.endpoint_name} doesn't have " - f"python method {name}. If you used the " - f"get_handle('{self.endpoint_name}', missing_ok=True) flag, " - f"Serve cannot know all methods for {self.endpoint_name}. " - "You can set the method manually via " - f"handle.options(method_name='{name}').remote().") - return self.options(method_name=name) @@ -237,7 +225,6 @@ def __reduce__(self): "controller_handle": self.controller_handle, "endpoint_name": self.endpoint_name, "handle_options": self.handle_options, - "known_python_methods": self.known_python_methods, "_internal_pickled_http_request": self._pickled_http_request, } return lambda kwargs: RayServeSyncHandle(**kwargs), (serialized_data, ) diff --git a/python/ray/serve/http_proxy.py b/python/ray/serve/http_proxy.py index 7eedc17fcfd5a..e129f5d60cab5 100644 --- a/python/ray/serve/http_proxy.py +++ b/python/ray/serve/http_proxy.py @@ -259,8 +259,11 @@ def __init__(self, port: int, controller_name: str, controller_namespace: str, - http_middlewares: List[ - "starlette.middleware.Middleware"] = []): # noqa: F821 + http_middlewares: Optional[List[ + "starlette.middleware.Middleware"]] = None): # noqa: F821 + if http_middlewares is None: + http_middlewares = [] + self.host = host self.port = port diff --git a/python/ray/serve/long_poll.py b/python/ray/serve/long_poll.py index b1133adb5a251..9d5a31bf86e6b 100644 --- a/python/ray/serve/long_poll.py +++ b/python/ray/serve/long_poll.py @@ -103,13 +103,14 @@ def _process_update(self, updates: Dict[str, UpdatedObject]): "Shutting down.") return + if isinstance(updates, ConnectionError): + logger.warning("LongPollClient connection failed, shutting down.") + return + if isinstance(updates, (ray.exceptions.RayTaskError)): - # This can happen during shutdown where the controller doesn't - # contain this key, we will just repull. - # NOTE(simon): should we repull or just wait in the long poll - # host? - if not isinstance(updates.as_instanceof_cause(), ValueError): - logger.error("LongPollHost errored\n" + updates.traceback_str) + # Some error happened in the controller. It could be a bug or some + # undesired state. + logger.error("LongPollHost errored\n" + updates.traceback_str) self._poll_next() return @@ -167,22 +168,21 @@ async def listen_for_change( until there's one updates. 
""" watched_keys = keys_to_snapshot_ids.keys() - nonexistent_keys = set(watched_keys) - set(self.snapshot_ids.keys()) - if len(nonexistent_keys) > 0: - raise ValueError(f"Keys not found: {nonexistent_keys}.") + existent_keys = set(watched_keys).intersection( + set(self.snapshot_ids.keys())) - # 2. If there are any outdated keys (by comparing snapshot ids) - # return immediately. + # If there are any outdated keys (by comparing snapshot ids) + # return immediately. client_outdated_keys = { key: UpdatedObject(self.object_snapshots[key], self.snapshot_ids[key]) - for key in watched_keys + for key in existent_keys if self.snapshot_ids[key] != keys_to_snapshot_ids[key] } if len(client_outdated_keys) > 0: return client_outdated_keys - # 3. Otherwise, register asyncio events to be waited. + # Otherwise, register asyncio events to be waited. async_task_to_watched_keys = {} for key in watched_keys: # Create a new asyncio event for this key diff --git a/python/ray/serve/backend_worker.py b/python/ray/serve/replica.py similarity index 91% rename from python/ray/serve/backend_worker.py rename to python/ray/serve/replica.py index a049bdfac3a84..cc90ada23fd7e 100644 --- a/python/ray/serve/backend_worker.py +++ b/python/ray/serve/replica.py @@ -15,12 +15,11 @@ from ray.serve.autoscaling_metrics import start_metrics_pusher from ray.serve.common import BackendTag, ReplicaTag +from ray.serve.config import BackendConfig from ray.serve.http_util import ASGIHTTPSender from ray.serve.utils import parse_request_item, _get_logger from ray.serve.exceptions import RayServeException from ray.util import metrics -from ray.serve.config import BackendConfig -from ray.serve.long_poll import LongPollClient, LongPollNamespace from ray.serve.router import Query, RequestMetadata from ray.serve.constants import ( BACKEND_RECONFIGURE_METHOD, @@ -32,7 +31,7 @@ logger = _get_logger() -def create_backend_replica(name: str, serialized_backend_def: bytes): +def create_replica_wrapper(name: str, serialized_backend_def: bytes): """Creates a replica class wrapping the provided function or class. This approach is picked over inheritance to avoid conflict between user @@ -43,7 +42,7 @@ def create_backend_replica(name: str, serialized_backend_def: bytes): # TODO(architkulkarni): Add type hints after upgrading cloudpickle class RayServeWrappedReplica(object): async def __init__(self, backend_tag, replica_tag, init_args, - backend_config_proto_bytes: bytes, + init_kwargs, backend_config_proto_bytes: bytes, version: BackendVersion, controller_name: str, detached: bool): backend = cloudpickle.loads(serialized_backend_def) @@ -72,7 +71,8 @@ async def __init__(self, backend_tag, replica_tag, init_args, # This allows backends to define an async __init__ method # (required for FastAPI backend definition). _callable = backend.__new__(backend) - await sync_to_async(_callable.__init__)(*init_args) + await sync_to_async(_callable.__init__)(*init_args, + **init_kwargs) # Setting the context again to update the servable_object. 
ray.serve.api._set_internal_replica_context( backend_tag, @@ -149,8 +149,6 @@ def __init__(self, _callable: Callable, backend_tag: BackendTag, self.replica_tag = replica_tag self.callable = _callable self.is_function = is_function - - self.backend_config = backend_config self.user_config = user_config self.version = version @@ -166,16 +164,6 @@ def __init__(self, _callable: Callable, backend_tag: BackendTag, "replica": self.replica_tag }) - self.loop = asyncio.get_event_loop() - self.long_poll_client = LongPollClient( - controller_handle, - { - (LongPollNamespace.BACKEND_CONFIGS, self.backend_tag): self. - _update_backend_configs, - }, - call_in_event_loop=self.loop, - ) - self.error_counter = metrics.Counter( "serve_deployment_error_counter", description=("The number of exceptions that have " @@ -217,6 +205,9 @@ def __init__(self, _callable: Callable, backend_tag: BackendTag, self.restart_counter.inc() + self._shutdown_wait_loop_s = ( + backend_config.graceful_shutdown_wait_loop_s) + if backend_config.autoscaling_config: config = backend_config.autoscaling_config start_metrics_pusher( @@ -240,10 +231,19 @@ def _collect_autoscaling_metrics(self): def get_runner_method(self, request_item: Query) -> Callable: method_name = request_item.metadata.call_method if not hasattr(self.callable, method_name): - raise RayServeException("Backend doesn't have method {} " - "which is specified in the request. " - "The available methods are {}".format( - method_name, dir(self.callable))) + # Filter to methods that don't start with '__' prefix. + def callable_method_filter(attr): + if attr.startswith("__"): + return False + elif not callable(getattr(self.callable, attr)): + return False + + return True + + methods = list(filter(callable_method_filter, dir(self.callable))) + raise RayServeException(f"Tried to call a method '{method_name}' " + "that does not exist. Available methods: " + f"{methods}.") if self.is_function: return self.callable return getattr(self.callable, method_name) @@ -309,9 +309,6 @@ async def reconfigure(self, getattr(self.callable, BACKEND_RECONFIGURE_METHOD)) await reconfigure_method(user_config) - def _update_backend_configs(self, new_config_bytes: bytes) -> None: - self.backend_config = BackendConfig.from_proto_bytes(new_config_bytes) - async def handle_request(self, request: Query) -> asyncio.Future: request.tick_enter_replica = time.time() logger.debug("Replica {} received request {}".format( @@ -341,18 +338,17 @@ async def prepare_for_shutdown(self): Trigger a graceful shutdown protocol that will wait for all the queued tasks to be completed and return to the controller. """ - sleep_time = self.backend_config.experimental_graceful_shutdown_wait_loop_s # noqa: E501 while True: # Sleep first because we want to make sure all the routers receive # the notification to remove this replica first. - await asyncio.sleep(sleep_time) + await asyncio.sleep(self._shutdown_wait_loop_s) if self.num_ongoing_requests == 0: break else: logger.info( - f"Waiting for an additional {sleep_time}s to shut down " - f"because there are {self.num_ongoing_requests} " - "ongoing requests.") + "Waiting for an additional " + f"{self._shutdown_wait_loop_s}s to shut down because " + f"there are {self.num_ongoing_requests} ongoing requests.") # Explicitly call the del method to trigger clean up. 
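# Minimal sketch of the drain loop in prepare_for_shutdown above: sleep in
# graceful_shutdown_wait_loop_s increments until no requests are in flight;
# if the loop never drains, the controller force-kills the replica after
# graceful_shutdown_timeout_s.
import asyncio

async def drain(num_ongoing_requests, wait_loop_s=2.0):
    while True:
        # Sleep first so routers have time to stop sending new requests.
        await asyncio.sleep(wait_loop_s)
        if num_ongoing_requests() == 0:
            break

asyncio.run(drain(lambda: 0, wait_loop_s=0.01))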
# We set the del method to noop after succssifully calling it so the diff --git a/python/ray/serve/tests/conftest.py b/python/ray/serve/tests/conftest.py index 1e635dd1b647a..36fdba0d5b7cc 100644 --- a/python/ray/serve/tests/conftest.py +++ b/python/ray/serve/tests/conftest.py @@ -9,6 +9,13 @@ serve.controller._CRASH_AFTER_CHECKPOINT_PROBABILITY = 0.5 +@pytest.fixture +def ray_shutdown(): + yield + serve.shutdown() + ray.shutdown() + + @pytest.fixture(scope="session") def _shared_serve_instance(): # Note(simon): diff --git a/python/ray/serve/tests/test_advanced.py b/python/ray/serve/tests/test_advanced.py index 74287ed358bef..03f606e58fcbd 100644 --- a/python/ray/serve/tests/test_advanced.py +++ b/python/ray/serve/tests/test_advanced.py @@ -9,12 +9,11 @@ def test_serve_forceful_shutdown(serve_instance): - @serve.deployment + @serve.deployment(_graceful_shutdown_timeout_s=0.1) def sleeper(): while True: time.sleep(1000) - sleeper._config.experimental_graceful_shutdown_timeout_s = 0.1 sleeper.deploy() handle = sleeper.get_handle() @@ -28,14 +27,15 @@ def sleeper(): def test_serve_graceful_shutdown(serve_instance): signal = SignalActor.remote() - @serve.deployment(name="wait", max_concurrent_queries=10) + @serve.deployment( + name="wait", + max_concurrent_queries=10, + _graceful_shutdown_timeout_s=1000, + _graceful_shutdown_wait_loop_s=0.5) class Wait: async def __call__(self, signal_actor): await signal_actor.wait.remote() - return "" - Wait._config.experimental_graceful_shutdown_wait_loop_s = 0.5 - Wait._config.experimental_graceful_shutdown_timeout_s = 1000 Wait.deploy() handle = Wait.get_handle() refs = [handle.remote(signal) for _ in range(10)] diff --git a/python/ray/serve/tests/test_autoscaling_metrics.py b/python/ray/serve/tests/test_autoscaling_metrics.py index e641f515d372d..d8a92d8a28b7a 100644 --- a/python/ray/serve/tests/test_autoscaling_metrics.py +++ b/python/ray/serve/tests/test_autoscaling_metrics.py @@ -59,20 +59,20 @@ def test_e2e(serve_instance): "min_replicas": 1, "max_replicas": 1 }, - max_concurrent_queries=1000) + # We will send over a lot of queries. This will make sure replicas are + # killed quickly during cleanup. + _graceful_shutdown_timeout_s=1, + max_concurrent_queries=1000, + version="v1") class A: def __call__(self): time.sleep(0.5) - # We will send over a lot of queries. This will make sure replicas are - # killed quickly during cleanup. - A._config.experimental_graceful_shutdown_timeout_s = 1 - A.deploy() handle = A.get_handle() [handle.remote() for _ in range(100)] - # Wait for metrics to propogate + # Wait for metrics to propagate def get_data(): return ray.get(serve_instance._controller. 
_dump_autoscaling_metrics_for_testing.remote()) diff --git a/python/ray/serve/tests/test_autoscaling_policy.py b/python/ray/serve/tests/test_autoscaling_policy.py index 56fb72ac4eea1..e72c2f68b65ce 100644 --- a/python/ray/serve/tests/test_autoscaling_policy.py +++ b/python/ray/serve/tests/test_autoscaling_policy.py @@ -1,3 +1,11 @@ +import sys +import time +import pytest + +import ray +from ray import serve +from ray._private.test_utils import wait_for_condition +from ray.serve.backend_state import ReplicaState from ray.serve.config import AutoscalingConfig from ray.serve.autoscaling_policy import calculate_desired_num_replicas @@ -71,3 +79,47 @@ def test_smoothing_factor(self): autoscaling_config=config, current_num_ongoing_requests=num_ongoing_requests) assert 5 <= desired_num_replicas <= 8 # 10 + 0.5 * (2.5 - 10) = 6.25 + + +@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.") +def test_e2e_basic_scale_up_down(serve_instance): + """Send 100 requests and check that we autoscale up, and then back down.""" + + @serve.deployment( + _autoscaling_config={ + "metrics_interval_s": 0.1, + "min_replicas": 1, + "max_replicas": 2, + "look_back_period_s": 0.2 + }, + # We will send over a lot of queries. This will make sure replicas are + # killed quickly during cleanup. + _graceful_shutdown_timeout_s=1, + max_concurrent_queries=1000, + version="v1") + class A: + def __call__(self): + time.sleep(1) + + A.deploy() + handle = A.get_handle() + [handle.remote() for _ in range(100)] + + controller = serve_instance._controller + + def get_num_running_replicas(): + replicas = ray.get( + controller._dump_replica_states_for_testing.remote("A")) + running_replicas = replicas.get([ReplicaState.RUNNING]) + return len(running_replicas) + + wait_for_condition(lambda: get_num_running_replicas() >= 2) + + # As the queue is drained, we should scale back down. + wait_for_condition(lambda: get_num_running_replicas() <= 1) + + +if __name__ == "__main__": + import sys + import pytest + sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_backend_state.py b/python/ray/serve/tests/test_backend_state.py index aa31dc6a9d82a..0112868821388 100644 --- a/python/ray/serve/tests/test_backend_state.py +++ b/python/ray/serve/tests/test_backend_state.py @@ -1,3 +1,5 @@ +import os +import sys import time from typing import Any, Dict, List, Optional, Tuple from unittest.mock import patch, Mock @@ -181,6 +183,7 @@ def set_starting_version(self, version: BackendVersion): def start(self, backend_info: BackendInfo, version: BackendVersion): self.started = True self.version = version + self.backend_info = backend_info def update_user_config(self, user_config: Any): self.started = True @@ -218,6 +221,7 @@ def available_resources(self) -> Dict[str, float]: def graceful_stop(self) -> None: assert self.started self.stopped = True + return self.backend_info.backend_config.graceful_shutdown_timeout_s def check_stopped(self) -> bool: return self.done_stopping @@ -526,9 +530,6 @@ def test_create_delete_single_replica(mock_backend_state): # Now the replica should be marked running. backend_state.update() check_counts(backend_state, total=1, by_state=[(ReplicaState.RUNNING, 1)]) - - # TODO(edoakes): can we remove this extra update period for completing it? - backend_state.update() assert goal_manager.check_complete(create_goal) # Removing the replica should transition it to stopping. 
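# Worked sketch of the smoothing arithmetic behind the comment
# "10 + 0.5 * (2.5 - 10) = 6.25" in test_smoothing_factor above: the replica
# count moves a smoothing_factor fraction of the way from the current count
# toward the raw desired count. This mirrors the test's expectation, not
# necessarily the exact internals of calculate_desired_num_replicas.
import math

def smoothed_num_replicas(current, raw_desired, smoothing_factor):
    return current + smoothing_factor * (raw_desired - current)

assert smoothed_num_replicas(10, 2.5, 0.5) == 6.25
assert 5 <= math.ceil(smoothed_num_replicas(10, 2.5, 0.5)) <= 8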
@@ -542,12 +543,9 @@ def test_create_delete_single_replica(mock_backend_state): # Once it's done stopping, replica should be removed. replica = backend_state._replicas.get()[0] replica._actor.set_done_stopping() - backend_state.update() - check_counts(backend_state, total=0) - - # TODO(edoakes): can we remove this extra update period for completing it? deleted = backend_state.update() assert deleted + check_counts(backend_state, total=0) assert goal_manager.check_complete(delete_goal) assert replica._actor.cleaned_up @@ -557,7 +555,7 @@ def test_force_kill(mock_backend_state): grace_period_s = 10 b_info_1, b_version_1 = backend_info( - experimental_graceful_shutdown_timeout_s=grace_period_s) + graceful_shutdown_timeout_s=grace_period_s) # Create and delete the backend. backend_state.deploy(b_info_1) @@ -571,8 +569,8 @@ def test_force_kill(mock_backend_state): check_counts(backend_state, total=1, by_state=[(ReplicaState.STOPPING, 1)]) assert backend_state._replicas.get()[0]._actor.stopped - backend_state.update() - backend_state.update() + for _ in range(10): + backend_state.update() # force_stop shouldn't be called until after the timer. assert not backend_state._replicas.get()[0]._actor.force_stopped_counter @@ -597,12 +595,9 @@ def test_force_kill(mock_backend_state): # Once the replica is done stopping, it should be removed. replica = backend_state._replicas.get()[0] replica._actor.set_done_stopping() - backend_state.update() - check_counts(backend_state, total=0) - - # TODO(edoakes): can we remove this extra update period for completing it? deleted = backend_state.update() assert deleted + check_counts(backend_state, total=0) assert goal_manager.check_complete(delete_goal) assert replica._actor.cleaned_up @@ -644,8 +639,6 @@ def test_redeploy_same_version(mock_backend_state): version=b_version_1, total=1, by_state=[(ReplicaState.RUNNING, 1)]) - - backend_state.update() assert goal_manager.check_complete(goal_1) # Test redeploying after the initial deployment has finished. @@ -727,12 +720,10 @@ def test_redeploy_no_version(mock_backend_state): states=[ReplicaState.STARTING])[0]._actor.set_ready() check_counts(backend_state, total=1, by_state=[(ReplicaState.STARTING, 1)]) - backend_state.update() - check_counts(backend_state, total=1, by_state=[(ReplicaState.RUNNING, 1)]) - deleted = backend_state.update() - assert goal_manager.check_complete(goal_3) assert not deleted + check_counts(backend_state, total=1, by_state=[(ReplicaState.RUNNING, 1)]) + assert goal_manager.check_complete(goal_3) def test_redeploy_new_version(mock_backend_state): @@ -826,16 +817,14 @@ def test_redeploy_new_version(mock_backend_state): total=1, by_state=[(ReplicaState.STARTING, 1)]) - backend_state.update() + deleted = backend_state.update() + assert not deleted check_counts( backend_state, version=b_version_3, total=1, by_state=[(ReplicaState.RUNNING, 1)]) - - deleted = backend_state.update() assert goal_manager.check_complete(goal_3) - assert not deleted def test_deploy_new_config_same_version(mock_backend_state): @@ -855,7 +844,6 @@ def test_deploy_new_config_same_version(mock_backend_state): version=b_version_1, total=1, by_state=[(ReplicaState.RUNNING, 1)]) - backend_state.update() assert goal_manager.check_complete(goal_id) # Update to a new config without changing the version. 
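# Sketch of the decision the following tests exercise, mirroring
# _stop_wrong_version_replicas earlier in this patch: a code version change
# stops and restarts the replica, while a user_config-only change with the
# same version reconfigures it in place.
def plan_update(old_version, new_version, old_user_config, new_user_config):
    if old_version != new_version:
        return "stop_and_restart"
    if old_user_config != new_user_config:
        return "update_user_config"
    return "no_op"

assert plan_update("v1", "v2", {}, {}) == "stop_and_restart"
assert plan_update("v1", "v1", {"a": 1}, {"a": 2}) == "update_user_config"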
@@ -886,8 +874,6 @@ def test_deploy_new_config_same_version(mock_backend_state): version=b_version_2, total=1, by_state=[(ReplicaState.RUNNING, 1)]) - - backend_state.update() assert goal_manager.check_complete(goal_id) @@ -907,7 +893,6 @@ def test_deploy_new_config_new_version(mock_backend_state): version=b_version_1, total=1, by_state=[(ReplicaState.RUNNING, 1)]) - backend_state.update() assert goal_manager.check_complete(create_goal) # Update to a new config and a new version. @@ -945,8 +930,6 @@ def test_deploy_new_config_new_version(mock_backend_state): version=b_version_2, total=1, by_state=[(ReplicaState.RUNNING, 1)]) - - backend_state.update() assert goal_manager.check_complete(update_goal) @@ -966,8 +949,6 @@ def test_initial_deploy_no_throttling(mock_backend_state): for replica in backend_state._replicas.get(): replica._actor.set_ready() - backend_state.update() - # Check that the new replicas have started. backend_state.update() check_counts( @@ -994,8 +975,6 @@ def test_new_version_deploy_throttling(mock_backend_state): for replica in backend_state._replicas.get(): replica._actor.set_ready() - backend_state.update() - # Check that the new replicas have started. backend_state.update() check_counts( @@ -1236,8 +1215,6 @@ def test_new_version_deploy_throttling(mock_backend_state): version=b_version_2, total=10, by_state=[(ReplicaState.RUNNING, 10)]) - - backend_state.update() assert goal_manager.check_complete(goal_2) @@ -1258,8 +1235,6 @@ def test_reconfigure_throttling(mock_backend_state): for replica in backend_state._replicas.get(): replica._actor.set_ready() - backend_state.update() - # Check that the new replicas have started. backend_state.update() check_counts(backend_state, total=2, by_state=[(ReplicaState.RUNNING, 2)]) @@ -1318,8 +1293,6 @@ def test_reconfigure_throttling(mock_backend_state): version=b_version_2, total=2, by_state=[(ReplicaState.RUNNING, 2)]) - - backend_state.update() assert goal_manager.check_complete(goal_1) @@ -1341,8 +1314,6 @@ def test_new_version_and_scale_down(mock_backend_state): for replica in backend_state._replicas.get(): replica._actor.set_ready() - backend_state.update() - # Check that the new replicas have started. backend_state.update() check_counts( @@ -1479,8 +1450,6 @@ def test_new_version_and_scale_down(mock_backend_state): version=b_version_2, total=2, by_state=[(ReplicaState.RUNNING, 2)]) - - backend_state.update() assert goal_manager.check_complete(goal_2) @@ -1501,8 +1470,6 @@ def test_new_version_and_scale_up(mock_backend_state): for replica in backend_state._replicas.get(): replica._actor.set_ready() - backend_state.update() - # Check that the new replicas have started. backend_state.update() check_counts(backend_state, total=2, by_state=[(ReplicaState.RUNNING, 2)]) @@ -1610,8 +1577,6 @@ def test_health_check(mock_backend_state): # Check that the new replicas have started. backend_state.update() check_counts(backend_state, total=2, by_state=[(ReplicaState.RUNNING, 2)]) - - backend_state.update() assert goal_manager.check_complete(goal_1) backend_state.update() @@ -1859,6 +1824,9 @@ def mock_backend_state_manager( yield backend_state_manager, timer, goal_manager # Clear checkpoint at the end of each test kv_store.delete(CHECKPOINT_KEY) + if sys.platform != "win32": + # This line fails on windows with a PermissionError. 
+ os.remove("test_kv_store.db") def test_shutdown(mock_backend_state_manager): @@ -1870,7 +1838,9 @@ def test_shutdown(mock_backend_state_manager): tag = "test" - b_info_1, b_version_1 = backend_info() + grace_period_s = 10 + b_info_1, b_version_1 = backend_info( + graceful_shutdown_timeout_s=grace_period_s) create_goal, updating = backend_state_manager.deploy_backend(tag, b_info_1) backend_state = backend_state_manager._backend_states[tag] @@ -1889,25 +1859,21 @@ def test_shutdown(mock_backend_state_manager): shutdown_goal = backend_state_manager.shutdown()[0] + timer.advance(grace_period_s + 0.1) backend_state_manager.update() check_counts(backend_state, total=1, by_state=[(ReplicaState.STOPPING, 1)]) assert backend_state._replicas.get()[0]._actor.stopped - assert backend_state._replicas.get()[0]._actor.force_stopped_counter == 1 assert not backend_state._replicas.get()[0]._actor.cleaned_up assert not goal_manager.check_complete(shutdown_goal) # Once it's done stopping, replica should be removed. replica = backend_state._replicas.get()[0] replica._actor.set_done_stopping() - backend_state.update() - check_counts(backend_state, total=0) - - # TODO(edoakes): can we remove this extra update period for completing it? backend_state_manager.update() + check_counts(backend_state, total=0) assert goal_manager.check_complete(shutdown_goal) assert replica._actor.cleaned_up - assert len(backend_state_manager._backend_states) == 0 @@ -1974,5 +1940,4 @@ def test_resume_backend_state_from_replica_tags(mock_backend_state_manager): if __name__ == "__main__": - import sys sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_config.py b/python/ray/serve/tests/test_config.py index dd30aeab0f77f..8c71cf8ae3a91 100644 --- a/python/ray/serve/tests/test_config.py +++ b/python/ray/serve/tests/test_config.py @@ -52,6 +52,8 @@ def function(_): # Check ray_actor_options validation. ReplicaConfig( Class, + tuple(), + dict(), ray_actor_options={ "num_cpus": 1.0, "num_gpus": 10, diff --git a/python/ray/serve/tests/test_deploy.py b/python/ray/serve/tests/test_deploy.py index d593081a43a9b..ab5ab2e2f3d75 100644 --- a/python/ray/serve/tests/test_deploy.py +++ b/python/ray/serve/tests/test_deploy.py @@ -10,6 +10,7 @@ import ray from ray._private.test_utils import SignalActor, wait_for_condition from ray import serve +from ray.serve.exceptions import RayServeException from ray.serve.utils import get_random_letters @@ -676,8 +677,8 @@ def b(self, *args): assert ray.get(handle.options(method_name="b").remote()) == "hello" # New code path assert ray.get(handle.b.remote()) == "hello" - with pytest.raises(AttributeError): - handle.c.remote() + with pytest.raises(RayServeException): + ray.get(handle.c.remote()) def test_init_args(serve_instance): @@ -733,6 +734,58 @@ def check(*args): check(10, 11, 12) +def test_init_kwargs(serve_instance): + with pytest.raises(TypeError): + + @serve.deployment(init_kwargs=[1, 2, 3]) + class BadInitArgs: + pass + + @serve.deployment(init_kwargs={"a": 1, "b": 2}) + class D: + def __init__(self, **kwargs): + self._kwargs = kwargs + + def get_kwargs(self, *args): + return self._kwargs + + D.deploy() + handle = D.get_handle() + + def check(kwargs): + assert ray.get(handle.get_kwargs.remote()) == kwargs + + # Basic sanity check. + check({"a": 1, "b": 2}) + + # Check passing args to `.deploy()`. + D.deploy(a=3, b=4) + check({"a": 3, "b": 4}) + + # Passing args to `.deploy()` shouldn't override those passed in decorator. 
+    D.deploy()
+    check({"a": 1, "b": 2})
+
+    # Check setting with `.options()`.
+    new_D = D.options(init_kwargs={"c": 8, "d": 10})
+    new_D.deploy()
+    check({"c": 8, "d": 10})
+
+    # Should not have changed old deployment object.
+    D.deploy()
+    check({"a": 1, "b": 2})
+
+    # Check that args are only updated on version change.
+    D.options(version="1").deploy()
+    check({"a": 1, "b": 2})
+
+    D.options(version="1").deploy(c=10, d=11)
+    check({"a": 1, "b": 2})
+
+    D.options(version="2").deploy(c=10, d=11)
+    check({"c": 10, "d": 11})
+
+
 def test_input_validation():
     name = "test"
 
diff --git a/python/ray/serve/tests/test_get_deployment.py b/python/ray/serve/tests/test_get_deployment.py
index cb1d6c9484e31..1f6968abe4974 100644
--- a/python/ray/serve/tests/test_get_deployment.py
+++ b/python/ray/serve/tests/test_get_deployment.py
@@ -116,6 +116,37 @@ def __call__(self, *arg):
         assert pid3 != pid2
 
 
+def test_init_kwargs(serve_instance):
+    name = "test"
+
+    @serve.deployment(name=name)
+    class D:
+        def __init__(self, *, val=None):
+            assert val is not None
+            self._val = val
+
+        def __call__(self, *arg):
+            return self._val, os.getpid()
+
+    D.deploy(val="1")
+    val1, pid1 = ray.get(D.get_handle().remote())
+    assert val1 == "1"
+
+    del D
+
+    D2 = serve.get_deployment(name=name)
+    D2.deploy()
+    val2, pid2 = ray.get(D2.get_handle().remote())
+    assert val2 == "1"
+    assert pid2 != pid1
+
+    D2 = serve.get_deployment(name=name)
+    D2.deploy(val="2")
+    val3, pid3 = ray.get(D2.get_handle().remote())
+    assert val3 == "2"
+    assert pid3 != pid2
+
+
 def test_scale_replicas(serve_instance):
     name = "test"
 
diff --git a/python/ray/serve/tests/test_handle.py b/python/ray/serve/tests/test_handle.py
index 95c55aba35b3e..360fb3336b247 100644
--- a/python/ray/serve/tests/test_handle.py
+++ b/python/ray/serve/tests/test_handle.py
@@ -1,9 +1,10 @@
+import concurrent.futures
 import pytest
 import requests
 
 import ray
-import concurrent.futures
 from ray import serve
+from ray.serve.exceptions import RayServeException
 
 
 @pytest.mark.asyncio
@@ -167,6 +168,30 @@ def call():
         ray.get(obj_ref)
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync", [True, False])
+async def test_nonexistent_method(serve_instance, sync):
+    @serve.deployment
+    class A:
+        def exists(self):
+            pass
+
+    A.deploy()
+    handle = A.get_handle(sync=sync)
+
+    if sync:
+        obj_ref = handle.does_not_exist.remote()
+    else:
+        obj_ref = await handle.does_not_exist.remote()
+
+    with pytest.raises(RayServeException) as excinfo:
+        ray.get(obj_ref)
+
+    exception_string = str(excinfo.value)
+    assert "'does_not_exist'" in exception_string
+    assert "Available methods: ['exists']" in exception_string
+
+
 if __name__ == "__main__":
     import sys
     import pytest
diff --git a/python/ray/serve/tests/test_long_poll.py b/python/ray/serve/tests/test_long_poll.py
index 79cf0c841ea35..2081e705d976e 100644
--- a/python/ray/serve/tests/test_long_poll.py
+++ b/python/ray/serve/tests/test_long_poll.py
@@ -37,6 +37,20 @@ def test_host_standalone(serve_instance):
     assert "key_2" in result
 
 
+def test_long_poll_wait_for_keys(serve_instance):
+    # Variation of the basic case, but the keys are requested before any
+    # values are set.
+    host = ray.remote(LongPollHost).remote()
+    object_ref = host.listen_for_change.remote({"key_1": -1, "key_2": -1})
+    ray.get(host.notify_changed.remote("key_1", 999))
+    ray.get(host.notify_changed.remote("key_2", 999))
+
+    # We should be able to get at least one of the results immediately.
+    result: Dict[str, UpdatedObject] = ray.get(object_ref)
+    assert set(result.keys()).issubset({"key_1", "key_2"})
+    assert {v.object_snapshot for v in result.values()} == {999}
+
+
 def test_long_poll_restarts(serve_instance):
     @ray.remote(
         max_restarts=-1,
diff --git a/python/ray/serve/tests/test_ray_client.py b/python/ray/serve/tests/test_ray_client.py
index 7bc2d54aad388..db640970eedbe 100644
--- a/python/ray/serve/tests/test_ray_client.py
+++ b/python/ray/serve/tests/test_ray_client.py
@@ -126,7 +126,7 @@ def hello(request):
 
 
 @pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows")
-def test_quickstart_task(serve_with_client):
+def test_quickstart_counter(serve_with_client):
     serve.start()
 
     @serve.deployment
@@ -140,10 +140,13 @@ def __call__(self, *args):
 
     # Deploy our class.
     Counter.deploy()
+    print("deploy finished")
 
     # Query our endpoint in two different ways: from HTTP and from Python.
     assert requests.get("http://127.0.0.1:8000/Counter").json() == {"count": 1}
+    print("query 1 finished")
     assert ray.get(Counter.get_handle().remote()) == {"count": 2}
+    print("query 2 finished")
 
 
 if __name__ == "__main__":
diff --git a/python/ray/serve/tests/test_regression.py b/python/ray/serve/tests/test_regression.py
index e4d519bc06ffc..9ac205803e492 100644
--- a/python/ray/serve/tests/test_regression.py
+++ b/python/ray/serve/tests/test_regression.py
@@ -71,7 +71,7 @@ async def __call__(self, _request):
     assert result.json() == 100.0
 
 
-def test_backend_worker_memory_growth(serve_instance):
+def test_replica_memory_growth(serve_instance):
     # https://github.com/ray-project/ray/issues/12395
     @serve.deployment(name="model")
     def gc_unreachable_objects(*args):
diff --git a/python/ray/serve/tests/test_standalone.py b/python/ray/serve/tests/test_standalone.py
index ce8183b2d8577..1c6df064247f5 100644
--- a/python/ray/serve/tests/test_standalone.py
+++ b/python/ray/serve/tests/test_standalone.py
@@ -27,13 +27,6 @@
 import ray._private.gcs_utils as gcs_utils
 
 
-@pytest.fixture
-def ray_shutdown():
-    yield
-    serve.shutdown()
-    ray.shutdown()
-
-
 @pytest.fixture
 def ray_cluster():
     cluster = Cluster()
@@ -102,7 +95,7 @@ def test_detached_deployment(ray_cluster):
     # https://github.com/ray-project/ray/issues/11437
 
     cluster = ray_cluster
-    head_node = cluster.add_node(node_ip_address="127.0.0.1", num_cpus=6)
+    head_node = cluster.add_node(num_cpus=6)
 
     # Create first job, check we can run a simple serve endpoint
     ray.init(head_node.address, namespace="serve")
diff --git a/python/ray/sgd/__init__.py b/python/ray/sgd/__init__.py
index c5d4677aa041e..d5f8ec4c0d6f1 100644
--- a/python/ray/sgd/__init__.py
+++ b/python/ray/sgd/__init__.py
@@ -1,2 +1 @@
-from ray.util.sgd.v2 import *  # noqa: F401, F403
-from ray.util.sgd.v2.callbacks import JsonLoggerCallback, TBXLoggerCallback  # noqa: E501, F401, F403
+from ray.util.sgd.v2 import *  # noqa: F401, F403
diff --git a/python/ray/sgd/callbacks.py b/python/ray/sgd/callbacks.py
new file mode 100644
index 0000000000000..9b85815190b9b
--- /dev/null
+++ b/python/ray/sgd/callbacks.py
@@ -0,0 +1 @@
+from ray.util.sgd.v2.callbacks import *  # noqa: E501, F401, F403
diff --git a/python/ray/state.py b/python/ray/state.py
index 3c2f2185caffb..b074fd4062641 100644
--- a/python/ray/state.py
+++ 
b/python/ray/state.py @@ -1,7 +1,6 @@ from collections import defaultdict import json import logging -import os import ray @@ -50,10 +49,6 @@ def _check_connected(self): # _really_init_global_state should have set self.global_state_accessor if self.global_state_accessor is None: - if os.environ.get("RAY_ENABLE_AUTO_CONNECT", "") != "0": - ray.client().connect() - # Retry connect! - return self._check_connected() raise ray.exceptions.RaySystemError( "Ray has not been started yet. You can start Ray with " "'ray.init()'.") @@ -720,6 +715,7 @@ def _live_node_ids(self): def _available_resources_per_node(self): """Returns a dictionary mapping node id to avaiable resources.""" + self._check_connected() available_resources_by_id = {} all_available_resources = \ @@ -811,7 +807,7 @@ def next_job_id(): @DeveloperAPI -@client_mode_hook +@client_mode_hook(auto_init=False) def nodes(): """Get a list of the nodes in the cluster (for debugging only). @@ -875,7 +871,7 @@ def actors(actor_id=None): @DeveloperAPI -@client_mode_hook +@client_mode_hook(auto_init=False) def timeline(filename=None): """Return a list of profiling events that can viewed as a timeline. @@ -917,7 +913,7 @@ def object_transfer_timeline(filename=None): @DeveloperAPI -@client_mode_hook +@client_mode_hook(auto_init=False) def cluster_resources(): """Get the current total cluster resources. @@ -932,7 +928,7 @@ def cluster_resources(): @DeveloperAPI -@client_mode_hook +@client_mode_hook(auto_init=False) def available_resources(): """Get the current available cluster resources. diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index f854f00e560e7..bcdb790d6d723 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -48,6 +48,7 @@ py_test_module_list( files = [ "test_client.py", "test_client_builder.py", + "test_client_compat.py", "test_client_init.py", "test_client_multi.py", "test_client_proxy.py", @@ -77,12 +78,12 @@ py_test_module_list( "test_placement_group.py", "test_placement_group_2.py", "test_placement_group_3.py", - "test_placement_group_mini_integration.py", "test_ray_init.py", "test_reconstruction.py", "test_reference_counting.py", "test_resource_demand_scheduler.py", "test_runtime_env_env_vars.py", + "test_runtime_env_plugin.py", "test_runtime_env_fork_process.py", "test_serialization.py", "test_shuffle.py", @@ -167,6 +168,7 @@ py_test_module_list( "test_failure_4.py", "test_object_spilling.py", "test_plasma_unlimited.py", + "test_placement_group_mini_integration.py", ], size = "large", extra_srcs = SRCS, @@ -300,6 +302,14 @@ py_test( deps = ["//:ray_lib"], ) +py_test( + name = "test_runtime_env_validation", + size = "small", + srcs = SRCS + ["test_runtime_env_validation.py"], + tags = ["exclusive", "team:serve"], + deps = ["//:ray_lib"], +) + # TODO(ekl) we can't currently support tagging these as flaky since there's # no way to filter by both flaky and client mode tests in bazel. py_test_module_list( diff --git a/python/ray/tests/client_test_utils.py b/python/ray/tests/client_test_utils.py index c7b0081d3274c..30c016d32bd3a 100644 --- a/python/ray/tests/client_test_utils.py +++ b/python/ray/tests/client_test_utils.py @@ -18,3 +18,20 @@ async def wait(self, should_wait=True): await self.ready_event.wait() return SignalActor + + +# See test_client::test_wrapped_actor_creation for details on usage of +# run_wrapped_actor_creation and SomeClass. 
+def run_wrapped_actor_creation(): + import ray + RemoteClass = ray.remote(SomeClass) + handle = RemoteClass.remote() + return ray.get(handle.ready.remote()) + + +class SomeClass: + def __init__(self): + pass + + def ready(self): + return 1 diff --git a/python/ray/tests/mock_setup_worker.py b/python/ray/tests/mock_setup_worker.py index a19a9ce22d1fd..7cd981b9ac00f 100644 --- a/python/ray/tests/mock_setup_worker.py +++ b/python/ray/tests/mock_setup_worker.py @@ -30,6 +30,9 @@ parser.add_argument( "--session-dir", type=str, help="the directory for the current session") +parser.add_argument( + "--language", type=str, help="the language type of the worker") + args, remaining_args = parser.parse_known_args() # add worker-shim-pid argument diff --git a/python/ray/tests/test_advanced.py b/python/ray/tests/test_advanced.py index b7962ff71e44b..041e5e7bb559a 100644 --- a/python/ray/tests/test_advanced.py +++ b/python/ray/tests/test_advanced.py @@ -777,14 +777,13 @@ def method(self): # This case tests whether RequestWorkerLeaseReply carries normal task resources # when the request is rejected (due to resource preemption by normal tasks). -@pytest.mark.skip( - reason="The period of pull based resource report (10ms) is hard-coded.") +@pytest.mark.skipif(sys.platform == "win32", reason="Time out on Windows") def test_worker_lease_reply_with_resources(ray_start_cluster): cluster = ray_start_cluster cluster.add_node( memory=2000 * 1024**2, _system_config={ - "raylet_report_resources_period_milliseconds": 1000000, + "gcs_resource_report_poll_period_ms": 1000000, "gcs_actor_scheduling_enabled": True, }) node2 = cluster.add_node(memory=1000 * 1024**2) diff --git a/python/ray/tests/test_advanced_3.py b/python/ray/tests/test_advanced_3.py index a03850916328a..90d3de16dd60d 100644 --- a/python/ray/tests/test_advanced_3.py +++ b/python/ray/tests/test_advanced_3.py @@ -1,5 +1,6 @@ # coding: utf-8 import glob +import json import logging import os import sys @@ -726,20 +727,19 @@ def test_k8s_cpu(): def test_sync_job_config(shutdown_only): num_java_workers_per_process = 8 - worker_env = { - "key": "value", - } + runtime_env = {"env_vars": {"key": "value"}} ray.init( job_config=ray.job_config.JobConfig( num_java_workers_per_process=num_java_workers_per_process, - worker_env=worker_env)) + runtime_env=runtime_env)) # Check that the job config is synchronized at the driver side. 
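# Sketch for illustration: this patch replaces worker_env with
# runtime_env["env_vars"], and the job-config assertions just below compare
# through the serialized form. The round-trip being assumed:
import json

env_vars = {"key": "value"}
serialized_runtime_env = json.dumps({"env_vars": env_vars})
assert json.loads(serialized_runtime_env)["env_vars"] == env_vars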
job_config = ray.worker.global_worker.core_worker.get_job_config() assert (job_config.num_java_workers_per_process == num_java_workers_per_process) - assert (job_config.worker_env == worker_env) + job_runtime_env = json.loads(job_config.runtime_env.serialized_runtime_env) + assert job_runtime_env["env_vars"] == runtime_env["env_vars"] @ray.remote def get_job_config(): @@ -751,7 +751,8 @@ def get_job_config(): job_config.ParseFromString(ray.get(get_job_config.remote())) assert (job_config.num_java_workers_per_process == num_java_workers_per_process) - assert (job_config.worker_env == worker_env) + job_runtime_env = json.loads(job_config.runtime_env.serialized_runtime_env) + assert job_runtime_env["env_vars"] == runtime_env["env_vars"] def test_duplicated_arg(ray_start_cluster): diff --git a/python/ray/tests/test_autoscaler.py b/python/ray/tests/test_autoscaler.py index d428188173cbd..4cc7ee63570fe 100644 --- a/python/ray/tests/test_autoscaler.py +++ b/python/ray/tests/test_autoscaler.py @@ -1,6 +1,7 @@ import json import jsonschema import os +import re import shutil from subprocess import CalledProcessError import tempfile @@ -13,7 +14,7 @@ from collections import defaultdict from ray.autoscaler._private.commands import get_or_create_head_node from jsonschema.exceptions import ValidationError -from typing import Dict, Callable +from typing import Dict, Callable, List, Optional import ray from ray.autoscaler._private.util import prepare_config, validate_config @@ -105,42 +106,56 @@ def check_output(self, cmd): return return_string.encode() - def assert_has_call(self, ip, pattern=None, exact=None): + def assert_has_call(self, + ip: str, + pattern: Optional[str] = None, + exact: Optional[List[str]] = None): + """Checks if the given value was called by this process runner. + + NOTE: Either pattern or exact must be specified, not both! + + Args: + ip: IP address of the node that the given call was executed on. + pattern: RegEx that matches one specific call. + exact: List of strings that when joined exactly match one call. + """ with self.lock: - assert pattern or exact, \ + assert bool(pattern) ^ bool(exact), \ "Must specify either a pattern or exact match." - out = "" + debug_output = "" if pattern is not None: for cmd in self.command_history(): if ip in cmd: - out += cmd - out += "\n" - if pattern in out: - return True + debug_output += cmd + debug_output += "\n" + if re.search(pattern, cmd): + return True else: raise Exception( - f"Did not find [{pattern}] in [{out}] for ip={ip}." - f"\n\nFull output: {self.command_history()}") + f"Did not find [{pattern}] in [{debug_output}] for " + f"ip={ip}.\n\nFull output: {self.command_history()}") elif exact is not None: exact_cmd = " ".join(exact) for cmd in self.command_history(): if ip in cmd: - out += cmd - out += "\n" + debug_output += cmd + debug_output += "\n" if cmd == exact_cmd: return True raise Exception( - f"Did not find [{exact_cmd}] in [{out}] for ip={ip}." - f"\n\nFull output: {self.command_history()}") + f"Did not find [{exact_cmd}] in [{debug_output}] for " + f"ip={ip}.\n\nFull output: {self.command_history()}") - def assert_not_has_call(self, ip, pattern): + def assert_not_has_call(self, ip: str, pattern: str): + """Ensure that the given regex pattern was never called. 
+ """ with self.lock: out = "" for cmd in self.command_history(): if ip in cmd: out += cmd out += "\n" - if pattern in out: + if re.search(pattern, out): raise Exception("Found [{}] in [{}] for {}".format( pattern, out, ip)) else: @@ -449,7 +464,10 @@ def waitFor(self, condition, num_retries=50, fail_msg=None): fail_msg = fail_msg or "Timed out waiting for {}".format(condition) raise RayTestTimeoutException(fail_msg) - def waitForNodes(self, expected, comparison=None, tag_filters={}): + def waitForNodes(self, expected, comparison=None, tag_filters=None): + if tag_filters is None: + tag_filters = {} + MAX_ITER = 50 for i in range(MAX_ITER): n = len(self.provider.non_terminated_nodes(tag_filters)) @@ -2560,8 +2578,7 @@ def testContinuousFileMounts(self): for i in [0, 1]: runner.assert_not_has_call(f"172.0.0.{i}", "setup_cmd") runner.assert_has_call( - f"172.0.0.{i}", f"172.0.0.{i}", - f"{file_mount_dir}/ ubuntu@172.0.0.{i}:" + f"172.0.0.{i}", f"{file_mount_dir}/ ubuntu@172.0.0.{i}:" f"{docker_mount_prefix}/home/test-folder/") def testFileMountsNonContinuous(self): @@ -2596,8 +2613,7 @@ def testFileMountsNonContinuous(self): for i in [0, 1]: runner.assert_has_call(f"172.0.0.{i}", "setup_cmd") runner.assert_has_call( - f"172.0.0.{i}", f"172.0.0.{i}", - f"{file_mount_dir}/ ubuntu@172.0.0.{i}:" + f"172.0.0.{i}", f"{file_mount_dir}/ ubuntu@172.0.0.{i}:" f"{docker_mount_prefix}/home/test-folder/") runner.clear_history() @@ -2640,8 +2656,7 @@ def testFileMountsNonContinuous(self): for i in [0, 1]: runner.assert_has_call(f"172.0.0.{i}", "setup_cmd") runner.assert_has_call( - f"172.0.0.{i}", f"172.0.0.{i}", - f"{file_mount_dir}/ ubuntu@172.0.0.{i}:" + f"172.0.0.{i}", f"{file_mount_dir}/ ubuntu@172.0.0.{i}:" f"{docker_mount_prefix}/home/test-folder/") def testAutodetectResources(self): diff --git a/python/ray/tests/test_client.py b/python/ray/tests/test_client.py index de552b1fe2977..4ab08dc95e4e2 100644 --- a/python/ray/tests/test_client.py +++ b/python/ray/tests/test_client.py @@ -6,9 +6,11 @@ import queue import threading import _thread +from unittest.mock import patch import ray.util.client.server.server as ray_client_server from ray.tests.client_test_utils import create_remote_signal_actor +from ray.tests.client_test_utils import run_wrapped_actor_creation from ray.util.client.common import ClientObjectRef from ray.util.client.ray_client_helpers import connect_to_client_or_not from ray.util.client.ray_client_helpers import ray_start_client_server @@ -24,11 +26,11 @@ def test_client_context_manager(ray_start_regular_shared, connect_to_client): with connect_to_client_or_not(connect_to_client): if connect_to_client: # Client mode is on. - assert client_mode_should_convert() + assert client_mode_should_convert(auto_init=True) # We're connected to Ray client. 
assert ray.util.client.ray.is_connected() else: - assert not client_mode_should_convert() + assert not client_mode_should_convert(auto_init=True) assert not ray.util.client.ray.is_connected() @@ -70,20 +72,20 @@ def run(self): def test_client_mode_hook_thread_safe(ray_start_regular_shared): with ray_start_client_server(): with enable_client_mode(): - assert client_mode_should_convert() + assert client_mode_should_convert(auto_init=True) lock = threading.Lock() lock.acquire() q = queue.Queue() def disable(): with disable_client_hook(): - q.put(client_mode_should_convert()) + q.put(client_mode_should_convert(auto_init=True)) lock.acquire() - q.put(client_mode_should_convert()) + q.put(client_mode_should_convert(auto_init=True)) t = threading.Thread(target=disable) t.start() - assert client_mode_should_convert() + assert client_mode_should_convert(auto_init=True) lock.release() t.join() assert q.get( @@ -467,7 +469,7 @@ def print_on_stderr_and_stdout(s): time.sleep(1) print_on_stderr_and_stdout.remote("Hello world") time.sleep(1) - assert len(log_msgs) == 2 + assert len(log_msgs) == 2, log_msgs assert all((msg.find("Hello world") for msg in log_msgs)) @@ -648,6 +650,7 @@ def stop_server(server): @pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.") +@patch.dict(os.environ, {"RAY_ENABLE_AUTO_CONNECT": "0"}) def test_client_gpu_ids(call_ray_stop_only): import ray ray.init(num_cpus=2) @@ -702,7 +705,42 @@ def test_object_ref_cleanup(): # See https://github.com/ray-project/ray/issues/17968 for details with ray_start_client_server(): result = run_string_as_driver(object_ref_cleanup_script) - assert result == "" + assert "Error in sys.excepthook:" not in result + assert "AttributeError: 'NoneType' object has no " not in result + assert "Exception ignored in" not in result + + +@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.") +@pytest.mark.parametrize( + "call_ray_start", + ["ray start --head --ray-client-server-port 25552 --port 0"], + indirect=True) +def test_wrapped_actor_creation(call_ray_start): + """ + When the client schedules an actor, the server will load a separate + copy of the actor class if it's defined in a separate file. This + means that modifications to the client's copy of the actor class + aren't propagated to the server. Currently, tracing logic modifies + the signatures of actor methods to pass around metadata when ray.remote + is applied to an actor class. However, if a user does something like: + + class SomeActor: + def __init__(self): + pass + + def decorate_actor(): + RemoteActor = ray.remote(SomeActor) + ... + + Then the SomeActor class will have its signatures modified on the client + side, but not on the server side, since ray.remote was applied inside of + the function instead of directly on the actor. Note if it were directly + applied to the actor then the signature would be modified when the server + imports the class. 
+    """
+    import ray
+    ray.init("ray://localhost:25552")
+    run_wrapped_actor_creation()
 
 
 if __name__ == "__main__":
diff --git a/python/ray/tests/test_client_compat.py b/python/ray/tests/test_client_compat.py
new file mode 100644
index 0000000000000..98f4e9f4ba43d
--- /dev/null
+++ b/python/ray/tests/test_client_compat.py
@@ -0,0 +1,33 @@
+import pytest
+import sys
+
+import ray
+try:
+    import pyspark  # noqa
+except ImportError:
+    pyspark = None
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
+@pytest.mark.skipif(pyspark is None, reason="PySpark dependency not found")
+@pytest.mark.parametrize(
+    "call_ray_start", [
+        "ray start --head --num-cpus=1 --min-worker-port=0 "
+        "--max-worker-port=0 --port 0 --ray-client-server-port 10002",
+    ],
+    indirect=True)
+def test_client_data_get(call_ray_start):
+    """PySpark import changes NamedTuple pickling behavior, leading
+    to incompatibilities with the Ray client and Ray Data. This test
+    makes sure that our fix in the ClientPickler works."""
+    address = call_ray_start
+    ip = address.split(":")[0]
+
+    ray.util.connect(f"{ip}:10002")
+
+    ray_pipeline = ray.data.from_items(list(range(1_000)))
+    ray.get(ray_pipeline.to_numpy()[0])
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main(["-v", __file__]))
diff --git a/python/ray/tests/test_client_library_integration.py b/python/ray/tests/test_client_library_integration.py
index 774f46954d045..417b31efb5e3b 100644
--- a/python/ray/tests/test_client_library_integration.py
+++ b/python/ray/tests/test_client_library_integration.py
@@ -14,11 +14,11 @@ def test_rllib_integration(ray_start_regular_shared):
         import ray.rllib.agents.dqn as dqn
         # Confirming the behavior of this context manager.
         # (Client mode hook not yet enabled.)
-        assert not client_mode_should_convert()
+        assert not client_mode_should_convert(auto_init=True)
         # Need to enable this for client APIs to be used.
         with enable_client_mode():
             # Confirming mode hook is enabled.
-            assert client_mode_should_convert()
+            assert client_mode_should_convert(auto_init=True)
 
             config = dqn.SIMPLE_Q_DEFAULT_CONFIG.copy()
             # Run locally.
@@ -38,11 +38,11 @@ def test_rllib_integration_tune(ray_start_regular_shared):
     with ray_start_client_server():
         # Confirming the behavior of this context manager.
         # (Client mode hook not yet enabled.)
-        assert not client_mode_should_convert()
+        assert not client_mode_should_convert(auto_init=True)
        # Need to enable this for client APIs to be used.
         with enable_client_mode():
             # Confirming mode hook is enabled.
-            assert client_mode_should_convert()
+            assert client_mode_should_convert(auto_init=True)
             tune.run(
                 "DQN", config={"env": "CartPole-v1"},
diff --git a/python/ray/tests/test_client_proxy.py b/python/ray/tests/test_client_proxy.py
index 03d1f34cb6582..8440268da6980 100644
--- a/python/ray/tests/test_client_proxy.py
+++ b/python/ray/tests/test_client_proxy.py
@@ -253,7 +253,10 @@ def test_prepare_runtime_init_req_no_modification():
     """
     Check that `prepare_runtime_init_req` properly extracts the JobConfig.
     """
-    job_config = JobConfig(worker_env={"KEY": "VALUE"}, ray_namespace="abc")
+    job_config = JobConfig(
+        runtime_env={"env_vars": {
+            "KEY": "VALUE"
+        }}, ray_namespace="abc")
     init_req = ray_client_pb2.DataRequest(
         init=ray_client_pb2.InitRequest(
             job_config=pickle.dumps(job_config),
@@ -273,7 +276,10 @@ def test_prepare_runtime_init_req_modified_job():
     Check that `prepare_runtime_init_req` properly extracts the JobConfig and
     modifies it according to `ray_client_server_env_prep`.
""" - job_config = JobConfig(worker_env={"KEY": "VALUE"}, ray_namespace="abc") + job_config = JobConfig( + runtime_env={"env_vars": { + "KEY": "VALUE" + }}, ray_namespace="abc") init_req = ray_client_pb2.DataRequest( init=ray_client_pb2.InitRequest( job_config=pickle.dumps(job_config), diff --git a/python/ray/tests/test_client_reconnect.py b/python/ray/tests/test_client_reconnect.py index b830403449ba3..0672b755f9eb1 100644 --- a/python/ray/tests/test_client_reconnect.py +++ b/python/ray/tests/test_client_reconnect.py @@ -294,6 +294,7 @@ def disconnect(middleman): disconnect_thread.join() +@pytest.mark.skipif(sys.platform == "win32", reason="Flaky on windows") def test_valid_actor_state(): """ Repeatedly inject errors in the middle of mutating actor calls. Check @@ -311,24 +312,28 @@ def incr(self): return self.val i = 0 + # This is to prevent erroring in the initial connection logic. + started = False def fail_every_seven(_): # Inject an error every seventh time this method is called - nonlocal i + nonlocal i, started i += 1 - if i % 7 == 0: + if i % 7 == 0 and started: raise RuntimeError with start_middleman_server( on_data_response=fail_every_seven, on_task_request=fail_every_seven, on_task_response=fail_every_seven): + started = True actor = IncrActor.remote() for _ in range(100): ref = actor.incr.remote() assert ray.get(ref) == 100 +@pytest.mark.skipif(sys.platform == "win32", reason="Flaky on windows") def test_valid_actor_state_2(): """ Do a full disconnect (cancel channel) every 11 requests. Failure diff --git a/python/ray/tests/test_dashboard.py b/python/ray/tests/test_dashboard.py index c92d9610ead84..578707baebf4a 100644 --- a/python/ray/tests/test_dashboard.py +++ b/python/ray/tests/test_dashboard.py @@ -4,14 +4,34 @@ import sys import time +import psutil import pytest import requests -from ray._private.test_utils import run_string_as_driver, wait_for_condition +from ray._private.test_utils import (run_string_as_driver, wait_for_condition, + get_error_message) import ray from ray import ray_constants +def search_agents(cluster): + all_processes = cluster.head_node.all_processes + raylet_proc_info = all_processes[ray_constants.PROCESS_TYPE_RAYLET][0] + raylet_proc = psutil.Process(raylet_proc_info.process.pid) + + def _search_agent(processes): + for p in processes: + try: + for c in p.cmdline(): + if "dashboard/agent.py" in c: + return p + except Exception: + pass + + agent_proc = _search_agent(raylet_proc.children()) + return agent_proc + + def test_ray_start_default_port_conflict(call_ray_stop_only, shutdown_only): subprocess.check_call(["ray", "start", "--head"]) ray.init(address="auto") @@ -90,8 +110,6 @@ def test_port_conflict(call_ray_stop_only, shutdown_only): sock.close() -@pytest.mark.skipif( - sys.version_info < (3, 5, 3), reason="requires python3.5.3 or higher") def test_dashboard(shutdown_only): addresses = ray.init(include_dashboard=True, num_cpus=1) dashboard_url = addresses["webui_url"] @@ -121,8 +139,32 @@ def test_dashboard(shutdown_only): f"Dashboard output log: {out_log}\n") -if __name__ == "__main__": - import sys +@pytest.mark.parametrize( + "ray_start_cluster_head", [{ + "metrics_export_port": 6379, + "_system_config": { + "agent_restart_interval_ms": 10, + "agent_max_restart_count": 5 + } + }], + indirect=True) +def test_dashboard_agent_restart(ray_start_cluster_head, error_pubsub): + """Test that when the agent fails to start many times in a row + if the error message is suppressed correctly without spamming + the driver. 
+ """ + # Choose a duplicated port for the agent so that it will crash. + p = error_pubsub + errors = get_error_message( + p, 1, ray_constants.DASHBOARD_AGENT_DIED_ERROR, timeout=10) + for e in errors: + assert ("There are 2 possible problems " + "if you see this error." in e.error_message) + # Make sure the agent process is not started anymore. + cluster = ray_start_cluster_head + wait_for_condition(lambda: search_agents(cluster) is None) + +if __name__ == "__main__": import pytest sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tests/test_distributed_sort.py b/python/ray/tests/test_distributed_sort.py index 55cc7e37ebdfd..75cb682b165e8 100644 --- a/python/ray/tests/test_distributed_sort.py +++ b/python/ray/tests/test_distributed_sort.py @@ -4,14 +4,19 @@ from ray.experimental.raysort import main +@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows") def test_distributed_sort(): - main.args = main.get_args() - main.args.ray_address = None - main.args.total_data_size = 1_000_000_000 - main.args.skip_input = True - main.args.skip_output = True - main.main() + args = main.get_args([ + "--total_data_size=1_000_000_000", + "--num_mappers=4", + "--num_reducers=4", + "--num_mappers_per_round=2", + "--ray_address=", + "--skip_input", + "--skip_output", + ]) + main.main(args) if __name__ == "__main__": - sys.exit(pytest.main(["-v", __file__])) + sys.exit(pytest.main(["-sv", __file__])) diff --git a/python/ray/tests/test_failure_2.py b/python/ray/tests/test_failure_2.py index 6bb0986e649c3..3b33e1c3f173b 100644 --- a/python/ray/tests/test_failure_2.py +++ b/python/ray/tests/test_failure_2.py @@ -67,11 +67,12 @@ class Foo: pass # The actor creation should be infeasible. - Foo.remote() + a = Foo.remote() errors = get_error_message(p, 1, ray_constants.INFEASIBLE_TASK_ERROR) assert len(errors) == 1 assert errors[0].type == ray_constants.INFEASIBLE_TASK_ERROR p.close() + del a def test_warning_for_too_many_actors(shutdown_only): diff --git a/python/ray/tests/test_multi_tenancy.py b/python/ray/tests/test_multi_tenancy.py index 1267570d3660b..f2913a50c05ba 100644 --- a/python/ray/tests/test_multi_tenancy.py +++ b/python/ray/tests/test_multi_tenancy.py @@ -111,12 +111,14 @@ def get_pid(): all_worker_pids.add(worker_pid) -def test_worker_env(shutdown_only): +@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.") +def test_runtime_env(shutdown_only): ray.init( - job_config=ray.job_config.JobConfig(worker_env={ - "foo1": "bar1", - "foo2": "bar2" - })) + job_config=ray.job_config.JobConfig( + runtime_env={"env_vars": { + "foo1": "bar1", + "foo2": "bar2" + }})) @ray.remote def get_env(key): diff --git a/python/ray/tests/test_object_manager.py b/python/ray/tests/test_object_manager.py index 1f2c5e5dc4944..e44bf22e83187 100644 --- a/python/ray/tests/test_object_manager.py +++ b/python/ray/tests/test_object_manager.py @@ -296,8 +296,6 @@ def driver(): ray.get(driver.remote()) -# TODO(ekl) this sometimes takes much longer (10+s) due to a higher level -# pull retry. We should try to resolve these hangs in the chunk transfer logic. 
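# Sketch for illustration: the test_distributed_sort change above passes an
# explicit argv list to main.get_args(...) instead of mutating module-level
# args. The testable-argparse pattern it presumes looks roughly like:
import argparse

def get_args(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--total_data_size", type=int, default=0)
    parser.add_argument("--skip_input", action="store_true")
    # argv=None falls back to sys.argv[1:]; tests inject a list instead.
    return parser.parse_args(argv)

assert get_args(["--total_data_size=1000", "--skip_input"]).skip_input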
def test_pull_bundles_admission_control(shutdown_only): cluster = Cluster() object_size = int(6e6) @@ -605,6 +603,52 @@ def task(x): ray.get(t, timeout=10) +@pytest.mark.parametrize( + "ray_start_cluster_head", [{ + "num_cpus": 0, + "object_store_memory": 75 * 1024 * 1024, + "_system_config": { + "worker_lease_timeout_milliseconds": 0, + "object_manager_pull_timeout_ms": 20000, + "object_spilling_threshold": 1.0, + } + }], + indirect=True) +def test_maximize_concurrent_pull_race_condition(ray_start_cluster_head): + # Test if https://github.com/ray-project/ray/issues/18062 is mitigated + cluster = ray_start_cluster_head + cluster.add_node(num_cpus=8, object_store_memory=75 * 1024 * 1024) + + @ray.remote + class RemoteObjectCreator: + def put(self, i): + return np.random.rand(i * 1024 * 1024) # 8 MB data + + def idle(self): + pass + + @ray.remote + def f(x): + print(f"timestamp={time.time()} pulled {len(x)*8} bytes") + time.sleep(1) + return + + remote_obj_creator = RemoteObjectCreator.remote() + remote_refs = [remote_obj_creator.put.remote(1) for _ in range(7)] + print(remote_refs) + # Make sure all objects are created. + ray.get(remote_obj_creator.idle.remote()) + + local_refs = [ray.put(np.random.rand(1 * 1024 * 1024)) for _ in range(20)] + remote_tasks = [f.remote(x) for x in local_refs] + + start = time.time() + ray.get(remote_tasks) + end = time.time() + assert end - start < 20, "Too much time spent in pulling objects, " \ + "check the amount of time in retries" + + if __name__ == "__main__": import pytest import sys diff --git a/python/ray/tests/test_output.py b/python/ray/tests/test_output.py index 93cba471ee21a..958cdecfe732e 100644 --- a/python/ray/tests/test_output.py +++ b/python/ray/tests/test_output.py @@ -65,13 +65,15 @@ def test_autoscaler_no_spam(): import ray import time -ray.init(num_cpus=1) +# Check that there are no false positives with custom resources. +ray.init(num_cpus=1, resources={"node:x": 1}) -@ray.remote(num_cpus=1) +@ray.remote(num_cpus=1, resources={"node:x": 1}) def f(): time.sleep(1) + print("task done") -ray.get([f.remote() for _ in range(5)]) +ray.get([f.remote() for _ in range(15)]) """ proc = run_string_as_driver_nonblocking(script) diff --git a/python/ray/tests/test_placement_group_3.py b/python/ray/tests/test_placement_group_3.py index 12afdfee47ecb..eeb6df0f5c4bb 100644 --- a/python/ray/tests/test_placement_group_3.py +++ b/python/ray/tests/test_placement_group_3.py @@ -608,5 +608,40 @@ def is_usage_updated(): assert cpu_usage == expected +def test_placement_group_removal_leak_regression(ray_start_cluster): + """Related issue: + https://github.com/ray-project/ray/issues/19131 + """ + cluster = ray_start_cluster + cluster.add_node(num_cpus=5) + ray.init(address=cluster.address) + + TOTAL_CPUS = 8 + bundles = [{"CPU": 1, "GPU": 1}] + bundles += [{"CPU": 1} for _ in range(TOTAL_CPUS - 1)] + + pg = placement_group(bundles, strategy="PACK") + # Here, we simulate that the ready task is queued and + # the new node is up. As soon as the new node is up, + # the ready task is scheduled. + # See https://github.com/ray-project/ray/pull/19138 + # for more details about the test. + o = pg.ready() + # Add an artificial delay until the new node is up. + time.sleep(3) + cluster.add_node(num_cpus=5, num_gpus=1) + ray.get(o) + bundle_resource_name = f"bundle_group_{pg.id.hex()}" + expected_bundle_wildcard_val = TOTAL_CPUS * 1000 + + # This should fail if there's a leakage + # because the bundle resources are never returned properly. 
+ def check_bundle_leaks(): + bundle_resources = ray.available_resources()[bundle_resource_name] + return expected_bundle_wildcard_val == bundle_resources + + wait_for_condition(check_bundle_leaks) + + if __name__ == "__main__": sys.exit(pytest.main(["-sv", __file__])) diff --git a/python/ray/tests/test_ray_debugger.py b/python/ray/tests/test_ray_debugger.py index fdc6c56da1eb3..3cc980bb14026 100644 --- a/python/ray/tests/test_ray_debugger.py +++ b/python/ray/tests/test_ray_debugger.py @@ -11,6 +11,7 @@ import ray from ray.cluster_utils import Cluster from ray._private.test_utils import run_string_as_driver, wait_for_condition +from ray._private import services def test_ray_debugger_breakpoint(shutdown_only): @@ -217,7 +218,7 @@ def f(): host, port = session["pdb_address"].split(":") if ray_debugger_external: - assert host not in ["localhost", "127.0.0.1"], host + assert host == services.get_node_ip_address(), host else: assert host == "localhost", host @@ -267,13 +268,13 @@ def f(): host1, port1 = session1["pdb_address"].split(":") if ray_debugger_external: - assert host1 not in ["localhost", "127.0.0.1"], host1 + assert host1 == services.get_node_ip_address(), host1 else: assert host1 == "localhost", host1 host2, port2 = session2["pdb_address"].split(":") if ray_debugger_external: - assert host2 not in ["localhost", "127.0.0.1"], host2 + assert host2 == services.get_node_ip_address(), host2 else: assert host2 == "localhost", host2 diff --git a/python/ray/tests/test_ray_init.py b/python/ray/tests/test_ray_init.py index 5040f4bd65ef4..3fdb6a6ea110d 100644 --- a/python/ray/tests/test_ray_init.py +++ b/python/ray/tests/test_ray_init.py @@ -11,6 +11,7 @@ from ray.client_builder import ClientContext from ray.cluster_utils import Cluster from ray._private.test_utils import run_string_as_driver +from ray._raylet import ClientObjectRef from ray.util.client.worker import Worker import grpc @@ -216,6 +217,7 @@ def test_ray_address(input, call_ray_start): res = ray.init(input) # Ensure this is not a client.connect() assert not isinstance(res, ClientContext) + ray.shutdown() class Credentials(grpc.ChannelCredentials): @@ -257,9 +259,47 @@ def mock_secure_channel(conn_str, with pytest.raises(Stop) as stop: ray.init("ray://127.0.0.1", _credentials=Credentials("test")) + ray.util.disconnect() assert stop.value.credentials.name == "test" +def test_auto_init_non_client(call_ray_start): + address = call_ray_start + with unittest.mock.patch.dict(os.environ, {"RAY_ADDRESS": address}): + res = ray.put(300) + # Ensure this is not a client.connect() + assert not isinstance(res, ClientObjectRef) + ray.shutdown() + + addr = "localhost:{}".format(address.split(":")[-1]) + with unittest.mock.patch.dict(os.environ, {"RAY_ADDRESS": addr}): + res = ray.put(300) + # Ensure this is not a client.connect() + assert not isinstance(res, ClientObjectRef) + + +@pytest.mark.parametrize( + "call_ray_start", + ["ray start --head --ray-client-server-port 25036 --port 0"], + indirect=True) +@pytest.mark.parametrize( + "function", [lambda: ray.put(300), lambda: ray.remote(ray.nodes).remote()]) +def test_auto_init_client(call_ray_start, function): + address = call_ray_start.split(":")[0] + with unittest.mock.patch.dict(os.environ, + {"RAY_ADDRESS": f"ray://{address}:25036"}): + res = function() + # Ensure this is a client connection. 
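# Sketch for illustration (the client-connection assertions continue below):
# these auto-init tests distinguish driver vs. client connections purely by
# the address scheme. A toy version of that dispatch rule:
def is_client_address(address: str) -> bool:
    # Only ray:// addresses route through the Ray client.
    return address.startswith("ray://")

assert is_client_address("ray://localhost:25036")
assert not is_client_address("localhost:6379")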
+ assert isinstance(res, ClientObjectRef) + ray.shutdown() + + with unittest.mock.patch.dict(os.environ, + {"RAY_ADDRESS": "ray://localhost:25036"}): + res = function() + # Ensure this is a client connection. + assert isinstance(res, ClientObjectRef) + + if __name__ == "__main__": import pytest import sys diff --git a/python/ray/tests/test_resource_demand_scheduler.py b/python/ray/tests/test_resource_demand_scheduler.py index add24d4a571a9..eb1260db32aa1 100644 --- a/python/ray/tests/test_resource_demand_scheduler.py +++ b/python/ray/tests/test_resource_demand_scheduler.py @@ -46,24 +46,52 @@ def get_nodes_for(*a, **kw): def test_util_score(): assert _utilization_score({"CPU": 64}, [{"TPU": 16}]) is None - assert _utilization_score({"GPU": 4}, [{"GPU": 2}]) == (0.5, 0.5) + assert _utilization_score({"GPU": 4}, [{"GPU": 2}]) == (1, 0.5, 0.5) assert _utilization_score({"GPU": 4}, [{"GPU": 1}, {"GPU": 1}]) == \ - (0.5, 0.5) - assert _utilization_score({"GPU": 2}, [{"GPU": 2}]) == (2, 2) - assert _utilization_score({"GPU": 2}, [{"GPU": 1}, {"GPU": 1}]) == (2, 2) - assert _utilization_score({"GPU": 2, "TPU": 1}, [{"GPU": 2}]) == (0, 1) - assert _utilization_score({"CPU": 64}, [{"CPU": 64}]) == (64, 64) - assert _utilization_score({"CPU": 64}, [{"CPU": 32}]) == (8, 8) + (1, 0.5, 0.5) + assert _utilization_score({"GPU": 2}, [{"GPU": 2}]) == (1, 2, 2) + assert _utilization_score({ + "GPU": 2 + }, [{ + "GPU": 1 + }, { + "GPU": 1 + }]) == (1, 2, 2) + assert _utilization_score({ + "GPU": 1 + }, [{ + "GPU": 1, + "CPU": 1 + }, { + "GPU": 1 + }]) == (1, 1, 1) + assert _utilization_score({ + "GPU": 1, + "CPU": 1 + }, [{ + "GPU": 1, + "CPU": 1 + }, { + "GPU": 1 + }]) == (2, 1, 1) + assert _utilization_score({"GPU": 2, "TPU": 1}, [{"GPU": 2}]) == (1, 0, 1) + assert _utilization_score({"CPU": 64}, [{"CPU": 64}]) == (1, 64, 64) + assert _utilization_score({"CPU": 64}, [{"CPU": 32}]) == (1, 8, 8) assert _utilization_score({"CPU": 64}, [{"CPU": 16}, {"CPU": 16}]) == \ - (8, 8) + (1, 8, 8) def test_gpu_node_util_score(): # Avoid scheduling CPU tasks on GPU node. 
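# Sketch for illustration: the expected scores in test_util_score gain a
# leading element, so nodes now compare lexicographically on roughly
# (resource types matched, min utilization, mean utilization). A toy scorer
# with assumed semantics (the real _utilization_score handles more cases,
# e.g. GPU-only nodes and fully-utilized resources):
def toy_score(node_resources, used):
    fracs = [used.get(k, 0) / v for k, v in node_resources.items()]
    num_matched = sum(1 for k in node_resources if used.get(k, 0) > 0)
    return (num_matched, min(fracs), sum(fracs) / len(fracs))

assert toy_score({"GPU": 4}, {"GPU": 2}) == (1, 0.5, 0.5)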
assert _utilization_score({"GPU": 1, "CPU": 1}, [{"CPU": 1}]) is None assert _utilization_score({"GPU": 1, "CPU": 1}, [{"CPU": 1, "GPU": 1}]) \ - == (1.0, 1.0) - assert _utilization_score({"GPU": 1, "CPU": 1}, [{"GPU": 1}]) == (0.0, 0.5) + == (2, 1.0, 1.0) + assert _utilization_score({ + "GPU": 1, + "CPU": 1 + }, [{ + "GPU": 1 + }]) == (1, 0.0, 0.5) def test_zero_resource(): @@ -197,7 +225,7 @@ def test_get_nodes_packing_heuristic(): }] * 8) + ([{ "CPU": 1 }] * 64)) == { - "m4.16xlarge": 1, + "m4.4xlarge": 2, "p2.8xlarge": 1 } @@ -215,6 +243,47 @@ def test_get_nodes_packing_heuristic(): } +def test_node_packing_gpu_cpu_bundles(): + TYPES = { + "cpu": { + "resources": { + "CPU": 16, + }, + "max_workers": 10, + }, + "gpu": { + "resources": { + "CPU": 16, + "GPU": 1, + }, + "max_workers": 10, + }, + } + nodes = get_nodes_for(TYPES, {}, "cpu", 9999, ([{ + "CPU": 1 + }] * 30 + [{ + "GPU": 1, + "CPU": 1 + }])) + assert nodes == {"gpu": 1, "cpu": 1} + + nodes = get_nodes_for(TYPES, {}, "cpu", 9999, ([{ + "GPU": 1, + "CPU": 1 + }] + [{ + "CPU": 1 + }] * 30)) + assert nodes == {"gpu": 1, "cpu": 1} + + nodes = get_nodes_for(TYPES, {}, "cpu", 9999, ([{ + "GPU": 1, + "CPU": 1 + }] + [{ + "CPU": 1 + }] * 15)) + assert nodes == {"gpu": 1} + + def test_gpu_node_avoid_cpu_task(): types = { "cpu": { @@ -630,13 +699,8 @@ def test_backlog_queue_impact_on_binpacking_time_aux( "CPU": 1 }]) # If not for the max launch concurrency the next assert should be: - # {'m4.large': 4, 'm4.4xlarge': 2, 'm4.16xlarge': 15, 'p2.8xlarge': 125}. - assert to_launch == { - "m4.large": 4, - "m4.4xlarge": 2, - "m4.16xlarge": 5, - "p2.8xlarge": 5 - } + # {'m4.16xlarge': 1, 'p2.8xlarge': 125, 'p2.xlarge': 1} + assert to_launch == {"m4.16xlarge": 1, "p2.8xlarge": 5, "p2.xlarge": 1} # Check the time it takes when there are 100 nodes available and the demand # requires another 75 nodes. @@ -1322,7 +1386,10 @@ def tearDown(self): shutil.rmtree(self.tmpdir) ray.shutdown() - def waitForNodes(self, expected, comparison=None, tag_filters={}): + def waitForNodes(self, expected, comparison=None, tag_filters=None): + if tag_filters is None: + tag_filters = {} + MAX_ITER = 50 for i in range(MAX_ITER): n = len(self.provider.non_terminated_nodes(tag_filters)) @@ -1664,7 +1731,7 @@ def testScaleUpMinWorkers(self): assert cnt == 2 def testScaleUpIgnoreUsed(self): - config = MULTI_WORKER_CLUSTER.copy() + config = copy.deepcopy(MULTI_WORKER_CLUSTER) # Commenting out this line causes the test case to fail?!?! 
config["min_workers"] = 0 config["target_utilization_fraction"] = 1.0 @@ -1705,7 +1772,7 @@ def testScaleUpIgnoreUsed(self): assert self.provider.mock_nodes[1].node_type == "p2.xlarge" def testRequestBundlesAccountsForHeadNode(self): - config = MULTI_WORKER_CLUSTER.copy() + config = copy.deepcopy(MULTI_WORKER_CLUSTER) config["head_node_type"] = "p2.8xlarge" config["min_workers"] = 0 config["max_workers"] = 50 @@ -1744,7 +1811,7 @@ def testRequestBundlesAccountsForHeadNode(self): assert self.provider.mock_nodes[1].node_type == "p2.8xlarge" def testRequestBundles(self): - config = MULTI_WORKER_CLUSTER.copy() + config = copy.deepcopy(MULTI_WORKER_CLUSTER) config["min_workers"] = 0 config["max_workers"] = 50 config_path = self.write_config(config) @@ -1781,7 +1848,7 @@ def testRequestBundles(self): assert self.provider.mock_nodes[4].node_type == "m4.16xlarge" def testResourcePassing(self): - config = MULTI_WORKER_CLUSTER.copy() + config = copy.deepcopy(MULTI_WORKER_CLUSTER) config["min_workers"] = 0 config["max_workers"] = 50 config_path = self.write_config(config) @@ -1812,7 +1879,7 @@ def testResourcePassing(self): assert self.provider.mock_nodes[2].node_type == "p2.8xlarge" # TODO (Alex): Autoscaler creates the node during one update then - # starts the updater in the enxt update. The sleep is largely + # starts the updater in the next update. The sleep is largely # unavoidable because the updater runs in its own thread and we have no # good way of ensuring that the commands are sent in time. autoscaler.update() @@ -1827,7 +1894,7 @@ def testResourcePassing(self): runner.assert_has_call("172.0.0.2", "\"GPU\":8") def testScaleUpLoadMetrics(self): - config = MULTI_WORKER_CLUSTER.copy() + config = copy.deepcopy(MULTI_WORKER_CLUSTER) config["min_workers"] = 0 config["max_workers"] = 50 config_path = self.write_config(config) @@ -1858,16 +1925,15 @@ def testScaleUpLoadMetrics(self): "CPU": 16 }]) autoscaler.update() - self.waitForNodes(2, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) + self.waitForNodes(1, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) nodes = { self.provider.mock_nodes[1].node_type, - self.provider.mock_nodes[2].node_type } - assert nodes == {"p2.xlarge", "m4.4xlarge"} + assert nodes == {"p2.xlarge"} def testCommandPassing(self): t = "custom" - config = MULTI_WORKER_CLUSTER.copy() + config = copy.deepcopy(MULTI_WORKER_CLUSTER) config["available_node_types"]["p2.8xlarge"][ "worker_setup_commands"] = ["new_worker_setup_command"] config["available_node_types"]["p2.xlarge"][ @@ -1923,7 +1989,7 @@ def testCommandPassing(self): "init_cmd") def testDockerWorkers(self): - config = MULTI_WORKER_CLUSTER.copy() + config = copy.deepcopy(MULTI_WORKER_CLUSTER) config["available_node_types"]["p2.8xlarge"]["docker"] = { "worker_image": "p2.8x_image:latest", "worker_run_options": ["p2.8x-run-options"] @@ -1981,7 +2047,7 @@ def testDockerWorkers(self): }]) autoscaler.update() self.waitForNodes(5) - assert self.provider.mock_nodes[4].node_type == "m4.16xlarge" + assert self.provider.mock_nodes[4].node_type == "m4.large" autoscaler.update() sleep(0.1) runner.assert_has_call(self.provider.mock_nodes[2].internal_ip, @@ -2044,7 +2110,7 @@ def testUpdateConfig(self): self.waitForNodes(0, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) def testEmptyDocker(self): - config = MULTI_WORKER_CLUSTER.copy() + config = copy.deepcopy(MULTI_WORKER_CLUSTER) del config["docker"] config["min_workers"] = 0 config["max_workers"] = 10 diff --git a/python/ray/tests/test_runtime_context.py 
b/python/ray/tests/test_runtime_context.py index 1c069e10066df..8ce983da2085a 100644 --- a/python/ray/tests/test_runtime_context.py +++ b/python/ray/tests/test_runtime_context.py @@ -4,6 +4,8 @@ import time import sys +from ray._private.test_utils import SignalActor + def test_was_current_actor_reconstructed(shutdown_only): ray.init() @@ -113,6 +115,119 @@ def echo2(self, s): assert ray.get(ray.get(obj)) == "hello" +def test_actor_stats_normal_task(ray_start_regular): + # Because it works at the core worker level, this API works for tasks. + @ray.remote + def func(): + return ray.get_runtime_context()._get_actor_call_stats() + + assert ray.get(func.remote())["func"] == { + "pending": 0, + "running": 1, + "finished": 0, + } + + +def test_actor_stats_sync_actor(ray_start_regular): + signal = SignalActor.remote() + + @ray.remote + class SyncActor: + def run(self): + return ray.get_runtime_context()._get_actor_call_stats() + + def wait_signal(self): + ray.get(signal.wait.remote()) + return ray.get_runtime_context()._get_actor_call_stats() + + actor = SyncActor.remote() + counts = ray.get(actor.run.remote()) + assert counts == { + "SyncActor.run": { + "pending": 0, + "running": 1, + "finished": 0 + }, + "SyncActor.__init__": { + "pending": 0, + "running": 0, + "finished": 1 + } + } + + ref = actor.wait_signal.remote() + other_refs = [actor.run.remote() for _ in range(3) + ] + [actor.wait_signal.remote() for _ in range(5)] + ray.wait(other_refs, timeout=1) + signal.send.remote() + counts = ray.get(ref) + assert counts == { + "SyncActor.run": { + "pending": 3, + "running": 0, + "finished": 1, # from previous run + }, + "SyncActor.wait_signal": { + "pending": 5, + "running": 1, + "finished": 0, + }, + "SyncActor.__init__": { + "pending": 0, + "running": 0, + "finished": 1 + } + } + + +def test_actor_stats_threaded_actor(ray_start_regular): + signal = SignalActor.remote() + + @ray.remote + class ThreadedActor: + def func(self): + ray.get(signal.wait.remote()) + return ray.get_runtime_context()._get_actor_call_stats() + + actor = ThreadedActor.options(max_concurrency=3).remote() + refs = [actor.func.remote() for _ in range(6)] + ready, _ = ray.wait(refs, timeout=1) + assert len(ready) == 0 + signal.send.remote() + results = ray.get(refs) + assert max(result["ThreadedActor.func"]["running"] + for result in results) > 1 + assert max(result["ThreadedActor.func"]["pending"] + for result in results) > 1 + + +def test_actor_stats_async_actor(ray_start_regular): + signal = SignalActor.remote() + + @ray.remote + class AysncActor: + async def func(self): + await signal.wait.remote() + return ray.get_runtime_context()._get_actor_call_stats() + + actor = AysncActor.options(max_concurrency=3).remote() + refs = [actor.func.remote() for _ in range(6)] + ready, _ = ray.wait(refs, timeout=1) + assert len(ready) == 0 + signal.send.remote() + results = ray.get(refs) + assert max(result["AysncActor.func"]["running"] for result in results) == 3 + assert max(result["AysncActor.func"]["pending"] for result in results) == 3 + + +# get_runtime_context() can be called outside of Ray so it should not start +# Ray automatically. 
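# Sketch for illustration: the call-stats dicts asserted above have the shape
# {method_name: {"pending": n, "running": n, "finished": n}}. A toy tracker
# with that lifecycle (hypothetical, not the core-worker implementation; the
# no-auto-init behavior noted above is tested just below):
from collections import defaultdict

class CallStats:
    def __init__(self):
        self.stats = defaultdict(
            lambda: {"pending": 0, "running": 0, "finished": 0})

    def submit(self, method):
        self.stats[method]["pending"] += 1

    def start(self, method):
        self.stats[method]["pending"] -= 1
        self.stats[method]["running"] += 1

    def finish(self, method):
        self.stats[method]["running"] -= 1
        self.stats[method]["finished"] += 1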
+def test_no_auto_init(shutdown_only): + assert not ray.is_initialized() + ray.get_runtime_context() + assert not ray.is_initialized() + + if __name__ == "__main__": import pytest sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tests/test_runtime_env.py b/python/ray/tests/test_runtime_env.py index 110beb4490a6b..0f9297238c3cd 100644 --- a/python/ray/tests/test_runtime_env.py +++ b/python/ray/tests/test_runtime_env.py @@ -13,7 +13,6 @@ from ray._private.test_utils import ( run_string_as_driver, run_string_as_driver_nonblocking, wait_for_condition) from ray._private.runtime_env import working_dir as working_dir_pkg -from ray._private.runtime_env.validation import override_task_or_actor_runtime_env # noqa: E501 from ray._private.utils import (get_wheel_filename, get_master_wheel_url, get_release_wheel_url) @@ -774,41 +773,38 @@ def test_container_option_serialize(): job_config = ray.job_config.JobConfig(runtime_env=runtime_env) job_config_serialized = job_config.serialize() # job_config_serialized is JobConfig protobuf serialized string, - # job_config.runtime_env.raw_json has container_option info - # job_config.serialized_runtime_env also has container_option info - assert job_config_serialized.count(b"image") == 2 + # job_config.runtime_env.serialized_runtime_env has container_option info + assert job_config_serialized.count(b"image") == 1 def test_working_dir_override_failure(shutdown_only): ray.init() - @ray.remote(runtime_env={"working_dir": "."}) - def f(): - pass - with pytest.raises(NotImplementedError): - f.remote() + + @ray.remote(runtime_env={"working_dir": "."}) + def f(): + pass @ray.remote def g(): pass with pytest.raises(NotImplementedError): - g.options(runtime_env={"working_dir": "."}).remote() - - @ray.remote(runtime_env={"working_dir": "."}) - class A: - pass + g.options(runtime_env={"working_dir": "."}) with pytest.raises(NotImplementedError): - A.remote() + + @ray.remote(runtime_env={"working_dir": "."}) + class A: + pass @ray.remote class B: pass with pytest.raises(NotImplementedError): - B.options(runtime_env={"working_dir": "."}).remote() + B.options(runtime_env={"working_dir": "."}) @pytest.mark.skipif( @@ -944,46 +940,6 @@ def test_large_file_error(shutdown_only): os.chdir(old_dir) -class TestOverrideTaskOrActorRuntimeEnv: - def test_working_dir_in_child_invalid(self): - child_env = {"working_dir": "some_dir"} - parent_env = {"working_dir": "other_dir", "uris": ["a", "b"]} - - with pytest.raises(NotImplementedError): - override_task_or_actor_runtime_env(child_env, parent_env) - - def test_uri_inherit(self): - child_env = {} - parent_env = {"working_dir": "other_dir", "uris": ["a", "b"]} - result_env = override_task_or_actor_runtime_env(child_env, parent_env) - assert result_env == {"uris": ["a", "b"]} - - # The dicts passed in should not be mutated. - assert child_env == {} - assert parent_env == {"working_dir": "other_dir", "uris": ["a", "b"]} - - def test_uri_override(self): - child_env = {"uris": ["c", "d"]} - parent_env = {"working_dir": "other_dir", "uris": ["a", "b"]} - result_env = override_task_or_actor_runtime_env(child_env, parent_env) - assert result_env["uris"] == ["c", "d"] - assert result_env.get("working_dir") is None - - # The dicts passed in should not be mutated. 
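# Sketch for illustration: the reworked test_working_dir_override_failure
# above expects NotImplementedError at declaration time (decorator or
# .options()) rather than at .remote() time. Fail-fast validation of that
# style, with a hypothetical validator:
def validate_runtime_env(runtime_env):
    if "working_dir" in (runtime_env or {}):
        raise NotImplementedError(
            "Overriding working_dir per task or actor is not supported.")
    return runtime_env

class RemoteDecl:
    def options(self, runtime_env=None):
        # Validate when options are declared, not when .remote() is called.
        self.runtime_env = validate_runtime_env(runtime_env)
        return self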
- assert child_env == {"uris": ["c", "d"]} - assert parent_env == {"working_dir": "other_dir", "uris": ["a", "b"]} - - def test_no_mutate(self): - child_env = {} - parent_env = {"working_dir": "other_dir", "uris": ["a", "b"]} - result_env = override_task_or_actor_runtime_env(child_env, parent_env) - assert result_env == {"uris": ["a", "b"]} - - # The dictis passed in should not be mutated. - assert child_env == {} - assert parent_env == {"working_dir": "other_dir", "uris": ["a", "b"]} - - if __name__ == "__main__": import sys sys.exit(pytest.main(["-sv", __file__])) diff --git a/python/ray/tests/test_runtime_env_complicated.py b/python/ray/tests/test_runtime_env_complicated.py index d8c334c413606..e5c7047f275b5 100644 --- a/python/ray/tests/test_runtime_env_complicated.py +++ b/python/ray/tests/test_runtime_env_complicated.py @@ -12,15 +12,16 @@ import yaml import ray -from ray._private.runtime_env import RuntimeEnvDict from ray._private.runtime_env.conda import ( inject_dependencies, _inject_ray_to_conda_site, _resolve_install_from_source_ray_dependencies, _current_py_version, ) -from ray._private.test_utils import (run_string_as_driver, - run_string_as_driver_nonblocking) + +from ray._private.runtime_env.conda_utils import get_conda_env_list +from ray._private.test_utils import ( + run_string_as_driver, run_string_as_driver_nonblocking, wait_for_condition) from ray._private.utils import get_conda_env_dir, get_conda_bin_executable if not os.environ.get("CI"): @@ -190,6 +191,39 @@ def test_job_config_conda_env(conda_envs, shutdown_only): ray.shutdown() +@pytest.mark.skipif( + os.environ.get("CONDA_DEFAULT_ENV") is None, + reason="must be run from within a conda environment") +@pytest.mark.skipif( + os.environ.get("CI") and sys.platform != "linux", + reason="This test is only run on linux CI machines.") +def test_job_eager_install(shutdown_only): + # Test enable eager install + runtime_env = {"conda": {"dependencies": ["toolz"]}, "eager_install": True} + env_count = len(get_conda_env_list()) + ray.init(runtime_env=runtime_env) + wait_for_condition( + lambda: len(get_conda_env_list()) == env_count + 1, timeout=60) + ray.shutdown() + # Test disable eager install + runtime_env = { + "conda": { + "dependencies": ["toolz"] + }, + "eager_install": False + } + ray.init(runtime_env=runtime_env) + with pytest.raises(RuntimeError): + wait_for_condition( + lambda: len(get_conda_env_list()) == env_count + 2, timeout=60) + ray.shutdown() + # Test unavailable type + runtime_env = {"conda": {"dependencies": ["toolz"]}, "eager_install": 123} + with pytest.raises(AssertionError): + ray.init(runtime_env=runtime_env) + ray.shutdown() + + def test_get_conda_env_dir(tmp_path): """ Typical output of `conda env list`, for context: @@ -449,28 +483,6 @@ def f(): assert ray.get(f.remote()) -@pytest.mark.skipif(sys.platform == "win32", reason="Unsupported on Windows.") -@pytest.mark.parametrize("use_working_dir", [True, False]) -def test_conda_input_filepath(use_working_dir, tmp_path): - conda_dict = {"dependencies": ["pip", {"pip": ["pip-install-test==0.5"]}]} - d = tmp_path / "pip_requirements" - d.mkdir() - p = d / "environment.yml" - - p.write_text(yaml.dump(conda_dict)) - - if use_working_dir: - runtime_env_dict = RuntimeEnvDict({ - "working_dir": str(d), - "conda": "environment.yml" - }) - else: - runtime_env_dict = RuntimeEnvDict({"conda": str(p)}) - - output_conda_dict = runtime_env_dict.get_parsed_dict().get("conda") - assert output_conda_dict == conda_dict - - @skipIf(sys.platform == "win32", "Fail to 
create temp dir.") def test_experimental_package(shutdown_only): ray.init(num_cpus=2) @@ -514,7 +526,7 @@ def test_experimental_package_github(shutdown_only): ["ray start --head --ray-client-server-port 24001 --port 0"], indirect=True) def test_client_working_dir_filepath(call_ray_start, tmp_path): - """Test that pip and conda relative filepaths work with working_dir.""" + """Test that pip and conda filepaths work with working_dir.""" working_dir = tmp_path / "requirements" working_dir.mkdir() @@ -524,10 +536,7 @@ def test_client_working_dir_filepath(call_ray_start, tmp_path): pip-install-test==0.5 """ pip_file.write_text(requirements_txt) - runtime_env_pip = { - "working_dir": str(working_dir), - "pip": "requirements.txt" - } + runtime_env_pip = {"working_dir": str(working_dir), "pip": str(pip_file)} conda_file = working_dir / "environment.yml" conda_dict = {"dependencies": ["pip", {"pip": ["pip-install-test==0.5"]}]} @@ -535,7 +544,7 @@ def test_client_working_dir_filepath(call_ray_start, tmp_path): conda_file.write_text(conda_str) runtime_env_conda = { "working_dir": str(working_dir), - "conda": "environment.yml" + "conda": str(conda_file) } @ray.remote @@ -557,6 +566,64 @@ def f(): assert ray.get(f.remote()) +@pytest.mark.skipif( + os.environ.get("CI") and sys.platform != "linux", + reason="This test is only run on linux CI machines.") +@pytest.mark.parametrize( + "call_ray_start", + ["ray start --head --ray-client-server-port 24001 --port 0"], + indirect=True) +def test_conda_pip_filepaths_remote(call_ray_start, tmp_path): + """Test that pip and conda filepaths work, simulating a remote cluster.""" + + working_dir = tmp_path / "requirements" + working_dir.mkdir() + + pip_file = working_dir / "requirements.txt" + requirements_txt = """ + pip-install-test==0.5 + """ + pip_file.write_text(requirements_txt) + runtime_env_pip = {"pip": str(pip_file)} + + conda_file = working_dir / "environment.yml" + conda_dict = {"dependencies": ["pip", {"pip": ["pip-install-test==0.5"]}]} + conda_str = yaml.dump(conda_dict) + conda_file.write_text(conda_str) + runtime_env_conda = {"conda": str(conda_file)} + + @ray.remote + def f(): + import pip_install_test # noqa + return True + + with ray.client("localhost:24001").connect(): + with pytest.raises(ModuleNotFoundError): + # Ensure pip-install-test is not installed in a client that doesn't + # use the runtime_env + ray.get(f.remote()) + + # pip and conda files should be parsed when the function is declared. + f_pip = f.options(runtime_env=runtime_env_pip) + f_conda = f.options(runtime_env=runtime_env_conda) + + # Remove the pip and conda files from the local filesystem. This is + # necessary to simulate the files not being present on the remote cluster, + # because in this single-machine test, the cluster has the same filesystem. + os.remove(pip_file) + os.remove(conda_file) + + # Test with and without a working_dir. + client_envs = [{}, {"working_dir": str(working_dir)}] + for runtime_env in client_envs: + with ray.client("localhost:24001").env(runtime_env).connect(): + with pytest.raises(ModuleNotFoundError): + # Ensure pip-install-test is not installed on the test machine + import pip_install_test # noqa + assert ray.get(f_pip.remote()) + assert ray.get(f_conda.remote()) + + install_env_script = """ import ray import time @@ -718,7 +785,7 @@ def test(self): # Start a new job on the same cluster using the Summit 2021 requirements. 
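The tests above pin down that `pip` and `conda` file paths in a `runtime_env` are read and parsed on the driver when `.options()` is called, not on the cluster nodes. A minimal sketch of what that implies, assuming a `/tmp/requirements.txt` file exists locally:

import ray

@ray.remote
def check():
    import pip_install_test  # noqa: F401
    return True

# The requirements file is consumed at declaration time, so it can be
# deleted afterwards without affecting where the task eventually runs.
check_with_env = check.options(runtime_env={"pip": "/tmp/requirements.txt"})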
with ray.client(f"localhost:{CLIENT_SERVER_PORT}").env({ "working_dir": str(tmp_path), - "pip": "requirements.txt" + "pip": str(requirement_path) }).connect(): @ray.remote @@ -752,7 +819,9 @@ def test(self): return Path("./test").read_text() - a = TestActor.options(runtime_env={"pip": "requirements.txt"}).remote() + a = TestActor.options(runtime_env={ + "pip": str(requirement_path) + }).remote() assert ray.get(a.test.remote()) == "Hello" # Check that per-task pip specification works and that the job's @@ -888,7 +957,7 @@ def f(self): @pytest.mark.skipif( os.environ.get("CI") and sys.platform != "linux", reason="This test is only run on linux CI machines.") -def test_runtime_env_logging_to_dirver(ray_start_regular_shared, log_pubsub): +def test_runtime_env_logging_to_driver(ray_start_regular_shared, log_pubsub): @ray.remote(runtime_env={"pip": [f"requests=={REQUEST_VERSIONS[0]}"]}) def func(): pass diff --git a/python/ray/tests/test_runtime_env_env_vars.py b/python/ray/tests/test_runtime_env_env_vars.py index 22ce5d5ce59b9..479a7f4130bd2 100644 --- a/python/ray/tests/test_runtime_env_env_vars.py +++ b/python/ray/tests/test_runtime_env_env_vars.py @@ -7,54 +7,37 @@ import ray -@pytest.mark.parametrize("use_runtime_env", [True, False]) -def test_override_environment_variables_task(ray_start_regular, - use_runtime_env): +def test_environment_variables_task(ray_start_regular): @ray.remote def get_env(key): return os.environ.get(key) - if use_runtime_env: - assert (ray.get( - get_env.options(runtime_env={ - "env_vars": { - "a": "b", - } - }).remote("a")) == "b") - else: - assert (ray.get( - get_env.options(override_environment_variables={ + assert (ray.get( + get_env.options(runtime_env={ + "env_vars": { "a": "b", - }).remote("a")) == "b") + } + }).remote("a")) == "b") -@pytest.mark.parametrize("use_runtime_env", [True, False]) -def test_override_environment_variables_actor(ray_start_regular, - use_runtime_env): +def test_environment_variables_actor(ray_start_regular): @ray.remote class EnvGetter: def get(self, key): return os.environ.get(key) - if use_runtime_env: - a = EnvGetter.options(runtime_env={ - "env_vars": { - "a": "b", - "c": "d", - } - }).remote() - else: - a = EnvGetter.options(override_environment_variables={ + a = EnvGetter.options(runtime_env={ + "env_vars": { "a": "b", "c": "d", - }).remote() + } + }).remote() + assert (ray.get(a.get.remote("a")) == "b") assert (ray.get(a.get.remote("c")) == "d") -@pytest.mark.parametrize("use_runtime_env", [True, False]) -def test_override_environment_variables_nested_task(ray_start_regular, - use_runtime_env): +def test_environment_variables_nested_task(ray_start_regular): @ray.remote def get_env(key): return os.environ.get(key) @@ -63,36 +46,19 @@ def get_env(key): def get_env_wrapper(key): return ray.get(get_env.remote(key)) - if use_runtime_env: - assert (ray.get( - get_env_wrapper.options(runtime_env={ - "env_vars": { - "a": "b", - } - }).remote("a")) == "b") - else: - assert (ray.get( - get_env_wrapper.options(override_environment_variables={ + assert (ray.get( + get_env_wrapper.options(runtime_env={ + "env_vars": { "a": "b", - }).remote("a")) == "b") - - -@pytest.mark.parametrize("use_runtime_env", [True, False]) -def test_override_environment_variables_multitenancy(shutdown_only, - use_runtime_env): - if use_runtime_env: - ray.init( - job_config=ray.job_config.JobConfig( - runtime_env={"env_vars": { - "foo1": "bar1", - "foo2": "bar2", - }})) - else: - ray.init( - job_config=ray.job_config.JobConfig(worker_env={ - "foo1": "bar1", - 
"foo2": "bar2", - })) + } + }).remote("a")) == "b") + + +def test_environment_variables_multitenancy(shutdown_only): + ray.init(runtime_env={"env_vars": { + "foo1": "bar1", + "foo2": "bar2", + }}) @ray.remote def get_env(key): @@ -100,48 +66,27 @@ def get_env(key): assert ray.get(get_env.remote("foo1")) == "bar1" assert ray.get(get_env.remote("foo2")) == "bar2" - if use_runtime_env: - assert ray.get( - get_env.options(runtime_env={ - "env_vars": { - "foo1": "baz1", - } - }).remote("foo1")) == "baz1" - assert ray.get( - get_env.options(runtime_env={ - "env_vars": { - "foo1": "baz1", - } - }).remote("foo2")) == "bar2" - else: - assert ray.get( - get_env.options(override_environment_variables={ + assert ray.get( + get_env.options(runtime_env={ + "env_vars": { "foo1": "baz1", - }).remote("foo1")) == "baz1" - assert ray.get( - get_env.options(override_environment_variables={ + } + }).remote("foo1")) == "baz1" + assert ray.get( + get_env.options(runtime_env={ + "env_vars": { "foo1": "baz1", - }).remote("foo2")) == "bar2" + } + }).remote("foo2")) == "bar2" -@pytest.mark.parametrize("use_runtime_env", [True, False]) -def test_override_environment_variables_complex(shutdown_only, - use_runtime_env): - if use_runtime_env: - ray.init(runtime_env={ - "env_vars": { - "a": "job_a", - "b": "job_b", - "z": "job_z", - } - }) - else: - ray.init( - job_config=ray.job_config.JobConfig(worker_env={ - "a": "job_a", - "b": "job_b", - "z": "job_z", - })) +def test_environment_variables_complex(shutdown_only): + ray.init( + runtime_env={"env_vars": { + "a": "job_a", + "b": "job_b", + "z": "job_z", + }}) @ray.remote def get_env(key): @@ -164,69 +109,45 @@ def get_task(self, key): return ray.get(get_env.remote(key)) def nested_get(self, key): - if use_runtime_env: - aa = NestedEnvGetter.options(runtime_env={ - "env_vars": { - "c": "e", - "d": "dd", - } - }).remote() - else: - aa = NestedEnvGetter.options(override_environment_variables={ + aa = NestedEnvGetter.options(runtime_env={ + "env_vars": { "c": "e", "d": "dd", - }).remote() + } + }).remote() return ray.get(aa.get.remote(key)) - if use_runtime_env: - a = EnvGetter.options(runtime_env={ - "env_vars": { - "a": "b", - "c": "d", - } - }).remote() - else: - a = EnvGetter.options(override_environment_variables={ + a = EnvGetter.options(runtime_env={ + "env_vars": { "a": "b", "c": "d", - }).remote() + } + }).remote() + assert (ray.get(a.get.remote("a")) == "b") assert (ray.get(a.get_task.remote("a")) == "b") assert (ray.get(a.nested_get.remote("a")) == "b") assert (ray.get(a.nested_get.remote("c")) == "e") assert (ray.get(a.nested_get.remote("d")) == "dd") - if use_runtime_env: - assert (ray.get( - get_env.options(runtime_env={ - "env_vars": { - "a": "b", - } - }).remote("a")) == "b") - else: - assert (ray.get( - get_env.options(override_environment_variables={ + assert (ray.get( + get_env.options(runtime_env={ + "env_vars": { "a": "b", - }).remote("a")) == "b") + } + }).remote("a")) == "b") assert (ray.get(a.get.remote("z")) == "job_z") assert (ray.get(a.get_task.remote("z")) == "job_z") assert (ray.get(a.nested_get.remote("z")) == "job_z") - if use_runtime_env: - assert (ray.get( - get_env.options(runtime_env={ - "env_vars": { - "a": "b", - } - }).remote("z")) == "job_z") - else: - assert (ray.get( - get_env.options(override_environment_variables={ + assert (ray.get( + get_env.options(runtime_env={ + "env_vars": { "a": "b", - }).remote("z")) == "job_z") + } + }).remote("z")) == "job_z") -@pytest.mark.parametrize("use_runtime_env", [True, False]) -def 
test_override_environment_variables_reuse(shutdown_only, use_runtime_env): +def test_environment_variables_reuse(shutdown_only): """Test that new tasks don't incorrectly reuse previous environments.""" ray.init() @@ -244,32 +165,20 @@ def g(): return os.environ.get(env_var_name) assert ray.get(f.remote()) is None - if use_runtime_env: - assert ray.get( - f.options(runtime_env={ - "env_vars": { - env_var_name: val1 - } - }).remote()) == val1 - else: - assert ray.get( - f.options(override_environment_variables={ + assert ray.get( + f.options(runtime_env={ + "env_vars": { env_var_name: val1 - }).remote()) == val1 + } + }).remote()) == val1 assert ray.get(f.remote()) is None assert ray.get(g.remote()) is None - if use_runtime_env: - assert ray.get( - f.options(runtime_env={ - "env_vars": { - env_var_name: val2 - } - }).remote()) == val2 - else: - assert ray.get( - f.options(override_environment_variables={ + assert ray.get( + f.options(runtime_env={ + "env_vars": { env_var_name: val2 - }).remote()) == val2 + } + }).remote()) == val2 assert ray.get(g.remote()) is None assert ray.get(f.remote()) is None @@ -278,9 +187,7 @@ def g(): # there aren't enough CPUs (2-4 on Travis CI vs. likely 8 on Buildkite) and # worker processes are being killed to adhere to the soft limit. @pytest.mark.skipif(sys.platform == "darwin", reason="Flaky on Travis CI.") -@pytest.mark.parametrize("use_runtime_env", [True, False]) -def test_override_environment_variables_env_caching(shutdown_only, - use_runtime_env): +def test_environment_variables_env_caching(shutdown_only): """Test that workers with specified envs are cached and reused. When a new task or actor is created with a new runtime env, a @@ -307,10 +214,7 @@ def g(): return task() def get_options(val): - if use_runtime_env: - return {"override_environment_variables": {env_var_name: val}} - else: - return {"runtime_env": {"env_vars": {env_var_name: val}}} + return {"runtime_env": {"env_vars": {env_var_name: val}}} # Empty runtime env does not set our env var. 
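For reference, the replacement for the removed `override_environment_variables=` keyword throughout these tests is `runtime_env={"env_vars": ...}`, settable job-wide via `ray.init` and per task or actor via `.options()`. A minimal sketch:

import os
import ray

ray.init(runtime_env={"env_vars": {"FOO": "job"}})

@ray.remote
def get_env(key):
    return os.environ.get(key)

assert ray.get(get_env.remote("FOO")) == "job"
# A per-task env_vars entry overrides the job-level value for that task only.
assert ray.get(
    get_env.options(runtime_env={
        "env_vars": {
            "FOO": "task"
        }
    }).remote("FOO")) == "task"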
assert ray.get(f.remote())[0] is None diff --git a/python/ray/tests/test_runtime_env_plugin.py b/python/ray/tests/test_runtime_env_plugin.py new file mode 100644 index 0000000000000..629cdca4e6d25 --- /dev/null +++ b/python/ray/tests/test_runtime_env_plugin.py @@ -0,0 +1,75 @@ +import os +import tempfile + +import pytest +from ray._private.runtime_env.context import RuntimeEnvContext +from ray._private.runtime_env.plugin import RuntimeEnvPlugin + +import ray + +MY_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env_plugin.MyPlugin" + + +class MyPlugin(RuntimeEnvPlugin): + env_key = "MY_PLUGIN_TEST_ENVIRONMENT_KEY" + + @staticmethod + def validate(runtime_env_dict: dict) -> str: + value = runtime_env_dict["plugins"][MY_PLUGIN_CLASS_PATH] + if value == "fail": + raise ValueError("not allowed") + return value + + @staticmethod + def modify_context(uri: str, runtime_env_dict: dict, + ctx: RuntimeEnvContext) -> None: + plugin_config_dict = runtime_env_dict["plugins"][MY_PLUGIN_CLASS_PATH] + ctx.env_vars[MyPlugin.env_key] = str(plugin_config_dict["env_value"]) + ctx.command_prefix.append( + f"echo {plugin_config_dict['tmp_content']} > " + f"{plugin_config_dict['tmp_file']}") + ctx.py_executable = ( + plugin_config_dict["prefix_command"] + " " + ctx.py_executable) + + +def test_simple_env_modification_plugin(ray_start_regular): + _, tmp_file_path = tempfile.mkstemp() + + @ray.remote + def f(): + import psutil + with open(tmp_file_path, "r") as f: + content = f.read().strip() + return { + "env_value": os.environ[MyPlugin.env_key], + "tmp_content": content, + "nice": psutil.Process().nice(), + } + + with pytest.raises(ValueError, match="not allowed"): + f.options(runtime_env={ + "plugins": { + MY_PLUGIN_CLASS_PATH: "fail" + } + }).remote() + + output = ray.get( + f.options( + runtime_env={ + "plugins": { + MY_PLUGIN_CLASS_PATH: { + "env_value": 42, + "tmp_file": tmp_file_path, + "tmp_content": "hello", + # See https://en.wikipedia.org/wiki/Nice_(Unix) + "prefix_command": "nice -n 19", + } + } + }).remote()) + + assert output == {"env_value": "42", "tmp_content": "hello", "nice": 19} + + +if __name__ == "__main__": + import sys + sys.exit(pytest.main(["-sv", __file__])) diff --git a/python/ray/tests/test_runtime_env_validation.py b/python/ray/tests/test_runtime_env_validation.py new file mode 100644 index 0000000000000..dd73db1c3bd0d --- /dev/null +++ b/python/ray/tests/test_runtime_env_validation.py @@ -0,0 +1,360 @@ +import os +import pytest +import sys +import tempfile +from pathlib import Path +import yaml + +from ray._private.runtime_env.validation import ( + parse_and_validate_working_dir, parse_and_validate_conda, + parse_and_validate_pip, parse_and_validate_env_vars, ParsedRuntimeEnv, + override_task_or_actor_runtime_env) + +CONDA_DICT = {"dependencies": ["pip", {"pip": ["pip-install-test==0.5"]}]} + +PIP_LIST = ["requests==1.0.0", "pip-install-test"] + + +@pytest.fixture +def test_directory(): + with tempfile.TemporaryDirectory() as tmp_dir: + path = Path(tmp_dir) + subdir = path / "subdir" + subdir.mkdir(parents=True) + requirements_file = subdir / "requirements.txt" + with requirements_file.open(mode="w") as f: + print("\n".join(PIP_LIST), file=f) + + good_conda_file = subdir / "good_conda_env.yaml" + with good_conda_file.open(mode="w") as f: + yaml.dump(CONDA_DICT, f) + + bad_conda_file = subdir / "bad_conda_env.yaml" + with bad_conda_file.open(mode="w") as f: + print("% this is not a YAML file %", file=f) + + old_dir = os.getcwd() + os.chdir(tmp_dir) + yield subdir, requirements_file, 
good_conda_file, bad_conda_file + os.chdir(old_dir) + + +class TestValidateWorkingDir: + @pytest.mark.parametrize("absolute_path", [True, False]) + def test_validate_working_dir_valid_path(self, test_directory, + absolute_path): + subdir, _, _, _ = test_directory + + rel1 = "." + assert parse_and_validate_working_dir( + rel1, is_task_or_actor=False) == rel1 + + if absolute_path: + subdir = subdir.resolve() + + rel2 = str(subdir) + assert parse_and_validate_working_dir( + rel2, is_task_or_actor=False) == rel2 + + def test_validate_working_dir_absolute_path(self, test_directory): + subdir, _, _, _ = test_directory + + abspath = str(subdir.resolve()) + assert parse_and_validate_working_dir( + abspath, is_task_or_actor=False) == abspath + + def test_validate_working_dir_invalid_path(self): + with pytest.raises(ValueError): + parse_and_validate_working_dir("fake_path", is_task_or_actor=False) + + def test_validate_working_dir_invalid_types(self): + with pytest.raises(TypeError): + parse_and_validate_working_dir( + { + "working_dir": 1 + }, is_task_or_actor=False) + + def test_validate_working_dir_reject_task_or_actor(self): + # Can't pass working_dir for tasks/actors. + with pytest.raises(NotImplementedError): + parse_and_validate_working_dir( + { + "working_dir": "." + }, is_task_or_actor=True) + + +@pytest.mark.skipif( + sys.platform == "win32", reason="Conda option not supported on Windows.") +class TestValidateConda: + def test_validate_conda_invalid_types(self): + with pytest.raises(TypeError): + parse_and_validate_conda(1) + + with pytest.raises(TypeError): + parse_and_validate_conda(True) + + def test_validate_conda_str(self, test_directory): + assert parse_and_validate_conda("my_env_name") == "my_env_name" + + def test_validate_conda_invalid_path(self): + with pytest.raises(ValueError): + parse_and_validate_conda("../bad_path.yaml") + + @pytest.mark.parametrize("absolute_path", [True, False]) + def test_validate_conda_valid_file(self, test_directory, absolute_path): + _, _, good_conda_file, _ = test_directory + + if absolute_path: + good_conda_file = good_conda_file.resolve() + + assert parse_and_validate_conda(str(good_conda_file)) == CONDA_DICT + + @pytest.mark.parametrize("absolute_path", [True, False]) + def test_validate_conda_invalid_file(self, test_directory, absolute_path): + _, _, _, bad_conda_file = test_directory + + if absolute_path: + bad_conda_file = bad_conda_file.resolve() + + with pytest.raises(ValueError): + parse_and_validate_conda(str(bad_conda_file)) + + def test_validate_conda_valid_dict(self): + assert parse_and_validate_conda(CONDA_DICT) == CONDA_DICT + + +@pytest.mark.skipif( + sys.platform == "win32", reason="Pip option not supported on Windows.") +class TestValidatePip: + def test_validate_pip_invalid_types(self): + with pytest.raises(TypeError): + parse_and_validate_pip(1) + + with pytest.raises(TypeError): + parse_and_validate_pip(True) + + def test_validate_pip_invalid_path(self): + with pytest.raises(ValueError): + parse_and_validate_pip("../bad_path.txt") + + @pytest.mark.parametrize("absolute_path", [True, False]) + def test_validate_pip_valid_file(self, test_directory, absolute_path): + _, requirements_file, _, _ = test_directory + + if absolute_path: + requirements_file = requirements_file.resolve() + + result = parse_and_validate_pip(str(requirements_file)) + assert result == PIP_LIST + + def test_validate_pip_valid_list(self): + result = parse_and_validate_pip(PIP_LIST) + assert result == PIP_LIST + + +class TestValidateEnvVars: + def 
test_type_validation(self): + # Only strings allowed. + with pytest.raises(TypeError, match=".*Dict[str, str]*"): + parse_and_validate_env_vars({"INT_ENV": 1}) + + with pytest.raises(TypeError, match=".*Dict[str, str]*"): + parse_and_validate_env_vars({1: "hi"}) + + +class TestParsedRuntimeEnv: + @pytest.mark.parametrize("is_task_or_actor", [True, False]) + def test_empty(self, is_task_or_actor): + assert ParsedRuntimeEnv({}, is_task_or_actor=is_task_or_actor) == {} + + @pytest.mark.skipif( + sys.platform == "win32", reason="Pip option not supported on Windows.") + @pytest.mark.parametrize("is_task_or_actor", [True, False]) + def test_serialization(self, is_task_or_actor): + env1 = ParsedRuntimeEnv( + { + "pip": ["requests"], + "env_vars": { + "hi1": "hi1", + "hi2": "hi2" + } + }, + is_task_or_actor=is_task_or_actor) + + env2 = ParsedRuntimeEnv( + { + "env_vars": { + "hi2": "hi2", + "hi1": "hi1" + }, + "pip": ["requests"] + }, + is_task_or_actor=is_task_or_actor) + + assert env1 == env2 + + serialized_env1 = env1.serialize() + serialized_env2 = env2.serialize() + + # Key ordering shouldn't matter. + assert serialized_env1 == serialized_env2 + + deserialized_env1 = ParsedRuntimeEnv.deserialize(serialized_env1) + deserialized_env2 = ParsedRuntimeEnv.deserialize(serialized_env2) + + assert env1 == deserialized_env1 == env2 == deserialized_env2 + + @pytest.mark.parametrize("is_task_or_actor", [True, False]) + def test_reject_pip_and_conda(self, is_task_or_actor): + with pytest.raises(ValueError): + ParsedRuntimeEnv( + { + "pip": ["requests"], + "conda": "env_name" + }, + is_task_or_actor=is_task_or_actor) + + @pytest.mark.skipif( + sys.platform == "win32", + reason="Conda and pip options not supported on Windows.") + @pytest.mark.parametrize("is_task_or_actor", [True, False]) + def test_ray_commit_injection(self, is_task_or_actor): + # Should not be injected if no pip and conda. + result = ParsedRuntimeEnv( + { + "env_vars": { + "hi": "hi" + } + }, is_task_or_actor=is_task_or_actor) + assert "_ray_commit" not in result + + # Should be injected if pip or conda present. + result = ParsedRuntimeEnv( + { + "pip": ["requests"], + }, is_task_or_actor=is_task_or_actor) + assert "_ray_commit" in result + + result = ParsedRuntimeEnv( + { + "conda": "env_name" + }, is_task_or_actor=is_task_or_actor) + assert "_ray_commit" in result + + # Should not override if passed. + result = ParsedRuntimeEnv( + { + "conda": "env_name", + "_ray_commit": "Blah" + }, + is_task_or_actor=is_task_or_actor) + assert result["_ray_commit"] == "Blah" + + @pytest.mark.parametrize("is_task_or_actor", [True, False]) + def test_inject_current_ray(self, is_task_or_actor): + # Should not be injected if not provided by env var. + result = ParsedRuntimeEnv( + { + "env_vars": { + "hi": "hi" + } + }, is_task_or_actor=is_task_or_actor) + assert "_inject_current_ray" not in result + + os.environ["RAY_RUNTIME_ENV_LOCAL_DEV_MODE"] = "1" + + # Should be injected if provided by env var. + result = ParsedRuntimeEnv({}, is_task_or_actor=is_task_or_actor) + assert result["_inject_current_ray"] + + # Should be preserved if passed. 
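A condensed sketch of the `_inject_current_ray` behavior these assertions pin down: the flag is derived from the `RAY_RUNTIME_ENV_LOCAL_DEV_MODE` environment variable, and an explicitly passed value always wins:

import os
from ray._private.runtime_env.validation import ParsedRuntimeEnv

os.environ["RAY_RUNTIME_ENV_LOCAL_DEV_MODE"] = "1"
assert ParsedRuntimeEnv({}, is_task_or_actor=False)["_inject_current_ray"]
# An explicit value is preserved even when the env var is set.
assert not ParsedRuntimeEnv(
    {
        "_inject_current_ray": False
    }, is_task_or_actor=False)["_inject_current_ray"]
del os.environ["RAY_RUNTIME_ENV_LOCAL_DEV_MODE"]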
+ result = ParsedRuntimeEnv( + { + "_inject_current_ray": False + }, is_task_or_actor=is_task_or_actor) + assert not result["_inject_current_ray"] + + del os.environ["RAY_RUNTIME_ENV_LOCAL_DEV_MODE"] + + +class TestOverrideRuntimeEnvs: + def test_override_uris(self): + child = {} + parent = {"uris": ["a", "b"]} + assert override_task_or_actor_runtime_env(child, parent) == parent + + child = {"uris": ["a", "b"]} + parent = {"uris": ["c", "d"]} + assert override_task_or_actor_runtime_env(child, parent) == child + + child = {"uris": ["a", "b"]} + parent = {} + assert override_task_or_actor_runtime_env(child, parent) == child + + def test_override_env_vars(self): + # (child, parent, expected) + TEST_CASES = [ + ({}, {}, {}), + (None, None, None), + ({"a": "b"}, {}, {"a": "b"}), + ({"a": "b"}, None, {"a": "b"}), + ({}, {"a": "b"}, {"a": "b"}), + (None, {"a": "b"}, {"a": "b"}), + ({"a": "b"}, {"a": "d"}, {"a": "b"}), + ({"a": "b"}, {"c": "d"}, {"a": "b", "c": "d"}), + ({"a": "b"}, {"a": "e", "c": "d"}, {"a": "b", "c": "d"}) + ] # yapf: disable + + for idx, (child, parent, expected) in enumerate(TEST_CASES): + child = {"env_vars": child} if child is not None else {} + parent = {"env_vars": parent} if parent is not None else {} + expected = {"env_vars": expected} if expected is not None else {} + assert override_task_or_actor_runtime_env( + child, parent) == expected, f"TEST_INDEX:{idx}" + + def test_uri_inherit(self): + child_env = {} + parent_env = {"working_dir": "other_dir", "uris": ["a", "b"]} + result_env = override_task_or_actor_runtime_env(child_env, parent_env) + assert result_env == {"uris": ["a", "b"]} + + # The dicts passed in should not be mutated. + assert child_env == {} + assert parent_env == {"working_dir": "other_dir", "uris": ["a", "b"]} + + def test_uri_override(self): + child_env = {"uris": ["c", "d"]} + parent_env = {"working_dir": "other_dir", "uris": ["a", "b"]} + result_env = override_task_or_actor_runtime_env(child_env, parent_env) + assert result_env["uris"] == ["c", "d"] + assert result_env.get("working_dir") is None + + # The dicts passed in should not be mutated. + assert child_env == {"uris": ["c", "d"]} + assert parent_env == {"working_dir": "other_dir", "uris": ["a", "b"]} + + def test_no_mutate(self): + child_env = {} + parent_env = {"working_dir": "other_dir", "uris": ["a", "b"]} + result_env = override_task_or_actor_runtime_env(child_env, parent_env) + assert result_env == {"uris": ["a", "b"]} + + # The dicts passed in should not be mutated. 
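Summarizing the override semantics these cases encode: child keys replace parent keys wholesale, except `env_vars`, which is merged key-by-key with the child taking precedence. A minimal sketch:

from ray._private.runtime_env.validation import (
    override_task_or_actor_runtime_env)

child = {"uris": ["c"], "env_vars": {"A": "child"}}
parent = {"uris": ["a", "b"], "env_vars": {"A": "parent", "B": "parent"}}

result = override_task_or_actor_runtime_env(child, parent)
assert result["uris"] == ["c"]  # replaced wholesale, not merged
assert result["env_vars"] == {"A": "child", "B": "parent"}  # merged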
+ assert child_env == {} + assert parent_env == {"working_dir": "other_dir", "uris": ["a", "b"]} + + def test_inherit_conda(self): + child_env = {"uris": ["a"]} + parent_env = {"conda": "my-env-name", "uris": ["a", "b"]} + result_env = override_task_or_actor_runtime_env(child_env, parent_env) + assert result_env == {"uris": ["a"], "conda": "my-env-name"} + + def test_inherit_pip(self): + child_env = {"uris": ["a"]} + parent_env = {"pip": ["pkg-name"], "uris": ["a", "b"]} + result_env = override_task_or_actor_runtime_env(child_env, parent_env) + assert result_env == {"uris": ["a"], "pip": ["pkg-name"]} + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", __file__])) diff --git a/python/ray/tests/test_scheduling.py b/python/ray/tests/test_scheduling.py index 10a4ab846e844..b834d67e0c67c 100644 --- a/python/ray/tests/test_scheduling.py +++ b/python/ray/tests/test_scheduling.py @@ -2,6 +2,7 @@ import collections import logging import platform +import subprocess import sys import time import unittest @@ -549,8 +550,8 @@ def __init__(self): def get_location(self): return ray.worker.global_worker.node.unique_id - @ray.remote - def task_cpu(num_cpus=0.5): + @ray.remote(num_cpus=0.5) + def task_cpu(): time.sleep(10) return ray.worker.global_worker.node.unique_id @@ -578,6 +579,100 @@ def launcher(): cluster.shutdown() +@pytest.mark.parametrize( + "ray_start_cluster", [{ + "num_cpus": 0, + "num_nodes": 1, + }], indirect=True) +def test_head_node_without_cpu(ray_start_cluster): + @ray.remote(num_cpus=1) + def f(): + return 1 + + f.remote() + + check_count = 0 + demand_1cpu = " {'CPU': 1.0}:" + while True: + status = subprocess.check_output(["ray", "status"]).decode() + if demand_1cpu in status: + break + check_count += 1 + assert check_count < 5, f"Incorrect demand. Last status {status}" + time.sleep(1) + + @ray.remote(num_cpus=2) + def g(): + return 2 + + g.remote() + + check_count = 0 + demand_2cpu = " {'CPU': 2.0}:" + while True: + status = subprocess.check_output(["ray", "status"]).decode() + if demand_1cpu in status and demand_2cpu in status: + break + check_count += 1 + assert check_count < 5, f"Incorrect demand. Last status {status}" + time.sleep(1) + + +@pytest.mark.skipif(sys.platform == "win32", reason="Fails on windows") +def test_gpu_scheduling_liveness(ray_start_cluster): + """Check if the GPU scheduling is in progress when + it is used with the placement group + Issue: https://github.com/ray-project/ray/issues/19130 + """ + cluster = ray_start_cluster + # Start a node without a gpu. + cluster.add_node(num_cpus=6) + ray.init(address=cluster.address) + + NUM_CPU_BUNDLES = 10 + + @ray.remote(num_cpus=1) + class Worker(object): + def __init__(self, i): + self.i = i + + def work(self): + time.sleep(0.1) + print("work ", self.i) + + @ray.remote(num_cpus=1, num_gpus=1) + class Trainer(object): + def __init__(self, i): + self.i = i + + def train(self): + time.sleep(0.2) + print("train ", self.i) + + bundles = [{"CPU": 1, "GPU": 1}] + bundles += [{"CPU": 1} for _ in range(NUM_CPU_BUNDLES)] + + pg = ray.util.placement_group(bundles, strategy="PACK") + o = pg.ready() + # Artificial delay to simulate the real world workload. + time.sleep(3) + print("Scaling up.") + cluster.add_node(num_cpus=6, num_gpus=1) + ray.get(o) + + workers = [ + Worker.options(placement_group=pg).remote(i) + for i in range(NUM_CPU_BUNDLES) + ] + trainer = Trainer.options(placement_group=pg).remote(0) + + # If the gpu scheduling doesn't properly work, the below + # code will hang. 
+ ray.get( + [workers[i].work.remote() for i in range(NUM_CPU_BUNDLES)], timeout=30) + ray.get(trainer.train.remote(), timeout=30) + + if __name__ == "__main__": import pytest sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tests/test_traceback.py b/python/ray/tests/test_traceback.py index 3081bcc6ec3d4..fa48ec62f09cb 100644 --- a/python/ray/tests/test_traceback.py +++ b/python/ray/tests/test_traceback.py @@ -270,6 +270,45 @@ def __repr__(self): assert label_dict["repr"] == actor_repr +def test_unpickleable_stacktrace(): + expected_output = """System error: Failed to unpickle serialized exception +traceback: Traceback (most recent call last): + File "FILE", line ZZ, in from_bytes + return pickle.loads(ray_exception.serialized_exception) +TypeError: __init__() missing 1 required positional argument: 'arg' + +The above exception was the direct cause of the following exception: + +Traceback (most recent call last): + File "FILE", line ZZ, in deserialize_objects + obj = self._deserialize_object(data, metadata, object_ref) + File "FILE", line ZZ, in _deserialize_object + return RayError.from_bytes(obj) + File "FILE", line ZZ, in from_bytes + raise RuntimeError(msg) from e +RuntimeError: Failed to unpickle serialized exception""" + + class NoPickleError(OSError): + def __init__(self, arg): + pass + + def g(a): + raise NoPickleError("asdf") + + @ray.remote + def f(): + a = 3 + b = 4 + c = a + b + return g(c) + + try: + ray.get(f.remote()) + except Exception as ex: + print(repr(scrub_traceback(str(ex)))) + assert clean_noqa(expected_output) == scrub_traceback(str(ex)) + + if __name__ == "__main__": import pytest import sys diff --git a/python/ray/tune/analysis/experiment_analysis.py b/python/ray/tune/analysis/experiment_analysis.py index e7cfc31810e1d..fbaa7207a04a0 100644 --- a/python/ray/tune/analysis/experiment_analysis.py +++ b/python/ray/tune/analysis/experiment_analysis.py @@ -1,9 +1,11 @@ import json import logging import os +import warnings from numbers import Number from typing import Any, Dict, List, Optional, Tuple +from ray.util.debug import log_once from ray.tune.utils import flatten_dict from ray.tune.utils.serialization import TuneFunctionDecoder from ray.tune.utils.util import is_nan_or_inf @@ -556,6 +558,17 @@ def best_result(self) -> Dict: "the metric and mode explicitly and fetch the last result.") return self.best_trial.last_result + def _delimiter(self): + # Deprecate: 1.9 (default should become `/`) + delimiter = os.environ.get("TUNE_RESULT_DELIM", ".") + if delimiter == "." and log_once("delimiter_deprecation"): + warnings.warn( + "Dataframes will use '/' instead of '.' to delimit " + "nested result keys in future versions of Ray. For forward " + "compatibility, set the environment variable " + "TUNE_RESULT_DELIM='/'") + return delimiter + @property def best_result_df(self) -> DataFrame: """Get the best result of the experiment as a pandas dataframe. @@ -569,7 +582,9 @@ def best_result_df(self) -> DataFrame: if not pd: raise ValueError("`best_result_df` requires pandas. Install with " "`pip install pandas`.") - best_result = flatten_dict(self.best_result, delimiter=".") + + best_result = flatten_dict( + self.best_result, delimiter=self._delimiter()) return pd.DataFrame.from_records([best_result], index="trial_id") @property @@ -579,12 +594,13 @@ def results(self) -> Dict[str, Dict]: @property def results_df(self) -> DataFrame: + """Get all the last results as a pandas dataframe.""" if not pd: - raise ValueError("`best_result_df` requires pandas. 
Install with " + raise ValueError("`results_df` requires pandas. Install with " "`pip install pandas`.") return pd.DataFrame.from_records( [ - flatten_dict(trial.last_result, delimiter=".") + flatten_dict(trial.last_result, delimiter=self._delimiter()) for trial in self.trials ], index="trial_id") diff --git a/python/ray/tune/commands.py b/python/ray/tune/commands.py index 7fbbe9776bde2..5d47605c63181 100644 --- a/python/ray/tune/commands.py +++ b/python/ray/tune/commands.py @@ -116,10 +116,9 @@ def list_trials(experiment_path, _check_tabulate() try: - checkpoints_df = Analysis(experiment_path).dataframe( - metric="episode_reward_mean", mode="max") - except TuneError: - raise click.ClickException("No trial data found!") + checkpoints_df = Analysis(experiment_path).dataframe() # last result + except TuneError as e: + raise click.ClickException("No trial data found!") from e def key_filter(k): return k in DEFAULT_CLI_KEYS or k.startswith(CONFIG_PREFIX) diff --git a/python/ray/tune/durable_trainable.py b/python/ray/tune/durable_trainable.py index db822434f1223..77b80e510af2b 100644 --- a/python/ray/tune/durable_trainable.py +++ b/python/ray/tune/durable_trainable.py @@ -171,14 +171,16 @@ def durable(trainable: Union[str, Type[Trainable], Callable]): A durable trainable class wrapped around your trainable. """ + overwrite_name = None if isinstance(trainable, str): trainable_cls = get_trainable_cls(trainable) + overwrite_name = f"Durable{trainable}" else: trainable_cls = trainable if not inspect.isclass(trainable_cls): # Function API - return wrap_function(trainable_cls, durable=True) + return wrap_function(trainable_cls, durable=True, name=overwrite_name) if not issubclass(trainable_cls, Trainable): raise ValueError( @@ -187,8 +189,14 @@ def durable(trainable: Union[str, Type[Trainable], Callable]): f"it does. 
Got: {type(trainable_cls)}")
 
     # else: Class API
+
+    # Class is already durable
+
+    if issubclass(trainable_cls, DurableTrainable):
+        return trainable_cls
+
     class _WrappedDurableTrainable(DurableTrainable, trainable_cls):
-        _name = trainable_cls.__name__ if hasattr(trainable_cls, "__name__") \
-            else "durable_trainable"
+        _name = overwrite_name or (trainable_cls.__name__ if hasattr(
+            trainable_cls, "__name__") else "durable_trainable")
 
     return _WrappedDurableTrainable
diff --git a/python/ray/tune/function_runner.py b/python/ray/tune/function_runner.py
index ae4235aa89099..e4c2018068d7a 100644
--- a/python/ray/tune/function_runner.py
+++ b/python/ray/tune/function_runner.py
@@ -10,6 +10,8 @@
 from functools import partial
 from numbers import Number
 
+from typing import Any, Callable, Optional
+
 from six.moves import queue
 
 from ray.util.debug import log_once
@@ -530,7 +532,10 @@ def _report_thread_runner_error(self, block=False):
             pass
 
 
-def wrap_function(train_func, durable=False, warn=True):
+def wrap_function(train_func: Callable[[Any], Any],
+                  durable: bool = False,
+                  warn: bool = True,
+                  name: Optional[str] = None):
     inherit_from = (FunctionRunner, )
 
     if hasattr(train_func, "__mixins__"):
@@ -562,8 +567,8 @@ def wrap_function(train_func, durable=False, warn=True):
                 "arguments to be `func(config, checkpoint_dir=None)`.")
 
     class ImplicitFunc(*inherit_from):
-        _name = train_func.__name__ if hasattr(train_func, "__name__") \
-            else "func"
+        _name = name or (train_func.__name__
+                         if hasattr(train_func, "__name__") else "func")
 
         def _trainable_func(self, config, reporter, checkpoint_dir):
             if not use_checkpoint and not use_reporter:
diff --git a/python/ray/tune/progress_reporter.py b/python/ray/tune/progress_reporter.py
index 0b69faa51550d..52f19f8029da2 100644
--- a/python/ray/tune/progress_reporter.py
+++ b/python/ray/tune/progress_reporter.py
@@ -1,5 +1,6 @@
 from __future__ import print_function
 
+import datetime
 from typing import Dict, List, Optional, Union
 
 import collections
@@ -8,15 +9,17 @@
 import numpy as np
 import time
 
+from ray.util.annotations import PublicAPI, DeveloperAPI
+from ray.util.queue import Queue
+
 from ray.tune.callback import Callback
 from ray.tune.logger import pretty_print, logger
-from ray.tune.result import (DEFAULT_METRIC, EPISODE_REWARD_MEAN,
-                             MEAN_ACCURACY, MEAN_LOSS, TRAINING_ITERATION,
-                             TIME_TOTAL_S, TIMESTEPS_TOTAL, AUTO_RESULT_KEYS)
-from ray.tune.trial import DEBUG_PRINT_INTERVAL, Trial
+from ray.tune.result import (
+    DEFAULT_METRIC, EPISODE_REWARD_MEAN, MEAN_ACCURACY, MEAN_LOSS, NODE_IP,
+    PID, TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL, AUTO_RESULT_KEYS)
+from ray.tune.trial import DEBUG_PRINT_INTERVAL, Trial, Location
 from ray.tune.utils import unflattened_lookup
 from ray.tune.utils.log import Verbosity, has_verbosity
-from ray.util.annotations import PublicAPI, DeveloperAPI
 
 try:
     from collections.abc import Mapping, MutableMapping
@@ -159,6 +162,8 @@ def __init__(
         self._max_report_freqency = max_report_frequency
         self._last_report_time = 0
 
+        self._start_time = time.time()
+
         self._metric = metric
         self._mode = mode
 
@@ -188,6 +193,12 @@ def set_search_properties(self, metric: Optional[str],
     def set_total_samples(self, total_samples: int):
         self._total_samples = total_samples
 
+    def set_start_time(self, timestamp: Optional[float] = None):
+        if timestamp is not None:
+            self._start_time = timestamp
+        else:
+            self._start_time = time.time()
+
     def should_report(self, trials: List[Trial], done: bool = False):
         if time.time() - self._last_report_time > self._max_report_freqency:
            self._last_report_time = time.time()
@@ -267,7 +278,11 @@ def _progress_str(self,
         if not self._metrics_override:
             user_metrics = self._infer_user_metrics(trials, self._infer_limit)
             self._metric_columns.update(user_metrics)
-        messages = ["== Status ==", memory_debug_str(), *sys_info]
+        messages = [
+            "== Status ==",
+            time_passed_str(self._start_time, time.time()),
+            memory_debug_str(), *sys_info
+        ]
         if done:
             max_progress = None
             max_error = None
@@ -416,15 +431,32 @@ def __init__(
                 "to `tune.run()` instead.")
 
         self._overwrite = overwrite
+        self._output_queue = None
+
+    def set_output_queue(self, queue: Queue):
+        self._output_queue = queue
 
     def report(self, trials: List[Trial], done: bool, *sys_info: Dict):
-        from IPython.display import clear_output
-        from IPython.core.display import display, HTML
-        if self._overwrite:
-            clear_output(wait=True)
+        overwrite = self._overwrite
         progress_str = self._progress_str(
             trials, done, *sys_info, fmt="html", delim="<br>
") - display(HTML(progress_str)) + + def update_output(): + from IPython.display import clear_output + from IPython.core.display import display, HTML + + if overwrite: + clear_output(wait=True) + + display(HTML(progress_str)) + + if self._output_queue is not None: + # If an output queue is set, send callable (e.g. when using + # Ray client) + self._output_queue.put(update_output) + else: + # Else, output directly + update_output() @PublicAPI @@ -510,6 +542,33 @@ def memory_debug_str(): "to resolve)") +def time_passed_str(start_time: float, current_time: float): + current_time_dt = datetime.datetime.fromtimestamp(current_time) + start_time_dt = datetime.datetime.fromtimestamp(start_time) + delta: datetime.timedelta = current_time_dt - start_time_dt + + rest = delta.total_seconds() + days = rest // (60 * 60 * 24) + + rest -= days * (60 * 60 * 24) + hours = rest // (60 * 60) + + rest -= hours * (60 * 60) + minutes = rest // 60 + + seconds = rest - minutes * 60 + + if days > 0: + running_for_str = f"{days:.0f} days, " + else: + running_for_str = "" + + running_for_str += f"{hours:02.0f}:{minutes:02.0f}:{seconds:05.2f}" + + return (f"Current time: {current_time_dt:%Y-%m-%d %H:%M:%S} " + f"(running for {running_for_str})") + + def _get_trials_by_state(trials: List[Trial]): trials_by_state = collections.defaultdict(list) for t in trials: @@ -774,6 +833,18 @@ def _fair_filter_trials(trials_by_state: Dict[str, List[Trial]], return filtered_trials +def _get_trial_location(trial: Trial, result: dict) -> Location: + # we get the location from the result, as the one in trial will be + # reset when trial terminates + node_ip, pid = result.get(NODE_IP, None), result.get(PID, None) + if node_ip and pid: + location = Location(node_ip, pid) + else: + # fallback to trial location if there hasn't been a report yet + location = trial.location + return location + + def _get_trial_info(trial: Trial, parameters: List[str], metrics: List[str]): """Returns the following information about a trial: @@ -786,7 +857,8 @@ def _get_trial_info(trial: Trial, parameters: List[str], metrics: List[str]): """ result = trial.last_result config = trial.config - trial_info = [str(trial), trial.status, str(trial.location)] + location = _get_trial_location(trial, result) + trial_info = [str(trial), trial.status, str(location)] trial_info += [ unflattened_lookup(param, config, default=None) for param in parameters ] diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index 52ec1102a5f78..959fba6c0dcff 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -93,11 +93,18 @@ class _TrialCleanup: Args: threshold (int): Number of futures to hold at once. If the threshold is passed, cleanup will kick in and remove futures. + force_cleanup (int): Grace periods for forceful actor termination. + If 0, actors will not be forcefully terminated. """ - def __init__(self, threshold: int = TRIAL_CLEANUP_THRESHOLD): + def __init__(self, + threshold: int = TRIAL_CLEANUP_THRESHOLD, + force_cleanup: int = 0): self.threshold = threshold self._cleanup_map = {} + if force_cleanup < 0: + force_cleanup = 0 + self._force_cleanup = force_cleanup def add(self, trial: Trial, actor: ActorHandle): """Adds a trial actor to be stopped. @@ -123,15 +130,27 @@ def cleanup(self, partial: bool = True): If partial=False, all futures are expected to return. If a future does not return within the timeout period, the cleanup terminates. 
""" + # At this point, self._cleanup_map holds the last references + # to actors. Removing those references either one-by-one + # (graceful termination case) or all at once, by reinstantiating + # self._cleanup_map (forceful termination case) will cause Ray + # to kill the actors during garbage collection. logger.debug("Cleaning up futures") num_to_keep = int(self.threshold) / 2 if partial else 0 while len(self._cleanup_map) > num_to_keep: dones, _ = ray.wait( - list(self._cleanup_map), timeout=DEFAULT_GET_TIMEOUT) + list(self._cleanup_map), + timeout=DEFAULT_GET_TIMEOUT + if not self._force_cleanup else self._force_cleanup) if not dones: logger.warning( "Skipping cleanup - trainable.stop did not return in " "time. Consider making `stop` a faster operation.") + if not partial and self._force_cleanup: + logger.warning( + "Forcing trainable cleanup by terminating actors.") + self._cleanup_map = {} + return else: done = dones[0] del self._cleanup_map[done] @@ -165,7 +184,9 @@ def __init__(self, # We use self._paused to store paused trials here. self._paused = {} - self._trial_cleanup = _TrialCleanup() + force_trial_cleanup = int( + os.environ.get("TUNE_FORCE_TRIAL_CLEANUP_S", "0")) + self._trial_cleanup = _TrialCleanup(force_cleanup=force_trial_cleanup) self._has_cleaned_up_pgs = False self._reuse_actors = reuse_actors # The maxlen will be updated when `set_max_pending_trials()` is called diff --git a/python/ray/tune/registry.py b/python/ray/tune/registry.py index d83c727179387..9f143db42d37d 100644 --- a/python/ray/tune/registry.py +++ b/python/ray/tune/registry.py @@ -1,5 +1,8 @@ import logging +import uuid + from types import FunctionType +from typing import Optional import ray import ray.cloudpickle as pickle @@ -114,23 +117,25 @@ def check_serializability(key, value): _global_registry.register(TEST, key, value) -def _make_key(category, key): +def _make_key(prefix, category, key): """Generate a binary key for the given category and key. Args: + prefix (str): Prefix category (str): The category of the item key (str): The unique identifier for the item Returns: The key to use for storing a the value. """ - return (b"TuneRegistry:" + category.encode("ascii") + b"/" + - key.encode("ascii")) + return (b"TuneRegistry:" + prefix.encode("ascii") + b":" + + category.encode("ascii") + b"/" + key.encode("ascii")) class _Registry: - def __init__(self): + def __init__(self, prefix: Optional[str] = None): self._to_flush = {} + self._prefix = prefix or uuid.uuid4().hex[:8] def register(self, category, key, value): """Registers the value with the global registry. 
@@ -148,14 +153,14 @@ def register(self, category, key, value): def contains(self, category, key): if _internal_kv_initialized(): - value = _internal_kv_get(_make_key(category, key)) + value = _internal_kv_get(_make_key(self._prefix, category, key)) return value is not None else: return (category, key) in self._to_flush def get(self, category, key): if _internal_kv_initialized(): - value = _internal_kv_get(_make_key(category, key)) + value = _internal_kv_get(_make_key(self._prefix, category, key)) if value is None: raise ValueError( "Registry value for {}/{} doesn't exist.".format( @@ -166,11 +171,12 @@ def get(self, category, key): def flush_values(self): for (category, key), value in self._to_flush.items(): - _internal_kv_put(_make_key(category, key), value, overwrite=True) + _internal_kv_put( + _make_key(self._prefix, category, key), value, overwrite=True) self._to_flush.clear() -_global_registry = _Registry() +_global_registry = _Registry(prefix="global") ray.worker._post_init_hooks.append(_global_registry.flush_values) diff --git a/python/ray/tune/result.py b/python/ray/tune/result.py index e9eb7f40212dc..c166da2c0ce8e 100644 --- a/python/ray/tune/result.py +++ b/python/ray/tune/result.py @@ -67,6 +67,10 @@ DEFAULT_RESULT_KEYS = (TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL, MEAN_ACCURACY, MEAN_LOSS) +# Metrics that don't require at least one iteration to complete +DEBUG_METRICS = (TRIAL_ID, "experiment_id", "date", "timestamp", PID, HOSTNAME, + NODE_IP, "config") + # Make sure this doesn't regress AUTO_RESULT_KEYS = ( TRAINING_ITERATION, diff --git a/python/ray/tune/tests/test_api.py b/python/ray/tune/tests/test_api.py index bb49de900fb1d..598b2a2dccf59 100644 --- a/python/ray/tune/tests/test_api.py +++ b/python/ray/tune/tests/test_api.py @@ -226,6 +226,14 @@ class B(Trainable): self.assertRaises(TypeError, lambda: register_trainable("foo", A)) self.assertRaises(TypeError, lambda: Experiment("foo", A)) + def testRegisterDurableTrainableTwice(self): + def train(config, reporter): + pass + + register_trainable("foo", train) + register_trainable("foo", tune.durable("foo")) + register_trainable("foo", tune.durable("foo")) + def testTrainableCallable(self): def dummy_fn(config, reporter, steps): reporter(timesteps_total=steps, done=True) diff --git a/python/ray/tune/tests/test_cluster.py b/python/ray/tune/tests/test_cluster.py index 98dd4e4b2da58..87a6f42f7af25 100644 --- a/python/ray/tune/tests/test_cluster.py +++ b/python/ray/tune/tests/test_cluster.py @@ -190,7 +190,7 @@ def test_remove_node_before_result(start_connected_emptyhead_cluster): running_trials = _get_running_trials(runner) assert len(running_trials) == 1 assert _check_trial_running(running_trials[0]) - assert not trial.last_result + assert not trial.has_reported_at_least_once assert trial.status == Trial.RUNNING cluster.remove_node(node) cluster.add_node(num_cpus=1) diff --git a/python/ray/tune/tests/test_progress_reporter.py b/python/ray/tune/tests/test_progress_reporter.py index 2e85fe0a6b368..6978d2c128c6f 100644 --- a/python/ray/tune/tests/test_progress_reporter.py +++ b/python/ray/tune/tests/test_progress_reporter.py @@ -3,13 +3,14 @@ import os import unittest from unittest.mock import MagicMock, Mock, patch + from ray import tune from ray._private.test_utils import run_string_as_driver from ray.tune.trial import Trial from ray.tune.result import AUTO_RESULT_KEYS -from ray.tune.progress_reporter import (CLIReporter, JupyterNotebookReporter, - _fair_filter_trials, best_trial_str, - detect_reporter, 
trial_progress_str) +from ray.tune.progress_reporter import ( + CLIReporter, JupyterNotebookReporter, _fair_filter_trials, best_trial_str, + detect_reporter, trial_progress_str, time_passed_str) EXPECTED_RESULT_1 = """Result logdir: /foo Number of trials: 5 (1 PENDING, 3 RUNNING, 1 TERMINATED) @@ -60,76 +61,92 @@ END_TO_END_COMMAND = """ import ray from ray import tune +from ray.tune.trial import Location +from ray.tune.progress_reporter import _get_trial_location +from unittest.mock import patch -reporter = tune.progress_reporter.CLIReporter(metric_columns=["done"]) -def f(config): - return {"done": True} +def mock_get_trial_location(trial, result): + location = _get_trial_location(trial, result) + if location.pid: + return Location("123.123.123.123", "1") + return location -ray.init(num_cpus=1) -tune.run_experiments({ - "one": { - "run": f, - "config": { - "a": tune.grid_search(list(range(10))), - }, - }, - "two": { - "run": f, - "config": { - "b": tune.grid_search(list(range(10))), - }, - }, - "three": { - "run": f, - "config": { - "c": tune.grid_search(list(range(10))), + +with patch("ray.tune.progress_reporter._get_trial_location", + mock_get_trial_location): + reporter = tune.progress_reporter.CLIReporter(metric_columns=["done"]) + + def f(config): + return {"done": True} + + ray.init(num_cpus=1) + tune.run_experiments( + { + "one": { + "run": f, + "config": { + "a": tune.grid_search(list(range(10))), + }, + }, + "two": { + "run": f, + "config": { + "b": tune.grid_search(list(range(10))), + }, + }, + "three": { + "run": f, + "config": { + "c": tune.grid_search(list(range(10))), + }, + }, }, - }, -}, verbose=3, progress_reporter=reporter)""" + verbose=3, + progress_reporter=reporter)""" EXPECTED_END_TO_END_START = """Number of trials: 30/30 (29 PENDING, 1 RUNNING) -+---------------+----------+-------+-----+-----+ -| Trial name | status | loc | a | b | -|---------------+----------+-------+-----+-----| -| f_xxxxx_00000 | RUNNING | | 0 | | -| f_xxxxx_00001 | PENDING | | 1 | |""" ++---------------+----------+-------------------+-----+-----+ +| Trial name | status | loc | a | b | +|---------------+----------+-------------------+-----+-----| +| f_xxxxx_00000 | RUNNING | 123.123.123.123:1 | 0 | | +| f_xxxxx_00001 | PENDING | | 1 | |""" EXPECTED_END_TO_END_END = """Number of trials: 30/30 (30 TERMINATED) -+---------------+------------+-------+-----+-----+-----+--------+ -| Trial name | status | loc | a | b | c | done | -|---------------+------------+-------+-----+-----+-----+--------| -| f_xxxxx_00000 | TERMINATED | | 0 | | | True | -| f_xxxxx_00001 | TERMINATED | | 1 | | | True | -| f_xxxxx_00002 | TERMINATED | | 2 | | | True | -| f_xxxxx_00003 | TERMINATED | | 3 | | | True | -| f_xxxxx_00004 | TERMINATED | | 4 | | | True | -| f_xxxxx_00005 | TERMINATED | | 5 | | | True | -| f_xxxxx_00006 | TERMINATED | | 6 | | | True | -| f_xxxxx_00007 | TERMINATED | | 7 | | | True | -| f_xxxxx_00008 | TERMINATED | | 8 | | | True | -| f_xxxxx_00009 | TERMINATED | | 9 | | | True | -| f_xxxxx_00010 | TERMINATED | | | 0 | | True | -| f_xxxxx_00011 | TERMINATED | | | 1 | | True | -| f_xxxxx_00012 | TERMINATED | | | 2 | | True | -| f_xxxxx_00013 | TERMINATED | | | 3 | | True | -| f_xxxxx_00014 | TERMINATED | | | 4 | | True | -| f_xxxxx_00015 | TERMINATED | | | 5 | | True | -| f_xxxxx_00016 | TERMINATED | | | 6 | | True | -| f_xxxxx_00017 | TERMINATED | | | 7 | | True | -| f_xxxxx_00018 | TERMINATED | | | 8 | | True | -| f_xxxxx_00019 | TERMINATED | | | 9 | | True | -| f_xxxxx_00020 | TERMINATED | | | | 0 | True 
| -| f_xxxxx_00021 | TERMINATED | | | | 1 | True | -| f_xxxxx_00022 | TERMINATED | | | | 2 | True | -| f_xxxxx_00023 | TERMINATED | | | | 3 | True | -| f_xxxxx_00024 | TERMINATED | | | | 4 | True | -| f_xxxxx_00025 | TERMINATED | | | | 5 | True | -| f_xxxxx_00026 | TERMINATED | | | | 6 | True | -| f_xxxxx_00027 | TERMINATED | | | | 7 | True | -| f_xxxxx_00028 | TERMINATED | | | | 8 | True | -| f_xxxxx_00029 | TERMINATED | | | | 9 | True | -+---------------+------------+-------+-----+-----+-----+--------+""" ++---------------+------------+-------------------+-----+-----+-----+--------+ +| Trial name | status | loc | a | b | c | done | +|---------------+------------+-------------------+-----+-----+-----+--------| +| f_xxxxx_00000 | TERMINATED | 123.123.123.123:1 | 0 | | | True | +| f_xxxxx_00001 | TERMINATED | 123.123.123.123:1 | 1 | | | True | +| f_xxxxx_00002 | TERMINATED | 123.123.123.123:1 | 2 | | | True | +| f_xxxxx_00003 | TERMINATED | 123.123.123.123:1 | 3 | | | True | +| f_xxxxx_00004 | TERMINATED | 123.123.123.123:1 | 4 | | | True | +| f_xxxxx_00005 | TERMINATED | 123.123.123.123:1 | 5 | | | True | +| f_xxxxx_00006 | TERMINATED | 123.123.123.123:1 | 6 | | | True | +| f_xxxxx_00007 | TERMINATED | 123.123.123.123:1 | 7 | | | True | +| f_xxxxx_00008 | TERMINATED | 123.123.123.123:1 | 8 | | | True | +| f_xxxxx_00009 | TERMINATED | 123.123.123.123:1 | 9 | | | True | +| f_xxxxx_00010 | TERMINATED | 123.123.123.123:1 | | 0 | | True | +| f_xxxxx_00011 | TERMINATED | 123.123.123.123:1 | | 1 | | True | +| f_xxxxx_00012 | TERMINATED | 123.123.123.123:1 | | 2 | | True | +| f_xxxxx_00013 | TERMINATED | 123.123.123.123:1 | | 3 | | True | +| f_xxxxx_00014 | TERMINATED | 123.123.123.123:1 | | 4 | | True | +| f_xxxxx_00015 | TERMINATED | 123.123.123.123:1 | | 5 | | True | +| f_xxxxx_00016 | TERMINATED | 123.123.123.123:1 | | 6 | | True | +| f_xxxxx_00017 | TERMINATED | 123.123.123.123:1 | | 7 | | True | +| f_xxxxx_00018 | TERMINATED | 123.123.123.123:1 | | 8 | | True | +| f_xxxxx_00019 | TERMINATED | 123.123.123.123:1 | | 9 | | True | +| f_xxxxx_00020 | TERMINATED | 123.123.123.123:1 | | | 0 | True | +| f_xxxxx_00021 | TERMINATED | 123.123.123.123:1 | | | 1 | True | +| f_xxxxx_00022 | TERMINATED | 123.123.123.123:1 | | | 2 | True | +| f_xxxxx_00023 | TERMINATED | 123.123.123.123:1 | | | 3 | True | +| f_xxxxx_00024 | TERMINATED | 123.123.123.123:1 | | | 4 | True | +| f_xxxxx_00025 | TERMINATED | 123.123.123.123:1 | | | 5 | True | +| f_xxxxx_00026 | TERMINATED | 123.123.123.123:1 | | | 6 | True | +| f_xxxxx_00027 | TERMINATED | 123.123.123.123:1 | | | 7 | True | +| f_xxxxx_00028 | TERMINATED | 123.123.123.123:1 | | | 8 | True | +| f_xxxxx_00029 | TERMINATED | 123.123.123.123:1 | | | 9 | True | ++---------------+------------+-------------------+-----+-----+-----+--------+""" # noqa EXPECTED_END_TO_END_AC = """Number of trials: 30/30 (30 TERMINATED) +---------------+------------+-------+-----+-----+-----+ @@ -217,15 +234,26 @@ def f(config): Trial train_xxxxx_00002 reported acc=8 with parameters={'do': 'twice'}. """ + \ "This trial completed." 
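The `123.123.123.123:1` location in the expected tables comes from patching `_get_trial_location`, which prefers the `NODE_IP`/`PID` recorded in a trial's last result and only falls back to `trial.location` before any result arrives. A minimal sketch; the `FakeTrial` stand-in and `Location`'s no-argument constructor are assumptions:

from ray.tune.progress_reporter import _get_trial_location
from ray.tune.result import NODE_IP, PID
from ray.tune.trial import Location

class FakeTrial:
    location = Location()  # empty fallback, used when the result has no pid

loc = _get_trial_location(FakeTrial(), {NODE_IP: "123.123.123.123", PID: 1})
assert (loc.hostname, loc.pid) == ("123.123.123.123", 1)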
-VERBOSE_TRIAL_DETAIL = """+-------------------+----------+-------+----------+ -| Trial name | status | loc | do | -|-------------------+----------+-------+----------| -| train_xxxxx_00000 | RUNNING | | complete |""" +VERBOSE_TRIAL_DETAIL = """+-------------------+----------+-------------------+----------+ +| Trial name | status | loc | do | +|-------------------+----------+-------------------+----------| +| train_xxxxx_00000 | RUNNING | 123.123.123.123:1 | complete |""" VERBOSE_CMD = """from ray import tune import random import numpy as np import time +from ray.tune.trial import Location +from ray.tune.progress_reporter import _get_trial_location +from unittest.mock import patch + + +def mock_get_trial_location(trial, result): + location = _get_trial_location(trial, result) + if location.pid: + return Location("123.123.123.123", "1") + return location + def train(config): if config["do"] == "complete": @@ -242,11 +270,14 @@ def train(config): random.seed(1234) np.random.seed(1234) -tune.run( - train, - config={ - "do": tune.grid_search(["complete", "once", "twice"]) - },""" + +with patch("ray.tune.progress_reporter._get_trial_location", + mock_get_trial_location): + tune.run( + train, + config={ + "do": tune.grid_search(["complete", "once", "twice"]) + },""" # Add "verbose=3)" etc @@ -424,6 +455,27 @@ def testProgressStr(self): best1 = best_trial_str(trials[1], "metric_1") assert best1 == EXPECTED_BEST_1 + def testTimeElapsed(self): + # Sun Feb 7 14:18:40 2016 -0800 + # (time of the first Ray commit) + time_start = 1454825920 + time_now = ( + time_start + 1 * 60 * 60 # 1 hour + + 31 * 60 # 31 minutes + + 22 # 22 seconds + ) # time to second commit + + # Local timezone output can be tricky, so we don't check the + # day and the hour in this test. + output = time_passed_str(time_start, time_now) + self.assertIn("Current time: 2016-02-", output) + self.assertIn(":50:02 (running for 01:31:22.00)", output) + + time_now += 2 * 60 * 60 * 24 # plus two days + output = time_passed_str(time_start, time_now) + self.assertIn("Current time: 2016-02-", output) + self.assertIn(":50:02 (running for 2 days, 01:31:22.00)", output) + def testCurrentBestTrial(self): trials = [] for i in range(5): diff --git a/python/ray/tune/tests/test_ray_trial_executor.py b/python/ray/tune/tests/test_ray_trial_executor.py index a21664a2c11ee..f5d87e7dd1926 100644 --- a/python/ray/tune/tests/test_ray_trial_executor.py +++ b/python/ray/tune/tests/test_ray_trial_executor.py @@ -2,6 +2,7 @@ import os import pytest +import time import unittest import ray @@ -11,7 +12,7 @@ from ray.tune.callback import Callback from ray.tune.ray_trial_executor import RayTrialExecutor from ray.tune.registry import _global_registry, TRAINABLE_CLASS -from ray.tune.result import TRAINING_ITERATION +from ray.tune.result import PID, TRAINING_ITERATION, TRIAL_ID from ray.tune.suggest import BasicVariantGenerator from ray.tune.trial import Trial, Checkpoint from ray.tune.resources import Resources @@ -252,6 +253,68 @@ def reset_config(self, config): self.assertEqual(trial.experiment_tag, "modified_mock") self.assertEqual(Trial.RUNNING, trial.status) + def testForceTrialCleanup(self): + class B(Trainable): + def step(self): + print("Step start") + time.sleep(10) + print("Step done") + return dict(my_metric=1, timesteps_this_iter=1, done=True) + + def reset_config(self, config): + self.config = config + return True + + def cleanup(self): + print("Cleanup start") + time.sleep(10) + print("Cleanup done") + + # First check if the trials terminate gracefully 
by default + trials = self.generate_trials({ + "run": B, + "config": { + "foo": 0 + }, + }, "grid_search") + trial = trials[0] + self.trial_executor.start_trial(trial) + self.assertEqual(Trial.RUNNING, trial.status) + time.sleep(5) + print("Stop trial") + self.trial_executor.stop_trial(trial) + print("Start trial cleanup") + start = time.time() + self.trial_executor.cleanup([trial]) + self.assertGreaterEqual(time.time() - start, 12.0) + + # Check forceful termination. It should run for much less than the + # sleep periods in the Trainable + trials = self.generate_trials({ + "run": B, + "config": { + "foo": 0 + }, + }, "grid_search") + trial = trials[0] + os.environ["TUNE_FORCE_TRIAL_CLEANUP_S"] = "1" + self.trial_executor = RayTrialExecutor(queue_trials=False) + os.environ["TUNE_FORCE_TRIAL_CLEANUP_S"] = "0" + self.trial_executor.start_trial(trial) + self.assertEqual(Trial.RUNNING, trial.status) + time.sleep(5) + print("Stop trial") + self.trial_executor.stop_trial(trial) + print("Start trial cleanup") + start = time.time() + self.trial_executor.cleanup([trial]) + self.assertLess(time.time() - start, 5.0) + + # also check if auto-filled metrics were returned + self.assertIn(PID, trial.last_result) + self.assertIn(TRIAL_ID, trial.last_result) + self.assertNotIn("my_metric", trial.last_result) + @staticmethod def generate_trials(spec, name): suggester = BasicVariantGenerator() @@ -480,6 +543,10 @@ def tearDown(self): ray.shutdown() _register_all() # re-register the evicted objects + def testForceTrialCleanup(self): + self.skipTest("Skipping as force trial cleanup is not applicable" + " for local mode.") + if __name__ == "__main__": import sys diff --git a/python/ray/tune/tests/test_trial_runner_3.py b/python/ray/tune/tests/test_trial_runner_3.py index e467eafa5e51e..44341ebf99cf6 100644 --- a/python/ray/tune/tests/test_trial_runner_3.py +++ b/python/ray/tune/tests/test_trial_runner_3.py @@ -555,7 +555,8 @@ def testTrialNoSave(self): self.assertTrue( runner2.get_trial("checkpoint").status == Trial.TERMINATED) self.assertTrue(runner2.get_trial("pending").status == Trial.PENDING) - self.assertTrue(not runner2.get_trial("pending").last_result) + self.assertTrue( + not runner2.get_trial("pending").has_reported_at_least_once) runner2.step() def testCheckpointWithFunction(self): diff --git a/python/ray/tune/tests/test_trial_runner_callbacks.py b/python/ray/tune/tests/test_trial_runner_callbacks.py index 16f40b7602712..f9cf300948ea6 100644 --- a/python/ray/tune/tests/test_trial_runner_callbacks.py +++ b/python/ray/tune/tests/test_trial_runner_callbacks.py @@ -154,7 +154,7 @@ def testCallbackSteps(self): result = {TRAINING_ITERATION: 1, "metric": 800, "done": False} self.executor.results[trials[1]] = result self.executor.next_trial = trials[1] - self.assertEqual(trials[1].last_result, {}) + self.assertTrue(not trials[1].has_reported_at_least_once) self.trial_runner.step() self.assertEqual(self.callback.state["trial_result"]["iteration"], 3) self.assertEqual(self.callback.state["trial_result"]["trial"].trial_id, diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py index 0e0a2dd65c701..798c08192ab21 100644 --- a/python/ray/tune/tests/test_trial_scheduler.py +++ b/python/ray/tune/tests/test_trial_scheduler.py @@ -847,6 +847,7 @@ def __init__(self, i, config): self.resources = Resources(1, 0) self.custom_trial_name = None self.custom_dirname = None + self._default_result_or_future = None def on_checkpoint(self, checkpoint): self.restored_checkpoint = 
checkpoint.value diff --git a/python/ray/tune/tests/test_trial_scheduler_pbt.py b/python/ray/tune/tests/test_trial_scheduler_pbt.py index 81b90dcfeebf2..31a8f02132101 100644 --- a/python/ray/tune/tests/test_trial_scheduler_pbt.py +++ b/python/ray/tune/tests/test_trial_scheduler_pbt.py @@ -192,8 +192,13 @@ def MockTrainingFuncSync(config, checkpoint_dir=None): "checkpoint") with open(checkpoint_path, "wb") as fp: pickle.dump((a, iter), fp) + # Different sleep times so that asynch test runs do not + # randomly succeed. If well performing trials finish later, + # then bad performing trials will already have continued + # to train, which is exactly what we want to test when + # comparing sync vs. async. + time.sleep(a / 20) # Score gets better every iteration. - time.sleep(1) tune.report(mean_accuracy=iter + a, a=a) self.MockTrainingFuncSync = MockTrainingFuncSync @@ -201,7 +206,10 @@ def MockTrainingFuncSync(config, checkpoint_dir=None): def tearDown(self): ray.shutdown() - def synchSetup(self, synch, param=[10, 20, 30]): + def synchSetup(self, synch, param=None): + if param is None: + param = [10, 20, 30] + scheduler = PopulationBasedTraining( time_attr="training_iteration", metric="mean_accuracy", diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index 7e63147ca4e00..3299d7aa4e861 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -8,17 +8,18 @@ import sys import tempfile import time -from typing import Any, Dict, Union +from typing import Any, Dict, Optional, Union import uuid import ray import ray.cloudpickle as pickle from ray.tune.resources import Resources from ray.tune.result import ( - DEFAULT_RESULTS_DIR, SHOULD_CHECKPOINT, TIME_THIS_ITER_S, - TIMESTEPS_THIS_ITER, DONE, TIMESTEPS_TOTAL, EPISODES_THIS_ITER, - EPISODES_TOTAL, TRAINING_ITERATION, RESULT_DUPLICATE, TRIAL_INFO, - STDOUT_FILE, STDERR_FILE) + DEBUG_METRICS, DEFAULT_RESULTS_DIR, HOSTNAME, NODE_IP, PID, + SHOULD_CHECKPOINT, TIME_THIS_ITER_S, TIME_TOTAL_S, TIMESTEPS_THIS_ITER, + DONE, TIMESTEPS_TOTAL, EPISODES_THIS_ITER, EPISODES_TOTAL, + TRAINING_ITERATION, RESULT_DUPLICATE, TRIAL_ID, TRIAL_INFO, STDOUT_FILE, + STDERR_FILE) from ray.tune.utils import UtilMonitor from ray.tune.utils.placement_groups import PlacementGroupFactory from ray.tune.utils.trainable import TrainableUtil @@ -154,6 +155,40 @@ def get_current_ip(self): self._local_ip = ray.util.get_node_ip_address() return self._local_ip + def get_auto_filled_metrics(self, + now: Optional[datetime] = None, + time_this_iter: Optional[float] = None, + debug_metrics_only: bool = False) -> dict: + """Return a dict with metrics auto-filled by the trainable. + + If ``debug_metrics_only`` is True, only metrics that don't + require at least one iteration will be returned + (``ray.tune.result.DEBUG_METRICS``). 
+ """ + if now is None: + now = datetime.today() + autofilled = { + TRIAL_ID: self.trial_id, + "experiment_id": self._experiment_id, + "date": now.strftime("%Y-%m-%d_%H-%M-%S"), + "timestamp": int(time.mktime(now.timetuple())), + TIME_THIS_ITER_S: time_this_iter, + TIME_TOTAL_S: self._time_total, + PID: os.getpid(), + HOSTNAME: platform.node(), + NODE_IP: self._local_ip, + "config": self.config, + "time_since_restore": self._time_since_restore, + "timesteps_since_restore": self._timesteps_since_restore, + "iterations_since_restore": self._iterations_since_restore + } + if debug_metrics_only: + autofilled = { + k: v + for k, v in autofilled.items() if k in DEBUG_METRICS + } + return autofilled + def is_actor(self): try: actor_id = ray.worker.global_worker.actor_id @@ -289,19 +324,7 @@ def train(self): result.setdefault("neg_mean_loss", -result["mean_loss"]) now = datetime.today() - result.update( - experiment_id=self._experiment_id, - date=now.strftime("%Y-%m-%d_%H-%M-%S"), - timestamp=int(time.mktime(now.timetuple())), - time_this_iter_s=time_this_iter, - time_total_s=self._time_total, - pid=os.getpid(), - hostname=platform.node(), - node_ip=self._local_ip, - config=self.config, - time_since_restore=self._time_since_restore, - timesteps_since_restore=self._timesteps_since_restore, - iterations_since_restore=self._iterations_since_restore) + result.update(self.get_auto_filled_metrics(now, time_this_iter)) monitor_data = self._monitor.get_data() if monitor_data: diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 6398b53f2292f..ede51f26ba5b1 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -8,19 +8,20 @@ import re import shutil import time -from typing import Callable, Dict, Sequence, Union +from typing import Callable, Dict, Optional, Sequence, Union import uuid import ray import ray.cloudpickle as cloudpickle -from ray.exceptions import GetTimeoutError +from ray.exceptions import GetTimeoutError, RayActorError from ray.tune import TuneError from ray.tune.checkpoint_manager import Checkpoint, CheckpointManager # NOTE(rkn): We import ray.tune.registry here instead of importing the names we # need because there are cyclic imports that may cause specific names to not # have been defined yet. See https://github.com/ray-project/ray/issues/1716. from ray.tune.registry import get_trainable_cls, validate_trainable -from ray.tune.result import DEFAULT_RESULTS_DIR, DONE, TRAINING_ITERATION +from ray.tune.result import (DEFAULT_RESULTS_DIR, DONE, NODE_IP, PID, + TRAINING_ITERATION, TRIAL_ID) from ray.tune.resources import Resources, \ json_to_resources, resources_to_json from ray.tune.utils.placement_groups import PlacementGroupFactory, \ @@ -299,7 +300,9 @@ def __init__(self, self.max_failures = max_failures # Local trial state that is updated during the run - self.last_result = {} + self._last_result = {} + self._default_result_or_future: Union[ray.ObjectRef, dict, None] = ( + None) self.last_update_time = -float("inf") # stores in memory max/min/avg/last-n-avg/last result for each @@ -394,6 +397,52 @@ def _setup_resources(self, log_always: bool = False): resource_kwargs["has_placement_group"] = True self.resources = Resources(**resource_kwargs) + def _get_default_result_or_future(self) -> Optional[dict]: + """Calls ray.get on self._default_result_or_future and assigns back. + + Returns None in case of exceptions. + Will also set the trial location if runner is set. 
+ """ + if self._default_result_or_future and isinstance( + self._default_result_or_future, ray.ObjectRef): + try: + self._default_result_or_future = ray.get( + self._default_result_or_future) + except RayActorError: # error during initialization + self._default_result_or_future = None + if self._default_result_or_future and self.runner: + self.set_location( + Location( + self._default_result_or_future.get(NODE_IP), + self._default_result_or_future.get(PID))) + return self._default_result_or_future + + @property + def last_result(self) -> dict: + # The logic in here is as follows: + # 1. If the trial has reported at least once, last_result would have + # been set and therefore would not be empty. We can just return it. + # 2. If the trial has not reported at least once but we have the + # future for the default results dict, (obtained through + # Trainable.get_auto_filled_metrics), we get that future + # and return it. + # 3. In the worst case where we have nothing, we just set the + # trial_id and return that. + result = self._last_result + if not {k for k in result if k != TRIAL_ID}: + self._get_default_result_or_future() + result = self._default_result_or_future or result + result.setdefault(TRIAL_ID, self.trial_id) + return result + + @last_result.setter + def last_result(self, val: dict): + self._last_result = val + + @property + def has_reported_at_least_once(self) -> bool: + return bool(self._last_result) + @property def node_ip(self): return self.location.hostname @@ -499,6 +548,11 @@ def update_resources( def set_runner(self, runner): self.runner = runner + if runner: + # Do not block here, the result will be gotten when last_result + # property is accessed + self._default_result_or_future = ( + runner.get_auto_filled_metrics.remote(debug_metrics_only=True)) self.checkpoint_manager.delete = CheckpointDeleter( self._trainable_name(), runner, self.node_ip) # No need to invalidate state cache: runner is not stored in json @@ -603,7 +657,7 @@ def update_last_result(self, result, terminate=False): if self.experiment_tag: result.update(experiment_tag=self.experiment_tag) - self.set_location(Location(result.get("node_ip"), result.get("pid"))) + self.set_location(Location(result.get(NODE_IP), result.get(PID))) self.last_result = result self.last_update_time = time.time() @@ -729,6 +783,7 @@ def __getstate__(self): state["_state_json"] = None state["_state_valid"] = False + state["_default_result_or_future"] = None return copy.deepcopy(state) diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 0d91ee3b8bc65..86759e1212188 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -15,8 +15,9 @@ from ray.tune.callback import CallbackList from ray.tune.stopper import NoopStopper from ray.tune.ray_trial_executor import RayTrialExecutor -from ray.tune.result import (DEFAULT_METRIC, TIME_THIS_ITER_S, - RESULT_DUPLICATE, SHOULD_CHECKPOINT) +from ray.tune.result import (DEBUG_METRICS, DEFAULT_METRIC, DONE, + TIME_THIS_ITER_S, RESULT_DUPLICATE, + SHOULD_CHECKPOINT) from ray.tune.syncer import CloudSyncer, get_cloud_syncer from ray.tune.trial import Checkpoint, Trial from ray.tune.schedulers import FIFOScheduler, TrialScheduler @@ -195,7 +196,9 @@ class TrialRunner: """ CKPT_FILE_TMPL = "experiment_state-{}.json" - VALID_RESUME_TYPES = [True, "LOCAL", "REMOTE", "PROMPT", "ERRORED_ONLY"] + VALID_RESUME_TYPES = [ + True, "LOCAL", "REMOTE", "PROMPT", "ERRORED_ONLY", "AUTO" + ] RAISE = "RAISE" def __init__(self, @@ -415,7 +418,7 @@ def 
_validate_resume(self, resume_type): Args: resume_type: One of True, "REMOTE", "LOCAL", - "PROMPT", "ERRORED_ONLY". + "PROMPT", "ERRORED_ONLY", "AUTO". """ # TODO: Consider supporting ERRORED_ONLY+REMOTE? if not resume_type: @@ -426,11 +429,54 @@ def _validate_resume(self, resume_type): # Not clear if we need this assertion, since we should always have a # local checkpoint dir. assert self._local_checkpoint_dir or self._remote_checkpoint_dir + + if resume_type == "AUTO": + if self._remote_checkpoint_dir: + logger.info( + f"Trying to find and download experiment checkpoint at " + f"{self._remote_checkpoint_dir}") + # Todo: This syncs the entire experiment including trial + # checkpoints. We should exclude these in the future. + try: + self._syncer.sync_down_if_needed() + self._syncer.wait() + except TuneError as e: + logger.warning( + f"Got error when trying to sync down: {e} " + f"\nPlease check this error message for potential " + f"access problems - if a directory was not found, " + f"that is expected at this stage when you're starting " + f"a new experiment.") + logger.info( + "No remote checkpoint was found or an error occurred " + "when trying to download the experiment checkpoint. " + "Please check the previous warning message for more " + "details. " + "Ray Tune will now start a new experiment.") + return False + logger.info( + "A remote experiment checkpoint was found and will be " + "used to restore the previous experiment state.") + return True + elif not self.checkpoint_exists(self._local_checkpoint_dir): + logger.info("No local checkpoint was found. " + "Ray Tune will now start a new experiment.") + return False + logger.info( + "A local experiment checkpoint was found and will be used " + "to restore the previous experiment state.") + return True + if resume_type in [True, "LOCAL", "PROMPT", "ERRORED_ONLY"]: if not self.checkpoint_exists(self._local_checkpoint_dir): raise ValueError( - f"Called resume ({resume_type}) when no checkpoint exists " - f"in local directory ({self._local_checkpoint_dir}).") + f"You called resume ({resume_type}) when no checkpoint " + f"exists in local directory " + f"({self._local_checkpoint_dir}). If you want to start " + f"a new experiment, use `resume=\"AUTO\"` or " + f"`resume=None`. If you expected an experiment to " + f"already exist, check if you supplied the correct " + f"`local_dir` to `tune.run()`.") elif resume_type == "PROMPT": if click.confirm(f"Resume from local directory? " f"({self._local_checkpoint_dir})"): @@ -448,12 +494,22 @@ def _validate_resume(self, resume_type): "`upload_dir` set to `tune.run(sync_config=...)`.") # Try syncing down the upload directory. - logger.info("Downloading from %s", self._remote_checkpoint_dir) - # TODO(ujvl): Note that this syncs down the entire directory, - # which may also contain trial checkpoints. We should selectively - # sync the necessary files instead. - self._syncer.sync_down_if_needed() - self._syncer.wait() + logger.info(f"Downloading experiment checkpoint from " + f"{self._remote_checkpoint_dir}") + # Todo: This syncs the entire experiment including trial + # checkpoints. We should exclude these in the future. + try: + self._syncer.sync_down_if_needed() + self._syncer.wait() + except TuneError as e: + raise RuntimeError( + "Syncing the remote experiment checkpoint to the driver " + "failed. Please check the error message. If you want to " + "start a new experiment, use `resume=\"AUTO\"` or " + "`resume=None`. 
If you expected an experiment to " + "already exist, check if you supplied the correct " + "`upload_dir` to the `tune.SyncConfig` passed to " + "`tune.run()`.") from e if not self.checkpoint_exists(self._local_checkpoint_dir): raise ValueError("Called resume when no checkpoint exists " @@ -932,15 +988,18 @@ def _process_trial_result(self, trial, result): def _validate_result_metrics(self, result): """ Check if any of the required metrics was not reported - in the last result. If the only item is `done=True`, this - means that no result was ever received and the trial just - returned. This is also okay and will not raise an error. + in the last result. If the only items are ``done`` or any of + DEBUG_METRICS, this means that no result was ever received and + the trial just returned. This is also okay and will not raise + an error. This will ignore checking for the DEFAULT_METRIC. """ - if int(os.environ.get("TUNE_DISABLE_STRICT_METRIC_CHECKING", - 0)) != 1 and (len(result) > 1 - or "done" not in result): + if int(os.environ.get( + "TUNE_DISABLE_STRICT_METRIC_CHECKING", 0)) != 1 and (len({ + k + for k in result if k not in list(DEBUG_METRICS) + [DONE] + }) > 1): base_metric = self._metric \ if self._metric != DEFAULT_METRIC else None scheduler_metric = self._scheduler_alg.metric \ diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index 8077f7c6e6cd2..f6bcb2b56a2ba 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -11,13 +11,15 @@ import ray from ray.util.annotations import PublicAPI +from ray.util.queue import Queue, Empty from ray.tune.analysis import ExperimentAnalysis from ray.tune.callback import Callback from ray.tune.error import TuneError from ray.tune.experiment import Experiment, convert_to_experiment_list from ray.tune.logger import Logger -from ray.tune.progress_reporter import detect_reporter, ProgressReporter +from ray.tune.progress_reporter import (detect_reporter, ProgressReporter, + JupyterNotebookReporter) from ray.tune.ray_trial_executor import RayTrialExecutor from ray.tune.registry import get_trainable_cls from ray.tune.stopper import Stopper @@ -314,7 +316,48 @@ def run( # Make sure tune.run is called on the sever node. remote_run = force_on_current_node(remote_run) - return ray.get(remote_run.remote(_remote=False, **remote_run_kwargs)) + # JupyterNotebooks don't work with remote tune runs out of the box + # (e.g. via Ray client) as they don't have access to the main + # process stdout. So we introduce a queue here that accepts + # callables, which will then be executed on the driver side. 
+ if isinstance(progress_reporter, JupyterNotebookReporter): + execute_queue = Queue(actor_options={ + "num_cpus": 0, + **force_on_current_node(None) + }) + progress_reporter.set_output_queue(execute_queue) + + def get_next_queue_item(): + try: + return execute_queue.get(block=False) + except Empty: + return None + + else: + # If we don't need a queue, use this dummy get fn instead of + # scheduling an unneeded actor + def get_next_queue_item(): + return None + + def _handle_execute_queue(): + execute_item = get_next_queue_item() + while execute_item: + if isinstance(execute_item, Callable): + execute_item() + + execute_item = get_next_queue_item() + + remote_future = remote_run.remote(_remote=False, **remote_run_kwargs) + + # ray.wait(...)[1] returns futures that are not ready, yet + while ray.wait([remote_future], timeout=0.2)[1]: + # Check if we have items to execute + _handle_execute_queue() + + # Handle queue one last time + _handle_execute_queue() + + return ray.get(remote_future) del remote_run_kwargs @@ -531,6 +574,7 @@ def sigint_handler(sig, frame): signal.signal(signal.SIGINT, sigint_handler) tune_start = time.time() + progress_reporter.set_start_time(tune_start) while not runner.is_finished() and not state[signal.SIGINT]: runner.step() if has_verbosity(Verbosity.V1_EXPERIMENT): diff --git a/python/ray/tune/utils/util.py b/python/ray/tune/utils/util.py index 0f4612c66c047..c41179c43b845 100644 --- a/python/ray/tune/utils/util.py +++ b/python/ray/tune/utils/util.py @@ -645,14 +645,15 @@ def get_current_node_resource_key() -> str: raise ValueError("Cannot found the node dictionary for current node.") -def force_on_current_node(task_or_actor): +def force_on_current_node(task_or_actor=None): """Given a task or actor, place it on the current node. If using Ray Client, the current node is the client server node. Args: task_or_actor: A Ray remote function or class to place on the - current node. + current node. If None, returns the options dict to pass to + another actor. Returns: The provided task or actor, but with options modified to force @@ -660,6 +661,10 @@ def force_on_current_node(task_or_actor): """ node_resource_key = get_current_node_resource_key() options = {"resources": {node_resource_key: 0.01}} + + if task_or_actor is None: + return options + return task_or_actor.options(**options) diff --git a/python/ray/util/__init__.py b/python/ray/util/__init__.py index 3177925e68fc0..7a326154bf1e3 100644 --- a/python/ray/util/__init__.py +++ b/python/ray/util/__init__.py @@ -19,7 +19,7 @@ @PublicAPI(stability="beta") -@client_mode_hook +@client_mode_hook(auto_init=True) def list_named_actors(all_namespaces: bool = False) -> List[str]: """List all named actors in the system. 
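Taken together, the tune.py and force_on_current_node hunks above form one pattern: a zero-CPU Queue actor is pinned to the driver (client server) node via a fractional custom node resource, and the driver polls it with short ray.wait timeouts while the remote run is still in flight. Below is a rough, runnable sketch of that polling loop under simplifying assumptions: remote_run stands in for the remote tune.run call, and the node-pinning options are only described in a comment because the node resource key is cluster-specific.

# Hedged sketch of the driver-side execute queue shown above. The
# Queue/Empty API and the ray.wait(...)[1] "not ready" convention come
# straight from the diff; the rest is illustrative.
import ray
from ray.util.queue import Queue, Empty

ray.init(num_cpus=2)

# force_on_current_node(None) would merge in something like
# {"resources": {"node:<ip>": 0.01}} to pin this actor; omitted here
# so the sketch runs on any cluster.
execute_queue = Queue(actor_options={"num_cpus": 0})

@ray.remote
def remote_run(queue):
    # Stand-in for the remote tune.run: ship a callable to the driver.
    queue.put(lambda: print("rendered on the driver"))
    return "done"

def handle_execute_queue(queue):
    while True:
        try:
            item = queue.get(block=False)
        except Empty:
            return
        if callable(item):
            item()

future = remote_run.remote(execute_queue)
# ray.wait(...)[1] returns the futures that are not ready yet.
while ray.wait([future], timeout=0.2)[1]:
    handle_execute_queue(execute_queue)
handle_execute_queue(execute_queue)  # drain once more after completion
assert ray.get(future) == "done"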
diff --git a/python/ray/util/client/__init__.py b/python/ray/util/client/__init__.py index f11b692d56f42..29b95c850c2a4 100644 --- a/python/ray/util/client/__init__.py +++ b/python/ray/util/client/__init__.py @@ -5,7 +5,6 @@ import os import sys import logging -import json import threading import grpc @@ -66,7 +65,7 @@ def connect(self, job_config = job_config or JobConfig() job_config.set_ray_namespace(namespace) if job_config is not None: - runtime_env = json.loads(job_config.get_serialized_runtime_env()) + runtime_env = job_config.runtime_env if runtime_env.get("pip") or runtime_env.get("conda"): logger.warning("The 'pip' or 'conda' field was specified in " "the runtime env, so it may take some time to " diff --git a/python/ray/util/client/client_pickler.py b/python/ray/util/client/client_pickler.py index 9c1ebef68d565..0faf3c99c68cd 100644 --- a/python/ray/util/client/client_pickler.py +++ b/python/ray/util/client/client_pickler.py @@ -49,12 +49,17 @@ else: import pickle # noqa: F401 + # NOTE(barakmich): These PickleStubs are really close to -# the data for an exectuion, with no arguments. Combine the two? -PickleStub = NamedTuple("PickleStub", - [("type", str), ("client_id", str), ("ref_id", bytes), - ("name", Optional[str]), - ("baseline_options", Optional[Dict])]) +# the data for an execution, with no arguments. Combine the two? +class PickleStub( + NamedTuple("PickleStub", [("type", str), ("client_id", str), + ("ref_id", bytes), ("name", Optional[str]), + ("baseline_options", Optional[Dict])])): + def __reduce__(self): + # PySpark's namedtuple monkey patch breaks compatibility with + # cloudpickle. Thus we revert this patch here if it exists. + return object.__reduce__(self) class ClientPickler(cloudpickle.CloudPickler): diff --git a/python/ray/util/client/options.py b/python/ray/util/client/options.py index 9c9df946d0cf5..ec6c568d5b347 100644 --- a/python/ray/util/client/options.py +++ b/python/ray/util/client/options.py @@ -36,7 +36,6 @@ "placement_group_bundle_index": (), "placement_group_capture_child_tasks": (), "runtime_env": (), - "override_environment_variables": (), } diff --git a/python/ray/util/client/server/proxier.py b/python/ray/util/client/server/proxier.py index 0fb2f07429b1d..165fbbcabe8a9 100644 --- a/python/ray/util/client/server/proxier.py +++ b/python/ray/util/client/server/proxier.py @@ -27,7 +27,7 @@ from ray.util.client.server.dataservicer import _get_reconnecting_from_context from ray._private.client_mode_hook import disable_client_hook from ray._private.parameter import RayParams -from ray._private.runtime_env import RuntimeEnvContext +from ray._private.runtime_env.context import RuntimeEnvContext from ray._private.services import ProcessInfo, start_ray_client_server from ray._private.utils import (detect_fate_sharing_support, add_port_to_grpc_server) @@ -264,7 +264,9 @@ def start_specific_server(self, client_id: str, f"ray_client_server_{specific_server.port}", unique=True) serialized_runtime_env = job_config.get_serialized_runtime_env() - if serialized_runtime_env == "{}": + if not serialized_runtime_env or serialized_runtime_env == "{}": + # TODO(edoakes): can we just remove this case and always send it + # to the agent? 
serialized_runtime_env_context = RuntimeEnvContext().serialize() else: serialized_runtime_env_context = self._create_runtime_env( diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index b5c50215c5488..e07f74c6c50c9 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -128,6 +128,9 @@ def __init__( self._connect_channel() self._has_connected = True + # Has Ray been initialized on the server? + self._serverside_ray_initialized = False + # Initialize the streams to finish protocol negotiation. self.data_client = DataClient(self, self._client_id, self.metadata) self.reference_count: Dict[bytes, int] = defaultdict(int) @@ -647,10 +650,17 @@ def list_named_actors(self, all_namespaces: bool) -> List[Dict[str, str]]: return json.loads(self.data_client.ListNamedActors(req).actors_json) def is_initialized(self) -> bool: - if self.server is not None: - return self.get_cluster_info( + if not self.is_connected() or self.server is None: + return False + if not self._serverside_ray_initialized: + # We only check that Ray is initialized on the server once to + # avoid making an RPC every time this function is called. This is + # safe to do because Ray only 'un-initializes' on the server when + # the Client connection is torn down. + self._serverside_ray_initialized = self.get_cluster_info( ray_client_pb2.ClusterInfoType.IS_INITIALIZED) - return False + + return self._serverside_ray_initialized def ping_server(self, timeout=None) -> bool: """Simple health check. diff --git a/python/ray/util/dask/scheduler_utils.py b/python/ray/util/dask/scheduler_utils.py index a1805048c989b..dba0c660b0c5b 100644 --- a/python/ray/util/dask/scheduler_utils.py +++ b/python/ray/util/dask/scheduler_utils.py @@ -371,8 +371,11 @@ def fire_task(): return nested_get(result, state["cache"]) -def apply_sync(func, args=(), kwds={}, callback=None): +def apply_sync(func, args=(), kwds=None, callback=None): """ A naive synchronous version of apply_async """ + if kwds is None: + kwds = {} + res = func(*args, **kwds) if callback is not None: callback(res) diff --git a/python/ray/util/placement_group.py b/python/ray/util/placement_group.py index 43741556f54e1..933695ea0fbe1 100644 --- a/python/ray/util/placement_group.py +++ b/python/ray/util/placement_group.py @@ -25,7 +25,7 @@ def _export_bundle_reservation_check_method_if_needed(): if bundle_reservation_check: return - @ray.remote(num_cpus=0, max_calls=0) + @ray.remote(num_cpus=0) def bundle_reservation_check_func(placement_group): return placement_group @@ -307,7 +307,7 @@ def get_current_placement_group() -> Optional[PlacementGroup]: None if the current task or actor wasn't created with any placement group. """ - if client_mode_should_convert(): + if client_mode_should_convert(auto_init=True): # Client mode is only a driver. 
return None worker = ray.worker.global_worker diff --git a/python/ray/util/sgd/torch/torch_runner.py b/python/ray/util/sgd/torch/torch_runner.py index 77bf9e1454ea9..90b8c0adb44cc 100644 --- a/python/ray/util/sgd/torch/torch_runner.py +++ b/python/ray/util/sgd/torch/torch_runner.py @@ -5,7 +5,6 @@ import ray import torch -from ray.util.sgd.torch.constants import USE_FP16, NUM_STEPS from ray.util.sgd import utils from ray.util.sgd.torch.utils import choose_amp_backend @@ -63,6 +62,7 @@ def setup_operator(self): world_rank=0, local_rank=0, is_distributed=False, + device=None, use_gpu=self.use_gpu, use_fp16=self.use_fp16, use_tqdm=self.use_tqdm, @@ -121,11 +121,6 @@ def train_epoch(self, info = info or {} self._toggle_profiling(profile=profile) - info.update({ - NUM_STEPS: num_steps, - USE_FP16: self.use_fp16, - "epoch_idx": self.epochs, - }) with self.timers.record("train_epoch"): if iterator is not None: # Dataset will provide us with a list of tuples but we @@ -141,7 +136,11 @@ def format_batch(batch): else: iterator = self.make_iterator( training=True, num_steps=num_steps) - train_stats = self.training_operator.train_epoch(iterator, info) + train_stats = self.training_operator.train_epoch( + iterator, + info=info, + num_steps=num_steps, + epoch_idx=self.epochs) # This is so that `epochs` is first in ordering. stats = dict(epoch=self.epochs, **train_stats) @@ -151,7 +150,6 @@ def format_batch(batch): def validate(self, num_steps=None, profile=False, info=None): """Evaluates the model on the validation data set.""" - info = info or {} self._toggle_profiling(profile=profile) with self.timers.record("validation"): diff --git a/python/ray/util/sgd/torch/training_operator.py b/python/ray/util/sgd/torch/training_operator.py index 7143d5c558fd0..3a37436c43e92 100644 --- a/python/ray/util/sgd/torch/training_operator.py +++ b/python/ray/util/sgd/torch/training_operator.py @@ -11,11 +11,8 @@ from ray.util.annotations import PublicAPI from ray.util.sgd.utils import (TimerCollection, AverageMeterCollection, NUM_SAMPLES) -from ray.util.sgd.torch.constants import ( - SCHEDULER_STEP_EPOCH, - NUM_STEPS, - SCHEDULER_STEP_BATCH, -) +from ray.util.sgd.torch.constants import (SCHEDULER_STEP_EPOCH, NUM_STEPS, + SCHEDULER_STEP_BATCH, USE_FP16) from ray.util.sgd.torch.utils import choose_amp_backend from torch.nn.parallel import DistributedDataParallel @@ -131,14 +128,15 @@ def __init__(self, config, world_rank, local_rank, - is_distributed=False, - device=None, - use_gpu=False, + is_distributed, + use_gpu, + device, use_fp16=False, use_tqdm=False, wrap_ddp=False, add_dist_sampler=False, scheduler_step_freq=None): + # You are not expected to override this method. self._world_rank = world_rank self._local_rank = local_rank @@ -456,7 +454,7 @@ def should_wrap_dataloader(loader): self._validation_loader = with_sampler( self._validation_loader) - def train_epoch(self, iterator, info): + def train_epoch(self, iterator, info=None, num_steps=None, epoch_idx=0): """Runs one standard training pass over the training dataloader. By default, this method will iterate over the given iterator and @@ -489,8 +487,10 @@ def train_epoch(self, ...): Args: iterator (iter): Iterator over the training data for the entire epoch. This iterator is expected to be entirely consumed. - info (dict): Dictionary for information to be used for custom - training operations. + info (Optional[dict]): Dictionary for information to be used for + custom training operations. + num_steps (Optional[int]): Number of steps in the iterator. 
+ epoch_idx (int): Index of current epoch. Returns: A dict of metrics from training. @@ -499,6 +499,14 @@ def train_epoch(self, ...): raise RuntimeError("Either set self.model in setup function or " "override this method to implement a custom " "training loop.") + + info = info or {} + + info.update({ + NUM_STEPS: num_steps, + USE_FP16: self.use_fp16, + "epoch_idx": epoch_idx + }) model = self.model scheduler = None if hasattr(self, "scheduler"): @@ -636,7 +644,7 @@ def train_batch(self, batch, batch_info): return {"train_loss": loss.item(), NUM_SAMPLES: target.size(0)} - def validate(self, val_iterator, info): + def validate(self, val_iterator, info=None): """Runs one standard validation pass over the val_iterator. This will call ``model.eval()`` and ``torch.no_grad`` when iterating @@ -648,8 +656,8 @@ def validate(self, val_iterator, info): Args: val_iterator (iter): Iterable constructed from the validation dataloader. - info: (dict): Dictionary for information to be used for custom - validation operations. + info: (Optional[dict]): Dictionary for information to be used for + custom validation operations. Returns: A dict of metrics from the evaluation. @@ -662,6 +670,8 @@ def validate(self, val_iterator, info): raise RuntimeError("Either set self.model in setup function or " "override this method to implement a custom " "validation loop.") + + info = info or {} model = self.model metric_meters = AverageMeterCollection() @@ -1151,13 +1161,13 @@ def schedulers(self): def get_test_operator(operator_cls): class _TestingOperator(operator_cls): - def train_epoch(self, iterator, info): + def train_epoch(self, iterator, info, **kwargs): func = self.config.get("custom_func") if callable(func): return func(self, iterator, info) return {"done": 1} - def validate(self, iterator, info): + def validate(self, iterator, info, **kwargs): return self.train_epoch(iterator, info) return _TestingOperator diff --git a/python/ray/util/sgd/v2/BUILD b/python/ray/util/sgd/v2/BUILD index 1f3bb55976689..7081a53b75591 100644 --- a/python/ray/util/sgd/v2/BUILD +++ b/python/ray/util/sgd/v2/BUILD @@ -24,6 +24,16 @@ py_test( "--max_train_steps=2", "--start_local", "--num_workers=2"] ) +py_test( + name = "tune_cifar_pytorch_pbt_example", + size = "medium", + main = "examples/tune_cifar_pytorch_pbt_example.py", + srcs = ["examples/tune_cifar_pytorch_pbt_example.py"], + tags = ["team:ml", "exclusive", "pytorch"], + deps = [":sgd_v2_lib"], + args = ["--smoke-test"] +) + py_test( name = "tune_linear_example", size = "medium", @@ -47,6 +57,14 @@ py_test( deps = [":sgd_v2_lib"] ) +py_test( + name = "test_gpu", + size = "medium", + srcs = ["tests/test_gpu.py"], + tags = ["team:ml", "exclusive", "gpu_only"], + deps = [":sgd_v2_lib"] +) + py_test( name = "test_session", size = "small", @@ -71,6 +89,15 @@ py_test( deps = [":sgd_v2_lib"] ) +py_test( + name = "test_utils", + size = "small", + srcs = ["tests/test_utils.py"], + tags = ["team:ml", "exclusive"], + deps = [":sgd_v2_lib"] +) + + py_test( name = "test_worker_group", size = "medium", diff --git a/python/ray/util/sgd/v2/__init__.py b/python/ray/util/sgd/v2/__init__.py index 49d68ce97309d..8fb122c160345 100644 --- a/python/ray/util/sgd/v2/__init__.py +++ b/python/ray/util/sgd/v2/__init__.py @@ -8,6 +8,6 @@ __all__ = [ "BackendConfig", "CheckpointStrategy", "HorovodConfig", "load_checkpoint", - "local_rank", "report", "save_checkpoint", "SGDCallback", "SGDIterator", - "TensorflowConfig", "TorchConfig", "Trainer", "world_rank" + "local_rank", "report", "save_checkpoint", 
"SGDIterator", + "TensorflowConfig", "SGDCallback", "TorchConfig", "Trainer", "world_rank" ] diff --git a/python/ray/util/sgd/v2/backends/backend.py b/python/ray/util/sgd/v2/backends/backend.py index 24d8b59f1e413..4feec51a5eb08 100644 --- a/python/ray/util/sgd/v2/backends/backend.py +++ b/python/ray/util/sgd/v2/backends/backend.py @@ -12,7 +12,7 @@ from ray.util.sgd.v2.checkpoint import CheckpointStrategy from ray.util.sgd.v2.constants import ENABLE_DETAILED_AUTOFILLED_METRICS_ENV, \ TUNE_INSTALLED, TUNE_CHECKPOINT_FILE_NAME, \ - TUNE_CHECKPOINT_ID + TUNE_CHECKPOINT_ID, ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV from ray.util.sgd.v2.session import TrainingResultType, TrainingResult from ray.util.sgd.v2.session import init_session, get_session, shutdown_session from ray.util.sgd.v2.utils import construct_path, check_for_failure @@ -275,15 +275,21 @@ def start(self, if initialization_hook: self._initialization_hook = initialization_hook self.worker_group.execute(initialization_hook) - if self._num_gpus_per_worker > 0: - self._setup_gpus() + + share_cuda_visible_devices_enabled = bool( + env_integer(ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV, + self._backend.share_cuda_visible_devices)) + + if (self._num_gpus_per_worker > 0 + and share_cuda_visible_devices_enabled): + self._share_cuda_visible_devices() self._backend.on_start(self.worker_group, self._backend_config) except RayActorError as exc: logger.exception(str(exc)) self._increment_failures() self._restart() - def _setup_gpus(self): + def _share_cuda_visible_devices(self): """Sets CUDA_VISIBLE_DEVICES on all workers. For each worker, CUDA_VISIBLE_DEVICES will be set to the GPU IDs @@ -685,6 +691,18 @@ def _increment_failures(self): class Backend(metaclass=abc.ABCMeta): + """Metaclass for distributed communication backend. + + Attributes: + share_cuda_visible_devices (bool): If True, each worker + process will have CUDA_VISIBLE_DEVICES set as the visible device + IDs of all workers on the same node for this training instance. + If False, each worker will have CUDA_VISIBLE_DEVICES set to the + device IDs allocated by Ray for that worker. + """ + + share_cuda_visible_devices: bool = False + def on_start(self, worker_group: WorkerGroup, backend_config: BackendConfig): """Logic for starting this backend.""" diff --git a/python/ray/util/sgd/v2/backends/horovod.py b/python/ray/util/sgd/v2/backends/horovod.py index 4f424d5212dec..4382130ae5749 100644 --- a/python/ray/util/sgd/v2/backends/horovod.py +++ b/python/ray/util/sgd/v2/backends/horovod.py @@ -52,6 +52,8 @@ def init_env_vars(world_rank: int, world_size: int, node_id: str): class HorovodBackend(Backend): + share_cuda_visible_devices: bool = True + def on_start(self, worker_group: WorkerGroup, backend_config: HorovodConfig): diff --git a/python/ray/util/sgd/v2/backends/torch.py b/python/ray/util/sgd/v2/backends/torch.py index 7d76b179c8d2d..1d1f0d39f366f 100644 --- a/python/ray/util/sgd/v2/backends/torch.py +++ b/python/ray/util/sgd/v2/backends/torch.py @@ -92,6 +92,8 @@ def shutdown_torch(destroy_process_group=False): class TorchBackend(Backend): + share_cuda_visible_devices: bool = True + def on_start(self, worker_group: WorkerGroup, backend_config: TorchConfig): if len(worker_group) > 1 and dist.is_available(): # Set the appropriate training backend. 
diff --git a/python/ray/util/sgd/v2/constants.py b/python/ray/util/sgd/v2/constants.py index 6ebd428f7b1cb..b0dc39e9cbfbc 100644 --- a/python/ray/util/sgd/v2/constants.py +++ b/python/ray/util/sgd/v2/constants.py @@ -44,3 +44,7 @@ # This needs to be added to the checkpoint dictionary so if the Tune trial # is restarted, the checkpoint_id can continue to increment. TUNE_CHECKPOINT_ID = "_current_checkpoint_id" + +# Integer value which if set will override the value of +# Backend.share_cuda_visible_devices. 1 for True, 0 for False. +ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV = "SGD_ENABLE_SHARE_CUDA_VISIBLE_DEVICES" diff --git a/python/ray/util/sgd/v2/examples/tensorflow_mnist_example.py b/python/ray/util/sgd/v2/examples/tensorflow_mnist_example.py index f87380cf9ce16..c299808c916aa 100644 --- a/python/ray/util/sgd/v2/examples/tensorflow_mnist_example.py +++ b/python/ray/util/sgd/v2/examples/tensorflow_mnist_example.py @@ -72,7 +72,7 @@ def train_func(config): return results -def train_tensorflow_mnist(num_workers=1, use_gpu=False): +def train_tensorflow_mnist(num_workers=2, use_gpu=False): trainer = Trainer( backend="tensorflow", num_workers=num_workers, use_gpu=use_gpu) trainer.start() @@ -98,7 +98,7 @@ def train_tensorflow_mnist(num_workers=1, use_gpu=False): "--num-workers", "-n", type=int, - default=1, + default=2, help="Sets number of workers for training.") parser.add_argument( "--use-gpu", diff --git a/python/ray/util/sgd/v2/examples/tune_cifar_pytorch_pbt_example.py b/python/ray/util/sgd/v2/examples/tune_cifar_pytorch_pbt_example.py new file mode 100644 index 0000000000000..1ff8054be367c --- /dev/null +++ b/python/ray/util/sgd/v2/examples/tune_cifar_pytorch_pbt_example.py @@ -0,0 +1,200 @@ +import numpy as np +import argparse +from filelock import FileLock + +import ray +from ray import tune +from ray.tune import CLIReporter +from ray.tune.schedulers import PopulationBasedTraining + +import torch +import torch.nn as nn +from torch.utils.data import DataLoader, DistributedSampler, Subset +from torchvision.datasets import CIFAR10 +import torchvision.transforms as transforms +from torch.nn.parallel import DistributedDataParallel + +from ray.util.sgd.torch.resnet import ResNet18 + +import ray.util.sgd.v2 as sgd +from ray.util.sgd.v2 import Trainer + + +def train(dataloader, model, loss_fn, optimizer, device): + size = len(dataloader.dataset) + for batch, (X, y) in enumerate(dataloader): + X, y = X.to(device), y.to(device) + + # Compute prediction error + pred = model(X) + loss = loss_fn(pred, y) + + # Backpropagation + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if batch % 100 == 0: + loss, current = loss.item(), batch * len(X) + print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]") + + +def validate(dataloader, model, loss_fn, device): + size = len(dataloader.dataset) + num_batches = len(dataloader) + model.eval() + test_loss, correct = 0, 0 + with torch.no_grad(): + for X, y in dataloader: + X, y = X.to(device), y.to(device) + pred = model(X) + test_loss += loss_fn(pred, y).item() + correct += (pred.argmax(1) == y).type(torch.float).sum().item() + test_loss /= num_batches + correct /= size + print(f"Test Error: \n " + f"Accuracy: {(100 * correct):>0.1f}%, " + f"Avg loss: {test_loss:>8f} \n") + return {"loss": test_loss} + + +def train_func(config): + device = torch.device(f"cuda:{sgd.local_rank()}" + if torch.cuda.is_available() else "cpu") + + epochs = config.pop("epochs", 3) + model = ResNet18(config) + model = model.to(device) + model = DistributedDataParallel( + 
model, + device_ids=[device.index] if torch.cuda.is_available() else None) + + # Create optimizer. + optimizer = torch.optim.SGD( + model.parameters(), + lr=config.get("lr", 0.1), + momentum=config.get("momentum", 0.9)) + + # Load in training and validation data. + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), + (0.2023, 0.1994, 0.2010)), + ]) # meanstd transformation + + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), + (0.2023, 0.1994, 0.2010)), + ]) + + with FileLock(".ray.lock"): + train_dataset = CIFAR10( + root="~/data", + train=True, + download=True, + transform=transform_train) + validation_dataset = CIFAR10( + root="~/data", + train=False, + download=False, + transform=transform_test) + + if config.get("test_mode"): + train_dataset = Subset(train_dataset, list(range(64))) + validation_dataset = Subset(validation_dataset, list(range(64))) + + train_loader = DataLoader( + train_dataset, + batch_size=config["batch_size"], + sampler=DistributedSampler(train_dataset)) + validation_loader = DataLoader( + validation_dataset, + batch_size=config["batch_size"], + sampler=DistributedSampler(validation_dataset)) + + # Create loss. + criterion = nn.CrossEntropyLoss() + + results = [] + + for _ in range(epochs): + train(train_loader, model, criterion, optimizer, device) + result = validate(validation_loader, model, criterion, device) + sgd.report(**result) + results.append(result) + + return results + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--address", + required=False, + type=str, + help="the address to use for Redis") + parser.add_argument( + "--num-workers", + "-n", + type=int, + default=2, + help="Sets number of workers for training.") + parser.add_argument( + "--num-epochs", type=int, default=5, help="Number of epochs to train.") + parser.add_argument( + "--smoke-test", + action="store_true", + default=False, + help="Finish quickly for testing.") + parser.add_argument( + "--use-gpu", + action="store_true", + default=False, + help="Enables GPU training") + + args, _ = parser.parse_known_args() + if args.smoke_test: + ray.init(num_cpus=4) + else: + ray.init(address=args.address) + + trainer = Trainer( + "torch", num_workers=args.num_workers, use_gpu=args.use_gpu) + Trainable = trainer.to_tune_trainable(train_func) + pbt_scheduler = PopulationBasedTraining( + time_attr="training_iteration", + metric="loss", + mode="min", + perturbation_interval=1, + hyperparam_mutations={ + # distribution for resampling + "lr": lambda: np.random.uniform(0.001, 1), + # allow perturbations within this set of categorical values + "momentum": [0.8, 0.9, 0.99], + }) + + reporter = CLIReporter() + reporter.add_metric_column("loss", "loss") + + analysis = tune.run( + Trainable, + num_samples=4, + config={ + "lr": tune.choice([0.001, 0.01, 0.1]), + "momentum": 0.8, + "batch_size": 128 * args.num_workers, + "epochs": args.num_epochs, + "test_mode": args.smoke_test # whether to to subset the data + }, + stop={"training_iteration": 2 if args.smoke_test else 100}, + max_failures=3, # used for fault tolerance + checkpoint_freq=3, # used for fault tolerance + keep_checkpoints_num=1, # used for fault tolerance + verbose=2, + progress_reporter=reporter, + scheduler=pbt_scheduler) + + print(analysis.get_best_config(metric="loss", mode="min")) diff --git 
a/python/ray/util/sgd/v2/tests/test_backend.py b/python/ray/util/sgd/v2/tests/test_backend.py index 985029b808118..65ac486dd9df1 100644 --- a/python/ray/util/sgd/v2/tests/test_backend.py +++ b/python/ray/util/sgd/v2/tests/test_backend.py @@ -8,6 +8,7 @@ from ray.util.sgd import v2 as sgd from ray.util.sgd.v2.backends.backend import BackendConfig, BackendExecutor from ray.util.sgd.v2.backends.tensorflow import TensorflowConfig +from ray.util.sgd.v2.constants import ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV from ray.util.sgd.v2.worker_group import WorkerGroup from ray.util.sgd.v2.backends.torch import TorchConfig @@ -321,6 +322,7 @@ def get_resources(): num_workers, expected_results = worker_results + os.environ[ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV] = "1" e = BackendExecutor( config, num_workers=num_workers, @@ -349,6 +351,7 @@ def get_resources(): num_workers, expected_results = worker_results + os.environ[ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV] = "1" e = BackendExecutor( config, num_workers=num_workers, @@ -374,6 +377,7 @@ def get_resources(): num_workers, expected_results = worker_results + os.environ[ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV] = "1" e = BackendExecutor( config, num_workers=num_workers, diff --git a/python/ray/util/sgd/v2/tests/test_gpu.py b/python/ray/util/sgd/v2/tests/test_gpu.py new file mode 100644 index 0000000000000..845e768cd6d47 --- /dev/null +++ b/python/ray/util/sgd/v2/tests/test_gpu.py @@ -0,0 +1,92 @@ +import pytest + +import ray +from ray.util.sgd.v2 import Trainer +from ray.util.sgd.v2.examples.horovod.horovod_example import train_func as \ + horovod_torch_train_func +from ray.util.sgd.v2.examples.tensorflow_mnist_example import train_func as \ + tensorflow_mnist_train_func +from ray.util.sgd.v2.examples.train_fashion_mnist_example import train_func \ + as fashion_mnist_train_func +from test_tune import torch_fashion_mnist, tune_tensorflow_mnist + + +@pytest.fixture +def ray_start_4_cpus_2_gpus(): + address_info = ray.init(num_cpus=4, num_gpus=2) + yield address_info + # The code after the yield will run as teardown code. 
+ ray.shutdown() + + +def test_tensorflow_mnist_gpu(ray_start_4_cpus_2_gpus): + num_workers = 2 + epochs = 3 + + trainer = Trainer("tensorflow", num_workers=num_workers, use_gpu=True) + config = {"lr": 1e-3, "batch_size": 64, "epochs": epochs} + trainer.start() + results = trainer.run(tensorflow_mnist_train_func, config) + trainer.shutdown() + + assert len(results) == num_workers + result = results[0] + + loss = result["loss"] + assert len(loss) == epochs + assert loss[-1] < loss[0] + + accuracy = result["accuracy"] + assert len(accuracy) == epochs + assert accuracy[-1] > accuracy[0] + + +def test_torch_fashion_mnist_gpu(ray_start_4_cpus_2_gpus): + num_workers = 2 + epochs = 3 + + trainer = Trainer("torch", num_workers=num_workers, use_gpu=True) + config = {"lr": 1e-3, "batch_size": 64, "epochs": epochs} + trainer.start() + results = trainer.run(fashion_mnist_train_func, config) + trainer.shutdown() + + assert len(results) == num_workers + + for result in results: + assert len(result) == epochs + assert result[-1] < result[0] + + +def test_horovod_torch_mnist_gpu(ray_start_4_cpus_2_gpus): + num_workers = 2 + num_epochs = 2 + trainer = Trainer("horovod", num_workers, use_gpu=True) + trainer.start() + results = trainer.run( + horovod_torch_train_func, + config={ + "num_epochs": num_epochs, + "lr": 1e-3 + }) + trainer.shutdown() + + assert len(results) == num_workers + for worker_result in results: + assert len(worker_result) == num_epochs + assert worker_result[num_epochs - 1] < worker_result[0] + + +def test_tune_fashion_mnist_gpu(ray_start_4_cpus_2_gpus): + torch_fashion_mnist(num_workers=2, use_gpu=True, num_samples=1) + + +def test_tune_tensorflow_mnist_gpu(ray_start_4_cpus_2_gpus): + tune_tensorflow_mnist(num_workers=2, use_gpu=True, num_samples=1) + + +if __name__ == "__main__": + import pytest + import sys + + sys.exit(pytest.main(["-v", "-x", "-s", __file__])) diff --git a/python/ray/util/sgd/v2/tests/test_trainer.py b/python/ray/util/sgd/v2/tests/test_trainer.py index f7da6310a2a96..9795017283a0b 100644 --- a/python/ray/util/sgd/v2/tests/test_trainer.py +++ b/python/ray/util/sgd/v2/tests/test_trainer.py @@ -5,26 +5,24 @@ import horovod.torch as hvd_torch import pytest + import ray import ray.util.sgd.v2 as sgd -import tensorflow as tf -import torch from ray._private.test_utils import wait_for_condition from ray.util.sgd.v2 import Trainer, TorchConfig, TensorflowConfig, \ HorovodConfig from ray.util.sgd.v2.backends.backend import BackendConfig, Backend, \ BackendExecutor from ray.util.sgd.v2.callbacks.callback import SGDCallback +from ray.util.sgd.v2.examples.horovod.horovod_example import train_func as \ + horovod_torch_train_func, HorovodTrainClass +from ray.util.sgd.v2.constants import ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV from ray.util.sgd.v2.examples.tensorflow_mnist_example import train_func as \ tensorflow_mnist_train_func from ray.util.sgd.v2.examples.train_fashion_mnist_example import train_func \ - as \ - fashion_mnist_train_func + as fashion_mnist_train_func from ray.util.sgd.v2.examples.train_linear_example import train_func as \ linear_train_func - -from ray.util.sgd.v2.examples.horovod.horovod_example import train_func as \ - horovod_torch_train_func, HorovodTrainClass from ray.util.sgd.v2.worker_group import WorkerGroup @@ -498,31 +496,6 @@ def test_tensorflow_mnist(ray_start_2_cpus): assert accuracy[-1] > accuracy[0] -@pytest.mark.skipif( - len(tf.config.list_physical_devices("GPU")) < 2, - reason="Only run if multiple GPUs are available.") -def 
test_tensorflow_mnist_gpu(ray_start_2_cpus_2_gpus): - num_workers = 2 - epochs = 3 - - trainer = Trainer("tensorflow", num_workers=num_workers, use_gpu=True) - config = {"lr": 1e-3, "batch_size": 64, "epochs": epochs} - trainer.start() - results = trainer.run(tensorflow_mnist_train_func, config) - trainer.shutdown() - - assert len(results) == num_workers - result = results[0] - - loss = result["loss"] - assert len(loss) == epochs - assert loss[-1] < loss[0] - - accuracy = result["accuracy"] - assert len(accuracy) == epochs - assert accuracy[-1] > accuracy[0] - - def test_torch_linear(ray_start_2_cpus): num_workers = 2 epochs = 3 @@ -557,26 +530,6 @@ def test_torch_fashion_mnist(ray_start_2_cpus): assert result[-1] < result[0] -@pytest.mark.skipif( - torch.cuda.device_count() < 2, - reason="Only run if multiple GPUs are available.") -def test_torch_fashion_mnist_gpu(ray_start_2_cpus_2_gpus): - num_workers = 2 - epochs = 3 - - trainer = Trainer("torch", num_workers=num_workers, use_gpu=True) - config = {"lr": 1e-3, "batch_size": 64, "epochs": epochs} - trainer.start() - results = trainer.run(fashion_mnist_train_func, config) - trainer.shutdown() - - assert len(results) == num_workers - - for result in results: - assert len(result) == epochs - assert result[-1] < result[0] - - def test_horovod_simple(ray_start_2_cpus): def simple_fn(): hvd_torch.init() @@ -610,28 +563,6 @@ def test_horovod_torch_mnist(ray_start_2_cpus): assert worker_result[num_epochs - 1] < worker_result[0] -@pytest.mark.skipif( - torch.cuda.device_count() < 2, - reason="Only run if multiple GPUs are available.") -def test_horovod_torch_mnist_gpu(ray_start_2_cpus_2_gpus): - num_workers = 2 - num_epochs = 2 - trainer = Trainer("horovod", num_workers, use_gpu=True) - trainer.start() - results = trainer.run( - horovod_torch_train_func, - config={ - "num_epochs": num_epochs, - "lr": 1e-3 - }) - trainer.shutdown() - - assert len(results) == num_workers - for worker_result in results: - assert len(worker_result) == num_epochs - assert worker_result[num_epochs - 1] < worker_result[0] - - def test_horovod_torch_mnist_stateful(ray_start_2_cpus): num_workers = 2 num_epochs = 2 @@ -986,7 +917,6 @@ def test_resources(ray_start_4_cpus_4_gpus_4_extra, resource, num_requested): def test_gpu_requests(ray_start_4_cpus_4_gpus_4_extra): - # GPUs should not be requested if `use_gpu` is False. with pytest.raises(ValueError): Trainer( @@ -1006,6 +936,8 @@ def test_gpu_requests(ray_start_4_cpus_4_gpus_4_extra): def get_resources(): return os.environ["CUDA_VISIBLE_DEVICES"] + os.environ[ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV] = "1" + # 0 GPUs will be requested and should not raise an error. 
trainer = Trainer(TestConfig(), num_workers=2, use_gpu=False) trainer.start() diff --git a/python/ray/util/sgd/v2/tests/test_tune.py b/python/ray/util/sgd/v2/tests/test_tune.py index fb9d39b6df8b0..0ec1db59542f8 100644 --- a/python/ray/util/sgd/v2/tests/test_tune.py +++ b/python/ray/util/sgd/v2/tests/test_tune.py @@ -1,18 +1,13 @@ import os import pytest - -import torch -import tensorflow as tf - import ray +import ray.util.sgd.v2 as sgd from ray import tune, cloudpickle from ray.tune import TuneError - -import ray.util.sgd.v2 as sgd from ray.util.sgd.v2 import Trainer -from ray.util.sgd.v2.constants import TUNE_CHECKPOINT_FILE_NAME from ray.util.sgd.v2.backends.backend import Backend, BackendConfig +from ray.util.sgd.v2.constants import TUNE_CHECKPOINT_FILE_NAME from ray.util.sgd.v2.examples.tensorflow_mnist_example import train_func as \ tensorflow_mnist_train_func from ray.util.sgd.v2.examples.train_fashion_mnist_example import train_func \ @@ -28,14 +23,6 @@ def ray_start_2_cpus(): ray.shutdown() -@pytest.fixture -def ray_start_4_cpus_4_gpus(): - address_info = ray.init(num_cpus=2, num_gpus=2) - yield address_info - # The code after the yield will run as teardown code. - ray.shutdown() - - @pytest.fixture def ray_start_8_cpus(): address_info = ray.init(num_cpus=8) @@ -83,13 +70,6 @@ def test_tune_torch_fashion_mnist(ray_start_8_cpus): torch_fashion_mnist(num_workers=2, use_gpu=False, num_samples=2) -@pytest.mark.skipif( - torch.cuda.device_count() < 2, - reason="Only run if multiple GPUs are available.") -def test_tune_fashion_mnist_gpu(ray_start_4_cpus_4_gpus): - torch_fashion_mnist(num_workers=2, use_gpu=True, num_samples=1) - - def tune_tensorflow_mnist(num_workers, use_gpu, num_samples): epochs = 2 trainer = Trainer("tensorflow", num_workers=num_workers, use_gpu=use_gpu) @@ -113,13 +93,6 @@ def test_tune_tensorflow_mnist(ray_start_8_cpus): tune_tensorflow_mnist(num_workers=2, use_gpu=False, num_samples=2) -@pytest.mark.skipif( - len(tf.config.list_physical_devices("GPU")) < 2, - reason="Only run if multiple GPUs are available.") -def test_tune_tensorflow_mnist_gpu(ray_start_4_cpus_4_gpus): - tune_tensorflow_mnist(num_workers=2, use_gpu=True, num_samples=1) - - def test_tune_error(ray_start_2_cpus): def train_func(config): raise RuntimeError("Error in training function!") diff --git a/python/ray/util/tracing/tracing_helper.py b/python/ray/util/tracing/tracing_helper.py index 73fb61c00767c..68696fe29c46d 100644 --- a/python/ray/util/tracing/tracing_helper.py +++ b/python/ray/util/tracing/tracing_helper.py @@ -290,6 +290,8 @@ def _invocation_remote_span( # If tracing feature flag is not on, perform a no-op. # Tracing doesn't work for cross lang yet. 
if not is_tracing_enabled() or self._is_cross_language: + if kwargs is not None: + assert "_ray_trace_ctx" not in kwargs return method(self, args, kwargs, *_args, **_kwargs) assert "_ray_trace_ctx" not in kwargs @@ -365,8 +367,7 @@ def _invocation_actor_class_remote_span( # If tracing feature flag is not on, perform a no-op if not is_tracing_enabled(): - if not self.__ray_metadata__.is_cross_language: - kwargs["_ray_trace_ctx"] = None + assert "_ray_trace_ctx" not in kwargs return method(self, args, kwargs, *_args, **_kwargs) class_name = self.__ray_metadata__.class_name @@ -404,6 +405,8 @@ def _start_span( # If tracing feature flag is not on, perform a no-op if (not is_tracing_enabled() or self._actor_ref()._ray_is_cross_language): + if kwargs is not None: + assert "_ray_trace_ctx" not in kwargs return method(self, args, kwargs, *_args, **_kwargs) class_name = (self._actor_ref() diff --git a/python/ray/worker.py b/python/ray/worker.py index 9f5dd31ca6da3..25757fef62dc2 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -191,8 +191,8 @@ def current_session_and_job(self): @property def runtime_env(self): """Get the runtime env in json format""" - return json.loads( - self.core_worker.get_job_config().runtime_env.raw_json) + return json.loads(self.core_worker.get_job_config() + .runtime_env.serialized_runtime_env) def get_serialization_context(self, job_id=None): """Get the SerializationContext of the job that this worker is processing. @@ -223,9 +223,6 @@ def check_connected(self): Exception: An exception is raised if the worker is not connected. """ if not self.connected: - if os.environ.get("RAY_ENABLE_AUTO_CONNECT", "") != "0": - ray.client().connect() - return raise RaySystemError("Ray has not been started yet. You can " "start Ray with 'ray.init()'.") @@ -479,7 +476,7 @@ def print_logs(self): @PublicAPI -@client_mode_hook +@client_mode_hook(auto_init=True) def get_gpu_ids(): """Get the IDs of the GPUs that are available to the worker. @@ -576,7 +573,7 @@ def get_dashboard_url(): @PublicAPI -@client_mode_hook +@client_mode_hook(auto_init=False) def init( address: Optional[str] = None, *, @@ -974,7 +971,7 @@ def init( @PublicAPI -@client_mode_hook +@client_mode_hook(auto_init=False) def shutdown(_exiting_interpreter: bool = False): """Disconnect the worker, and terminate processes started by ray.init(). @@ -1240,7 +1237,7 @@ def listen_error_messages_raylet(worker, threads_stopped): @PublicAPI -@client_mode_hook +@client_mode_hook(auto_init=False) def is_initialized() -> bool: """Check if ray.init has been called yet. @@ -1559,7 +1556,7 @@ def show_in_dashboard(message: str, key: str = "", dtype: str = "text"): @PublicAPI -@client_mode_hook +@client_mode_hook(auto_init=True) def get(object_refs: Union[ray.ObjectRef, List[ray.ObjectRef]], *, timeout: Optional[float] = None) -> Union[Any, List[Any]]: @@ -1648,7 +1645,7 @@ def get(object_refs: Union[ray.ObjectRef, List[ray.ObjectRef]], @PublicAPI -@client_mode_hook +@client_mode_hook(auto_init=True) def put(value: Any, *, _owner: Optional["ray.actor.ActorHandle"] = None) -> ray.ObjectRef: """Store an object in the object store. 
@@ -1702,7 +1699,7 @@ def put(value: Any, *, @PublicAPI -@client_mode_hook +@client_mode_hook(auto_init=True) def wait(object_refs: List[ray.ObjectRef], *, num_returns: int = 1, @@ -1809,7 +1806,7 @@ def wait(object_refs: List[ray.ObjectRef], @PublicAPI -@client_mode_hook +@client_mode_hook(auto_init=True) def get_actor(name: str, namespace: Optional[str] = None) -> "ray.actor.ActorHandle": """Get a handle to a named actor. @@ -1841,7 +1838,7 @@ def get_actor(name: str, @PublicAPI -@client_mode_hook +@client_mode_hook(auto_init=True) def kill(actor: "ray.actor.ActorHandle", *, no_restart: bool = True): """Kill an actor forcefully. @@ -1870,7 +1867,7 @@ def kill(actor: "ray.actor.ActorHandle", *, no_restart: bool = True): @PublicAPI -@client_mode_hook +@client_mode_hook(auto_init=True) def cancel(object_ref: ray.ObjectRef, *, force: bool = False, @@ -1932,6 +1929,7 @@ def make_decorator(num_returns=None, max_restarts=None, max_task_retries=None, runtime_env=None, + placement_group="default", worker=None, retry_exceptions=None): def decorator(function_or_class): @@ -1963,7 +1961,7 @@ def decorator(function_or_class): Language.PYTHON, function_or_class, None, num_cpus, num_gpus, memory, object_store_memory, resources, accelerator_type, num_returns, max_calls, max_retries, retry_exceptions, - runtime_env) + runtime_env, placement_group) if inspect.isclass(function_or_class): if num_returns is not None: @@ -2101,15 +2099,6 @@ def method(self): retry_exceptions (bool): Only for *remote functions*. This specifies whether application-level errors should be retried up to max_retries times. - override_environment_variables (Dict[str, str]): (Deprecated in Ray - 1.4.0, will be removed in Ray 1.6--please use the ``env_vars`` - field of :ref:`runtime-environments` instead.) This specifies - environment variables to override for the actor or task. The - overrides are propagated to all child actors and tasks. This - is a dictionary mapping variable names to their values. Existing - variables can be overridden, new ones can be created, and an - existing variable can be unset by setting it to an empty string. - Note: can only be set via `.options()`. 
""" worker = global_worker @@ -2121,7 +2110,8 @@ def method(self): valid_kwargs = [ "num_returns", "num_cpus", "num_gpus", "memory", "object_store_memory", "resources", "accelerator_type", "max_calls", "max_restarts", - "max_task_retries", "max_retries", "runtime_env", "retry_exceptions" + "max_task_retries", "max_retries", "runtime_env", "retry_exceptions", + "placement_group" ] error_string = ("The @ray.remote decorator must be applied either " "with no arguments and no parentheses, for example " @@ -2154,6 +2144,7 @@ def method(self): object_store_memory = kwargs.get("object_store_memory") max_retries = kwargs.get("max_retries") runtime_env = kwargs.get("runtime_env") + placement_group = kwargs.get("placement_group", "default") retry_exceptions = kwargs.get("retry_exceptions") return make_decorator( @@ -2169,5 +2160,6 @@ def method(self): max_task_retries=max_task_retries, max_retries=max_retries, runtime_env=runtime_env, + placement_group=placement_group, worker=worker, retry_exceptions=retry_exceptions) diff --git a/python/ray/workers/setup_worker.py b/python/ray/workers/setup_worker.py index 23fbc6e8e150d..b40737c1a8ad0 100644 --- a/python/ray/workers/setup_worker.py +++ b/python/ray/workers/setup_worker.py @@ -3,7 +3,8 @@ import logging import os -from ray._private.runtime_env import RuntimeEnvContext +from ray._private.runtime_env.context import RuntimeEnvContext +from ray.core.generated.common_pb2 import Language logger = logging.getLogger(__name__) @@ -26,6 +27,9 @@ type=str, help="the worker allocated resource") +parser.add_argument( + "--language", type=str, help="the language type of the worker") + def get_tmp_dir(remaining_args): for arg in remaining_args: @@ -117,5 +121,5 @@ def start_worker_in_container(container_option, args, remaining_args): # probably not even go through this codepath. runtime_env_context = RuntimeEnvContext.deserialize( args.serialized_runtime_env_context or "{}") - - runtime_env_context.exec_worker(remaining_args) + runtime_env_context.exec_worker(remaining_args, + Language.Value(args.language)) diff --git a/python/ray/workflow/common.py b/python/ray/workflow/common.py index e883cdfacd0b4..169b318ed5d76 100644 --- a/python/ray/workflow/common.py +++ b/python/ray/workflow/common.py @@ -32,7 +32,8 @@ def get_qualname(f): def ensure_ray_initialized(): - ray.worker.global_worker.check_connected() + if not ray.is_initialized(): + ray.init() @dataclass diff --git a/python/ray/workflow/execution.py b/python/ray/workflow/execution.py index 6de22cef05943..b660c65fa8a9e 100644 --- a/python/ray/workflow/execution.py +++ b/python/ray/workflow/execution.py @@ -32,8 +32,9 @@ def run(entry_workflow: Workflow, # Workflow ID format: {Entry workflow UUID}.{Unix time to nanoseconds} workflow_id = f"{str(uuid.uuid4())}.{time.time():.9f}" - logger.info(f"Workflow job created. [id=\"{workflow_id}\", storage_url=" - f"\"{store.storage_url}\"].") + logger.info( + f"Workflow job created. [id=\"{workflow_id}\", storage_url=" + f"\"{store.storage_url}\"]. 
Type: {entry_workflow.data.step_type} ") with workflow_context.workflow_step_context(workflow_id, store.storage_url): @@ -51,7 +52,7 @@ def run(entry_workflow: Workflow, # - it's a new workflow # TODO (yic): follow up with force rerun if entry_workflow.data.step_type != StepType.FUNCTION or not wf_exists: - commit_step(ws, "", entry_workflow, None) + commit_step(ws, "", entry_workflow, exception=None) workflow_manager = get_or_create_management_actor() ignore_existing = (entry_workflow.data.step_type != StepType.FUNCTION) # NOTE: It is important to 'ray.get' the returned output. This diff --git a/python/ray/workflow/recovery.py b/python/ray/workflow/recovery.py index 58902b4419681..8c64c2cba4100 100644 --- a/python/ray/workflow/recovery.py +++ b/python/ray/workflow/recovery.py @@ -51,8 +51,8 @@ def _recover_workflow_step(args: List[Any], kwargs: Dict[str, Any], def _construct_resume_workflow_from_step( - reader: workflow_storage.WorkflowStorage, - step_id: StepID) -> Union[Workflow, StepID]: + reader: workflow_storage.WorkflowStorage, step_id: StepID, + input_map: Dict[StepID, Any]) -> Union[Workflow, StepID]: """Try to construct a workflow (step) that recovers the workflow step. If the workflow step already has an output checkpointing file, we return the workflow step id instead. @@ -60,6 +60,8 @@ def _construct_resume_workflow_from_step( Args: reader: The storage reader for inspecting the step. step_id: The ID of the step we want to recover. + input_map: A context storing the inputs that have already been loaded. + This context is important for deduplicating step execution. Returns: A workflow that recovers the step, or a ID of a step @@ -70,8 +72,8 @@ def _construct_resume_workflow_from_step( # we already have the output return step_id if isinstance(result.output_step_id, str): - return _construct_resume_workflow_from_step(reader, - result.output_step_id) + return _construct_resume_workflow_from_step( + reader, result.output_step_id, input_map) # output does not exists or not valid. try to reconstruct it. 
if not result.is_recoverable(): raise WorkflowStepNotRecoverableError(step_id) @@ -79,7 +81,14 @@ def _construct_resume_workflow_from_step( with serialization.objectref_cache(): input_workflows = [] for i, _step_id in enumerate(result.workflows): - r = _construct_resume_workflow_from_step(reader, _step_id) + # Check whether the step has been loaded or not to avoid + # duplication + if _step_id in input_map: + r = input_map[_step_id] + else: + r = _construct_resume_workflow_from_step( + reader, _step_id, input_map) + input_map[_step_id] = r if isinstance(r, Workflow): input_workflows.append(r) else: @@ -119,15 +128,15 @@ def _resume_workflow_step_executor(workflow_id: str, step_id: "StepID", try: store = storage.create_storage(store_url) wf_store = workflow_storage.WorkflowStorage(workflow_id, store) - r = _construct_resume_workflow_from_step(wf_store, step_id) + r = _construct_resume_workflow_from_step(wf_store, step_id, {}) except Exception as e: raise WorkflowNotResumableError(workflow_id) from e if isinstance(r, Workflow): - with workflow_context.workflow_step_context(workflow_id, - store.storage_url): - from ray.workflow.step_executor import (execute_workflow) - result = execute_workflow(r, last_step_of_workflow=True) + with workflow_context.workflow_step_context( + workflow_id, store.storage_url, last_step_of_workflow=True): + from ray.workflow.step_executor import execute_workflow + result = execute_workflow(r) return result.persisted_output, result.volatile_output assert isinstance(r, StepID) return wf_store.load_step_output(r), None diff --git a/python/ray/workflow/step_executor.py b/python/ray/workflow/step_executor.py index 878c7b40bf451..b5416b5a40218 100644 --- a/python/ray/workflow/step_executor.py +++ b/python/ray/workflow/step_executor.py @@ -134,33 +134,9 @@ def _resolve_step_inputs( return signature.recover_args(flattened_args) -def execute_workflow( - workflow: "Workflow", - outer_most_step_id: Optional[str] = None, - last_step_of_workflow: bool = False) -> "WorkflowExecutionResult": +def execute_workflow(workflow: "Workflow") -> "WorkflowExecutionResult": """Execute workflow. - To fully explain what we are doing, we need to introduce some syntax first. - The syntax for dependencies between workflow steps - "B.step(A.step())" is "A - B"; the syntax for nested workflow steps - "def A(): return B.step()" is "A / B". - - In a chain/DAG of step dependencies, the "output step" is the step of last - (topological) order. For example, in "A - B - C", C is the output step. - - In a chain of nested workflow steps, the initial "output step" is - called the "outer most step" for other "output steps". For example, in - "A / B / C / D", "A" is the outer most step for "B", "C", "D"; - in the hybrid workflow "((A - B) / C / D) - (E / (F - G) / H)", - "B" is the outer most step for "C", "D"; "E" is the outer most step - for "G", "H". - - Args: - workflow: The workflow to be executed. - outer_most_step_id: The ID of the outer most workflow. None if it - does not exists. - last_step_of_workflow: The step that generates the output of the - workflow (including nested steps). Returns: An object ref that represent the result. 
""" @@ -173,8 +149,8 @@ def execute_workflow( **workflow_data.ray_options).remote( workflow_data.step_type, workflow_data.func_body, workflow_context.get_workflow_step_context(), workflow.step_id, - baked_inputs, outer_most_step_id, workflow_data.catch_exceptions, - workflow_data.max_retries, last_step_of_workflow) + baked_inputs, workflow_data.catch_exceptions, + workflow_data.max_retries) if not isinstance(persisted_output, WorkflowOutputType): raise TypeError("Unexpected return type of the workflow.") @@ -197,7 +173,6 @@ async def _write_step_inputs(wf_storage: workflow_storage.WorkflowStorage, # TODO(suquark): in the future we should write to storage directly # with plasma store object in memory. args_obj = ray.get(inputs.inputs.args) - workflow_id = wf_storage._workflow_id storage = wf_storage._storage save_tasks = [ @@ -213,19 +188,13 @@ async def _write_step_inputs(wf_storage: workflow_storage.WorkflowStorage, await asyncio.gather(*save_tasks) -def commit_step(store: workflow_storage.WorkflowStorage, - step_id: "StepID", - ret: Union["Workflow", Any], - exception: Optional[Exception], - outer_most_step_id: Optional[str] = None): +def commit_step(store: workflow_storage.WorkflowStorage, step_id: "StepID", + ret: Union["Workflow", Any], exception: Optional[Exception]): """Checkpoint the step output. Args: store: The storage the current workflow is using. step_id: The ID of the step. ret: The returned object of the workflow step. - outer_most_step_id: The ID of the outer most workflow. None if it - does not exists. See "step_executor.execute_workflow" for detailed - explanation. """ from ray.workflow.common import Workflow if isinstance(ret, Workflow): @@ -236,7 +205,12 @@ def commit_step(store: workflow_storage.WorkflowStorage, ] asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks)) - store.save_step_output(step_id, ret, exception, outer_most_step_id) + context = workflow_context.get_workflow_step_context() + store.save_step_output( + step_id, + ret, + exception=exception, + outer_most_step_id=context.outer_most_step_id) def _wrap_run(func: Callable, step_type: StepType, step_id: "StepID", @@ -328,12 +302,11 @@ def _wrap_run(func: Callable, step_type: StepType, step_id: "StepID", @ray.remote(num_returns=2) -def _workflow_step_executor( - step_type: StepType, func: Callable, - context: workflow_context.WorkflowStepContext, step_id: "StepID", - baked_inputs: "_BakedWorkflowInputs", outer_most_step_id: "StepID", - catch_exceptions: bool, max_retries: int, - last_step_of_workflow: bool) -> Any: +def _workflow_step_executor(step_type: StepType, func: Callable, + context: workflow_context.WorkflowStepContext, + step_id: "StepID", + baked_inputs: "_BakedWorkflowInputs", + catch_exceptions: bool, max_retries: int) -> Any: """Executor function for workflow step. Args: @@ -342,13 +315,9 @@ def _workflow_step_executor( context: Workflow step context. Used to access correct storage etc. step_id: The ID of the step. baked_inputs: The processed inputs for the step. - outer_most_step_id: See "step_executor.execute_workflow" for - explanation. catch_exceptions: If set to be true, return (Optional[Result], Optional[Error]) instead of Result. max_retries: Max number of retries encounter of a failure. - last_step_of_workflow: The step that generates the output of the - workflow (including nested steps). Returns: Workflow step output. 
@@ -361,7 +330,7 @@ def _workflow_step_executor( func, step_type, step_id, catch_exceptions, max_retries, *args, **kwargs) except Exception as e: - commit_step(store, step_id, None, e, outer_most_step_id) + commit_step(store, step_id, None, e) raise e if step_type == StepType.READONLY_ACTOR_METHOD: if isinstance(volatile_output, Workflow): @@ -371,26 +340,28 @@ def _workflow_step_executor( assert not isinstance(persisted_output, Workflow) else: store = workflow_storage.get_workflow_storage() - commit_step(store, step_id, persisted_output, None, outer_most_step_id) + commit_step(store, step_id, persisted_output, None) + outer_most_step_id = context.outer_most_step_id if isinstance(persisted_output, Workflow): if step_type == StepType.FUNCTION: # Passing down outer most step so inner nested steps would # access the same outer most step. - if not outer_most_step_id: + if not context.outer_most_step_id: # The current workflow step returns a nested workflow, and # there is no outer step for the current step. So the # current step is the outer most step for the inner nested # workflow steps. outer_most_step_id = workflow_context.get_current_step_id() assert volatile_output is None - # execute sub-workflow - result = execute_workflow(persisted_output, outer_most_step_id, - last_step_of_workflow) + # Execute sub-workflow. Pass down "outer_most_step_id". + with workflow_context.fork_workflow_step_context( + outer_most_step_id=outer_most_step_id): + result = execute_workflow(persisted_output) # When virtual actor returns a workflow in the method, # the volatile_output and persisted_output will be put together persisted_output = result.persisted_output volatile_output = result.volatile_output - elif last_step_of_workflow: + elif context.last_step_of_workflow: # advance the progress of the workflow store.advance_progress(step_id) _record_step_status(step_id, WorkflowStatus.SUCCESSFUL) @@ -415,9 +386,11 @@ class _BakedWorkflowInputs: @classmethod def from_workflow_inputs(cls, inputs: "WorkflowInputs"): - workflow_outputs = [ - execute_workflow(w).persisted_output for w in inputs.workflows - ] + with workflow_context.fork_workflow_step_context( + outer_most_step_id=None, last_step_of_workflow=False): + workflow_outputs = [ + execute_workflow(w).persisted_output for w in inputs.workflows + ] return cls(inputs.args, workflow_outputs, inputs.workflow_refs) def __reduce__(self): @@ -427,7 +400,10 @@ def __reduce__(self): def _record_step_status(step_id: "StepID", status: "WorkflowStatus", - outputs: List["ObjectRef"] = []) -> None: + outputs: Optional[List["ObjectRef"]] = None) -> None: + if outputs is None: + outputs = [] + workflow_id = workflow_context.get_current_workflow_id() workflow_manager = get_management_actor() ray.get( diff --git a/python/ray/workflow/tests/test_basic_workflows_2.py b/python/ray/workflow/tests/test_basic_workflows_2.py index dad390635cab7..acecfb14dc014 100644 --- a/python/ray/workflow/tests/test_basic_workflows_2.py +++ b/python/ray/workflow/tests/test_basic_workflows_2.py @@ -1,10 +1,13 @@ +import os import pytest import ray import re from filelock import FileLock +from pathlib import Path from ray._private.test_utils import run_string_as_driver, SignalActor from ray import workflow from ray.tests.conftest import * # noqa +from unittest.mock import patch def test_init_twice(call_ray_start, reset_workflow, tmp_path): @@ -22,9 +25,11 @@ def test_init_twice(call_ray_start, reset_workflow, tmp_path): def test_init_twice_2(call_ray_start, reset_workflow, tmp_path): - 
run_string_as_driver(driver_script) - with pytest.raises(RuntimeError): - workflow.init(str(tmp_path)) + with patch.dict(os.environ, {"RAY_ADDRESS": call_ray_start}): + run_string_as_driver(driver_script) + with pytest.raises( + RuntimeError, match=".*different from the workflow manager.*"): + workflow.init(str(tmp_path)) @pytest.mark.parametrize( @@ -285,6 +290,38 @@ def f2(*w): f.run() +def test_dedupe_indirect(workflow_start_regular, tmp_path): + counter = Path(tmp_path) / "counter.txt" + lock = Path(tmp_path) / "lock.txt" + counter.write_text("0") + + @workflow.step + def incr(): + with FileLock(str(lock)): + c = int(counter.read_text()) + c += 1 + counter.write_text(f"{c}") + + @workflow.step + def identity(a): + return a + + @workflow.step + def join(*a): + return counter.read_text() + + # Here a is passed to two steps and we need to ensure + # it's only executed once + a = incr.step() + i1 = identity.step(a) + i2 = identity.step(a) + assert "1" == join.step(i1, i2).run() + assert "2" == join.step(i1, i2).run() + # pass a multiple times + assert "3" == join.step(a, a, a, a).run() + assert "4" == join.step(a, a, a, a).run() + + if __name__ == "__main__": import sys sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/workflow/tests/test_lifetime.py b/python/ray/workflow/tests/test_lifetime.py index 8d12399369ac9..64a519fa19a48 100644 --- a/python/ray/workflow/tests/test_lifetime.py +++ b/python/ray/workflow/tests/test_lifetime.py @@ -1,3 +1,4 @@ +import os import ray import time import pytest @@ -5,6 +6,7 @@ run_string_as_driver) from ray.tests.conftest import * # noqa from ray import workflow +from unittest.mock import patch driver_script = """ import time @@ -29,21 +31,23 @@ def foo(x): def test_workflow_lifetime_1(call_ray_start, reset_workflow): # Case 1: driver exits normally - run_string_as_driver(driver_script.format(5)) - workflow.init() - output = workflow.get_output("driver_terminated") - assert ray.get(output) == 20 + with patch.dict(os.environ, {"RAY_ADDRESS": call_ray_start}): + run_string_as_driver(driver_script.format(5)) + workflow.init() + output = workflow.get_output("driver_terminated") + assert ray.get(output) == 20 def test_workflow_lifetime_2(call_ray_start, reset_workflow): # Case 2: driver terminated - proc = run_string_as_driver_nonblocking(driver_script.format(100)) - time.sleep(10) - proc.kill() - time.sleep(1) - workflow.init() - output = workflow.get_output("driver_terminated") - assert ray.get(output) == 20 + with patch.dict(os.environ, {"RAY_ADDRESS": call_ray_start}): + proc = run_string_as_driver_nonblocking(driver_script.format(100)) + time.sleep(10) + proc.kill() + time.sleep(1) + workflow.init() + output = workflow.get_output("driver_terminated") + assert ray.get(output) == 20 if __name__ == "__main__": diff --git a/python/ray/workflow/workflow_access.py b/python/ray/workflow/workflow_access.py index 0524637cf08da..c1b5d78d253a0 100644 --- a/python/ray/workflow/workflow_access.py +++ b/python/ray/workflow/workflow_access.py @@ -327,8 +327,8 @@ def load(wf_store, workflow_id, step_id): actor = get_management_actor() return actor.get_output.remote(workflow_id, result.output_step_id) - raise ValueError( - f"No such step id {step_id} in workflow {workflow_id}") + raise ValueError(f"Cannot load output from step id {step_id} " + f"in workflow {workflow_id}") return ray.put( _SelfDereferenceObject(None, diff --git a/python/ray/workflow/workflow_context.py b/python/ray/workflow/workflow_context.py index ffbeaafb6ce7f..7dec0937695f5 100644 --- 
a/python/ray/workflow/workflow_context.py +++ b/python/ray/workflow/workflow_context.py @@ -1,40 +1,58 @@ +from dataclasses import dataclass, field import logging -from typing import Optional, List +from typing import Optional, List, TYPE_CHECKING from contextlib import contextmanager from ray.workflow.common import WorkflowStatus logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from ray.workflow.common import StepID + +@dataclass class WorkflowStepContext: - def __init__(self, - workflow_id: str = None, - storage_url: str = None, - workflow_scope: List[str] = None): - """ - The structure for saving workflow step context. The context provides - critical info (e.g. where to checkpoint, which is its parent step) - for the step to execute correctly. - - Args: - workflow_id: The workflow job ID. - storage_url: The storage of the workflow, used for checkpointing. - workflow_scope: The "calling stack" of the current workflow step. - It describe the parent workflow steps. - """ - self.workflow_id = workflow_id - self.storage_url = storage_url - self.workflow_scope = workflow_scope or [] - - def __reduce__(self): - return WorkflowStepContext, (self.workflow_id, self.storage_url, - self.workflow_scope) + """ + The structure for saving workflow step context. The context provides + critical info (e.g. where to checkpoint, which is its parent step) + for the step to execute correctly. + + To fully explain what we are doing, we need to introduce some syntax + first. The syntax for dependencies between workflow steps + "B.step(A.step())" is "A - B"; the syntax for nested workflow steps + "def A(): return B.step()" is "A / B". + + In a chain/DAG of step dependencies, the "output step" is the step of + last (topological) order. For example, in "A - B - C", C is the + output step. + + In a chain of nested workflow steps, the initial "output step" is + called the "outer most step" for other "output steps". For example, in + "A / B / C / D", "A" is the outer most step for "B", "C", "D"; + in the hybrid workflow "((A - B) / C / D) - (E / (F - G) / H)", + "B" is the outer most step for "C", "D"; "E" is the outer most step + for "G", "H". + """ + # ID of the workflow. + workflow_id: Optional[str] = None + # The storage of the workflow, used for checkpointing. + storage_url: Optional[str] = None + # The "calling stack" of the current workflow step. It describes + # the parent workflow steps. + workflow_scope: List[str] = field(default_factory=list) + # The ID of the outer most workflow. "None" if it does not exist. + outer_most_step_id: "Optional[StepID]" = None + # The step that generates the output of the workflow (including all + # nested steps). + last_step_of_workflow: bool = False _context: Optional[WorkflowStepContext] = None @contextmanager -def workflow_step_context(workflow_id, storage_url) -> None: +def workflow_step_context(workflow_id, + storage_url, + last_step_of_workflow=False) -> None: """Initialize the workflow step context. 
Args: @@ -45,7 +63,48 @@ def workflow_step_context(workflow_id, storage_url) -> None: original_context = _context assert workflow_id is not None try: - _context = WorkflowStepContext(workflow_id, storage_url) + _context = WorkflowStepContext( + workflow_id, + storage_url, + last_step_of_workflow=last_step_of_workflow) + yield + finally: + _context = original_context + + +_sentinel = object() + + +@contextmanager +def fork_workflow_step_context( + workflow_id: Optional[str] = _sentinel, + storage_url: Optional[str] = _sentinel, + workflow_scope: Optional[List[str]] = _sentinel, + outer_most_step_id: Optional[str] = _sentinel, + last_step_of_workflow: Optional[bool] = _sentinel): + """Fork the workflow step context. + Inherits the original value if no value is provided. + + Args: + workflow_id: The ID of the workflow. + storage_url: The storage the workflow is using. + """ + global _context + original_context = _context + assert workflow_id is not None + try: + _context = WorkflowStepContext( + workflow_id=original_context.workflow_id + if workflow_id is _sentinel else workflow_id, + storage_url=original_context.storage_url + if storage_url is _sentinel else storage_url, + workflow_scope=original_context.workflow_scope + if workflow_scope is _sentinel else workflow_scope, + outer_most_step_id=original_context.outer_most_step_id + if outer_most_step_id is _sentinel else outer_most_step_id, + last_step_of_workflow=original_context.last_step_of_workflow + if last_step_of_workflow is _sentinel else last_step_of_workflow, + ) yield finally: _context = original_context diff --git a/python/ray/workflow/workflow_storage.py b/python/ray/workflow/workflow_storage.py index bf18f471483de..5a188cca1a3f2 100644 --- a/python/ray/workflow/workflow_storage.py +++ b/python/ray/workflow/workflow_storage.py @@ -117,9 +117,9 @@ def load_step_output(self, step_id: StepID) -> Any: # In this case, there is no such step raise output_err - def save_step_output(self, step_id: StepID, ret: Union[Workflow, Any], + def save_step_output(self, step_id: StepID, ret: Union[Workflow, Any], *, exception: Optional[Exception], - outer_most_step_id: Optional[StepID]) -> None: + outer_most_step_id: StepID) -> None: """When a workflow step returns, 1. If the returned object is a workflow, this means we are a nested workflow. We save the output metadata that points to the workflow. @@ -130,8 +130,7 @@ def save_step_output(self, step_id: StepID, ret: Union[Workflow, Any], it means we are in the workflow job driver process. ret: The returned object from a workflow step. exception: This step should throw exception. - outer_most_step_id: See - "step_executor.execute_workflow" for explanation. + outer_most_step_id: See WorkflowStepContext. """ tasks = [] if isinstance(ret, Workflow): @@ -154,14 +153,9 @@ def save_step_output(self, step_id: StepID, ret: Union[Workflow, Any], # tasks.append(self._put(self._key_step_output(step_id), ret)) dynamic_output_id = step_id # TODO (yic): Delete exception file - - # outer_most_step_id == "" indicates the root step of a - # workflow. This would directly update "outputs.json" in - # the workflow dir, and we want to avoid it. 
- if outer_most_step_id is not None and outer_most_step_id != "": - tasks.append( - self._update_dynamic_output(outer_most_step_id, - dynamic_output_id)) + tasks.append( + self._update_dynamic_output(outer_most_step_id, + dynamic_output_id)) else: assert ret is None promise = serialization.dump_to_storage( @@ -271,10 +265,15 @@ async def _update_dynamic_output(self, outer_most_step_id: StepID, critical for scalability of virtual actors. Args: - outer_most_step_id: ID of outer_most_step. See - "step_executor.execute_workflow" for explanation. + outer_most_step_id: See WorkflowStepContext for explanation. dynamic_output_step_id: ID of dynamic_step. """ + # outer_most_step_id == "" indicates the root step of a + # workflow. This would directly update "outputs.json" in + # the workflow dir, and we want to avoid it. + if outer_most_step_id is None or outer_most_step_id == "": + return + metadata = await self._get( self._key_step_output_metadata(outer_most_step_id), True) if (dynamic_output_step_id != metadata["output_step_id"] diff --git a/python/requirements.txt b/python/requirements.txt index 4d0baeaf9ef80..8c9ac2dff5a21 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -27,7 +27,7 @@ requests ## setup.py extras dm_tree flask -gym +gym==0.19 lz4 scikit-image opencv-python-headless==4.3.0.36 @@ -68,6 +68,7 @@ opentelemetry-exporter-otlp==1.1.0 pexpect Pillow; platform_system != "Windows" pygments +pyspark pytest==5.4.3 pytest-asyncio pytest-rerunfailures diff --git a/python/requirements/ml/requirements_rllib.txt b/python/requirements/ml/requirements_rllib.txt index a81e52c9c1f08..6bba94e49fc99 100644 --- a/python/requirements/ml/requirements_rllib.txt +++ b/python/requirements/ml/requirements_rllib.txt @@ -10,9 +10,9 @@ kaggle_environments==1.7.11 # Unity3D testing mlagents_envs==0.27.0 # For tests on PettingZoo's multi-agent envs. -pettingzoo==1.11.0 +pettingzoo==1.11.1 pymunk==6.0.0 -supersuit +supersuit==2.6.6 # For testing in MuJoCo-like envs (in PyBullet). pybullet==3.1.7 # For tests on RecSim and Kaggle envs. diff --git a/python/requirements_linters.txt b/python/requirements_linters.txt index 6f5661b1f2b2f..69f457fea1688 100644 --- a/python/requirements_linters.txt +++ b/python/requirements_linters.txt @@ -1,5 +1,6 @@ flake8==3.9.1 flake8-comprehensions flake8-quotes==2.0.0 +flake8-bugbear==21.9.2 mypy==0.782 yapf==0.23.0 diff --git a/python/setup.py b/python/setup.py index 62d1e4e36fa46..3f168d6510e4e 100644 --- a/python/setup.py +++ b/python/setup.py @@ -23,7 +23,7 @@ logger = logging.getLogger(__name__) SUPPORTED_PYTHONS = [(3, 6), (3, 7), (3, 8), (3, 9)] -SUPPORTED_BAZEL = (3, 4, 1) +SUPPORTED_BAZEL = (4, 2, 1) ROOT_DIR = os.path.dirname(__file__) BUILD_JAVA = os.getenv("RAY_INSTALL_JAVA") == "1" @@ -184,6 +184,11 @@ def get_packages(self): # in this directory if setup_spec.type == SetupType.RAY: setup_spec.extras = { + "data": [ + "pandas", + "pyarrow>=4.0.1", + "fsspec", + ], "default": [ "aiohttp", "aiohttp_cors", @@ -534,6 +539,19 @@ def copy_file(target_dir, filename, rootdir): return 0 +def add_system_dlls(dlls, target_dir): + """ + Copy any dlls required by the c-extension module and not already + provided by python. They will end up in the wheel next to the c-extension + module which will guarantee they are available at runtime. 
+ """ + for dll in dlls: + # Installing Visual Studio will copy the runtime dlls to system32 + src = os.path.join(r"c:\Windows\system32", dll) + assert os.path.exists(src) + shutil.copy(src, target_dir) + + def pip_run(build_ext): build(True, BUILD_JAVA, True) @@ -558,6 +576,13 @@ def pip_run(build_ext): copied_files = 0 for filename in setup_spec.files_to_include: copied_files += copy_file(build_ext.build_lib, filename, ROOT_DIR) + if sys.platform == "win32": + # _raylet.pyd links to some MSVC runtime DLLS, this one may not be + # present on a user's machine. While vcruntime140.dll and + # vcruntime140_1.dll are also required, they are provided by CPython. + runtime_dlls = ["msvcp140.dll"] + add_system_dlls(runtime_dlls, os.path.join(build_ext.build_lib, "ray")) + copied_files += len(runtime_dlls) print("# of files copied to {}: {}".format(build_ext.build_lib, copied_files)) diff --git a/release/.buildkite/build_pipeline.py b/release/.buildkite/build_pipeline.py index 96d58a2b54f2a..9cbfdbdc08a0c 100644 --- a/release/.buildkite/build_pipeline.py +++ b/release/.buildkite/build_pipeline.py @@ -99,6 +99,7 @@ def __init__(self, name: str, retry: int = 0): "~/ray/release/nightly_tests/nightly_tests.yaml": [ "dask_on_ray_large_scale_test_no_spilling", "dask_on_ray_large_scale_test_spilling", + "pg_autoscaling_regression_test", ], "~/ray/release/long_running_tests/long_running_tests.yaml": [ SmokeTest("actor_deaths"), diff --git a/release/RELEASE_CHECKLIST.md b/release/RELEASE_CHECKLIST.md index f8a55bfff9aec..e4770a1cb6fdd 100644 --- a/release/RELEASE_CHECKLIST.md +++ b/release/RELEASE_CHECKLIST.md @@ -31,6 +31,7 @@ This checklist is meant to be used in conjunction with the RELEASE_PROCESS.rst d - [ ] Test passing - [ ] Results added to `release/release_logs` - [ ] microbenchmark +- [ ] `kubernetes` manual release tests pass - [ ] ``weekly`` release test suite - [ ] Test passing diff --git a/release/RELEASE_PROCESS.rst b/release/RELEASE_PROCESS.rst index 59da95846cf3c..b2f7d05db5492 100644 --- a/release/RELEASE_PROCESS.rst +++ b/release/RELEASE_PROCESS.rst @@ -172,6 +172,9 @@ Release tests are added and maintained by the respective teams. As another example, if you just want to kick off all nightly RLLib tests, select the respective test suite and specify ``rllib`` in the test file filter. +6. **Kubernetes tests must be run manually.** Refer to ``kubernetes_manual_tests/README.md``. + Feel free to ping code owner(s) of OSS Kubernetes support to run these. 
+ Identify and Resolve Release Blockers ------------------------------------- If a release blocking issue arises in the course of testing, you should diff --git a/release/alerts/xgboost_tests.py b/release/alerts/xgboost_tests.py index 59ab2880adf76..8b77cc17f49c7 100644 --- a/release/alerts/xgboost_tests.py +++ b/release/alerts/xgboost_tests.py @@ -43,7 +43,9 @@ def handle_result(created_on: datetime.datetime, category: str, else: # train scripts if test_name == "train_small": - target_time = 30 + # Leave a couple of seconds for ray connect setup + # (without connect it should finish in < 30) + target_time = 45 elif test_name == "train_moderate": target_time = 60 elif test_name == "train_gpu": diff --git a/release/e2e.py b/release/e2e.py index 1b5fe71d15923..f47a0bbeecf08 100644 --- a/release/e2e.py +++ b/release/e2e.py @@ -264,11 +264,30 @@ def getenv_default(key: str, default: Optional[str] = None): } REPORT_S = 30 +RETRY_MULTIPLIER = 2 + + +def exponential_backoff_retry(f, retry_exceptions, initial_retry_delay_s, + max_retries): + retry_cnt = 0 + retry_delay_s = initial_retry_delay_s + while True: + try: + return f() + except retry_exceptions as e: + retry_cnt += 1 + if retry_cnt > max_retries: + raise + logger.info(f"Function call failed due to {e}, " + f"retrying in {retry_delay_s} seconds...") + time.sleep(retry_delay_s) + retry_delay_s *= RETRY_MULTIPLIER + + def maybe_fetch_api_token(): if GLOBAL_CONFIG["ANYSCALE_CLI_TOKEN"] is None: - print("Missing ANYSCALE_CLI_TOKEN, retrieving from AWS secrets store") + logger.info( + "Missing ANYSCALE_CLI_TOKEN, retrieving from AWS secrets store") # NOTE(simon) This should automatically retrieve # release-automation@anyscale.com's anyscale token GLOBAL_CONFIG["ANYSCALE_CLI_TOKEN"] = boto3.client( @@ -405,7 +424,8 @@ def populate_wheels_sanity_check(commit: Optional[str] = None): raise RuntimeError(f"Could not populate wheels sanity check command: " f"Commit hash missing. 
Got: {commit}") - cmd = f"python -c 'import ray; assert ray.__commit__ == \"{commit}\"'" + cmd = (f"python -c 'import ray; " + f"assert ray.__commit__ == \"{commit}\", ray.__commit__'") os.environ["RAY_WHEELS_SANITY_CHECK"] = cmd @@ -463,7 +483,7 @@ def has_errored(result: Dict[Any, Any]) -> bool: return result.get("status", "invalid") != "finished" -def report_result(test_suite: str, test_name: str, status: str, logs: str, +def report_result(test_suite: str, test_name: str, status: str, last_logs: str, results: Dict[Any, Any], artifacts: Dict[Any, Any], category: str): now = datetime.datetime.utcnow() @@ -477,67 +497,66 @@ def report_result(test_suite: str, test_name: str, status: str, logs: str, f"results, artifacts, category) " f"VALUES (:created_on, :test_suite, :test_name, :status, :last_logs, " f":results, :artifacts, :category)") - - rds_data_client.execute_statement( - database=GLOBAL_CONFIG["RELEASE_AWS_DB_NAME"], - parameters=[ - { - "name": "created_on", - "typeHint": "TIMESTAMP", - "value": { - "stringValue": now.strftime("%Y-%m-%d %H:%M:%S") - }, - }, - { - "name": "test_suite", - "value": { - "stringValue": test_suite - } - }, - { - "name": "test_name", - "value": { - "stringValue": test_name - } - }, - { - "name": "status", - "value": { - "stringValue": status - } - }, - { - "name": "last_logs", - "value": { - "stringValue": logs - } - }, - { - "name": "results", - "typeHint": "JSON", - "value": { - "stringValue": json.dumps(results) - }, - }, - { - "name": "artifacts", - "typeHint": "JSON", - "value": { - "stringValue": json.dumps(artifacts) - }, - }, - { - "name": "category", - "value": { - "stringValue": category - } - }, - ], - secretArn=GLOBAL_CONFIG["RELEASE_AWS_DB_SECRET_ARN"], - resourceArn=GLOBAL_CONFIG["RELEASE_AWS_DB_RESOURCE_ARN"], - schema=schema, - sql=sql, - ) + parameters = [{ + "name": "created_on", + "typeHint": "TIMESTAMP", + "value": { + "stringValue": now.strftime("%Y-%m-%d %H:%M:%S") + }, + }, { + "name": "test_suite", + "value": { + "stringValue": test_suite + } + }, { + "name": "test_name", + "value": { + "stringValue": test_name + } + }, { + "name": "status", + "value": { + "stringValue": status + } + }, { + "name": "last_logs", + "value": { + "stringValue": last_logs + } + }, { + "name": "results", + "typeHint": "JSON", + "value": { + "stringValue": json.dumps(results) + }, + }, { + "name": "artifacts", + "typeHint": "JSON", + "value": { + "stringValue": json.dumps(artifacts) + }, + }, { + "name": "category", + "value": { + "stringValue": category + } + }] + + # Default boto3 call timeout is 45 seconds. + retry_delay_s = 64 + MAX_RDS_RETRY = 3 + exponential_backoff_retry( + lambda: rds_data_client.execute_statement( + database=GLOBAL_CONFIG["RELEASE_AWS_DB_NAME"], + parameters=parameters, + secretArn=GLOBAL_CONFIG["RELEASE_AWS_DB_SECRET_ARN"], + resourceArn=GLOBAL_CONFIG["RELEASE_AWS_DB_RESOURCE_ARN"], + schema=schema, + sql=sql), + retry_exceptions=rds_data_client.exceptions.StatementTimeoutException, + initial_retry_delay_s=retry_delay_s, + max_retries=MAX_RDS_RETRY) + logger.info("Result has been persisted to the databse") def log_results_and_artifacts(result: Dict): @@ -903,7 +922,11 @@ def wait_for_session_command_to_complete(create_session_command_result, # Sleep 1 sec before next check. 
time.sleep(1) - result = sdk.get_session_command(session_command_id=scd_id) + result = exponential_backoff_retry( + lambda: sdk.get_session_command(session_command_id=scd_id), + retry_exceptions=Exception, + initial_retry_delay_s=10, + max_retries=3) completed = result.result.finished_at if state_str == "CMD_RUN": @@ -934,10 +957,14 @@ def wait_for_session_command_to_complete(create_session_command_result, def get_command_logs(session_controller: SessionController, scd_id: str, lines: int = 50): - result = session_controller.api_client.get_execution_logs_api_v2_session_commands_session_command_id_execution_logs_get( # noqa: E501 - session_command_id=scd_id, - start_line=-1 * lines, - end_line=0) + result = exponential_backoff_retry( + lambda: session_controller.api_client.get_execution_logs_api_v2_session_commands_session_command_id_execution_logs_get( # noqa: E501 + session_command_id=scd_id, + start_line=-1 * lines, + end_line=0), + retry_exceptions=Exception, + initial_retry_delay_s=10, + max_retries=3) return result.result.lines @@ -1777,7 +1804,7 @@ def run_test(test_config_file: str, report: bool = True, keep_results_dir: bool = False, session_name: Optional[str] = None, - app_config_id_override=None): + app_config_id_override=None) -> Dict[str, Any]: with open(test_config_file, "rt") as f: test_configs = yaml.load(f, Loader=yaml.FullLoader) @@ -1836,18 +1863,18 @@ def run_test(test_config_file: str, logger.info("Kicked off test. It's now up to the `--check` " "part of the script to track its process.") - return + return {} else: # `--check` or no kick off only if status == "nosession": logger.info(f"No running session found for test {test_name}, so " f"assuming everything is fine.") - return + return {} if status == "kickoff": logger.info(f"Test {test_name} is still running.") - return + return {} last_logs = result.get("last_logs", "No logs.") @@ -1857,7 +1884,7 @@ def run_test(test_config_file: str, test_suite=test_suite, test_name=test_name, status=status, - logs=last_logs, + last_logs=last_logs, results=result.get("results", {}), artifacts=result.get("artifacts", {}), category=category, @@ -1872,7 +1899,7 @@ def run_test(test_config_file: str, if has_errored(result): raise RuntimeError(last_logs) - return + return report_kwargs if __name__ == "__main__": @@ -1935,7 +1962,6 @@ def run_test(test_config_file: str, "You have to set the ANYSCALE_PROJECT environment variable!") maybe_fetch_api_token() - if args.ray_wheels: os.environ["RAY_WHEELS"] = str(args.ray_wheels) url = str(args.ray_wheels) @@ -1955,7 +1981,7 @@ def run_test(test_config_file: str, test_config_file = os.path.abspath(os.path.expanduser(args.test_config)) - run_test( + result_dict = run_test( test_config_file=test_config_file, test_name=args.test_name, project_id=GLOBAL_CONFIG["ANYSCALE_PROJECT"], @@ -1970,3 +1996,30 @@ def run_test(test_config_file: str, keep_results_dir=args.keep_results_dir, app_config_id_override=args.app_config_id_override, ) + + if result_dict: + # If we get a result dict, check if any alerts should be raised + from alert import SUITE_TO_FN, default_handle_result + + logger.info("Checking if results are valid...") + + handle_result_kwargs = result_dict.copy() + handle_result_kwargs["created_on"] = None + + test_suite = handle_result_kwargs.get("test_suite", None) + test_name = handle_result_kwargs.get("test_name", None) + category = handle_result_kwargs.get("category", None) + + handle_fn = SUITE_TO_FN.get(test_suite, None) + if not handle_fn: + logger.warning(f"No handle for suite 
{test_suite}") + alert = default_handle_result(**handle_result_kwargs) + else: + alert = handle_fn(**handle_result_kwargs) + + if alert: + # If we get an alert, the test failed. + raise RuntimeError(alert) + else: + logger.info(f"No alert raised for test {test_suite}/{test_name} " + f"({category}) - the test successfully passed!") diff --git a/release/golden_notebook_tests/dask_xgboost_app_config.yaml b/release/golden_notebook_tests/dask_xgboost_app_config.yaml index 072b183099476..a05da857edef8 100755 --- a/release/golden_notebook_tests/dask_xgboost_app_config.yaml +++ b/release/golden_notebook_tests/dask_xgboost_app_config.yaml @@ -5,9 +5,8 @@ debian_packages: python: pip_packages: - - pytest - pandas>=1.3.0 # otherwise, a version mismatch between local and remote will cause an exception - - xgboost_ray[default] + - git+https://github.com/ray-project/xgboost_ray.git#egg=xgboost_ray[default] - dask - fastapi - uvicorn @@ -16,5 +15,5 @@ python: post_build_cmds: - pip uninstall -y ray || true - - pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }} + - pip install -U {{ env["RAY_WHEELS"] | default("ray") }} - {{ env["RAY_WHEELS_SANITY_CHECK"] | default("echo No Ray wheels sanity check") }} diff --git a/release/golden_notebook_tests/golden_notebook_tests.yaml b/release/golden_notebook_tests/golden_notebook_tests.yaml index 1fae1e1d65824..e6d5838d10333 100644 --- a/release/golden_notebook_tests/golden_notebook_tests.yaml +++ b/release/golden_notebook_tests/golden_notebook_tests.yaml @@ -1,4 +1,7 @@ - name: dask_xgboost_test + owner: + mail: "antoni@anyscale.com" + slack: "@team_ml" cluster: app_config: dask_xgboost_app_config.yaml compute_template: compute_tpl.yaml @@ -8,8 +11,18 @@ autosuspend_mins: 10 timeout: 1200 script: python workloads/dask_xgboost_test.py + args: + [ + "--num-actors 4", + "--cpus-per-actor 4", + "--num-actors-inference 16", + "--cpus-per-actor-inference 1", + ] - name: modin_xgboost_test + owner: + mail: "antoni@anyscale.com" + slack: "@team_ml" cluster: app_config: modin_xgboost_app_config.yaml compute_template: compute_tpl.yaml @@ -19,6 +32,13 @@ autosuspend_mins: 10 timeout: 1200 script: python workloads/modin_xgboost_test.py + args: + [ + "--num-actors 4", + "--cpus-per-actor 4", + "--num-actors-inference 16", + "--cpus-per-actor-inference 1", + ] - name: torch_tune_serve_test owner: @@ -34,4 +54,3 @@ autosuspend_mins: 10 timeout: 1200 script: python workloads/torch_tune_serve_test.py - diff --git a/release/golden_notebook_tests/modin_xgboost_app_config.yaml b/release/golden_notebook_tests/modin_xgboost_app_config.yaml index c17fa85ca0144..5fb35e7b03fdd 100755 --- a/release/golden_notebook_tests/modin_xgboost_app_config.yaml +++ b/release/golden_notebook_tests/modin_xgboost_app_config.yaml @@ -5,7 +5,8 @@ debian_packages: python: pip_packages: - - pytest + - pandas>=1.3.0 # otherwise, a version mismatch between local and remote will cause an exception + - git+https://github.com/ray-project/xgboost_ray.git#egg=xgboost_ray[default] - modin - s3fs - fastapi @@ -16,4 +17,4 @@ python: post_build_cmds: - pip uninstall -y ray || true - pip install -U {{ env["RAY_WHEELS"] | default("ray") }} - - pip install git+https://github.com/ray-project/xgboost_ray.git#egg=xgboost_ray[default] + - {{ env["RAY_WHEELS_SANITY_CHECK"] | default("echo No Ray wheels sanity check") }} diff --git a/release/golden_notebook_tests/workloads/dask_xgboost_test.py b/release/golden_notebook_tests/workloads/dask_xgboost_test.py index 99755eb4399bb..c10bf91d96754 100644 --- 
a/release/golden_notebook_tests/workloads/dask_xgboost_test.py +++ b/release/golden_notebook_tests/workloads/dask_xgboost_test.py @@ -1,135 +1,28 @@ -import argparse -import json +import ray import os import time +import json +from util import import_and_execute_test_script, wait_for_cluster_client -import dask -import dask.dataframe as dd -import ray -from ray import tune - -from ray.util.dask import ray_dask_get - -from xgboost_ray import RayDMatrix, RayParams, train, predict - -from utils.utils import is_anyscale_connect - -FILE_URL = "https://ray-ci-higgs.s3.us-west-2.amazonaws.com/" \ - "simpleHIGGS.csv" - - -def train_xgboost(config, train_df, test_df, target_column, ray_params): - # distributed loading of a parquet dataset - train_set = RayDMatrix(train_df, target_column) - test_set = RayDMatrix(test_df, target_column) - - evals_result = {} - - start_time = time.time() - # Train the classifier - bst = train( - params=config, - dtrain=train_set, - evals=[(test_set, "eval")], - evals_result=evals_result, - verbose_eval=False, - num_boost_round=100, - ray_params=ray_params) - print(f"Total time taken: {time.time()-start_time}") - - model_path = "model.xgb" - bst.save_model(model_path) - print("Final validation error: {:.4f}".format( - evals_result["eval"]["error"][-1])) - - return bst - - -def tune_xgboost(train_df, test_df, target_column): - # Set XGBoost config. - config = { - "tree_method": "approx", - "objective": "binary:logistic", - "eval_metric": ["logloss", "error"], - "eta": tune.loguniform(1e-4, 1e-1), - "subsample": tune.uniform(0.5, 1.0), - "max_depth": tune.randint(1, 9) - } - - ray_params = RayParams( - max_actor_restarts=1, gpus_per_actor=0, cpus_per_actor=4, num_actors=4) - - analysis = tune.run( - tune.with_parameters( - train_xgboost, - train_df=train_df, - test_df=test_df, - target_column=target_column, - ray_params=ray_params), - # Use the `get_tune_resources` helper function to set the resources. - resources_per_trial=ray_params.get_tune_resources(), - config=config, - num_samples=1, - metric="eval-error", - mode="min", - verbose=1) - - accuracy = 1. 
- analysis.best_result["eval-error"] - print(f"Best model parameters: {analysis.best_config}") - print(f"Best model total accuracy: {accuracy:.4f}") - - return analysis.best_config +NOTEBOOK_PATH_RELATIVE_TO_RAY_REPO = ( + "doc/examples/dask_xgboost/dask_xgboost.py") def main(): - print("Loading HIGGS data.") - - dask.config.set(scheduler=ray_dask_get) - colnames = ["label"] + ["feature-%02d" % i for i in range(1, 29)] - data = dd.read_csv(FILE_URL, names=colnames) - - print("Loaded HIGGS data.") - - # partition on a column - df_train = data[(data["feature-01"] < 0.4)] - df_validation = data[(data["feature-01"] >= 0.4) - & (data["feature-01"] < 0.8)] - - config = { - "tree_method": "approx", - "objective": "binary:logistic", - "eval_metric": ["logloss", "error"], - } - - bst = train_xgboost( - config, df_train, df_validation, "label", - RayParams(max_actor_restarts=1, cpus_per_actor=4, num_actors=4)) - tune_xgboost(df_train, df_validation, "label") - inference_df = RayDMatrix( - df_train[sorted(df_train.columns)], ignore=["label", "partition"]) - predict( - bst, - inference_df, - ray_params=RayParams(cpus_per_actor=2, num_actors=16)) + import_and_execute_test_script(NOTEBOOK_PATH_RELATIVE_TO_RAY_REPO) if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--smoke-test", - action="store_true", - help="Finish quickly for testing.") - args = parser.parse_args() - start = time.time() addr = os.environ.get("RAY_ADDRESS") job_name = os.environ.get("RAY_JOB_NAME", "dask_xgboost_test") - if is_anyscale_connect(addr): + if addr is not None and addr.startswith("anyscale://"): ray.init(address=addr, job_name=job_name) else: ray.init(address="auto") + wait_for_cluster_client(4, 600) main() taken = time.time() - start diff --git a/release/golden_notebook_tests/workloads/modin_xgboost_test.py b/release/golden_notebook_tests/workloads/modin_xgboost_test.py index 4180351e7cb40..d5fb36f07b23e 100644 --- a/release/golden_notebook_tests/workloads/modin_xgboost_test.py +++ b/release/golden_notebook_tests/workloads/modin_xgboost_test.py @@ -1,131 +1,28 @@ -import argparse -import json +import ray import os import time +import json +from util import import_and_execute_test_script, wait_for_cluster_client -import modin.pandas as pd -import ray -from ray import tune -from xgboost_ray import RayDMatrix, RayParams, train, predict - -from utils.utils import is_anyscale_connect - -FILE_URL = "https://ray-ci-higgs.s3.us-west-2.amazonaws.com/" \ - "simpleHIGGS.csv" - - -def train_xgboost(config, train_df, test_df, target_column, ray_params): - # distributed loading of a parquet dataset - train_set = RayDMatrix(train_df, target_column) - test_set = RayDMatrix(test_df, target_column) - - evals_result = {} - - start_time = time.time() - # Train the classifier - bst = train( - params=config, - dtrain=train_set, - evals=[(test_set, "eval")], - evals_result=evals_result, - verbose_eval=False, - num_boost_round=100, - ray_params=ray_params) - print(f"Total time taken: {time.time()-start_time}") - - model_path = "model.xgb" - bst.save_model(model_path) - print("Final validation error: {:.4f}".format( - evals_result["eval"]["error"][-1])) - - return bst - - -def tune_xgboost(train_df, test_df, target_column): - # Set XGBoost config. 
- config = { - "tree_method": "approx", - "objective": "binary:logistic", - "eval_metric": ["logloss", "error"], - "eta": tune.loguniform(1e-4, 1e-1), - "subsample": tune.uniform(0.5, 1.0), - "max_depth": tune.randint(1, 9) - } - - ray_params = RayParams( - max_actor_restarts=1, gpus_per_actor=0, cpus_per_actor=1, num_actors=2) - - analysis = tune.run( - tune.with_parameters( - train_xgboost, - train_df=train_df, - test_df=test_df, - target_column=target_column, - ray_params=ray_params), - # Use the `get_tune_resources` helper function to set the resources. - resources_per_trial=ray_params.get_tune_resources(), - config=config, - num_samples=1, - metric="eval-error", - mode="min", - verbose=1) - - accuracy = 1. - analysis.best_result["eval-error"] - print(f"Best model parameters: {analysis.best_config}") - print(f"Best model total accuracy: {accuracy:.4f}") - - return analysis.best_config +NOTEBOOK_PATH_RELATIVE_TO_RAY_REPO = ( + "doc/examples/modin_xgboost/modin_xgboost.py") def main(): - print("Loading HIGGS data.") - - colnames = ["label"] + ["feature-%02d" % i for i in range(1, 29)] - - data = pd.read_csv(FILE_URL, names=colnames) - - print("Loaded HIGGS data.") - - # partition on a column - df_train = data[(data["feature-01"] < 0.4)] - df_validation = data[(data["feature-01"] >= 0.4) - & (data["feature-01"] < 0.8)] - - config = { - "tree_method": "approx", - "objective": "binary:logistic", - "eval_metric": ["logloss", "error"], - } - - bst = train_xgboost( - config, df_train, df_validation, "label", - RayParams(max_actor_restarts=1, cpus_per_actor=4, num_actors=4)) - # tune_xgboost(df_train, df_validation, "label") # broken atm - inference_df = RayDMatrix( - df_train[sorted(df_train.columns)], ignore=["label", "partition"]) - predict( - bst, - inference_df, - ray_params=RayParams(cpus_per_actor=1, num_actors=16)) + import_and_execute_test_script(NOTEBOOK_PATH_RELATIVE_TO_RAY_REPO) if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--smoke-test", - action="store_true", - help="Finish quickly for testing.") - args = parser.parse_args() - start = time.time() addr = os.environ.get("RAY_ADDRESS") job_name = os.environ.get("RAY_JOB_NAME", "modin_xgboost_test") - if is_anyscale_connect(addr): + if addr is not None and addr.startswith("anyscale://"): ray.init(address=addr, job_name=job_name) else: ray.init(address="auto") + wait_for_cluster_client(4, 600) main() taken = time.time() - start diff --git a/release/golden_notebook_tests/workloads/torch_tune_serve_test.py b/release/golden_notebook_tests/workloads/torch_tune_serve_test.py index 15bd43a575a7a..9b511d5765ae6 100644 --- a/release/golden_notebook_tests/workloads/torch_tune_serve_test.py +++ b/release/golden_notebook_tests/workloads/torch_tune_serve_test.py @@ -17,8 +17,6 @@ from torch.utils.data import DataLoader, Subset from torchvision.datasets import MNIST -from utils.utils import is_anyscale_connect - def load_mnist_data(train: bool, download: bool): transform = transforms.Compose( @@ -200,7 +198,7 @@ def test_predictions(test_mode=False): addr = os.environ.get("RAY_ADDRESS") job_name = os.environ.get("RAY_JOB_NAME", "torch_tune_serve_test") - if is_anyscale_connect(addr): + if addr is not None and addr.startswith("anyscale://"): client = ray.init(address=addr, job_name=job_name) else: client = ray.init(address="auto") diff --git a/release/golden_notebook_tests/workloads/util.py b/release/golden_notebook_tests/workloads/util.py new file mode 100644 index 0000000000000..a0efc28b0e73a --- 
/dev/null
+++ b/release/golden_notebook_tests/workloads/util.py
@@ -0,0 +1,50 @@
+from pathlib import Path
+import importlib.util
+import ray
+import time
+
+
+def import_and_execute_test_script(relative_path_to_test_script: str):
+    """Imports and executes a module from a path relative to Ray repo root."""
+    # Find the root of the Ray repo checkout.
+    ray_path = next(
+        x for x in Path(__file__).resolve().parents if str(x).endswith("/ray"))
+    notebook_path = ray_path.joinpath(relative_path_to_test_script)
+    assert notebook_path.exists()
+
+    spec = importlib.util.spec_from_file_location("notebook_test",
+                                                  notebook_path)
+    notebook_test_module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(notebook_test_module)
+
+
+def wait_for_cluster_client(num_nodes: int,
+                            max_time_s: int,
+                            feedback_interval_s: int = 10):
+    """Waits until `num_nodes` nodes are up, or raises after `max_time_s`."""
+    assert ray.is_initialized()
+    curr_nodes = 0
+    start = time.time()
+    next_feedback = start
+    max_time = start + max_time_s
+    while curr_nodes < num_nodes:
+        now = time.time()
+
+        if now >= max_time:
+            raise RuntimeError(
+                f"Maximum wait time reached, but only "
+                f"{curr_nodes}/{num_nodes} nodes came up. Aborting.")
+
+        if now >= next_feedback:
+            passed = now - start
+            print(f"Waiting for more nodes to come up: "
+                  f"{curr_nodes}/{num_nodes} "
+                  f"({passed:.0f} seconds passed)")
+            next_feedback = now + feedback_interval_s
+
+        time.sleep(5)
+        curr_nodes = len(ray.nodes())
+
+    passed = time.time() - start
+    print(f"Cluster is up: {curr_nodes}/{num_nodes} nodes online after "
+          f"{passed:.0f} seconds")
diff --git a/release/golden_notebook_tests/workloads/utils/utils.py b/release/golden_notebook_tests/workloads/utils/utils.py
deleted file mode 100644
index 071f076c72aee..0000000000000
--- a/release/golden_notebook_tests/workloads/utils/utils.py
+++ /dev/null
@@ -1,5 +0,0 @@
-def is_anyscale_connect(address: str) -> bool:
-    """Returns whether or not the Ray Address points to an Anyscale cluster."""
-    is_anyscale_connect = address is not None and address.startswith(
-        "anyscale://")
-    return is_anyscale_connect
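Taken together, the two helpers above reduce each golden-notebook workload to the same small entrypoint. A minimal sketch of that shared pattern (the notebook path and the job-name default below are placeholders; each workload supplies its own, as the diffs above show):

    import os
    import time

    import ray
    from util import import_and_execute_test_script, wait_for_cluster_client

    # Placeholder path; each workload points at its own notebook-derived script.
    NOTEBOOK_PATH_RELATIVE_TO_RAY_REPO = "doc/examples/dask_xgboost/dask_xgboost.py"

    if __name__ == "__main__":
        start = time.time()
        addr = os.environ.get("RAY_ADDRESS")
        job_name = os.environ.get("RAY_JOB_NAME", "golden_notebook_test")
        # Anyscale connect addresses carry a job name; anything else is treated
        # as a regular cluster and discovered via address="auto".
        if addr is not None and addr.startswith("anyscale://"):
            ray.init(address=addr, job_name=job_name)
        else:
            ray.init(address="auto")
        # Block until the expected 4 nodes have joined (at most 600 seconds),
        # then run the notebook end to end.
        wait_for_cluster_client(4, 600)
        import_and_execute_test_script(NOTEBOOK_PATH_RELATIVE_TO_RAY_REPO)
        print(f"Test took {time.time() - start:.2f} seconds.")

The inline `anyscale://` prefix check is what replaces the `is_anyscale_connect` helper deleted above.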
diff --git a/release/kubernetes_manual_tests/README.md b/release/kubernetes_manual_tests/README.md
new file mode 100644
index 0000000000000..12b61f272b079
--- /dev/null
+++ b/release/kubernetes_manual_tests/README.md
@@ -0,0 +1,25 @@
+# ray-k8s-tests
+
+These tests are not automated and thus **must be run manually** for each release.
+If you have issues running them, bug the code owner(s) for OSS Kubernetes support.
+
+## How to run
+1. Configure kubectl and Helm 3 to access a K8s cluster.
+2. `git checkout releases/`
+3. You might have to locally pip install the Ray wheel for the relevant commit (or pip install -e) in a conda env; see the Ray client note below.
+4. cd to this directory
+5. `IMAGE=rayproject/ray: bash k8s_release_tests.sh`
+6. Test outcomes will be reported at the end of the output.
+
+This runs three tests and does the necessary resource creation/teardown. The tests typically take about 15 minutes to finish.
+
+## Notes
+1. Anyscale employees: You should have access to create a K8s cluster using either GKE or EKS; ask the OSS Kubernetes code owner if in doubt.
+2. Your Ray cluster should be able to accommodate 30 1-CPU pods to run all of the tests.
+3. These tests use basic Ray client functionality -- your locally installed Ray version may need to be updated to match the one in the release image.
+4. The tests do a poor job of Ray client port-forwarding process clean-up -- if a test fails, a port-forwarding process may be left running in the background. To identify the rogue process, run `ps aux | grep "port-forward"`, then `kill` it.
+5. Some errors will appear on screen during the run -- that's normal; error recovery is being tested.
+
+## Running individual tests
+To run any of the three tests individually, substitute `k8s-test.sh`, `helm-test.sh`, or `k8s-test-scale.sh` in step 5 of **How to run**.
+It's the last of these that needs 30 1-CPU pods; 10 is enough for either of the other two. The scale test is currently somewhat flaky -- rerun it if it fails.
diff --git a/release/kubernetes_manual_tests/helm-test.sh b/release/kubernetes_manual_tests/helm-test.sh
new file mode 100755
index 0000000000000..273ddb5c1cc11
--- /dev/null
+++ b/release/kubernetes_manual_tests/helm-test.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+set -x
+kubectl create namespace helm-test
+kubectl create namespace helm-test2
+KUBERNETES_OPERATOR_TEST_NAMESPACE=helm-test KUBERNETES_OPERATOR_TEST_IMAGE="$IMAGE" python ../../python/ray/tests/kubernetes_e2e/test_helm.py
+kubectl delete namespace helm-test
+kubectl delete namespace helm-test2
+kubectl delete -f ../../deploy/charts/ray/crds/cluster_crd.yaml
diff --git a/release/kubernetes_manual_tests/k8s-test-scale.sh b/release/kubernetes_manual_tests/k8s-test-scale.sh
new file mode 100755
index 0000000000000..59ea06c80f5f1
--- /dev/null
+++ b/release/kubernetes_manual_tests/k8s-test-scale.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+set -x
+kubectl create namespace scale-test
+kubectl create namespace scale-test2
+KUBERNETES_OPERATOR_TEST_NAMESPACE=scale-test KUBERNETES_OPERATOR_TEST_IMAGE="$IMAGE" python ../../python/ray/tests/kubernetes_e2e/test_k8s_operator_scaling.py
+kubectl -n scale-test delete --all rayclusters
+kubectl -n scale-test2 delete --all rayclusters
+kubectl delete -f ../../deploy/components/operator_cluster_scoped.yaml
+kubectl delete namespace scale-test
+kubectl delete namespace scale-test2
+kubectl delete -f ../../deploy/charts/ray/crds/cluster_crd.yaml
diff --git a/release/kubernetes_manual_tests/k8s-test.sh b/release/kubernetes_manual_tests/k8s-test.sh
new file mode 100755
index 0000000000000..aa0ec6325d880
--- /dev/null
+++ b/release/kubernetes_manual_tests/k8s-test.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+set -x
+kubectl create namespace basic-test
+kubectl apply -f ../../deploy/charts/ray/crds/cluster_crd.yaml
+KUBERNETES_OPERATOR_TEST_NAMESPACE=basic-test KUBERNETES_OPERATOR_TEST_IMAGE="$IMAGE" python ../../python/ray/tests/kubernetes_e2e/test_k8s_operator_basic.py
+kubectl -n basic-test delete --all rayclusters
+kubectl -n basic-test delete deployment ray-operator
+kubectl delete namespace basic-test
+kubectl delete -f ../../deploy/charts/ray/crds/cluster_crd.yaml
diff --git a/release/kubernetes_manual_tests/k8s_release_tests.sh b/release/kubernetes_manual_tests/k8s_release_tests.sh
new file mode 100644
index 0000000000000..6576dcdabfa39
--- /dev/null
+++ b/release/kubernetes_manual_tests/k8s_release_tests.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+set -x
+IMAGE="$IMAGE" bash k8s-test.sh
+BASIC_SUCCEEDED=$?
+IMAGE="$IMAGE" bash helm-test.sh
+HELM_SUCCEEDED=$?
+IMAGE="$IMAGE" bash k8s-test-scale.sh
+SCALE_SUCCEEDED=$?
+
+if (( BASIC_SUCCEEDED == 0 ))
+then
+  echo "k8s-test.sh test succeeded"
+else
+  echo "k8s-test.sh test failed"
+fi
+
+if (( HELM_SUCCEEDED == 0 ))
+then
+  echo "helm-test.sh test succeeded"
+else
+  echo "helm-test.sh test failed"
+fi
+
+if (( SCALE_SUCCEEDED == 0 ))
+then
+  echo "k8s-test-scale.sh test succeeded"
+else
+  echo "k8s-test-scale.sh test failed. Try re-running k8s-test-scale.sh on its own; it is expected to be flaky."
+fi
+
diff --git a/release/long_running_tests/tpl_cpu_1.yaml b/release/long_running_tests/tpl_cpu_1.yaml
index 1045aa8948456..a22bc5dfc95a7 100644
--- a/release/long_running_tests/tpl_cpu_1.yaml
+++ b/release/long_running_tests/tpl_cpu_1.yaml
@@ -22,3 +22,8 @@ aws:
           Value: '{{env["ANYSCALE_USER"]}}'
         - Key: anyscale-expiration
           Value: '{{env["EXPIRATION_2D"]}}'
+
+    BlockDeviceMappings:
+      - DeviceName: /dev/sda1
+        Ebs:
+          VolumeSize: 202
\ No newline at end of file
diff --git a/release/nightly_tests/dask_on_ray/large_scale_dask_on_ray_app_config.yaml b/release/nightly_tests/dask_on_ray/large_scale_dask_on_ray_app_config.yaml
index 9b7a0a9a11d3f..1aa0b86782476 100644
--- a/release/nightly_tests/dask_on_ray/large_scale_dask_on_ray_app_config.yaml
+++ b/release/nightly_tests/dask_on_ray/large_scale_dask_on_ray_app_config.yaml
@@ -1,5 +1,4 @@
 base_image: "anyscale/ray-ml:pinned-nightly-py37"
-env_vars: {}
 debian_packages: []
 
 python:
diff --git a/release/nightly_tests/dataset/app_config.yaml b/release/nightly_tests/dataset/app_config.yaml
index c0cc753990de9..5f311fbabfe87 100644
--- a/release/nightly_tests/dataset/app_config.yaml
+++ b/release/nightly_tests/dataset/app_config.yaml
@@ -1,5 +1,4 @@
 base_image: "anyscale/ray-ml:pinned-nightly-py37-gpu"
-env_vars: {}
 
 python:
   pip_packages: []
diff --git a/release/nightly_tests/dataset/dataset_shuffle_data_loader.py b/release/nightly_tests/dataset/dataset_shuffle_data_loader.py
index da3a7d74649f0..e917624a4712b 100644
--- a/release/nightly_tests/dataset/dataset_shuffle_data_loader.py
+++ b/release/nightly_tests/dataset/dataset_shuffle_data_loader.py
@@ -85,7 +85,7 @@ def create_torch_iterator(split, batch_size, rank=None):
 
 def create_dataset(filenames, repeat_times):
     pipeline = ray.data.read_parquet(list(filenames))\
-        .repeat(times=repeat_times).random_shuffle()
+        .repeat(times=repeat_times).random_shuffle_each_window()
     return pipeline
 
diff --git a/release/nightly_tests/dataset/pipelined_ingestion_app.yaml b/release/nightly_tests/dataset/pipelined_ingestion_app.yaml
index 2fbda804b9b50..23ee18a1008b7 100644
--- a/release/nightly_tests/dataset/pipelined_ingestion_app.yaml
+++ b/release/nightly_tests/dataset/pipelined_ingestion_app.yaml
@@ -1,5 +1,4 @@
 base_image: "anyscale/ray-ml:pinned-nightly-py37-gpu"
-env_vars: {}
 
 python:
   pip_packages: []
diff --git a/release/nightly_tests/dataset/pipelined_training.py b/release/nightly_tests/dataset/pipelined_training.py
index d9a4b9245bee1..c8c7486724755 100644
--- a/release/nightly_tests/dataset/pipelined_training.py
+++ b/release/nightly_tests/dataset/pipelined_training.py
@@ -244,12 +244,12 @@ def __next__(self):
             i * num_rows // num_windows // num_workers
             for i in range(1, num_workers)
         ]
-        pipe = pipe.random_shuffle(_spread_resource_prefix="node:")
+        pipe = pipe.random_shuffle_each_window(_spread_resource_prefix="node:")
         pipe_shards = pipe.split_at_indices(split_indices)
     else:
        ds = ray.data.read_parquet(files, _spread_resource_prefix="node:")
        pipe = ds.repeat(epochs)
-        pipe = pipe.random_shuffle(_spread_resource_prefix="node:")
+        pipe = 
pipe.random_shuffle_each_window(_spread_resource_prefix="node:") pipe_shards = pipe.split(num_workers, equal=True) return pipe_shards diff --git a/release/nightly_tests/dataset/pipelined_training_app.yaml b/release/nightly_tests/dataset/pipelined_training_app.yaml index 2fbda804b9b50..23ee18a1008b7 100644 --- a/release/nightly_tests/dataset/pipelined_training_app.yaml +++ b/release/nightly_tests/dataset/pipelined_training_app.yaml @@ -1,5 +1,4 @@ base_image: "anyscale/ray-ml:pinned-nightly-py37-gpu" -env_vars: {} python: pip_packages: [] diff --git a/release/nightly_tests/dataset/shuffle_app_config.yaml b/release/nightly_tests/dataset/shuffle_app_config.yaml index ac02d79b90415..d89acec77a973 100644 --- a/release/nightly_tests/dataset/shuffle_app_config.yaml +++ b/release/nightly_tests/dataset/shuffle_app_config.yaml @@ -1,5 +1,4 @@ base_image: "anyscale/ray-ml:pinned-nightly-py37-gpu" -env_vars: {} python: pip_packages: ["boto3", "numpy", "torch", "tqdm", "pyarrow"] diff --git a/release/nightly_tests/decision_tree/decision_tree_app_config.yaml b/release/nightly_tests/decision_tree/decision_tree_app_config.yaml index 92f5d3707fe1c..70ae8eb896d16 100644 --- a/release/nightly_tests/decision_tree/decision_tree_app_config.yaml +++ b/release/nightly_tests/decision_tree/decision_tree_app_config.yaml @@ -1,5 +1,4 @@ base_image: "anyscale/ray-ml:pinned-nightly-py37" -env_vars: {} debian_packages: [] python: diff --git a/release/nightly_tests/many_nodes_tests/app_config.yaml b/release/nightly_tests/many_nodes_tests/app_config.yaml index 67eb10caac1e7..9586d050b0418 100644 --- a/release/nightly_tests/many_nodes_tests/app_config.yaml +++ b/release/nightly_tests/many_nodes_tests/app_config.yaml @@ -1,5 +1,5 @@ base_image: "anyscale/ray-ml:pinned-nightly-py37" -env_vars: {} +env_vars: {"RAY_gcs_server_rpc_server_thread_num": "8", "RAY_GCS_ACTOR_SCHEDULING_ENABLED": "true"} debian_packages: [] python: diff --git a/release/nightly_tests/nightly_tests.yaml b/release/nightly_tests/nightly_tests.yaml index d932924ffa6a3..9482eade1e713 100644 --- a/release/nightly_tests/nightly_tests.yaml +++ b/release/nightly_tests/nightly_tests.yaml @@ -317,13 +317,24 @@ prepare: python wait_cluster.py 32 1000 script: python dask_on_ray/dask_on_ray_sort.py --nbytes 1_000_000_000_000 --npartitions 1000 --num-nodes 31 --ray --data-dir /tmp/ray --s3-bucket core-nightly-test -- name: many_nodes_actor_test +# TODO (yic): Add this back when we make it stable +# - name: many_nodes_actor_test +# cluster: +# app_config: many_nodes_tests/app_config.yaml +# compute_template: many_nodes_tests/compute_config.yaml + +# run: +# timeout: 7200 +# prepare: python wait_cluster.py 500 5400 +# script: python many_nodes_tests/actor_test.py --cpus-per-actor=0.1 --total-actors=10000 +# # TODO: enable failure test later +# #&& python many_nodes_tests/actor_test.py --cpus-per-actor=0.1 --total-actors=10000 --fail --no-report && python many_nodes_tests/actor_test.py --cpus-per-actor=0.1 --total-actors=10000 --no-report + +- name: pg_autoscaling_regression_test cluster: - app_config: many_nodes_tests/app_config.yaml - compute_template: many_nodes_tests/compute_config.yaml + app_config: placement_group_tests/app_config.yaml + compute_template: placement_group_tests/compute.yaml run: - timeout: 7200 - prepare: python wait_cluster.py 500 5400 - script: python many_nodes_tests/actor_test.py --cpus-per-actor=0.1 --total-actors=10000 - # TODO(yic): Add extra test for python many_nodes_tests/actor_test.py --cpus-per-actor=0.1 --total-actors=10000 --fail 
--no-report && python many_nodes_tests/actor_test.py --cpus-per-actor=0.1 --total-actors=10000 --no-report + timeout: 1200 + script: python placement_group_tests/pg_run.py diff --git a/release/nightly_tests/placement_group_tests/app_config.yaml b/release/nightly_tests/placement_group_tests/app_config.yaml new file mode 100644 index 0000000000000..d30247838e1e9 --- /dev/null +++ b/release/nightly_tests/placement_group_tests/app_config.yaml @@ -0,0 +1,12 @@ +base_image: "anyscale/ray-ml:pinned-nightly-py37" +debian_packages: [] + +python: + pip_packages: [] + conda_packages: [] + +post_build_cmds: + - pip uninstall -y ray + - pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }} + - pip3 install -U ray[default] + - {{ env["RAY_WHEELS_SANITY_CHECK"] | default("echo No Ray wheels sanity check") }} diff --git a/release/nightly_tests/placement_group_tests/cluster.py b/release/nightly_tests/placement_group_tests/cluster.py new file mode 100644 index 0000000000000..a12ed798a4e99 --- /dev/null +++ b/release/nightly_tests/placement_group_tests/cluster.py @@ -0,0 +1,13 @@ +import time +from ray.cluster_utils import Cluster + +cluster = Cluster() + +cluster.add_node(num_cpus=16) + +time.sleep(20) +print("Scaling up.") +cluster.add_node(num_cpus=16, num_gpus=1) + +print("Scaled up. Waiting for 1000 seconds until done.") +time.sleep(1000) diff --git a/release/nightly_tests/placement_group_tests/compute.yaml b/release/nightly_tests/placement_group_tests/compute.yaml new file mode 100644 index 0000000000000..5b619db7651a4 --- /dev/null +++ b/release/nightly_tests/placement_group_tests/compute.yaml @@ -0,0 +1,27 @@ +cloud_id: {{env["ANYSCALE_CLOUD_ID"]}} +region: us-west-2 + +aws: + BlockDeviceMappings: + - DeviceName: /dev/sda1 + Ebs: + VolumeSize: 500 + +head_node_type: + name: head_node + instance_type: m5.4xlarge + +worker_node_types: + - name: cpu_node + instance_type: m5.4xlarge + min_workers: 0 + max_workers: 2 + use_spot: false + - name: fake_gpu_node + instance_type: m5.4xlarge + min_workers: 0 + max_workers: 2 + use_spot: false + resources: + cpu: 16 + gpu: 1 diff --git a/release/nightly_tests/placement_group_tests/pg_run.py b/release/nightly_tests/placement_group_tests/pg_run.py new file mode 100644 index 0000000000000..7bb616c2dcaa3 --- /dev/null +++ b/release/nightly_tests/placement_group_tests/pg_run.py @@ -0,0 +1,65 @@ +import os +import time +import json + +import ray +from ray.util.placement_group import placement_group + +# Tests are supposed to run for 10 minutes. 
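+# Layout: one PACK placement group holds a single {CPU: 1, GPU: 1} bundle for
+# the Trainer actor plus NUM_CPU_BUNDLES single-CPU bundles for the Worker
+# actors; the driver loop below then keeps all of them busy until RUNTIME
+# elapses.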
+RUNTIME = 600 +NUM_CPU_BUNDLES = 30 + + +@ray.remote(num_cpus=1) +class Worker(object): + def __init__(self, i): + self.i = i + + def work(self): + time.sleep(0.1) + print("work ", self.i) + + +@ray.remote(num_cpus=1, num_gpus=1) +class Trainer(object): + def __init__(self, i): + self.i = i + + def train(self): + time.sleep(0.2) + print("train ", self.i) + + +def main(): + ray.init(address="auto") + + bundles = [{"CPU": 1, "GPU": 1}] + bundles += [{"CPU": 1} for _ in range(NUM_CPU_BUNDLES)] + + pg = placement_group(bundles, strategy="PACK") + + ray.get(pg.ready()) + + workers = [ + Worker.options(placement_group=pg).remote(i) + for i in range(NUM_CPU_BUNDLES) + ] + + trainer = Trainer.options(placement_group=pg).remote(0) + + start = time.time() + while True: + ray.get([workers[i].work.remote() for i in range(NUM_CPU_BUNDLES)]) + ray.get(trainer.train.remote()) + end = time.time() + if end - start > RUNTIME: + break + + if "TEST_OUTPUT_JSON" in os.environ: + out_file = open(os.environ["TEST_OUTPUT_JSON"], "w") + results = {} + json.dump(results, out_file) + + +if __name__ == "__main__": + main() diff --git a/release/nightly_tests/shuffle/shuffle_app_config.yaml b/release/nightly_tests/shuffle/shuffle_app_config.yaml index 67eb10caac1e7..d30247838e1e9 100644 --- a/release/nightly_tests/shuffle/shuffle_app_config.yaml +++ b/release/nightly_tests/shuffle/shuffle_app_config.yaml @@ -1,5 +1,4 @@ base_image: "anyscale/ray-ml:pinned-nightly-py37" -env_vars: {} debian_packages: [] python: @@ -10,5 +9,4 @@ post_build_cmds: - pip uninstall -y ray - pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }} - pip3 install -U ray[default] - - echo {{env["DATESTAMP"]}} - {{ env["RAY_WHEELS_SANITY_CHECK"] | default("echo No Ray wheels sanity check") }} diff --git a/release/nightly_tests/shuffle_data_loader/shuffle_data_loader_app_config.yaml b/release/nightly_tests/shuffle_data_loader/shuffle_data_loader_app_config.yaml index 2fea571c90f77..536c7b6da27f4 100644 --- a/release/nightly_tests/shuffle_data_loader/shuffle_data_loader_app_config.yaml +++ b/release/nightly_tests/shuffle_data_loader/shuffle_data_loader_app_config.yaml @@ -1,5 +1,4 @@ base_image: "anyscale/ray-ml:pinned-nightly-py37" -env_vars: {} debian_packages: [] python: diff --git a/release/nightly_tests/stress_tests/stress_tests_app_config.yaml b/release/nightly_tests/stress_tests/stress_tests_app_config.yaml index 1f264f9fa1e44..66c99bb3bfe5a 100644 --- a/release/nightly_tests/stress_tests/stress_tests_app_config.yaml +++ b/release/nightly_tests/stress_tests/stress_tests_app_config.yaml @@ -1,5 +1,4 @@ base_image: "anyscale/ray-ml:pinned-nightly-py37" -env_vars: {} debian_packages: [] python: diff --git a/release/release_logs/1.7.0/benchmarks/many_actors.txt b/release/release_logs/1.7.0/benchmarks/many_actors.txt new file mode 100644 index 0000000000000..2995df9b7f18d --- /dev/null +++ b/release/release_logs/1.7.0/benchmarks/many_actors.txt @@ -0,0 +1,10 @@ +{ + "actors_per_second": 333.2797984180003, + "num_actors": 10000, + "time": 30.0048189163208, + "success": "1", + "_runtime": 43.551865577697754, + "_session_url": "https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_han7mApDaGYvrbvhuLKBSGBz", + "_commit_url": "https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.7.0/2367a2cb9033913b68b1230316496ae273c25b54/ray-1.7.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable": true +} diff --git a/release/release_logs/1.7.0/benchmarks/many_nodes.txt 
b/release/release_logs/1.7.0/benchmarks/many_nodes.txt new file mode 100644 index 0000000000000..d6d5a3c0b6631 --- /dev/null +++ b/release/release_logs/1.7.0/benchmarks/many_nodes.txt @@ -0,0 +1,10 @@ +{ + "tasks_per_second": 3.224712885579051, + "num_tasks": 1000, + "time": 610.1051273345947, + "success": "1", + "_runtime": 620.4832813739777, + "_session_url": "https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_6f82dxdGaxTV4uZNSamTYGLY", + "_commit_url": "https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.7.0/2367a2cb9033913b68b1230316496ae273c25b54/ray-1.7.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable": true +} diff --git a/release/release_logs/1.7.0/benchmarks/many_pgs.txt b/release/release_logs/1.7.0/benchmarks/many_pgs.txt new file mode 100644 index 0000000000000..560c050dcecb4 --- /dev/null +++ b/release/release_logs/1.7.0/benchmarks/many_pgs.txt @@ -0,0 +1,10 @@ +{ + "pgs_per_second": 17.06879130613137, + "num_pgs": 1000, + "time": 58.586456537246704, + "success": "1", + "_runtime": 69.5553240776062, + "_session_url": "https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_gr3X2VEThCAQrtiHrJRd8yxW", + "_commit_url": "https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.7.0/2367a2cb9033913b68b1230316496ae273c25b54/ray-1.7.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable": true +} diff --git a/release/release_logs/1.7.0/benchmarks/many_tasks.txt b/release/release_logs/1.7.0/benchmarks/many_tasks.txt new file mode 100644 index 0000000000000..fa9c7d8d41db2 --- /dev/null +++ b/release/release_logs/1.7.0/benchmarks/many_tasks.txt @@ -0,0 +1,10 @@ +{ + "tasks_per_second": 27.508657888123608, + "num_tasks": 10000, + "time": 663.5219151973724, + "success": "1", + "_runtime": 674.2678966522217, + "_session_url": "https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_XCJkRqS4HkuHLXehx7i6Fwvc", + "_commit_url": "https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.7.0/2367a2cb9033913b68b1230316496ae273c25b54/ray-1.7.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable": true +} diff --git a/release/release_logs/1.7.0/microbenchmark.txt b/release/release_logs/1.7.0/microbenchmark.txt new file mode 100644 index 0000000000000..b5fa29117583d --- /dev/null +++ b/release/release_logs/1.7.0/microbenchmark.txt @@ -0,0 +1,134 @@ +{ + "single_client_get_calls": [ + 34647.91400708946, + 311.7390971967917 + ], + "single_client_put_calls": [ + 58969.83872190603, + 869.618205663433 + ], + "multi_client_put_calls": [ + 199832.5755298421, + 2482.9205035774476 + ], + "single_client_get_calls_Plasma_Store": [ + 7082.757370159696, + 146.62873820799672 + ], + "single_client_put_calls_Plasma_Store": [ + 6321.65654587901, + 11.077913617295936 + ], + "multi_client_put_calls_Plasma_Store": [ + 9186.218655830648, + 112.23231532820908 + ], + "single_client_put_gigabytes": [ + 20.299125005168346, + 5.063681202623047 + ], + "single_client_tasks_and_get_batch": [ + 13.14018865978927, + 0.3152301478634011 + ], + "multi_client_put_gigabytes": [ + 36.56441662881655, + 1.843382220404724 + ], + "single_client_get_object_containing_10k_refs": [ + 10.351906653488715, + 0.23442465466734483 + ], + "single_client_tasks_sync": [ + 1257.4155346823063, + 16.879731074181798 + ], + "single_client_tasks_async": [ + 13436.707639489237, + 467.0229967004351 + ], + "multi_client_tasks_async": [ + 37893.82918345513, + 2501.210898297811 + ], + "1_1_actor_calls_sync": [ + 2018.517206134362, + 
4.133444448098185 + ], + "1_1_actor_calls_async": [ + 5107.498479502846, + 155.05763494606228 + ], + "1_1_actor_calls_concurrent": [ + 4974.868578485068, + 46.89895438701842 + ], + "1_n_actor_calls_async": [ + 13035.656413458306, + 263.67959962428176 + ], + "n_n_actor_calls_async": [ + 42424.91241384691, + 909.2063842725172 + ], + "n_n_actor_calls_with_arg_async": [ + 2910.8727809194884, + 142.55651461439174 + ], + "1_1_async_actor_calls_sync": [ + 1434.0111494545497, + 15.145616176257736 + ], + "1_1_async_actor_calls_async": [ + 3227.631490168903, + 74.52309737428871 + ], + "1_1_async_actor_calls_with_args_async": [ + 2417.18007329992, + 42.010241468147406 + ], + "1_n_async_actor_calls_async": [ + 13212.476889889944, + 280.91562344862103 + ], + "n_n_async_actor_calls_async": [ + 32212.030653578477, + 4172.2556150359205 + ], + "client__get_calls": [ + 1518.5267029642152, + 18.33838666361156 + ], + "client__put_calls": [ + 869.7170835067376, + 8.603084105450836 + ], + "client__put_gigabytes": [ + 0.11768745420143228, + 0.002542373184018965 + ], + "client__tasks_and_put_batch": [ + 58861.12144186892, + 546.7701167395176 + ], + "client__1_1_actor_calls_sync": [ + 472.8343418119895, + 6.16968890867776 + ], + "client__1_1_actor_calls_async": [ + 742.6478263697102, + 2.886810073788351 + ], + "client__1_1_actor_calls_concurrent": [ + 729.3572241473628, + 19.903703549912592 + ], + "client__tasks_and_get_batch": [ + 0.6990944804839968, + 0.00738047968242822 + ], + "_runtime": 558.9188287258148, + "_session_url": "https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_AHVUzrAzUMiLZ4p9EEAbL68s", + "_commit_url": "https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.7.0/2367a2cb9033913b68b1230316496ae273c25b54/ray-1.7.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable": true +} diff --git a/release/release_logs/1.7.0/scalability/object_store.txt b/release/release_logs/1.7.0/scalability/object_store.txt new file mode 100644 index 0000000000000..6917229b88dc5 --- /dev/null +++ b/release/release_logs/1.7.0/scalability/object_store.txt @@ -0,0 +1,10 @@ +{ + "broadcast_time": 611.015479593, + "object_size": 1073741824, + "num_nodes": 50, + "success": "1", + "_runtime": 620.4363269805908, + "_session_url": "https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_Chj4PHZqrEjbzc8Ni4RY1Fev", + "_commit_url": "https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.7.0/2367a2cb9033913b68b1230316496ae273c25b54/ray-1.7.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable": true +} diff --git a/release/release_logs/1.7.0/scalability/single_node.txt b/release/release_logs/1.7.0/scalability/single_node.txt new file mode 100644 index 0000000000000..c868fa3c8eb4e --- /dev/null +++ b/release/release_logs/1.7.0/scalability/single_node.txt @@ -0,0 +1,16 @@ +{ + "args_time": 17.256289814000013, + "num_args": 10000, + "returns_time": 5.854934190999984, + "num_returns": 3000, + "get_time": 25.88724605799996, + "queued_time": 140.99555420300004, + "num_queued": 1000000, + "large_object_time": 294.249499343, + "large_object_size": 107374182400, + "success": "1", + "_runtime": 528.4356288909912, + "_session_url": "https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_ELgpggWSHiqhksawLcz4urEP", + "_commit_url": "https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.7.0/2367a2cb9033913b68b1230316496ae273c25b54/ray-1.7.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable": true +} diff --git 
a/release/release_logs/1.7.0/stress_tests/dead_actors.txt b/release/release_logs/1.7.0/stress_tests/dead_actors.txt new file mode 100644 index 0000000000000..ab763e4173b75 --- /dev/null +++ b/release/release_logs/1.7.0/stress_tests/dead_actors.txt @@ -0,0 +1,11 @@ +{ + "success": 1, + "total_time": 130.34314274787903, + "avg_iteration_time": 1.303428828716278, + "max_iteration_time": 3.651247501373291, + "min_iteration_time": 0.09438443183898926, + "_runtime": 902.0143933296204, + "_session_url": "https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_pxDnaxYFzDNsyifjJNV1qhqs", + "_commit_url": "https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.7.0/2367a2cb9033913b68b1230316496ae273c25b54/ray-1.7.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable": true +} diff --git a/release/release_logs/1.7.0/stress_tests/many_tasks.txt b/release/release_logs/1.7.0/stress_tests/many_tasks.txt new file mode 100644 index 0000000000000..a0244c5b28489 --- /dev/null +++ b/release/release_logs/1.7.0/stress_tests/many_tasks.txt @@ -0,0 +1,19 @@ +{ + "success": 1, + "stage_0_time": 5.256332874298096, + "stage_1_time": 174.50774693489075, + "stage_1_avg_iteration_time": 17.450765538215638, + "stage_1_max_iteration_time": 17.627604961395264, + "stage_1_min_iteration_time": 17.23277997970581, + "stage_2_time": 268.01243686676025, + "stage_2_avg_iteration_time": 53.60213441848755, + "stage_2_max_iteration_time": 59.097413063049316, + "stage_2_min_iteration_time": 48.71518564224243, + "stage_3_creation_time": 0.5777060985565186, + "stage_3_time": 2066.70570230484, + "stage_4_spread": 3.2197082901427945, + "_runtime": 5045.744384527206, + "_session_url": "https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_b8v2V4Tr7vwee6tCDjTjdXLL", + "_commit_url": "https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.7.0/2367a2cb9033913b68b1230316496ae273c25b54/ray-1.7.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable": true +} diff --git a/release/release_logs/1.7.0/stress_tests/placement_group.txt b/release/release_logs/1.7.0/stress_tests/placement_group.txt new file mode 100644 index 0000000000000..cbe7c99c54a04 --- /dev/null +++ b/release/release_logs/1.7.0/stress_tests/placement_group.txt @@ -0,0 +1,9 @@ +{ + "success": 1, + "avg_pg_create_time_ms": 0.9874122837809874, + "avg_pg_remove_time_ms": 4.4027920900909265, + "_runtime": 458.8596382141113, + "_session_url": "https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_7uQL743cWCzdDT3ZYTpRDETi", + "_commit_url": "https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.7.0/2367a2cb9033913b68b1230316496ae273c25b54/ray-1.7.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable": true +} diff --git a/release/util/pip_download_test.sh b/release/util/pip_download_test.sh index 6ab91732ab255..c1d998b44e2b1 100755 --- a/release/util/pip_download_test.sh +++ b/release/util/pip_download_test.sh @@ -56,7 +56,7 @@ do else failed=true fi - if sh sanity_check_cpp.sh; then + if bash sanity_check_cpp.sh; then echo "PYTHON ${PYTHON_VERSION} succeed sanity check C++." 
else cpp_failed=true diff --git a/rllib/BUILD b/rllib/BUILD index f4c527bbb8099..b09e149b14a22 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -876,25 +876,13 @@ py_test( srcs = ["agents/ddpg/tests/test_ddpg.py"] ) -# DQNTrainer/SimpleQTrainer +# DQNTrainer py_test( name = "test_dqn", tags = ["team:ml", "trainers_dir"], size = "large", srcs = ["agents/dqn/tests/test_dqn.py"] ) -py_test( - name = "test_r2d2", - tags = ["team:ml", "trainers_dir"], - size = "large", - srcs = ["agents/dqn/tests/test_r2d2.py"] -) -py_test( - name = "test_simple_q", - tags = ["team:ml", "trainers_dir"], - size = "medium", - srcs = ["agents/dqn/tests/test_simple_q.py"] -) # Dreamer py_test( @@ -1002,6 +990,22 @@ py_test( srcs = ["agents/qmix/tests/test_qmix.py"] ) +# R2D2Trainer +py_test( + name = "test_r2d2", + tags = ["team:ml", "trainers_dir"], + size = "large", + srcs = ["agents/dqn/tests/test_r2d2.py"] +) + +# RNNSACTrainer +py_test( + name = "test_rnnsac", + tags = ["team:ml", "trainers_dir"], + size = "medium", + srcs = ["agents/sac/tests/test_rnnsac.py"] +) + # SACTrainer py_test( name = "test_sac", @@ -1010,6 +1014,14 @@ py_test( srcs = ["agents/sac/tests/test_sac.py"] ) +# SimpleQTrainer +py_test( + name = "test_simple_q", + tags = ["team:ml", "trainers_dir"], + size = "medium", + srcs = ["agents/dqn/tests/test_simple_q.py"] +) + # TD3Trainer py_test( name = "test_td3", @@ -1328,18 +1340,38 @@ py_test( # -------------------------------------------------------------------- sh_test( - name = "env/tests/test_local_inference", + name = "env/tests/test_local_inference_cartpole", tags = ["team:ml", "env"], size = "medium", - srcs = ["env/tests/test_local_inference.sh"], + srcs = ["env/tests/test_policy_client_server_setup.sh"], + args = ["local", "cartpole"], data = glob(["examples/serving/*.py"]), ) sh_test( - name = "env/tests/test_remote_inference", + name = "env/tests/test_remote_inference_cartpole", tags = ["team:ml", "env"], size = "medium", - srcs = ["env/tests/test_remote_inference.sh"], + srcs = ["env/tests/test_policy_client_server_setup.sh"], + args = ["remote", "cartpole"], + data = glob(["examples/serving/*.py"]), +) + +sh_test( + name = "env/tests/test_local_inference_unity3d", + tags = ["team:ml", "env"], + size = "medium", + srcs = ["env/tests/test_policy_client_server_setup.sh"], + args = ["local", "unity3d"], + data = glob(["examples/serving/*.py"]), +) + +sh_test( + name = "env/tests/test_remote_inference_unity3d", + tags = ["team:ml", "env"], + size = "medium", + srcs = ["env/tests/test_policy_client_server_setup.sh"], + args = ["remote", "unity3d"], data = glob(["examples/serving/*.py"]), ) @@ -1350,6 +1382,13 @@ py_test( srcs = ["env/tests/test_record_env_wrapper.py"] ) +py_test( + name = "env/tests/test_remote_worker_envs", + tags = ["team:ml", "env"], + size = "medium", + srcs = ["env/tests/test_remote_worker_envs.py"] +) + py_test( name = "env/wrappers/tests/test_unity3d_env", tags = ["team:ml", "env"], @@ -1847,14 +1886,14 @@ py_test( args = ["TestSupportedMultiAgentOffPolicy"] ) -# py_test( -# name = "tests/test_supported_spaces_pg", -# main = "tests/test_supported_spaces.py", -# tags = ["team:ml", "tests_dir", "tests_dir_S"], -# size = "enormous", -# srcs = ["tests/test_supported_spaces.py"], -# args = ["TestSupportedSpacesPG"] -# ) +py_test( + name = "tests/test_supported_spaces_pg", + main = "tests/test_supported_spaces.py", + tags = ["team:ml", "tests_dir", "tests_dir_S"], + size = "large", + srcs = ["tests/test_supported_spaces.py"], + args = ["TestSupportedSpacesPG"] + ) 
py_test( name = "tests/test_supported_spaces_off_policy", diff --git a/rllib/agents/a3c/a3c_tf_policy.py b/rllib/agents/a3c/a3c_tf_policy.py index cbc5bbbd797d6..6e7b362a4fd95 100644 --- a/rllib/agents/a3c/a3c_tf_policy.py +++ b/rllib/agents/a3c/a3c_tf_policy.py @@ -111,7 +111,7 @@ def grad_stats(policy: Policy, train_batch: SampleBatch, "grad_gnorm": tf.linalg.global_norm(grads), "vf_explained_var": explained_variance( train_batch[Postprocessing.VALUE_TARGETS], - policy.model.value_function()), + policy.model.value_function()) } diff --git a/rllib/agents/a3c/a3c_torch_policy.py b/rllib/agents/a3c/a3c_torch_policy.py index 99172adb814e0..ea44f4767cfdc 100644 --- a/rllib/agents/a3c/a3c_torch_policy.py +++ b/rllib/agents/a3c/a3c_torch_policy.py @@ -72,19 +72,25 @@ def actor_critic_loss(policy: Policy, model: ModelV2, total_loss = (pi_err + value_err * policy.config["vf_loss_coeff"] - entropy * policy.config["entropy_coeff"]) - policy.entropy = entropy - policy.pi_err = pi_err - policy.value_err = value_err + # Store values for stats function in model (tower), such that for + # multi-GPU, we do not override them during the parallel loss phase. + model.tower_stats["entropy"] = entropy + model.tower_stats["pi_err"] = pi_err + model.tower_stats["value_err"] = value_err return total_loss def loss_and_entropy_stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]: + return { - "policy_entropy": policy.entropy, - "policy_loss": policy.pi_err, - "vf_loss": policy.value_err, + "policy_entropy": torch.mean( + torch.stack(policy.get_tower_stats("entropy"))), + "policy_loss": torch.mean( + torch.stack(policy.get_tower_stats("pi_err"))), + "vf_loss": torch.mean( + torch.stack(policy.get_tower_stats("value_err"))), } diff --git a/rllib/agents/a3c/tests/test_a2c.py b/rllib/agents/a3c/tests/test_a2c.py index 2394b3f5812b7..4c8a259245adc 100644 --- a/rllib/agents/a3c/tests/test_a2c.py +++ b/rllib/agents/a3c/tests/test_a2c.py @@ -3,7 +3,7 @@ import ray import ray.rllib.agents.a3c as a3c from ray.rllib.utils.test_utils import check_compute_single_action, \ - framework_iterator + check_train_results, framework_iterator class TestA2C(unittest.TestCase): @@ -29,6 +29,7 @@ def test_a2c_compilation(self): trainer = a3c.A2CTrainer(config=config, env=env) for i in range(num_iterations): results = trainer.train() + check_train_results(results) print(results) check_compute_single_action(trainer) trainer.stop() @@ -37,7 +38,9 @@ def test_a2c_exec_impl(ray_start_regular): config = {"min_iter_time_s": 0} for _ in framework_iterator(config): trainer = a3c.A2CTrainer(env="CartPole-v0", config=config) - assert isinstance(trainer.train(), dict) + results = trainer.train() + check_train_results(results) + print(results) check_compute_single_action(trainer) trainer.stop() @@ -48,7 +51,9 @@ def test_a2c_exec_impl_microbatch(ray_start_regular): } for _ in framework_iterator(config): trainer = a3c.A2CTrainer(env="CartPole-v0", config=config) - assert isinstance(trainer.train(), dict) + results = trainer.train() + check_train_results(results) + print(results) check_compute_single_action(trainer) trainer.stop() diff --git a/rllib/agents/a3c/tests/test_a3c.py b/rllib/agents/a3c/tests/test_a3c.py index 6ffbab01f955f..59147f213a7a5 100644 --- a/rllib/agents/a3c/tests/test_a3c.py +++ b/rllib/agents/a3c/tests/test_a3c.py @@ -3,7 +3,7 @@ import ray import ray.rllib.agents.a3c as a3c from ray.rllib.utils.test_utils import check_compute_single_action, \ - framework_iterator + check_train_results, framework_iterator class 
TestA3C(unittest.TestCase): @@ -31,6 +31,7 @@ def test_a3c_compilation(self): trainer = a3c.A3CTrainer(config=config, env=env) for i in range(num_iterations): results = trainer.train() + check_train_results(results) print(results) check_compute_single_action( trainer, include_state=config["model"]["use_lstm"]) diff --git a/rllib/agents/ars/tests/test_ars.py b/rllib/agents/ars/tests/test_ars.py index b6bb3c8df7277..a78353de44ac4 100644 --- a/rllib/agents/ars/tests/test_ars.py +++ b/rllib/agents/ars/tests/test_ars.py @@ -7,9 +7,16 @@ class TestARS(unittest.TestCase): + @classmethod + def setUpClass(cls): + ray.init(num_cpus=3) + + @classmethod + def tearDownClass(cls): + ray.shutdown() + def test_ars_compilation(self): """Test whether an ARSTrainer can be built on all frameworks.""" - ray.init(num_cpus=3) config = ars.DEFAULT_CONFIG.copy() # Keep it simple. config["model"]["fcnet_hiddens"] = [10] @@ -30,7 +37,6 @@ def test_ars_compilation(self): check_compute_single_action(trainer) trainer.stop() - ray.shutdown() if __name__ == "__main__": diff --git a/rllib/agents/cql/cql.py b/rllib/agents/cql/cql.py index 3c9c026c7bc34..19f1573e29ba9 100644 --- a/rllib/agents/cql/cql.py +++ b/rllib/agents/cql/cql.py @@ -14,10 +14,11 @@ from ray.rllib.execution.train_ops import MultiGPUTrainOneStep, TrainOneStep, \ UpdateTargetNetwork from ray.rllib.offline.shuffled_input import ShuffledInput -from ray.rllib.policy.policy import LEARNER_STATS_KEY, Policy +from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils import merge_dicts from ray.rllib.utils.framework import try_import_tf, try_import_tfp +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.typing import TrainerConfigDict tf1, tf, tfv = try_import_tf() diff --git a/rllib/agents/cql/cql_torch_policy.py b/rllib/agents/cql/cql_torch_policy.py index fed6470dc585e..f62b23069a4fd 100644 --- a/rllib/agents/cql/cql_torch_policy.py +++ b/rllib/agents/cql/cql_torch_policy.py @@ -14,12 +14,12 @@ build_sac_model_and_action_dist, optimizer_fn, ComputeTDErrorMixin, \ TargetNetworkMixin, setup_late_mixins, action_distribution_fn from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper -from ray.rllib.policy.policy import LEARNER_STATS_KEY from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.typing import LocalOptimizer, TensorType, \ TrainerConfigDict from ray.rllib.utils.torch_ops import apply_grad_clipping, \ @@ -250,23 +250,29 @@ def cql_loss(policy: Policy, model: ModelV2, critic_loss[1].backward(retain_graph=False) policy.critic_optims[1].step() - # Save for stats function. - policy.q_t = q_t_selected - policy.policy_t = policy_t - policy.log_pis_t = log_pis_t - model.td_error = td_error - policy.actor_loss = actor_loss - policy.critic_loss = critic_loss - policy.alpha_loss = alpha_loss - policy.log_alpha_value = model.log_alpha - policy.alpha_value = alpha - policy.target_entropy = model.target_entropy - # CQL Stats. - policy.cql_loss = cql_loss + # Store values for stats function in model (tower), such that for + # multi-GPU, we do not override them during the parallel loss phase. + # SAC stats. 
+ model.tower_stats["q_t"] = q_t_selected + model.tower_stats["policy_t"] = policy_t + model.tower_stats["log_pis_t"] = log_pis_t + model.tower_stats["actor_loss"] = actor_loss + model.tower_stats["critic_loss"] = critic_loss + model.tower_stats["alpha_loss"] = alpha_loss + model.tower_stats["log_alpha_value"] = model.log_alpha + model.tower_stats["alpha_value"] = alpha + model.tower_stats["target_entropy"] = model.target_entropy + # CQL stats. + model.tower_stats["cql_loss"] = cql_loss + + # TD-error tensor in final stats + # will be concatenated and retrieved for each individual batch item. + model.tower_stats["td_error"] = td_error + if use_lagrange: - policy.log_alpha_prime_value = model.log_alpha_prime[0] - policy.alpha_prime_value = alpha_prime - policy.alpha_prime_loss = alpha_prime_loss + model.tower_stats["log_alpha_prime_value"] = model.log_alpha_prime[0] + model.tower_stats["alpha_prime_value"] = alpha_prime + model.tower_stats["alpha_prime_loss"] = alpha_prime_loss if obs.shape[0] == policy.config["train_batch_size"]: policy.alpha_prime_optim.zero_grad() @@ -274,22 +280,27 @@ def cql_loss(policy: Policy, model: ModelV2, policy.alpha_prime_optim.step() # Return all loss terms corresponding to our optimizers. - if use_lagrange: - return tuple([policy.actor_loss] + policy.critic_loss + - [policy.alpha_loss] + [policy.alpha_prime_loss]) - return tuple([policy.actor_loss] + policy.critic_loss + - [policy.alpha_loss]) + return tuple([actor_loss] + critic_loss + [alpha_loss] + + ([alpha_prime_loss] if use_lagrange else [])) def cql_stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]: - sac_dict = stats(policy, train_batch) - sac_dict["cql_loss"] = torch.mean(torch.stack(policy.cql_loss)) + # Get SAC loss stats. + stats_dict = stats(policy, train_batch) + + # Add CQL loss stats to the dict. 
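+    # (`get_tower_stats("cql_loss")` returns one entry per GPU tower, and
+    # `cql_loss` is itself a list of loss tensors, hence the star-unpacking
+    # into `torch.stack` below.)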
+ stats_dict["cql_loss"] = torch.mean( + torch.stack(*policy.get_tower_stats("cql_loss"))) + if policy.config["lagrangian"]: - sac_dict["log_alpha_prime_value"] = policy.log_alpha_prime_value - sac_dict["alpha_prime_value"] = policy.alpha_prime_value - sac_dict["alpha_prime_loss"] = policy.alpha_prime_loss - return sac_dict + stats_dict["log_alpha_prime_value"] = torch.mean( + torch.stack(policy.get_tower_stats("log_alpha_prime_value"))) + stats_dict["alpha_prime_value"] = torch.mean( + torch.stack(policy.get_tower_stats("alpha_prime_value"))) + stats_dict["alpha_prime_loss"] = torch.mean( + torch.stack(policy.get_tower_stats("alpha_prime_loss"))) + return stats_dict def cql_optimizer_fn(policy: Policy, config: TrainerConfigDict) -> \ diff --git a/rllib/agents/cql/tests/test_cql.py b/rllib/agents/cql/tests/test_cql.py index 9f8466a220e00..7e3ef58896f67 100644 --- a/rllib/agents/cql/tests/test_cql.py +++ b/rllib/agents/cql/tests/test_cql.py @@ -7,7 +7,7 @@ import ray.rllib.agents.cql as cql from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.test_utils import check_compute_single_action, \ - framework_iterator + check_train_results, framework_iterator tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() @@ -69,10 +69,13 @@ def test_cql_compilation(self): for fw in framework_iterator(config): trainer = cql.CQLTrainer(config=config) for i in range(num_iterations): - results = trainer.train().get("evaluation") - if results: + results = trainer.train() + check_train_results(results) + print(results) + eval_results = results.get("evaluation") + if eval_results: print(f"iter={trainer.iteration} " - f"R={results['episode_reward_mean']}") + f"R={eval_results['episode_reward_mean']}") check_compute_single_action(trainer) diff --git a/rllib/agents/ddpg/ddpg_tf_model.py b/rllib/agents/ddpg/ddpg_tf_model.py index 53d2d666dc60c..f3c4a3ece6e9b 100644 --- a/rllib/agents/ddpg/ddpg_tf_model.py +++ b/rllib/agents/ddpg/ddpg_tf_model.py @@ -1,6 +1,6 @@ import numpy as np import gym -from typing import List +from typing import List, Optional from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.utils.framework import try_import_tf @@ -29,9 +29,9 @@ def __init__( model_config: ModelConfigDict, name: str, # Extra DDPGActionModel args: - actor_hiddens: List[int] = [256, 256], + actor_hiddens: Optional[List[int]] = None, actor_hidden_activation: str = "relu", - critic_hiddens: List[int] = [256, 256], + critic_hiddens: Optional[List[int]] = None, critic_hidden_activation: str = "relu", twin_q: bool = False, add_layer_norm: bool = False): @@ -48,6 +48,12 @@ def __init__( should be defined in subclasses of DDPGActionModel. 
""" + if actor_hiddens is None: + actor_hiddens = [256, 256] + + if critic_hiddens is None: + critic_hiddens = [256, 256] + super(DDPGTFModel, self).__init__(obs_space, action_space, num_outputs, model_config, name) diff --git a/rllib/agents/ddpg/ddpg_tf_policy.py b/rllib/agents/ddpg/ddpg_tf_policy.py index 8c24a84c04a5e..d3c295feba940 100644 --- a/rllib/agents/ddpg/ddpg_tf_policy.py +++ b/rllib/agents/ddpg/ddpg_tf_policy.py @@ -28,7 +28,7 @@ from ray.rllib.utils.spaces.simplex import Simplex from ray.rllib.utils.tf_ops import huber_loss, make_tf_callable from ray.rllib.utils.typing import TrainerConfigDict, TensorType, \ - LocalOptimizer, ModelGradients, PolicyID + LocalOptimizer, ModelGradients from ray.util.debug import log_once tf1, tf, tfv = try_import_tf() @@ -429,17 +429,17 @@ def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space, TargetNetworkMixin.__init__(policy, config) -def validate_spaces(pid: PolicyID, observation_space: gym.spaces.Space, +def validate_spaces(policy: Policy, observation_space: gym.spaces.Space, action_space: gym.spaces.Space, config: TrainerConfigDict) -> None: if not isinstance(action_space, Box): raise UnsupportedSpaceException( "Action space ({}) of {} is not supported for " - "DDPG.".format(action_space, pid)) + "DDPG.".format(action_space, policy)) elif len(action_space.shape) > 1: raise UnsupportedSpaceException( "Action space ({}) of {} has multiple dimensions " - "{}. ".format(action_space, pid, action_space.shape) + + "{}. ".format(action_space, policy, action_space.shape) + "Consider reshaping this into a single dimension, " "using a Tuple action space, or the multi-agent API.") diff --git a/rllib/agents/ddpg/ddpg_torch_model.py b/rllib/agents/ddpg/ddpg_torch_model.py index 2297ee0b2a815..615e0ea8b5814 100644 --- a/rllib/agents/ddpg/ddpg_torch_model.py +++ b/rllib/agents/ddpg/ddpg_torch_model.py @@ -1,6 +1,6 @@ import numpy as np import gym -from typing import List, Dict, Union +from typing import List, Dict, Union, Optional from ray.rllib.models.torch.misc import SlimFC from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 @@ -31,9 +31,9 @@ def __init__( model_config: ModelConfigDict, name: str, # Extra DDPGActionModel args: - actor_hiddens: List[int] = [256, 256], + actor_hiddens: Optional[List[int]] = None, actor_hidden_activation: str = "relu", - critic_hiddens: List[int] = [256, 256], + critic_hiddens: Optional[List[int]] = None, critic_hidden_activation: str = "relu", twin_q: bool = False, add_layer_norm: bool = False): @@ -51,6 +51,12 @@ def __init__( only defines the layers for the output heads. Those layers for forward() should be defined in subclasses of DDPGTorchModel. """ + if actor_hiddens is None: + actor_hiddens = [256, 256] + + if critic_hiddens is None: + critic_hiddens = [256, 256] + nn.Module.__init__(self) super(DDPGTorchModel, self).__init__(obs_space, action_space, num_outputs, model_config, name) diff --git a/rllib/agents/ddpg/ddpg_torch_policy.py b/rllib/agents/ddpg/ddpg_torch_policy.py index ef22a5e75fd47..c6eb6bddbda6e 100644 --- a/rllib/agents/ddpg/ddpg_torch_policy.py +++ b/rllib/agents/ddpg/ddpg_torch_policy.py @@ -172,18 +172,17 @@ def ddpg_actor_critic_loss(policy: Policy, model: ModelV2, _, [actor_loss, critic_loss] = model.custom_loss( [actor_loss, critic_loss], input_dict) - # Store values for stats function. 
-    policy.q_t = q_t
-    policy.actor_loss = actor_loss
-    policy.critic_loss = critic_loss
-
-    # Store td-error in model, such that for multi-GPU, we do not override
-    # them during the parallel loss phase. TD-error tensor in final stats
-    # can then be concatenated and retrieved for each individual batch item.
-    model.td_error = td_error
+    # Store values for stats function in model (tower), such that for
+    # multi-GPU, we do not override them during the parallel loss phase.
+    model.tower_stats["q_t"] = q_t
+    model.tower_stats["actor_loss"] = actor_loss
+    model.tower_stats["critic_loss"] = critic_loss
+    # TD-error tensor in final stats
+    # will be concatenated and retrieved for each individual batch item.
+    model.tower_stats["td_error"] = td_error
 
     # Return two loss terms (corresponding to the two optimizers, we create).
-    return policy.actor_loss, policy.critic_loss
+    return actor_loss, critic_loss
 
 
 def make_ddpg_optimizers(policy: Policy,
@@ -217,12 +216,16 @@ def apply_gradients_fn(policy: Policy, gradients: GradInfoDict) -> None:
 
 def build_ddpg_stats(policy: Policy,
                      batch: SampleBatch) -> Dict[str, TensorType]:
+
+    q_t = torch.stack(policy.get_tower_stats("q_t"))
     stats = {
-        "actor_loss": policy.actor_loss,
-        "critic_loss": policy.critic_loss,
-        "mean_q": torch.mean(policy.q_t),
-        "max_q": torch.max(policy.q_t),
-        "min_q": torch.min(policy.q_t),
+        "actor_loss": torch.mean(
+            torch.stack(policy.get_tower_stats("actor_loss"))),
+        "critic_loss": torch.mean(
+            torch.stack(policy.get_tower_stats("critic_loss"))),
+        "mean_q": torch.mean(q_t),
+        "max_q": torch.max(q_t),
+        "min_q": torch.min(q_t),
     }
     return stats
@@ -251,8 +254,8 @@ def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
             # (one TD-error value per item in batch to update PR weights).
             loss_fn(self, self.model, None, input_dict)
 
-            # Self.td_error is set within actor_critic_loss call.
-            return self.model.td_error
+            # `self.model.tower_stats["td_error"]` is set in actor_critic_loss.
+ return self.model.tower_stats["td_error"] self.compute_td_error = compute_td_error diff --git a/rllib/agents/ddpg/tests/test_apex_ddpg.py b/rllib/agents/ddpg/tests/test_apex_ddpg.py index 61556fb9b961b..16ebab9a1f9ae 100644 --- a/rllib/agents/ddpg/tests/test_apex_ddpg.py +++ b/rllib/agents/ddpg/tests/test_apex_ddpg.py @@ -4,7 +4,7 @@ import ray import ray.rllib.agents.ddpg.apex as apex_ddpg from ray.rllib.utils.test_utils import check, check_compute_single_action, \ - framework_iterator + check_train_results, framework_iterator class TestApexDDPG(unittest.TestCase): @@ -40,7 +40,9 @@ def test_apex_ddpg_compilation_and_per_worker_epsilon_values(self): check(scale, [0.0] + expected) for _ in range(num_iterations): - print(trainer.train()) + results = trainer.train() + check_train_results(results) + print(results) check_compute_single_action(trainer) # Test again per-worker scale distribution diff --git a/rllib/agents/ddpg/tests/test_ddpg.py b/rllib/agents/ddpg/tests/test_ddpg.py index be404e720d48e..7f72e03d0e30c 100644 --- a/rllib/agents/ddpg/tests/test_ddpg.py +++ b/rllib/agents/ddpg/tests/test_ddpg.py @@ -13,7 +13,7 @@ from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.numpy import fc, huber_loss, l2_loss, relu, sigmoid from ray.rllib.utils.test_utils import check, check_compute_single_action, \ - framework_iterator + check_train_results, framework_iterator from ray.rllib.utils.torch_ops import convert_to_torch_tensor tf1, tf, tfv = try_import_tf() @@ -45,6 +45,7 @@ def test_ddpg_compilation(self): trainer = ddpg.DDPGTrainer(config=config, env="Pendulum-v0") for i in range(num_iterations): results = trainer.train() + check_train_results(results) print(results) check_compute_single_action(trainer) # Ensure apply_gradient_fn is being called and updating global_step @@ -288,8 +289,9 @@ def test_ddpg_loss_function(self): elif fw == "torch": loss_torch(policy, policy.model, None, input_) - c, a, t = policy.critic_loss, policy.actor_loss, \ - policy.model.td_error + c, a, t = policy.get_tower_stats("critic_loss")[0], \ + policy.get_tower_stats("actor_loss")[0], \ + policy.get_tower_stats("td_error")[0] # Check pure loss values. 
check(c, expect_c) check(a, expect_a) diff --git a/rllib/agents/ddpg/tests/test_td3.py b/rllib/agents/ddpg/tests/test_td3.py index 75b84e4ddc57e..a542cf5a1574d 100644 --- a/rllib/agents/ddpg/tests/test_td3.py +++ b/rllib/agents/ddpg/tests/test_td3.py @@ -5,7 +5,7 @@ import ray.rllib.agents.ddpg.td3 as td3 from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.test_utils import check, check_compute_single_action, \ - framework_iterator + check_train_results, framework_iterator tf1, tf, tfv = try_import_tf() @@ -30,6 +30,7 @@ def test_td3_compilation(self): num_iterations = 1 for i in range(num_iterations): results = trainer.train() + check_train_results(results) print(results) check_compute_single_action(trainer) trainer.stop() diff --git a/rllib/agents/dqn/apex.py b/rllib/agents/dqn/apex.py index 74afc564f1708..49c24b07ed3e7 100644 --- a/rllib/agents/dqn/apex.py +++ b/rllib/agents/dqn/apex.py @@ -33,6 +33,7 @@ from ray.rllib.utils import merge_dicts from ray.rllib.utils.actors import create_colocated from ray.rllib.utils.annotations import override +from ray.rllib.utils.metrics.learner_info import LEARNER_INFO from ray.rllib.utils.typing import SampleBatchType from ray.tune.trainable import Trainable from ray.tune.utils.placement_groups import PlacementGroupFactory @@ -227,7 +228,7 @@ def add_apex_metrics(result: dict) -> dict: result["info"].update({ "exploration_infos": exploration_infos, "learner_queue": learner_thread.learner_queue_size.stats(), - "learner": copy.deepcopy(learner_thread.stats), + LEARNER_INFO: copy.deepcopy(learner_thread.learner_info), "replay_shard_0": replay_stats, }) return result diff --git a/rllib/agents/dqn/dqn.py b/rllib/agents/dqn/dqn.py index ac4b8f0dbb8e5..5f1eadf020a39 100644 --- a/rllib/agents/dqn/dqn.py +++ b/rllib/agents/dqn/dqn.py @@ -25,7 +25,8 @@ from ray.rllib.execution.rollout_ops import ParallelRollouts from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork, \ MultiGPUTrainOneStep -from ray.rllib.policy.policy import LEARNER_STATS_KEY, Policy +from ray.rllib.policy.policy import Policy +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.typing import TrainerConfigDict from ray.util.iter import LocalIterator @@ -200,8 +201,17 @@ def update_prio(item): td_error = info.get("td_error", info[LEARNER_STATS_KEY].get("td_error")) samples.policy_batches[policy_id].set_get_interceptor(None) - prio_dict[policy_id] = (samples.policy_batches[policy_id] - .get("batch_indexes"), td_error) + batch_indices = samples.policy_batches[policy_id].get( + "batch_indexes") + # In case the buffer stores sequences, TD-error could already + # be calculated per sequence chunk. 
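+                # (Sequence buffers, e.g. for R2D2, yield one TD-error per
+                # stored sequence of length T, while `batch_indexes` has one
+                # entry per timestep; keeping only the first index of each
+                # T-step sequence lines the two up again.)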
+ if len(batch_indices) != len(td_error): + T = local_replay_buffer.replay_sequence_length + assert len(batch_indices) > len( + td_error) and len(batch_indices) % T == 0 + batch_indices = batch_indices.reshape([-1, T])[:, 0] + assert len(batch_indices) == len(td_error) + prio_dict[policy_id] = (batch_indices, td_error) local_replay_buffer.update_priorities(prio_dict) return info_dict diff --git a/rllib/agents/dqn/dqn_torch_policy.py b/rllib/agents/dqn/dqn_torch_policy.py index d060a1ce4012a..a7826d0da489c 100644 --- a/rllib/agents/dqn/dqn_torch_policy.py +++ b/rllib/agents/dqn/dqn_torch_policy.py @@ -121,7 +121,7 @@ def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask, # Do forward pass on loss to update td error attribute build_q_losses(self, self.model, None, input_dict) - return self.q_loss.td_error + return self.model.tower_stats["q_loss"].td_error self.compute_td_error = compute_td_error @@ -216,8 +216,9 @@ def get_distribution_inputs_and_class( is_training=is_training) q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals - policy.q_values = q_vals - return policy.q_values, TorchCategorical, [] # state-out + model.tower_stats["q_values"] = q_vals + + return q_vals, TorchCategorical, [] # state-out def build_q_losses(policy: Policy, model, _, @@ -286,19 +287,21 @@ def build_q_losses(policy: Policy, model, _, q_probs_tp1_best = torch.sum( q_probs_tp1 * torch.unsqueeze(q_tp1_best_one_hot_selection, -1), 1) - policy.q_loss = QLoss( - q_t_selected, q_logits_t_selected, q_tp1_best, q_probs_tp1_best, - train_batch[PRIO_WEIGHTS], train_batch[SampleBatch.REWARDS], - train_batch[SampleBatch.DONES].float(), config["gamma"], - config["n_step"], config["num_atoms"], config["v_min"], - config["v_max"]) + q_loss = QLoss(q_t_selected, q_logits_t_selected, q_tp1_best, + q_probs_tp1_best, train_batch[PRIO_WEIGHTS], + train_batch[SampleBatch.REWARDS], + train_batch[SampleBatch.DONES].float(), config["gamma"], + config["n_step"], config["num_atoms"], config["v_min"], + config["v_max"]) - # Store td-error in model, such that for multi-GPU, we do not override - # them during the parallel loss phase. TD-error tensor in final stats - # can then be concatenated and retrieved for each individual batch item. - model.td_error = policy.q_loss.td_error + # Store values for stats function in model (tower), such that for + # multi-GPU, we do not override them during the parallel loss phase. + model.tower_stats["td_error"] = q_loss.td_error + # TD-error tensor in final stats + # will be concatenated and retrieved for each individual batch item. 
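+    # (The full QLoss object is stored as well so that `build_q_stats`
+    # below can average each of its stats across all towers that actually
+    # computed a loss.)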
+ model.tower_stats["q_loss"] = q_loss - return policy.q_loss.loss + return q_loss.loss def adam_optimizer(policy: Policy, @@ -314,9 +317,16 @@ def adam_optimizer(policy: Policy, def build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]: - return dict({ - "cur_lr": policy.cur_lr, - }, **policy.q_loss.stats) + stats = {} + for stats_key in policy.model_gpu_towers[0].tower_stats[ + "q_loss"].stats.keys(): + stats[stats_key] = torch.mean( + torch.stack([ + t.tower_stats["q_loss"].stats[stats_key].to(policy.device) + for t in policy.model_gpu_towers if "q_loss" in t.tower_stats + ])) + stats["cur_lr"] = policy.cur_lr + return stats def setup_early_mixins(policy: Policy, obs_space, action_space, @@ -385,7 +395,7 @@ def grad_process_and_td_error_fn(policy: Policy, def extra_action_out_fn(policy: Policy, input_dict, state_batches, model, action_dist) -> Dict[str, TensorType]: - return {"q_values": policy.q_values} + return {"q_values": model.tower_stats["q_values"]} DQNTorchPolicy = build_policy_class( diff --git a/rllib/agents/dqn/learner_thread.py b/rllib/agents/dqn/learner_thread.py index 0f8d6f15bd79a..93bed4b18de5e 100644 --- a/rllib/agents/dqn/learner_thread.py +++ b/rllib/agents/dqn/learner_thread.py @@ -1,9 +1,8 @@ import queue import threading -from ray.rllib.evaluation.metrics import get_learner_stats -from ray.rllib.policy.policy import LEARNER_STATS_KEY from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder from ray.rllib.utils.timer import TimerStat from ray.rllib.utils.window_stat import WindowStat @@ -33,7 +32,7 @@ def __init__(self, local_worker): self.daemon = True self.weights_updated = False self.stopped = False - self.stats = {} + self.learner_info = {} def run(self): # Switch on eager mode if configured. @@ -49,11 +48,18 @@ def step(self): if replay is not None: prio_dict = {} with self.grad_timer: - grad_out = self.local_worker.learn_on_batch(replay) - for pid, info in grad_out.items(): - td_error = info.get( - "td_error", - info[LEARNER_STATS_KEY].get("td_error")) + # Use LearnerInfoBuilder as a unified way to build the + # final results dict from `learn_on_loaded_batch` call(s). + # This makes sure results dicts always have the same + # structure no matter the setup (multi-GPU, multi-agent, + # minibatch SGD, tf vs torch). + learner_info_builder = LearnerInfoBuilder(num_devices=1) + multi_agent_results = self.local_worker.learn_on_batch( + replay) + for pid, results in multi_agent_results.items(): + learner_info_builder.add_learn_on_batch_results( + results, pid) + td_error = results["td_error"] # Switch off auto-conversion from numpy to torch/tf # tensors for the indices. This may lead to errors # when sent to the buffer for processing @@ -62,7 +68,7 @@ def step(self): prio_dict[pid] = ( replay.policy_batches[pid].get("batch_indexes"), td_error) - self.stats[pid] = get_learner_stats(info) + self.learner_info = learner_info_builder.finalize() self.grad_timer.push_units_processed(replay.count) self.outqueue.put((ra, prio_dict, replay.count)) self.learner_queue_size.push(self.inqueue.qsize()) diff --git a/rllib/agents/dqn/r2d2.py b/rllib/agents/dqn/r2d2.py index 7985b55fe305a..d568272e957e9 100644 --- a/rllib/agents/dqn/r2d2.py +++ b/rllib/agents/dqn/r2d2.py @@ -28,7 +28,7 @@ DEFAULT_CONFIG = dqn.DQNTrainer.merge_trainer_configs( dqn.DEFAULT_CONFIG, # See keys in impala.py, which are also supported. { - # Learning rate for adam optimizer + # Learning rate for adam optimizer. 
"lr": 1e-4, # Discount factor. "gamma": 0.997, @@ -40,8 +40,6 @@ "num_workers": 2, # Batch mode must be complete_episodes. "batch_mode": "complete_episodes", - # R2D2 does not suport n-step > 1 yet! - "n_step": 1, # If True, assume a zero-initialized state input (no matter where in # the episode the sequence is located). @@ -71,7 +69,6 @@ # Size of the replay buffer (in sequences, not timesteps). "buffer_size": 100000, # If True prioritized replay buffer will be used. - # Note: Not supported yet by R2D2! "prioritized_replay": False, # Set automatically: The number of contiguous environment steps to # replay at once. Will be calculated via @@ -91,7 +88,8 @@ def validate_config(config: TrainerConfigDict) -> None: """Checks and updates the config based on settings. - Rewrites rollout_fragment_length to take into account n_step truncation. + Rewrites rollout_fragment_length to take into account burn-in and + max_seq_len truncation. """ if config["replay_sequence_length"] != -1: raise ValueError( @@ -102,15 +100,9 @@ def validate_config(config: TrainerConfigDict) -> None: config["replay_sequence_length"] = \ config["burn_in"] + config["model"]["max_seq_len"] - if config.get("prioritized_replay"): - raise ValueError("Prioritized replay is not supported for R2D2 yet!") - if config.get("batch_mode") != "complete_episodes": raise ValueError("`batch_mode` must be 'complete_episodes'!") - if config["n_step"] > 1: - raise ValueError("`n_step` > 1 not yet supported by R2D2!") - def calculate_rr_weights(config: TrainerConfigDict) -> List[float]: """Calculate the round robin weights for the rollout and train steps""" diff --git a/rllib/agents/dqn/r2d2_tf_policy.py b/rllib/agents/dqn/r2d2_tf_policy.py index 1d72d12e7e25b..d34c35a44976b 100644 --- a/rllib/agents/dqn/r2d2_tf_policy.py +++ b/rllib/agents/dqn/r2d2_tf_policy.py @@ -156,7 +156,7 @@ def r2d2_loss(policy: Policy, model, _, def reduce_mean_valid(t): return tf.reduce_mean(tf.boolean_mask(t, seq_mask)) - # Make sure use the correct time indices: + # Make sure to use the correct time indices: # Q(t) - [gamma * r + Q^(t+1)] q_selected = tf.reshape(q_selected, [B, T])[:, :-1] td_error = q_selected - tf.stop_gradient( @@ -164,7 +164,9 @@ def reduce_mean_valid(t): td_error = td_error * tf.cast(seq_mask, tf.float32) weights = tf.reshape(weights, [B, T])[:, :-1] policy._total_loss = reduce_mean_valid(weights * huber_loss(td_error)) - policy._td_error = tf.reshape(td_error, [-1]) + # Store the TD-error per time chunk (b/c we need only one mean + # prioritized replay weight per stored sequence). 
+ policy._td_error = tf.reduce_mean(td_error, axis=-1) policy._loss_stats = { "mean_q": reduce_mean_valid(q_selected), "min_q": tf.reduce_min(q_selected), diff --git a/rllib/agents/dqn/r2d2_torch_policy.py b/rllib/agents/dqn/r2d2_torch_policy.py index 894c6dc2fb729..97c34327f7215 100644 --- a/rllib/agents/dqn/r2d2_torch_policy.py +++ b/rllib/agents/dqn/r2d2_torch_policy.py @@ -19,8 +19,8 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy import LearningRateSchedule from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.torch_ops import apply_grad_clipping, FLOAT_MIN, \ - huber_loss, sequence_mask +from ray.rllib.utils.torch_ops import apply_grad_clipping, \ + concat_multi_gpu_td_errors, FLOAT_MIN, huber_loss, sequence_mask from ray.rllib.utils.typing import TensorType, TrainerConfigDict torch, nn = try_import_torch() @@ -170,16 +170,20 @@ def reduce_mean_valid(t): td_error = q_selected - target.reshape([B, T])[:, :-1].detach() td_error = td_error * seq_mask weights = weights.reshape([B, T])[:, :-1] - policy._total_loss = reduce_mean_valid(weights * huber_loss(td_error)) - policy._td_error = td_error.reshape([-1]) - policy._loss_stats = { - "mean_q": reduce_mean_valid(q_selected), - "min_q": torch.min(q_selected), - "max_q": torch.max(q_selected), - "mean_td_error": reduce_mean_valid(td_error), - } + total_loss = reduce_mean_valid(weights * huber_loss(td_error)) - return policy._total_loss + # Store values for stats function in model (tower), such that for + # multi-GPU, we do not override them during the parallel loss phase. + model.tower_stats["total_loss"] = total_loss + model.tower_stats["mean_q"] = reduce_mean_valid(q_selected) + model.tower_stats["min_q"] = torch.min(q_selected) + model.tower_stats["max_q"] = torch.max(q_selected) + model.tower_stats["mean_td_error"] = reduce_mean_valid(td_error) + # Store per time chunk (b/c we need only one mean + # prioritized replay weight per stored sequence). 
+ model.tower_stats["td_error"] = torch.mean(td_error, dim=-1) + + return total_loss def h_function(x, epsilon=1.0): @@ -233,15 +237,23 @@ def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask, # Do forward pass on loss to update td error attribute r2d2_loss(self, self.model, None, input_dict) - return self._td_error + return self.model.tower_stats["td_error"] self.compute_td_error = compute_td_error -def build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]: - return dict({ +def build_q_stats(policy: Policy, batch: SampleBatch) -> Dict[str, TensorType]: + + return { "cur_lr": policy.cur_lr, - }, **policy._loss_stats) + "total_loss": torch.mean( + torch.stack(policy.get_tower_stats("total_loss"))), + "mean_q": torch.mean(torch.stack(policy.get_tower_stats("mean_q"))), + "min_q": torch.mean(torch.stack(policy.get_tower_stats("min_q"))), + "max_q": torch.mean(torch.stack(policy.get_tower_stats("max_q"))), + "mean_td_error": torch.mean( + torch.stack(policy.get_tower_stats("mean_td_error"))), + } def setup_early_mixins(policy: Policy, obs_space, action_space, @@ -279,7 +291,7 @@ def extra_action_out_fn(policy: Policy, input_dict, state_batches, model, postprocess_fn=postprocess_nstep_and_prio, optimizer_fn=adam_optimizer, extra_grad_process_fn=grad_process_and_td_error_fn, - extra_learn_fetches_fn=lambda policy: {"td_error": policy._td_error}, + extra_learn_fetches_fn=concat_multi_gpu_td_errors, extra_action_out_fn=extra_action_out_fn, before_init=setup_early_mixins, before_loss_init=before_loss_init, diff --git a/rllib/agents/dqn/simple_q_tf_policy.py b/rllib/agents/dqn/simple_q_tf_policy.py index 0801b6fd26e63..13e62bca1fd9a 100644 --- a/rllib/agents/dqn/simple_q_tf_policy.py +++ b/rllib/agents/dqn/simple_q_tf_policy.py @@ -181,7 +181,7 @@ def compute_q_values(policy: Policy, explore, is_training=None) -> TensorType: model_out, _ = model({ - SampleBatch.CUR_OBS: obs, + SampleBatch.OBS: obs, "is_training": is_training if is_training is not None else policy._get_is_training_placeholder(), }, [], None) diff --git a/rllib/agents/dqn/simple_q_torch_policy.py b/rllib/agents/dqn/simple_q_torch_policy.py index 055ce51598265..205fa6042e09e 100644 --- a/rllib/agents/dqn/simple_q_torch_policy.py +++ b/rllib/agents/dqn/simple_q_torch_policy.py @@ -16,7 +16,7 @@ from ray.rllib.policy.torch_policy import TorchPolicy from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.torch_ops import huber_loss +from ray.rllib.utils.torch_ops import concat_multi_gpu_td_errors, huber_loss from ray.rllib.utils.typing import TensorType, TrainerConfigDict torch, nn = try_import_torch() @@ -112,12 +112,20 @@ def build_q_losses(policy: Policy, model, dist_class, td_error = q_t_selected - q_t_selected_target.detach() loss = torch.mean(huber_loss(td_error)) - # save TD error as an attribute for outside access - policy.td_error = td_error + # Store values for stats function in model (tower), such that for + # multi-GPU, we do not override them during the parallel loss phase. + model.tower_stats["loss"] = loss + # TD-error tensor in final stats + # will be concatenated and retrieved for each individual batch item. 
+ model.tower_stats["td_error"] = td_error return loss +def stats_fn(policy: Policy, batch: SampleBatch) -> Dict[str, TensorType]: + return {"loss": torch.mean(torch.stack(policy.get_tower_stats("loss")))} + + def extra_action_out_fn(policy: Policy, input_dict, state_batches, model, action_dist) -> Dict[str, TensorType]: """Adds q-values to the action out dict.""" @@ -144,10 +152,11 @@ def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space, framework="torch", loss_fn=build_q_losses, get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG, + stats_fn=stats_fn, extra_action_out_fn=extra_action_out_fn, after_init=setup_late_mixins, make_model_and_action_dist=build_q_model_and_distribution, mixins=[TargetNetworkMixin], action_distribution_fn=get_distribution_inputs_and_class, - extra_learn_fetches_fn=lambda policy: {"td_error": policy.td_error}, + extra_learn_fetches_fn=concat_multi_gpu_td_errors, ) diff --git a/rllib/agents/dqn/tests/test_apex_dqn.py b/rllib/agents/dqn/tests/test_apex_dqn.py index 63c051310baec..93702bf8d7c1b 100644 --- a/rllib/agents/dqn/tests/test_apex_dqn.py +++ b/rllib/agents/dqn/tests/test_apex_dqn.py @@ -4,8 +4,10 @@ import ray import ray.rllib.agents.dqn.apex as apex from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID +from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \ + LEARNER_STATS_KEY from ray.rllib.utils.test_utils import check, check_compute_single_action, \ - framework_iterator + check_train_results, framework_iterator class TestApexDQN(unittest.TestCase): @@ -26,7 +28,9 @@ def test_apex_zero_workers(self): config["optimizer"]["num_replay_buffer_shards"] = 1 for _ in framework_iterator(config): trainer = apex.ApexTrainer(config=config, env="CartPole-v0") - trainer.train() + results = trainer.train() + check_train_results(results) + print(results) trainer.stop() def test_apex_dqn_compilation_and_per_worker_epsilon_values(self): @@ -53,7 +57,9 @@ def test_apex_dqn_compilation_and_per_worker_epsilon_values(self): check_compute_single_action(trainer) for i in range(2): - print(trainer.train()) + results = trainer.train() + check_train_results(results) + print(results) # Test again per-worker epsilon distribution # (should not have changed). @@ -97,7 +103,8 @@ def _step_n_times(trainer, n: int): """ for _ in range(n): results = trainer.train() - return results["info"]["learner"][DEFAULT_POLICY_ID]["cur_lr"] + return results["info"][LEARNER_INFO][DEFAULT_POLICY_ID][ + LEARNER_STATS_KEY]["cur_lr"] # Check eager execution frameworks here, since it's easier to control # exact timesteps with these frameworks. 
diff --git a/rllib/agents/dqn/tests/test_dqn.py b/rllib/agents/dqn/tests/test_dqn.py index dbf4876742b1f..fbf029a511243 100644 --- a/rllib/agents/dqn/tests/test_dqn.py +++ b/rllib/agents/dqn/tests/test_dqn.py @@ -4,7 +4,7 @@ import ray import ray.rllib.agents.dqn as dqn from ray.rllib.utils.test_utils import check, check_compute_single_action, \ - framework_iterator + check_train_results, framework_iterator class TestDQN(unittest.TestCase): @@ -30,6 +30,7 @@ def test_dqn_compilation(self): trainer = dqn.DQNTrainer(config=plain_config, env="CartPole-v0") for i in range(num_iterations): results = trainer.train() + check_train_results(results) print(results) check_compute_single_action(trainer) @@ -46,6 +47,7 @@ def test_dqn_compilation(self): trainer = dqn.DQNTrainer(config=rainbow_config, env="CartPole-v0") for i in range(num_iterations): results = trainer.train() + check_train_results(results) print(results) check_compute_single_action(trainer) diff --git a/rllib/agents/dqn/tests/test_r2d2.py b/rllib/agents/dqn/tests/test_r2d2.py index d6e0d52d285e8..44b2e0887a1c5 100644 --- a/rllib/agents/dqn/tests/test_r2d2.py +++ b/rllib/agents/dqn/tests/test_r2d2.py @@ -4,7 +4,7 @@ import ray.rllib.agents.dqn as dqn from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.test_utils import check_compute_single_action, \ - framework_iterator + check_train_results, framework_iterator tf1, tf, tfv = try_import_tf() torch, nn = try_import_torch() @@ -43,6 +43,7 @@ def test_r2d2_compilation(self): trainer = dqn.R2D2Trainer(config=config, env="CartPole-v0") for i in range(num_iterations): results = trainer.train() + check_train_results(results) print(results) check_compute_single_action(trainer, include_state=True) diff --git a/rllib/agents/dqn/tests/test_simple_q.py b/rllib/agents/dqn/tests/test_simple_q.py index 12cddac283208..299bf39f63e51 100644 --- a/rllib/agents/dqn/tests/test_simple_q.py +++ b/rllib/agents/dqn/tests/test_simple_q.py @@ -10,7 +10,7 @@ from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.numpy import fc, one_hot, huber_loss from ray.rllib.utils.test_utils import check, check_compute_single_action, \ - framework_iterator + check_train_results, framework_iterator tf1, tf, tfv = try_import_tf() @@ -41,6 +41,7 @@ def test_simple_q_compilation(self): sb = rw.sample() assert sb.count == config["rollout_fragment_length"] results = trainer.train() + check_train_results(results) print(results) check_compute_single_action(trainer) diff --git a/rllib/agents/dreamer/dreamer.py b/rllib/agents/dreamer/dreamer.py index 4a8170f527875..b3433f62cd5a0 100644 --- a/rllib/agents/dreamer/dreamer.py +++ b/rllib/agents/dreamer/dreamer.py @@ -7,11 +7,12 @@ from ray.rllib.agents.dreamer.dreamer_torch_policy import DreamerTorchPolicy from ray.rllib.agents.trainer_template import build_trainer from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \ - LEARNER_INFO, _get_shared_metrics + _get_shared_metrics from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch from ray.rllib.evaluation.metrics import collect_metrics from ray.rllib.agents.dreamer.dreamer_model import DreamerModel from ray.rllib.execution.rollout_ops import ParallelRollouts +from ray.rllib.utils.metrics.learner_info import LEARNER_INFO from ray.rllib.utils.typing import SampleBatchType logger = logging.getLogger(__name__) diff --git a/rllib/agents/impala/tests/test_impala.py b/rllib/agents/impala/tests/test_impala.py index ba5e28e82073c..13e22240e34aa 100644 --- 
a/rllib/agents/impala/tests/test_impala.py +++ b/rllib/agents/impala/tests/test_impala.py @@ -4,8 +4,10 @@ import ray.rllib.agents.impala as impala from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \ + LEARNER_STATS_KEY from ray.rllib.utils.test_utils import check, \ - check_compute_single_action, framework_iterator + check_compute_single_action, check_train_results, framework_iterator tf1, tf, tfv = try_import_tf() @@ -39,7 +41,10 @@ def test_impala_compilation(self): # to do with LSTMs, though). trainer = impala.ImpalaTrainer(config=local_cfg, env=env) for i in range(num_iterations): - print(trainer.train()) + results = trainer.train() + check_train_results(results) + print(results) + check_compute_single_action( trainer, include_state=lstm, @@ -61,7 +66,8 @@ def test_impala_lr_schedule(self): config["env"] = "CartPole-v0" def get_lr(result): - return result["info"]["learner"][DEFAULT_POLICY_ID]["cur_lr"] + return result["info"][LEARNER_INFO][DEFAULT_POLICY_ID][ + LEARNER_STATS_KEY]["cur_lr"] for fw in framework_iterator(config, frameworks=("tf", "torch")): trainer = impala.ImpalaTrainer(config=config) diff --git a/rllib/agents/impala/vtrace_tf_policy.py b/rllib/agents/impala/vtrace_tf_policy.py index 99960a3206b2c..f5b5ddc4192db 100644 --- a/rllib/agents/impala/vtrace_tf_policy.py +++ b/rllib/agents/impala/vtrace_tf_policy.py @@ -111,8 +111,13 @@ def __init__(self, self.mean_entropy = tf.reduce_mean(masked_entropy) # The summed weighted loss. - self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff - - self.entropy * entropy_coeff) + self.total_loss = self.pi_loss - self.entropy * entropy_coeff + + # Optional vf loss (or in a separate term due to separate + # optimizers/networks). + self.loss_wo_vf = self.total_loss + if not config["_separate_vf_optimizer"]: + self.total_loss += self.vf_loss * vf_loss_coeff def _make_time_major(policy, seq_lens, tensor, drop_last=False): @@ -220,7 +225,10 @@ def make_time_major(*args, **kw): clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"], clip_pg_rho_threshold=policy.config["vtrace_clip_pg_rho_threshold"]) - return policy.loss.total_loss + if policy.config.get("_separate_vf_optimizer"): + return policy.loss.loss_wo_vf, policy.loss.vf_loss + else: + return policy.loss.total_loss def stats(policy, train_batch): @@ -239,13 +247,21 @@ def stats(policy, train_batch): "vf_loss": policy.loss.mean_vf_loss, "vf_explained_var": explained_variance( tf.reshape(policy.loss.value_targets, [-1]), - tf.reshape(values_batched, [-1])), + tf.reshape(values_batched, [-1])) } def grad_stats(policy, train_batch, grads): + # We have support for more than one loss (list of lists of grads). + if policy.config.get("_tf_policy_handles_more_than_one_loss"): + grad_gnorm = [tf.linalg.global_norm(g) for g in grads] + # Old case: We have a single list of grads (only one loss term and + # optimizer). 
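A quick sketch of the two gradient shapes `grad_stats` now handles; `tf.linalg.global_norm` returns the square root of the sum of squares over all tensors in a list.

```python
import tensorflow as tf

# One loss/optimizer: a flat list of gradient tensors.
grads = [tf.constant([3.0, 4.0])]
print(tf.linalg.global_norm(grads).numpy())  # 5.0

# Several losses/optimizers: a list of gradient lists, one gnorm each.
multi_grads = [[tf.constant([3.0, 4.0])], [tf.constant([6.0, 8.0])]]
print([tf.linalg.global_norm(g).numpy() for g in multi_grads])  # [5.0, 10.0]
```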
+ else: + grad_gnorm = tf.linalg.global_norm(grads) + return { - "grad_gnorm": tf.linalg.global_norm(grads), + "grad_gnorm": grad_gnorm, } diff --git a/rllib/agents/impala/vtrace_torch_policy.py b/rllib/agents/impala/vtrace_torch_policy.py index ec279cd5573b0..c8738d1875f63 100644 --- a/rllib/agents/impala/vtrace_torch_policy.py +++ b/rllib/agents/impala/vtrace_torch_policy.py @@ -1,10 +1,12 @@ import gym import logging import numpy as np +from typing import Any, Dict import ray import ray.rllib.agents.impala.vtrace_torch as vtrace from ray.rllib.models.torch.torch_action_dist import TorchCategorical +from ray.rllib.policy.policy import Policy from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy import LearningRateSchedule, \ @@ -182,17 +184,22 @@ def _make_time_major(*args, **kw): clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"], clip_pg_rho_threshold=policy.config["vtrace_clip_pg_rho_threshold"]) - # Store loss object only for multi-GPU tower 0. - if model is policy.model_gpu_towers[0]: - policy.loss = loss - values_batched = make_time_major( - policy, - train_batch.get(SampleBatch.SEQ_LENS), - values, - drop_last=policy.config["vtrace"]) - policy._vf_explained_var = explained_variance( - torch.reshape(loss.value_targets, [-1]), - torch.reshape(values_batched, [-1])), + # Store values for stats function in model (tower), such that for + # multi-GPU, we do not override them during the parallel loss phase. + model.tower_stats["pi_loss"] = loss.pi_loss + model.tower_stats["vf_loss"] = loss.vf_loss + model.tower_stats["entropy"] = loss.entropy + model.tower_stats["mean_entropy"] = loss.mean_entropy + model.tower_stats["total_loss"] = loss.total_loss + + values_batched = make_time_major( + policy, + train_batch.get(SampleBatch.SEQ_LENS), + values, + drop_last=policy.config["vtrace"]) + model.tower_stats["vf_explained_var"] = explained_variance( + torch.reshape(loss.value_targets, [-1]), + torch.reshape(values_batched, [-1])) return loss.total_loss @@ -236,15 +243,21 @@ def make_time_major(policy, seq_lens, tensor, drop_last=False): return res -def stats(policy, train_batch): +def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, Any]: + return { "cur_lr": policy.cur_lr, - "policy_loss": policy.loss.pi_loss, - "entropy": policy.loss.mean_entropy, + "total_loss": torch.mean( + torch.stack(policy.get_tower_stats("total_loss"))), + "policy_loss": torch.mean( + torch.stack(policy.get_tower_stats("pi_loss"))), + "entropy": torch.mean( + torch.stack(policy.get_tower_stats("mean_entropy"))), "entropy_coeff": policy.entropy_coeff, "var_gnorm": global_norm(policy.model.trainable_variables()), - "vf_loss": policy.loss.vf_loss, - "vf_explained_var": policy._vf_explained_var, + "vf_loss": torch.mean(torch.stack(policy.get_tower_stats("vf_loss"))), + "vf_explained_var": torch.mean( + torch.stack(policy.get_tower_stats("vf_explained_var"))), } diff --git a/rllib/agents/maml/maml.py b/rllib/agents/maml/maml.py index 9d82a0e192cc5..c85d4f158b3c5 100644 --- a/rllib/agents/maml/maml.py +++ b/rllib/agents/maml/maml.py @@ -8,11 +8,12 @@ from ray.rllib.agents.trainer_template import build_trainer from ray.rllib.evaluation.metrics import get_learner_stats from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \ - STEPS_TRAINED_COUNTER, LEARNER_INFO, _get_shared_metrics + STEPS_TRAINED_COUNTER, _get_shared_metrics from ray.rllib.policy.sample_batch import SampleBatch from 
ray.rllib.execution.metric_ops import CollectMetrics from ray.rllib.evaluation.metrics import collect_metrics from ray.rllib.utils.deprecation import DEPRECATED_VALUE +from ray.rllib.utils.metrics.learner_info import LEARNER_INFO from ray.util.iter import from_actors logger = logging.getLogger(__name__) @@ -98,9 +99,10 @@ def __call__(self, data_tuple): # Metric Updating metrics = _get_shared_metrics() metrics.counters[STEPS_SAMPLED_COUNTER] += samples.count + fetches = None for i in range(self.maml_optimizer_steps): fetches = self.workers.local_worker().learn_on_batch(samples) - fetches = get_learner_stats(fetches) + learner_stats = get_learner_stats(fetches) # Sync workers with meta policy self.workers.sync_weights() @@ -110,11 +112,12 @@ def __call__(self, data_tuple): # Update KLS def update(pi, pi_id): - assert "inner_kl" not in fetches, ( - "inner_kl should be nested under policy id key", fetches) - if pi_id in fetches: - assert "inner_kl" in fetches[pi_id], (fetches, pi_id) - pi.update_kls(fetches[pi_id]["inner_kl"]) + assert "inner_kl" not in learner_stats, ( + "inner_kl should be nested under policy id key", learner_stats) + if pi_id in learner_stats: + assert "inner_kl" in learner_stats[pi_id], (learner_stats, + pi_id) + pi.update_kls(learner_stats[pi_id]["inner_kl"]) else: logger.warning("No data for {}, not updating kl".format(pi_id)) diff --git a/rllib/agents/maml/tests/test_maml.py b/rllib/agents/maml/tests/test_maml.py index b84e028571907..e1905b5cc853f 100644 --- a/rllib/agents/maml/tests/test_maml.py +++ b/rllib/agents/maml/tests/test_maml.py @@ -3,7 +3,7 @@ import ray import ray.rllib.agents.maml as maml from ray.rllib.utils.test_utils import check_compute_single_action, \ - framework_iterator + check_train_results, framework_iterator class TestMAML(unittest.TestCase): @@ -34,7 +34,9 @@ def test_maml_compilation(self): env_ = "ray.rllib.examples.env.{}".format(env) trainer = maml.MAMLTrainer(config=config, env=env_) for i in range(num_iterations): - trainer.train() + results = trainer.train() + check_train_results(results) + print(results) check_compute_single_action( trainer, include_prev_action_reward=True) trainer.stop() diff --git a/rllib/agents/marwil/tests/test_bc.py b/rllib/agents/marwil/tests/test_bc.py index c6508330e43de..d6ac234897839 100644 --- a/rllib/agents/marwil/tests/test_bc.py +++ b/rllib/agents/marwil/tests/test_bc.py @@ -6,7 +6,7 @@ import ray.rllib.agents.marwil as marwil from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.test_utils import check_compute_single_action, \ - framework_iterator + check_train_results, framework_iterator tf1, tf, tfv = try_import_tf() @@ -51,7 +51,11 @@ def test_bc_compilation_and_learning_from_offline_file(self): trainer = marwil.BCTrainer(config=config, env="CartPole-v0") learnt = False for i in range(num_iterations): - eval_results = trainer.train().get("evaluation") + results = trainer.train() + check_train_results(results) + print(results) + + eval_results = results.get("evaluation") if eval_results: print("iter={} R={}".format( i, eval_results["episode_reward_mean"])) diff --git a/rllib/agents/marwil/tests/test_marwil.py b/rllib/agents/marwil/tests/test_marwil.py index 29c6b678ecf2c..b8ca7af86ae21 100644 --- a/rllib/agents/marwil/tests/test_marwil.py +++ b/rllib/agents/marwil/tests/test_marwil.py @@ -9,7 +9,7 @@ from ray.rllib.offline import JsonReader from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.test_utils import check, check_compute_single_action, \ 
- framework_iterator + check_train_results, framework_iterator tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() @@ -57,7 +57,11 @@ def test_marwil_compilation_and_learning_from_offline_file(self): trainer = marwil.MARWILTrainer(config=config, env="CartPole-v0") learnt = False for i in range(num_iterations): - eval_results = trainer.train().get("evaluation") + results = trainer.train() + check_train_results(results) + print(results) + + eval_results = results.get("evaluation") if eval_results: print("iter={} R={} ".format( i, eval_results["episode_reward_mean"])) diff --git a/rllib/agents/mbmpo/mbmpo.py b/rllib/agents/mbmpo/mbmpo.py index 0a537213ac193..aaf2d835e6c1f 100644 --- a/rllib/agents/mbmpo/mbmpo.py +++ b/rllib/agents/mbmpo/mbmpo.py @@ -26,10 +26,11 @@ get_learner_stats from ray.rllib.evaluation.worker_set import WorkerSet from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \ - STEPS_TRAINED_COUNTER, LEARNER_INFO, _get_shared_metrics + STEPS_TRAINED_COUNTER, _get_shared_metrics from ray.rllib.execution.metric_ops import CollectMetrics from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch from ray.rllib.utils.deprecation import DEPRECATED_VALUE +from ray.rllib.utils.metrics.learner_info import LEARNER_INFO from ray.rllib.utils.sgd import standardized from ray.rllib.utils.torch_ops import convert_to_torch_tensor from ray.rllib.utils.typing import EnvType, TrainerConfigDict @@ -160,17 +161,19 @@ def __call__(self, data_tuple): adapt_metrics_dict, prefix="MAMLIter{}".format(self.step_counter)) # MAML Meta-update. + fetches = None for i in range(self.maml_optimizer_steps): fetches = self.workers.local_worker().learn_on_batch(samples) - fetches = get_learner_stats(fetches) + learner_stats = get_learner_stats(fetches) # Update KLs. 
def update(pi, pi_id): - assert "inner_kl" not in fetches, ( - "inner_kl should be nested under policy id key", fetches) - if pi_id in fetches: - assert "inner_kl" in fetches[pi_id], (fetches, pi_id) - pi.update_kls(fetches[pi_id]["inner_kl"]) + assert "inner_kl" not in learner_stats, ( + "inner_kl should be nested under policy id key", learner_stats) + if pi_id in learner_stats: + assert "inner_kl" in learner_stats[pi_id], (learner_stats, + pi_id) + pi.update_kls(learner_stats[pi_id]["inner_kl"]) else: logger.warning("No data for {}, not updating kl".format(pi_id)) diff --git a/rllib/agents/mbmpo/tests/test_mbmpo.py b/rllib/agents/mbmpo/tests/test_mbmpo.py index de708fd50d58c..941686c3e717b 100644 --- a/rllib/agents/mbmpo/tests/test_mbmpo.py +++ b/rllib/agents/mbmpo/tests/test_mbmpo.py @@ -3,7 +3,7 @@ import ray import ray.rllib.agents.mbmpo as mbmpo from ray.rllib.utils.test_utils import check_compute_single_action, \ - framework_iterator + check_train_results, framework_iterator class TestMBMPO(unittest.TestCase): @@ -28,8 +28,12 @@ def test_mbmpo_compilation(self): trainer = mbmpo.MBMPOTrainer( config=config, env="ray.rllib.examples.env.mbmpo_env.CartPoleWrapper") + for i in range(num_iterations): - trainer.train() + results = trainer.train() + check_train_results(results) + print(results) + check_compute_single_action( trainer, include_prev_action_reward=False) trainer.stop() diff --git a/rllib/agents/pg/pg_torch_policy.py b/rllib/agents/pg/pg_torch_policy.py index d707f01f2364e..34a17c5e03f97 100644 --- a/rllib/agents/pg/pg_torch_policy.py +++ b/rllib/agents/pg/pg_torch_policy.py @@ -44,11 +44,15 @@ def pg_torch_loss( # L = -E[ log(pi(a|s)) * A] log_probs = action_dist.logp(train_batch[SampleBatch.ACTIONS]) - # Save the loss in the policy object for the stats_fn below. - policy.pi_err = -torch.mean( + # Final policy loss. + policy_loss = -torch.mean( log_probs * train_batch[Postprocessing.ADVANTAGES]) - return policy.pi_err + # Store values for stats function in model (tower), such that for + # multi-GPU, we do not override them during the parallel loss phase. + model.tower_stats["policy_loss"] = policy_loss + + return policy_loss def pg_loss_stats(policy: Policy, @@ -64,8 +68,8 @@ def pg_loss_stats(policy: Policy, """ return { - # `pi_err` (the loss) is stored inside `pg_torch_loss()`. 
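For reference, the quantity `pg_torch_loss` computes above, L = -E[log pi(a|s) * A], with made-up tensors:

```python
import torch

log_probs = torch.log(torch.tensor([0.9, 0.5, 0.2]))   # log pi(a|s)
advantages = torch.tensor([1.0, -0.5, 2.0])            # A
policy_loss = -torch.mean(log_probs * advantages)      # L
```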
- "policy_loss": policy.pi_err.item(), + "policy_loss": torch.mean( + torch.stack(policy.get_tower_stats("policy_loss"))), } diff --git a/rllib/agents/pg/tests/test_pg.py b/rllib/agents/pg/tests/test_pg.py index 44a52829beaf3..40b985cc8e488 100644 --- a/rllib/agents/pg/tests/test_pg.py +++ b/rllib/agents/pg/tests/test_pg.py @@ -7,8 +7,9 @@ from ray.rllib.models.tf.tf_action_dist import Categorical from ray.rllib.models.torch.torch_action_dist import TorchCategorical from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.utils import check, check_compute_single_action, fc, \ - framework_iterator +from ray.rllib.utils.numpy import fc +from ray.rllib.utils.test_utils import check, check_compute_single_action, \ + check_train_results, framework_iterator class TestPG(unittest.TestCase): @@ -31,7 +32,10 @@ def test_pg_compilation(self): for env in ["FrozenLake-v0", "CartPole-v0"]: trainer = pg.PGTrainer(config=config, env=env) for i in range(num_iterations): - print(trainer.train()) + results = trainer.train() + check_train_results(results) + print(results) + check_compute_single_action( trainer, include_prev_action_reward=True) diff --git a/rllib/agents/ppo/appo_tf_policy.py b/rllib/agents/ppo/appo_tf_policy.py index 455044bebfe1d..142b96d6e247f 100644 --- a/rllib/agents/ppo/appo_tf_policy.py +++ b/rllib/agents/ppo/appo_tf_policy.py @@ -304,7 +304,7 @@ def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]: "vf_loss": policy._mean_vf_loss, "vf_explained_var": explained_variance( tf.reshape(policy._value_targets, [-1]), - tf.reshape(values_batched, [-1])), + tf.reshape(values_batched, [-1])) } if policy.config["vtrace"]: diff --git a/rllib/agents/ppo/appo_torch_policy.py b/rllib/agents/ppo/appo_torch_policy.py index f8ee24989d825..324b73bf5a6b7 100644 --- a/rllib/agents/ppo/appo_torch_policy.py +++ b/rllib/agents/ppo/appo_torch_policy.py @@ -159,7 +159,7 @@ def reduce_mean_valid(t): torch.clamp(logp_ratio, 1 - policy.config["clip_param"], 1 + policy.config["clip_param"])) - mean_kl = reduce_mean_valid(action_kl) + mean_kl_loss = reduce_mean_valid(action_kl) mean_policy_loss = -reduce_mean_valid(surrogate_loss) # The value function loss. @@ -188,7 +188,7 @@ def reduce_mean_valid(t): torch.clamp(logp_ratio, 1 - policy.config["clip_param"], 1 + policy.config["clip_param"])) - mean_kl = reduce_mean_valid(action_kl) + mean_kl_loss = reduce_mean_valid(action_kl) mean_policy_loss = -reduce_mean_valid(surrogate_loss) # The value function loss. @@ -208,16 +208,17 @@ def reduce_mean_valid(t): # Optional additional KL Loss if policy.config["use_kl_loss"]: - total_loss += policy.kl_coeff * mean_kl - - policy._total_loss = total_loss - policy._mean_policy_loss = mean_policy_loss - # Backward compatibility: Deprecate policy._mean_kl. - policy._mean_kl_loss = policy._mean_kl = mean_kl - policy._mean_vf_loss = mean_vf_loss - policy._mean_entropy = mean_entropy - policy._value_targets = value_targets - policy._vf_explained_var = explained_variance( + total_loss += policy.kl_coeff * mean_kl_loss + + # Store values for stats function in model (tower), such that for + # multi-GPU, we do not override them during the parallel loss phase. 
+ model.tower_stats["total_loss"] = total_loss + model.tower_stats["mean_policy_loss"] = mean_policy_loss + model.tower_stats["mean_kl_loss"] = mean_kl_loss + model.tower_stats["mean_vf_loss"] = mean_vf_loss + model.tower_stats["mean_entropy"] = mean_entropy + model.tower_stats["value_targets"] = value_targets + model.tower_stats["vf_explained_var"] = explained_variance( torch.reshape(value_targets, [-1]), torch.reshape( values_time_major[:-1] @@ -239,22 +240,28 @@ def stats(policy: Policy, train_batch: SampleBatch): """ stats_dict = { "cur_lr": policy.cur_lr, - "policy_loss": policy._mean_policy_loss, - "entropy": policy._mean_entropy, + "total_loss": torch.mean( + torch.stack(policy.get_tower_stats("total_loss"))), + "policy_loss": torch.mean( + torch.stack(policy.get_tower_stats("mean_policy_loss"))), + "entropy": torch.mean( + torch.stack(policy.get_tower_stats("mean_entropy"))), "var_gnorm": global_norm(policy.model.trainable_variables()), - "vf_loss": policy._mean_vf_loss, - "vf_explained_var": policy._vf_explained_var, + "vf_loss": torch.mean( + torch.stack(policy.get_tower_stats("mean_vf_loss"))), + "vf_explained_var": torch.mean( + torch.stack(policy.get_tower_stats("vf_explained_var"))), } if policy.config["vtrace"]: is_stat_mean = torch.mean(policy._is_ratio, [0, 1]) is_stat_var = torch.var(policy._is_ratio, [0, 1]) - stats_dict.update({"mean_IS": is_stat_mean}) - stats_dict.update({"var_IS": is_stat_var}) + stats_dict["mean_IS"] = is_stat_mean + stats_dict["var_IS"] = is_stat_var if policy.config["use_kl_loss"]: - stats_dict.update({"kl": policy._mean_kl_loss}) - stats_dict.update({"KL_Coeff": policy.kl_coeff}) + stats_dict["kl"] = policy.get_tower_stats("mean_kl_loss") + stats_dict["KL_Coeff"] = policy.kl_coeff return stats_dict diff --git a/rllib/agents/ppo/ddppo.py b/rllib/agents/ppo/ddppo.py index d3eee646999e4..b7c15918b16fe 100644 --- a/rllib/agents/ppo/ddppo.py +++ b/rllib/agents/ppo/ddppo.py @@ -26,9 +26,10 @@ from ray.rllib.execution.rollout_ops import ParallelRollouts from ray.rllib.execution.metric_ops import StandardMetricsReporting from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \ - STEPS_TRAINED_COUNTER, LEARNER_INFO, LEARN_ON_BATCH_TIMER, \ + STEPS_TRAINED_COUNTER, LEARN_ON_BATCH_TIMER, \ _get_shared_metrics, _get_global_vars from ray.rllib.evaluation.rollout_worker import get_global_worker +from ray.rllib.utils.metrics.learner_info import LEARNER_INFO from ray.rllib.utils.sgd import do_minibatch_sgd from ray.rllib.utils.typing import TrainerConfigDict from ray.util.iter import LocalIterator @@ -75,6 +76,11 @@ "truncate_episodes": True, # This is auto set based on sample batch size. "train_batch_size": -1, + # Kl divergence penalty should be fixed to 0 in DDPPO because in order + # for it to be used as a penalty, we would have to un-decentralize + # DDPPO + "kl_coeff": 0.0, + "kl_target": 0.0 }, _allow_unknown_configs=True, ) @@ -131,6 +137,13 @@ def validate_config(config): raise ValueError( "Distributed data parallel requires truncate_episodes " "batch mode.") + # DDPPO doesn't support KL penalties like PPO-1. + # In order to support KL penalties, DDPPO would need to become + # undecentralized, which defeats the purpose of the algorithm. + # Users can still tune the entropy coefficient to control the + # policy entropy (similar to controlling the KL penalty). 
+ if config["kl_coeff"] != 0.0 or config["kl_target"] != 0.0: + raise ValueError("DDPPO doesn't support KL penalties like PPO-1") def execution_plan(workers: WorkerSet, diff --git a/rllib/agents/ppo/ppo.py b/rllib/agents/ppo/ppo.py index e0ced5d82cdeb..e43d460087b84 100644 --- a/rllib/agents/ppo/ppo.py +++ b/rllib/agents/ppo/ppo.py @@ -20,9 +20,11 @@ StandardizeFields, SelectExperiences from ray.rllib.execution.train_ops import TrainOneStep, MultiGPUTrainOneStep from ray.rllib.execution.metric_ops import StandardMetricsReporting -from ray.rllib.policy.policy import LEARNER_STATS_KEY, Policy +from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.deprecation import DEPRECATED_VALUE +from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \ + LEARNER_STATS_KEY from ray.rllib.utils.typing import TrainerConfigDict from ray.util.iter import LocalIterator @@ -217,12 +219,12 @@ def warn_about_bad_reward_scales(config, result): return result # Punt on handling multiagent case. # Warn about excessively high VF loss. - learner_stats = result["info"]["learner"] - if DEFAULT_POLICY_ID in learner_stats: + learner_info = result["info"][LEARNER_INFO] + if DEFAULT_POLICY_ID in learner_info: scaled_vf_loss = config["vf_loss_coeff"] * \ - learner_stats[DEFAULT_POLICY_ID][LEARNER_STATS_KEY]["vf_loss"] + learner_info[DEFAULT_POLICY_ID][LEARNER_STATS_KEY]["vf_loss"] - policy_loss = learner_stats[DEFAULT_POLICY_ID][LEARNER_STATS_KEY][ + policy_loss = learner_info[DEFAULT_POLICY_ID][LEARNER_STATS_KEY][ "policy_loss"] if config.get("model", {}).get("vf_share_layers") and \ scaled_vf_loss > 100: diff --git a/rllib/agents/ppo/ppo_torch_policy.py b/rllib/agents/ppo/ppo_torch_policy.py index f8f310e6b07e3..69e19e33d7817 100644 --- a/rllib/agents/ppo/ppo_torch_policy.py +++ b/rllib/agents/ppo/ppo_torch_policy.py @@ -105,15 +105,15 @@ def reduce_mean_valid(t): policy.config["vf_loss_coeff"] * vf_loss - policy.entropy_coeff * curr_entropy) - # Store stats in policy for stats_fn. - policy._total_loss = total_loss - policy._mean_policy_loss = mean_policy_loss - policy._mean_vf_loss = mean_vf_loss - policy._vf_explained_var = explained_variance( + # Store values for stats function in model (tower), such that for + # multi-GPU, we do not override them during the parallel loss phase. + model.tower_stats["total_loss"] = total_loss + model.tower_stats["mean_policy_loss"] = mean_policy_loss + model.tower_stats["mean_vf_loss"] = mean_vf_loss + model.tower_stats["vf_explained_var"] = explained_variance( train_batch[Postprocessing.VALUE_TARGETS], model.value_function()) - policy._mean_entropy = mean_entropy - # Backward compatibility: Deprecate policy._mean_kl. 
- policy._mean_kl_loss = policy._mean_kl = mean_kl_loss + model.tower_stats["mean_entropy"] = mean_entropy + model.tower_stats["mean_kl_loss"] = mean_kl_loss return total_loss @@ -132,12 +132,17 @@ def kl_and_loss_stats(policy: Policy, return { "cur_kl_coeff": policy.kl_coeff, "cur_lr": policy.cur_lr, - "total_loss": policy._total_loss, - "policy_loss": policy._mean_policy_loss, - "vf_loss": policy._mean_vf_loss, - "vf_explained_var": policy._vf_explained_var, - "kl": policy._mean_kl_loss, - "entropy": policy._mean_entropy, + "total_loss": torch.mean( + torch.stack(policy.get_tower_stats("total_loss"))), + "policy_loss": torch.mean( + torch.stack(policy.get_tower_stats("mean_policy_loss"))), + "vf_loss": torch.mean( + torch.stack(policy.get_tower_stats("mean_vf_loss"))), + "vf_explained_var": torch.mean( + torch.stack(policy.get_tower_stats("vf_explained_var"))), + "kl": torch.mean(torch.stack(policy.get_tower_stats("mean_kl_loss"))), + "entropy": torch.mean( + torch.stack(policy.get_tower_stats("mean_entropy"))), "entropy_coeff": policy.entropy_coeff, } diff --git a/rllib/agents/ppo/tests/test_appo.py b/rllib/agents/ppo/tests/test_appo.py index 32a5989263f7c..be007f3dd9995 100644 --- a/rllib/agents/ppo/tests/test_appo.py +++ b/rllib/agents/ppo/tests/test_appo.py @@ -3,7 +3,7 @@ import ray import ray.rllib.agents.ppo as ppo from ray.rllib.utils.test_utils import check_compute_single_action, \ - framework_iterator + check_train_results, framework_iterator class TestAPPO(unittest.TestCase): @@ -27,7 +27,9 @@ def test_appo_compilation(self): _config["vtrace"] = False trainer = ppo.APPOTrainer(config=_config, env="CartPole-v0") for i in range(num_iterations): - print(trainer.train()) + results = trainer.train() + check_train_results(results) + print(results) check_compute_single_action(trainer) trainer.stop() @@ -36,7 +38,9 @@ def test_appo_compilation(self): _config["vtrace"] = True trainer = ppo.APPOTrainer(config=_config, env="CartPole-v0") for i in range(num_iterations): - print(trainer.train()) + results = trainer.train() + check_train_results(results) + print(results) check_compute_single_action(trainer) trainer.stop() @@ -55,10 +59,12 @@ def test_appo_two_tf_optimizers(self): num_iterations = 2 # Only supported for tf so far. 
-        for _ in framework_iterator(config, frameworks="tf"):
+        for _ in framework_iterator(config, frameworks=("tf2", "tf")):
             trainer = ppo.APPOTrainer(config=config, env="CartPole-v0")
             for i in range(num_iterations):
-                print(trainer.train())
+                results = trainer.train()
+                check_train_results(results)
+                print(results)
             check_compute_single_action(trainer)
             trainer.stop()
diff --git a/rllib/agents/ppo/tests/test_ddppo.py b/rllib/agents/ppo/tests/test_ddppo.py
index e1191cfb2cd35..0e8154a662d12 100644
--- a/rllib/agents/ppo/tests/test_ddppo.py
+++ b/rllib/agents/ppo/tests/test_ddppo.py
@@ -1,11 +1,13 @@
 import unittest
+import pytest
 
 import ray
 import ray.rllib.agents.ppo as ppo
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
-from ray.rllib.policy.policy import LEARNER_STATS_KEY
+from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \
+    LEARNER_STATS_KEY
 from ray.rllib.utils.test_utils import check, check_compute_single_action, \
-    framework_iterator
+    check_train_results, framework_iterator
 
 
 class TestDDPPO(unittest.TestCase):
@@ -26,7 +28,9 @@ def test_ddppo_compilation(self):
         for _ in framework_iterator(config, frameworks="torch"):
             trainer = ppo.ddppo.DDPPOTrainer(config=config, env="CartPole-v0")
             for i in range(num_iterations):
-                trainer.train()
+                results = trainer.train()
+                check_train_results(results)
+                print(results)
                 # Make sure, weights on all workers are the same (including
                 # local one).
                 weights = trainer.workers.foreach_worker(
@@ -48,13 +52,24 @@ def test_ddppo_schedule(self):
         trainer = ppo.ddppo.DDPPOTrainer(config=config, env="CartPole-v0")
         for _ in range(num_iterations):
             result = trainer.train()
-            lr = result["info"]["learner"][DEFAULT_POLICY_ID][
+            lr = result["info"][LEARNER_INFO][DEFAULT_POLICY_ID][
                 LEARNER_STATS_KEY]["cur_lr"]
         trainer.stop()
         assert lr == 0.0, "lr should anneal to 0.0"
 
+    def test_validate_config(self):
+        """Test that DDPPO raises an error when invalid configs are passed."""
+        config = ppo.ddppo.DEFAULT_CONFIG.copy()
+        config["kl_coeff"] = 1.
+        msg = "DDPPO doesn't support KL penalties like PPO-1"
+        with pytest.raises(ValueError, match=msg):
+            ppo.ddppo.DDPPOTrainer(config=config, env="CartPole-v0")
+        config["kl_coeff"] = 0.
+        config["kl_target"] = 1.
+        with pytest.raises(ValueError, match=msg):
+            ppo.ddppo.DDPPOTrainer(config=config, env="CartPole-v0")
+
 
 if __name__ == "__main__":
-    import pytest
     import sys
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/agents/ppo/tests/test_ppo.py b/rllib/agents/ppo/tests/test_ppo.py
index 2dfcec41010b5..198922ee7a338 100644
--- a/rllib/agents/ppo/tests/test_ppo.py
+++ b/rllib/agents/ppo/tests/test_ppo.py
@@ -14,11 +14,12 @@
 from ray.rllib.models.tf.tf_action_dist import Categorical
 from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
 from ray.rllib.models.torch.torch_action_dist import TorchCategorical
-from ray.rllib.policy.policy import LEARNER_STATS_KEY
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
+from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \
+    LEARNER_STATS_KEY
 from ray.rllib.utils.numpy import fc
-from ray.rllib.utils.test_utils import check, framework_iterator, \
-    check_compute_single_action
+from ray.rllib.utils.test_utils import check, check_compute_single_action, \
+    check_train_results, framework_iterator
 
 # Fake CartPole episode of n time steps.
 FAKE_BATCH = SampleBatch({
@@ -59,7 +60,8 @@ def _check_lr_tf(policy, policy_id):
         assert lr == optim_lr, "LR scheduling error!"
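The schedule assertions in these PPO tests assume piecewise-linear interpolation over timesteps (as in RLlib's PiecewiseSchedule); for the `[[0, 5e-5], [128, 0.0]]` schedule used below, roughly:

```python
def lr_at(timestep, schedule=((0, 5e-5), (128, 0.0))):
    """Linear interpolation between (timestep, value) points (sketch)."""
    (t0, v0), (t1, v1) = schedule
    if timestep >= t1:
        return v1  # clamp past the final point
    return v0 + (v1 - v0) * (timestep - t0) / (t1 - t0)

assert lr_at(0) == 5e-5
assert lr_at(64) == 2.5e-5
assert lr_at(128) == 0.0
```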
def on_train_result(self, *, trainer, result: dict, **kwargs): - stats = result["info"]["learner"][DEFAULT_POLICY_ID][LEARNER_STATS_KEY] + stats = result["info"][LEARNER_INFO][DEFAULT_POLICY_ID][ + LEARNER_STATS_KEY] # Learning rate should go to 0 after 1 iter. check(stats["cur_lr"], 5e-5 if trainer.iteration == 1 else 0.0) # Entropy coeff goes to 0.05, then 0.0 (per iter). @@ -90,7 +92,7 @@ def test_ppo_compilation_and_schedule_mixins(self): config["model"]["lstm_cell_size"] = 10 config["model"]["max_seq_len"] = 20 # Use default-native keras models whenever possible. - config["model"]["_use_default_native_models"] = True + # config["model"]["_use_default_native_models"] = True # Setup lr- and entropy schedules for testing. config["lr_schedule"] = [[0, config["lr"]], [128, 0.0]] @@ -124,7 +126,9 @@ def test_ppo_compilation_and_schedule_mixins(self): check(lr, config["lr"]) for i in range(num_iterations): - print(trainer.train()) + results = trainer.train() + check_train_results(results) + print(results) check_compute_single_action( trainer, @@ -313,6 +317,19 @@ def test_ppo_loss_function(self): check(pl, np.mean(-pg_loss)) check(v, np.mean(vf_loss), decimals=4) check(tl, overall_loss, decimals=4) + elif fw == "torch": + check(policy.model.tower_stats["mean_kl_loss"], kl) + check(policy.model.tower_stats["mean_entropy"], entropy) + check(policy.model.tower_stats["mean_policy_loss"], + np.mean(-pg_loss)) + check( + policy.model.tower_stats["mean_vf_loss"], + np.mean(vf_loss), + decimals=4) + check( + policy.model.tower_stats["total_loss"], + overall_loss, + decimals=4) else: check(policy._mean_kl_loss, kl) check(policy._mean_entropy, entropy) diff --git a/rllib/agents/qmix/qmix_policy.py b/rllib/agents/qmix/qmix_policy.py index 6c1078cbad314..ca0324ce4d08f 100644 --- a/rllib/agents/qmix/qmix_policy.py +++ b/rllib/agents/qmix/qmix_policy.py @@ -8,7 +8,6 @@ from ray.rllib.agents.qmix.model import RNNModel, _get_size from ray.rllib.env.multi_agent_env import ENV_STATE from ray.rllib.env.wrappers.group_agents_wrapper import GROUP_REWARDS -from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY from ray.rllib.models.torch.torch_action_dist import TorchCategorical from ray.rllib.policy.policy import Policy from ray.rllib.policy.rnn_sequencing import chop_into_sequences @@ -16,6 +15,7 @@ from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.modelv2 import _unpack_obs from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.annotations import override # Torch must be installed. diff --git a/rllib/agents/sac/rnnsac.py b/rllib/agents/sac/rnnsac.py index 3fb67e50d8ced..79bf6cdc816bc 100644 --- a/rllib/agents/sac/rnnsac.py +++ b/rllib/agents/sac/rnnsac.py @@ -11,10 +11,6 @@ { # Batch mode (see common config) "batch_mode": "complete_episodes", - # If True prioritized replay buffer will be used. - "prioritized_replay": False, - # RNNSAC does not suport n-step > 1 yet! - "n_step": 1, # If True, assume a zero-initialized state input (no matter where in # the episode the sequence is located). 
# If False, store the initial states along with each SampleBatch, use @@ -50,9 +46,6 @@ def validate_config(config: TrainerConfigDict) -> None: config["replay_sequence_length"] = \ config["burn_in"] + config["model"]["max_seq_len"] - if config["n_step"] > 1: - raise ValueError("`n_step` > 1 not yet supported by RNNSAC!") - def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]: """Policy class picker function. Class is chosen based on DL-framework. diff --git a/rllib/agents/sac/rnnsac_torch_policy.py b/rllib/agents/sac/rnnsac_torch_policy.py index c0d223c0a4766..faef59e1bee67 100644 --- a/rllib/agents/sac/rnnsac_torch_policy.py +++ b/rllib/agents/sac/rnnsac_torch_policy.py @@ -371,6 +371,7 @@ def reduce_mean_valid(t): critic_loss.append( reduce_mean_valid( train_batch[PRIO_WEIGHTS] * huber_loss(twin_td_error))) + td_error = td_error * seq_mask # Alpha- and actor losses. # Note: In the papers, alpha is used directly, here we take the log. @@ -401,26 +402,21 @@ def reduce_mean_valid(t): actor_loss = reduce_mean_valid(alpha.detach() * log_pis_t - q_t_det_policy) - # Save for stats function. - policy.q_t = q_t * seq_mask[..., None] - policy.policy_t = policy_t * seq_mask[..., None] - policy.log_pis_t = log_pis_t * seq_mask[..., None] - - # Store td-error in model, such that for multi-GPU, we do not override - # them during the parallel loss phase. TD-error tensor in final stats - # can then be concatenated and retrieved for each individual batch item. - model.td_error = td_error * seq_mask - - policy.actor_loss = actor_loss - policy.critic_loss = critic_loss - policy.alpha_loss = alpha_loss - policy.log_alpha_value = model.log_alpha - policy.alpha_value = alpha - policy.target_entropy = model.target_entropy + # Store values for stats function in model (tower), such that for + # multi-GPU, we do not override them during the parallel loss phase. + model.tower_stats["q_t"] = q_t * seq_mask[..., None] + model.tower_stats["policy_t"] = policy_t * seq_mask[..., None] + model.tower_stats["log_pis_t"] = log_pis_t * seq_mask[..., None] + model.tower_stats["actor_loss"] = actor_loss + model.tower_stats["critic_loss"] = critic_loss + model.tower_stats["alpha_loss"] = alpha_loss + # Store per time chunk (b/c we need only one mean + # prioritized replay weight per stored sequence). + model.tower_stats["td_error"] = torch.mean( + td_error.reshape([-1, T]), dim=-1) # Return all loss terms corresponding to our optimizers. - return tuple([policy.actor_loss] + policy.critic_loss + - [policy.alpha_loss]) + return tuple([actor_loss] + critic_loss + [alpha_loss]) RNNSACTorchPolicy = SACTorchPolicy.with_updates( diff --git a/rllib/agents/sac/sac_tf_model.py b/rllib/agents/sac/sac_tf_model.py index 546de04ab47c9..0b78f65a526fb 100644 --- a/rllib/agents/sac/sac_tf_model.py +++ b/rllib/agents/sac/sac_tf_model.py @@ -1,6 +1,7 @@ import gym from gym.spaces import Box, Discrete import numpy as np +import tree # pip install dm_tree from typing import Dict, List, Optional from ray.rllib.models.catalog import ModelCatalog @@ -267,13 +268,18 @@ def get_policy_output(self, model_out: TensorType) -> TensorType: Returns: TensorType: Distribution inputs for sampling actions. """ - # Model outs may come as original Tuple observations, concat them + # Model outs may come as original Tuple/Dict observations, concat them # here if this is the case. 
if isinstance(self.action_model.obs_space, Box): if isinstance(model_out, (list, tuple)): model_out = tf.concat(model_out, axis=-1) elif isinstance(model_out, dict): - model_out = tf.concat(list(model_out.values()), axis=-1) + model_out = tf.concat( + [ + tf.expand_dims(val, 1) if len(val.shape) == 1 else val + for val in tree.flatten(model_out.values()) + ], + axis=-1) out, _ = self.action_model({"obs": model_out}, [], None) return out diff --git a/rllib/agents/sac/sac_tf_policy.py b/rllib/agents/sac/sac_tf_policy.py index 111d8b717f494..629de0efce536 100644 --- a/rllib/agents/sac/sac_tf_policy.py +++ b/rllib/agents/sac/sac_tf_policy.py @@ -6,7 +6,6 @@ from gym.spaces import Box, Discrete from functools import partial import logging -import numpy as np from typing import Dict, List, Optional, Tuple, Type, Union import ray @@ -53,9 +52,6 @@ def build_sac_model(policy: Policy, obs_space: gym.spaces.Space, target model will be created in this function and assigned to `policy.target_model`. """ - # With separate state-preprocessor (before obs+action concat). - num_outputs = int(np.product(obs_space.shape)) - # Force-ignore any additionally provided hidden layer sizes. # Everything should be configured using SAC's "Q_model" and "policy_model" # settings. @@ -70,7 +66,7 @@ def build_sac_model(policy: Policy, obs_space: gym.spaces.Space, model = ModelCatalog.get_model_v2( obs_space=obs_space, action_space=action_space, - num_outputs=num_outputs, + num_outputs=None, model_config=config["model"], framework=config["framework"], default_model=default_model_cls, @@ -90,7 +86,7 @@ def build_sac_model(policy: Policy, obs_space: gym.spaces.Space, policy.target_model = ModelCatalog.get_model_v2( obs_space=obs_space, action_space=action_space, - num_outputs=num_outputs, + num_outputs=None, model_config=config["model"], framework=config["framework"], default_model=default_model_cls, diff --git a/rllib/agents/sac/sac_torch_model.py b/rllib/agents/sac/sac_torch_model.py index 64bbb40920453..1fdc09412da13 100644 --- a/rllib/agents/sac/sac_torch_model.py +++ b/rllib/agents/sac/sac_torch_model.py @@ -1,6 +1,7 @@ import gym from gym.spaces import Box, Discrete import numpy as np +import tree # pip install dm_tree from typing import Dict, List, Optional from ray.rllib.models.catalog import ModelCatalog @@ -281,7 +282,12 @@ def get_policy_output(self, model_out: TensorType) -> TensorType: if isinstance(model_out, (list, tuple)): model_out = torch.cat(model_out, dim=-1) elif isinstance(model_out, dict): - model_out = torch.cat(list(model_out.values()), dim=-1) + model_out = torch.cat( + [ + torch.unsqueeze(val, 1) if len(val.shape) == 1 else val + for val in tree.flatten(model_out.values()) + ], + dim=-1) out, _ = self.action_model({"obs": model_out}, [], None) return out diff --git a/rllib/agents/sac/sac_torch_policy.py b/rllib/agents/sac/sac_torch_policy.py index 6bfdb98decc7b..dee2693abf29e 100644 --- a/rllib/agents/sac/sac_torch_policy.py +++ b/rllib/agents/sac/sac_torch_policy.py @@ -5,6 +5,7 @@ import gym from gym.spaces import Box, Discrete import logging +import tree # pip install dm_tree from typing import Dict, List, Optional, Tuple, Type, Union import ray @@ -314,26 +315,21 @@ def actor_critic_loss( # the Q-net(s)' variables. actor_loss = torch.mean(alpha.detach() * log_pis_t - q_t_det_policy) - # Save for stats function. 
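Both the TF and Torch hunks above apply the same idea to Dict/Tuple model outs: flatten the structure, lift rank-1 (batch-only) entries to rank 2, then concatenate along the last axis. A standalone numpy sketch:

```python
import numpy as np
import tree  # pip install dm_tree

# A Dict observation's model-out: one rank-2 entry, one rank-1 entry.
model_out = {"a": np.zeros((4, 3)), "b": np.ones((4,))}

flat = [
    np.expand_dims(v, 1) if v.ndim == 1 else v
    for v in tree.flatten(model_out)  # flattens dict values by sorted key
]
out = np.concatenate(flat, axis=-1)
assert out.shape == (4, 4)
```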
- policy.q_t = q_t - policy.policy_t = policy_t - policy.log_pis_t = log_pis_t + # Store values for stats function in model (tower), such that for + # multi-GPU, we do not override them during the parallel loss phase. + model.tower_stats["q_t"] = q_t + model.tower_stats["policy_t"] = policy_t + model.tower_stats["log_pis_t"] = log_pis_t + model.tower_stats["actor_loss"] = actor_loss + model.tower_stats["critic_loss"] = critic_loss + model.tower_stats["alpha_loss"] = alpha_loss - # Store td-error in model, such that for multi-GPU, we do not override - # them during the parallel loss phase. TD-error tensor in final stats - # can then be concatenated and retrieved for each individual batch item. - model.td_error = td_error - - policy.actor_loss = actor_loss - policy.critic_loss = critic_loss - policy.alpha_loss = alpha_loss - policy.log_alpha_value = model.log_alpha - policy.alpha_value = alpha - policy.target_entropy = model.target_entropy + # TD-error tensor in final stats + # will be concatenated and retrieved for each individual batch item. + model.tower_stats["td_error"] = td_error # Return all loss terms corresponding to our optimizers. - return tuple([policy.actor_loss] + policy.critic_loss + - [policy.alpha_loss]) + return tuple([actor_loss] + critic_loss + [alpha_loss]) def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]: @@ -346,17 +342,23 @@ def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]: Returns: Dict[str, TensorType]: The stats dict. """ + q_t = torch.stack(policy.get_tower_stats("q_t")) + return { - "actor_loss": torch.mean(policy.actor_loss), - "critic_loss": torch.mean(torch.stack(policy.critic_loss)), - "alpha_loss": torch.mean(policy.alpha_loss), - "alpha_value": torch.mean(policy.alpha_value), - "log_alpha_value": torch.mean(policy.log_alpha_value), - "target_entropy": policy.target_entropy, - "policy_t": torch.mean(policy.policy_t), - "mean_q": torch.mean(policy.q_t), - "max_q": torch.max(policy.q_t), - "min_q": torch.min(policy.q_t), + "actor_loss": torch.mean( + torch.stack(policy.get_tower_stats("actor_loss"))), + "critic_loss": torch.mean( + torch.stack(tree.flatten(policy.get_tower_stats("critic_loss")))), + "alpha_loss": torch.mean( + torch.stack(policy.get_tower_stats("alpha_loss"))), + "alpha_value": torch.exp(policy.model.log_alpha), + "log_alpha_value": policy.model.log_alpha, + "target_entropy": policy.model.target_entropy, + "policy_t": torch.mean( + torch.stack(policy.get_tower_stats("policy_t"))), + "mean_q": torch.mean(q_t), + "max_q": torch.max(q_t), + "min_q": torch.min(q_t), } @@ -430,9 +432,9 @@ def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask, # (one TD-error value per item in batch to update PR weights). actor_critic_loss(self, self.model, None, input_dict) - # `self.td_error` is set within actor_critic_loss call. Return - # its updated value here. - return self.td_error + # `self.model.td_error` is set within actor_critic_loss call. + # Return its updated value here. + return self.model.tower_stats["td_error"] # Assign the method to policy (self) for later usage. 
self.compute_td_error = compute_td_error diff --git a/rllib/agents/sac/tests/test_rnnsac.py b/rllib/agents/sac/tests/test_rnnsac.py new file mode 100644 index 0000000000000..f0e8c5a750c57 --- /dev/null +++ b/rllib/agents/sac/tests/test_rnnsac.py @@ -0,0 +1,73 @@ +import unittest + +import ray +import ray.rllib.agents.sac as sac +from ray.rllib.utils.framework import try_import_tf, try_import_torch +from ray.rllib.utils.test_utils import check_compute_single_action, \ + framework_iterator + +tf1, tf, tfv = try_import_tf() +torch, nn = try_import_torch() + + +class TestRNNSAC(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + ray.init() + + @classmethod + def tearDownClass(cls) -> None: + ray.shutdown() + + def test_rnnsac_compilation(self): + """Test whether a R2D2Trainer can be built on all frameworks.""" + config = sac.RNNSAC_DEFAULT_CONFIG.copy() + config["num_workers"] = 0 # Run locally. + + # Wrap with an LSTM and use a very simple base-model. + config["model"] = { + "max_seq_len": 20, + } + config["policy_model"] = { + "use_lstm": True, + "lstm_cell_size": 64, + "fcnet_hiddens": [10], + "lstm_use_prev_action": True, + "lstm_use_prev_reward": True, + } + config["Q_model"] = { + "use_lstm": True, + "lstm_cell_size": 64, + "fcnet_hiddens": [10], + "lstm_use_prev_action": True, + "lstm_use_prev_reward": True, + } + + # Test with PR activated. + config["prioritized_replay"] = True + + config["burn_in"] = 20 + config["zero_init_states"] = True + + config["lr"] = 5e-4 + + num_iterations = 1 + + # Test building an RNNSAC agent in all frameworks. + for _ in framework_iterator(config, frameworks="torch"): + trainer = sac.RNNSACTrainer(config=config, env="CartPole-v0") + for i in range(num_iterations): + results = trainer.train() + print(results) + + check_compute_single_action( + trainer, + include_state=True, + include_prev_action_reward=True, + ) + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/agents/sac/tests/test_sac.py b/rllib/agents/sac/tests/test_sac.py index d9b1de208af33..06083b33e3fa9 100644 --- a/rllib/agents/sac/tests/test_sac.py +++ b/rllib/agents/sac/tests/test_sac.py @@ -1,5 +1,5 @@ from gym import Env -from gym.spaces import Box, Discrete, Tuple +from gym.spaces import Box, Dict, Discrete, Tuple import numpy as np import re import unittest @@ -21,8 +21,9 @@ from ray.rllib.utils.numpy import fc, huber_loss, relu from ray.rllib.utils.spaces.simplex import Simplex from ray.rllib.utils.test_utils import check, check_compute_single_action, \ - framework_iterator + check_train_results, framework_iterator from ray.rllib.utils.torch_ops import convert_to_torch_tensor +from ray import tune tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() @@ -71,8 +72,6 @@ def test_sac_compilation(self): config["num_workers"] = 0 # Run locally. 
config["n_step"] = 3 config["twin_q"] = True - config["clip_actions"] = False - config["normalize_actions"] = True config["learning_starts"] = 0 config["prioritized_replay"] = True config["rollout_fragment_length"] = 10 @@ -92,22 +91,28 @@ def test_sac_compilation(self): image_space = Box(-1.0, 1.0, shape=(84, 84, 3)) simple_space = Box(-1.0, 1.0, shape=(3, )) + tune.register_env( + "random_dict_env", lambda _: RandomEnv({ + "observation_space": Dict({ + "a": simple_space, + "b": Discrete(2), + "c": image_space, }), + "action_space": Box(-1.0, 1.0, shape=(1, )), })) + tune.register_env( + "random_tuple_env", lambda _: RandomEnv({ + "observation_space": Tuple([ + simple_space, Discrete(2), image_space]), + "action_space": Box(-1.0, 1.0, shape=(1, )), })) + for fw in framework_iterator(config): # Test for different env types (discrete w/ and w/o image, + cont). for env in [ - RandomEnv, + "random_dict_env", + "random_tuple_env", "MsPacmanNoFrameskip-v4", "CartPole-v0", ]: print("Env={}".format(env)) - if env == RandomEnv: - config["env_config"] = { - "observation_space": Tuple((simple_space, Discrete(2), - image_space)), - "action_space": Box(-1.0, 1.0, shape=(1, )), - } - else: - config["env_config"] = {} # Test making the Q-model a custom one for CartPole, otherwise, # use the default model. config["Q_model"]["custom_model"] = "batch_norm{}".format( @@ -116,6 +121,7 @@ def test_sac_compilation(self): trainer = sac.SACTrainer(config=config, env=env) for i in range(num_iterations): results = trainer.train() + check_train_results(results) print(results) check_compute_single_action(trainer) @@ -306,8 +312,10 @@ def test_sac_loss_function(self): elif fw == "torch": loss_torch(policy, policy.model, None, input_) - c, a, e, t = policy.critic_loss, policy.actor_loss, \ - policy.alpha_loss, policy.model.td_error + c, a, e, t = policy.get_tower_stats("critic_loss")[0], \ + policy.get_tower_stats("actor_loss")[0], \ + policy.get_tower_stats("alpha_loss")[0], \ + policy.get_tower_stats("td_error")[0] # Test actor gradients. 
policy.actor_optim.zero_grad() diff --git a/rllib/agents/tests/test_trainer.py b/rllib/agents/tests/test_trainer.py index baf7b665963ca..937206deac138 100644 --- a/rllib/agents/tests/test_trainer.py +++ b/rllib/agents/tests/test_trainer.py @@ -13,6 +13,7 @@ from ray.rllib.examples.env.multi_agent import MultiAgentCartPole from ray.rllib.examples.parallel_evaluation_and_training import \ AssertNumEvalEpisodesCallback +from ray.rllib.utils.metrics.learner_info import LEARNER_INFO from ray.rllib.utils.test_utils import framework_iterator @@ -72,7 +73,7 @@ def test_add_delete_policy(self): trainer = pg.PGTrainer(config=config) pol0 = trainer.get_policy("p0") r = trainer.train() - self.assertTrue("p0" in r["info"]["learner"]) + self.assertTrue("p0" in r["info"][LEARNER_INFO]) for i in range(1, 3): def new_mapping_fn(agent_id, episode, worker, **kwargs): diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index 7147ba9ea85c7..a1f4b64ee2426 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -8,22 +8,24 @@ import pickle import tempfile import time -from typing import Callable, Dict, List, Optional, Type, Union +from typing import Callable, Dict, List, Optional, Tuple, Type, Union import ray from ray.actor import ActorHandle from ray.exceptions import RayError from ray.rllib.agents.callbacks import DefaultCallbacks from ray.rllib.env.env_context import EnvContext +from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.env.utils import gym_env_creator from ray.rllib.evaluation.collectors.simple_list_collector import \ SimpleListCollector +from ray.rllib.evaluation.episode import MultiAgentEpisode from ray.rllib.evaluation.metrics import collect_metrics from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.evaluation.worker_set import WorkerSet from ray.rllib.models import MODEL_DEFAULTS from ray.rllib.policy.policy import Policy, PolicySpec -from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch from ray.rllib.utils import deep_update, FilterManager, merge_dicts from ray.rllib.utils.annotations import Deprecated, DeveloperAPI, override, \ PublicAPI @@ -36,7 +38,7 @@ from ray.rllib.utils.spaces import space_utils from ray.rllib.utils.typing import AgentID, EnvInfoDict, EnvType, EpisodeID, \ PartialTrainerConfigDict, PolicyID, ResultDict, TensorStructType, \ - TrainerConfigDict + TensorType, TrainerConfigDict from ray.tune.logger import Logger, UnifiedLogger from ray.tune.registry import ENV_CREATOR, register_env, _global_registry from ray.tune.resources import Resources @@ -1007,17 +1009,29 @@ def _sync_weights_to_workers( @PublicAPI def compute_single_action( self, - observation: TensorStructType, - state: List[TensorStructType] = None, - prev_action: TensorStructType = None, - prev_reward: float = None, - info: EnvInfoDict = None, + observation: Optional[TensorStructType] = None, + state: Optional[List[TensorStructType]] = None, + *, + prev_action: Optional[TensorStructType] = None, + prev_reward: Optional[float] = None, + info: Optional[EnvInfoDict] = None, + input_dict: Optional[SampleBatch] = None, policy_id: PolicyID = DEFAULT_POLICY_ID, full_fetch: bool = False, - explore: bool = None, - unsquash_actions: Optional[bool] = None, - clip_actions: Optional[bool] = None, - ) -> TensorStructType: + explore: Optional[bool] = None, + timestep: Optional[int] = None, + episode: Optional[MultiAgentEpisode] = None, + unsquash_action: Optional[bool] = None, + 
clip_action: Optional[bool] = None, + + # Deprecated args. + unsquash_actions=DEPRECATED_VALUE, + clip_actions=DEPRECATED_VALUE, + + # Kwargs placeholder for future compatibility. + **kwargs, + ) -> Union[TensorStructType, Tuple[TensorStructType, List[TensorType], + Dict[str, TensorType]]]: """Computes an action for the specified policy on the local worker. Note that you can also access the policy object through @@ -1025,70 +1039,123 @@ def compute_single_action( directly. Args: - observation (TensorStructType): observation from the environment. - state (List[TensorStructType]): RNN hidden state, if any. If state - is not None, then all of compute_single_action(...) is returned - (computed action, rnn state(s), logits dictionary). - Otherwise compute_single_action(...)[0] is returned - (computed action). - prev_action (TensorStructType): Previous action value, if any. - prev_reward (float): Previous reward, if any. - info (EnvInfoDict): info object, if any - policy_id (PolicyID): Policy to query (only applies to - multi-agent). - full_fetch (bool): Whether to return extra action fetch results. - This is always set to True if RNN state is specified. - explore (bool): Whether to pick an exploitation or exploration - action (default: None -> use self.config["explore"]). - unsquash_actions (bool): Should actions be unsquashed according to - the env's/Policy's action space? - clip_actions (bool): Should actions be clipped according to the - env's/Policy's action space? + observation: Single (unbatched) observation from the + environment. + state: List of all RNN hidden (single, unbatched) state tensors. + prev_action: Single (unbatched) previous action value. + prev_reward: Single (unbatched) previous reward value. + info: Env info dict, if any. + input_dict: An optional SampleBatch that holds all the values + for: obs, state, prev_action, and prev_reward, plus maybe + custom defined views of the current env trajectory. Note + that only one of `obs` or `input_dict` must be non-None. + policy_id: Policy to query (only applies to multi-agent). + Default: "default_policy". + full_fetch: Whether to return extra action fetch results. + This is always set to True if `state` is specified. + explore: Whether to apply exploration to the action. + Default: None -> use self.config["explore"]. + timestep: The current (sampling) time step. + episode: This provides access to all of the internal episodes' + state, which may be useful for model-based or multi-agent + algorithms. + unsquash_action: Should actions be unsquashed according to the + env's/Policy's action space? If None, use the value of + self.config["normalize_actions"]. + clip_action: Should actions be clipped according to the + env's/Policy's action space? If None, use the value of + self.config["clip_actions"]. + + Keyword Args: + kwargs: forward compatibility placeholder Returns: - any: The computed action if full_fetch=False, or - tuple: The full output of policy.compute_actions() if - full_fetch=True or we have an RNN-based Policy. + The computed action if full_fetch=False, or a tuple of a) the + full output of policy.compute_actions() if full_fetch=True + or we have an RNN-based Policy. Raises: KeyError: If the `policy_id` cannot be found in this Trainer's local worker. 
""" + if clip_actions != DEPRECATED_VALUE: + deprecation_warning( + old="Trainer.compute_single_action(`clip_actions`=...)", + new="Trainer.compute_single_action(`clip_action`=...)", + error=False) + clip_action = clip_actions + if unsquash_actions != DEPRECATED_VALUE: + deprecation_warning( + old="Trainer.compute_single_action(`unsquash_actions`=...)", + new="Trainer.compute_single_action(`unsquash_action`=...)", + error=False) + unsquash_action = unsquash_actions + + # User provided an input-dict: Assert that `obs`, `prev_a|r`, `state` + # are all None. + err_msg = "Provide either `input_dict` OR [`observation`, ...] as " \ + "args to Trainer.compute_single_action!" + if input_dict is not None: + assert observation is None and prev_action is None and \ + prev_reward is None and state is None, err_msg + observation = input_dict[SampleBatch.OBS] + else: + assert observation is not None, err_msg + + # Get the policy to compute the action for (in the multi-agent case, + # Trainer may hold >1 policies). policy = self.get_policy(policy_id) if policy is None: raise KeyError( f"PolicyID '{policy_id}' not found in PolicyMap of the " f"Trainer's local worker!") - local_worker = self.workers.local_worker() - if state is None: - state = [] - # Check the preprocessor and preprocess, if necessary. pp = local_worker.preprocessors[policy_id] if pp and type(pp).__name__ != "NoPreprocessor": observation = pp.transform(observation) - filtered_observation = local_worker.filters[policy_id]( + observation = local_worker.filters[policy_id]( observation, update=False) - # Compute the action. - result = policy.compute_single_action( - filtered_observation, - state, - prev_action, - prev_reward, - info, - unsquash_actions=unsquash_actions, - clip_actions=clip_actions, - explore=explore) + # Input-dict. + if input_dict is not None: + input_dict[SampleBatch.OBS] = observation + action, state, extra = policy.compute_single_action( + input_dict=input_dict, + explore=explore, + timestep=timestep, + episode=episode, + ) + # Individual args. + else: + action, state, extra = policy.compute_single_action( + obs=observation, + state=state, + prev_action=prev_action, + prev_reward=prev_reward, + info=info, + explore=explore, + timestep=timestep, + episode=episode, + ) + + # If we work in normalized action space (normalize_actions=True), + # we re-translate here into the env's action space. + if unsquash_action: + action = space_utils.unsquash_action(action, + policy.action_space_struct) + # Clip, according to env's action space. + elif clip_action: + action = space_utils.clip_action(action, + policy.action_space_struct) # Return 3-Tuple: Action, states, and extra-action fetches. if state or full_fetch: - return result + return action, state, extra # Ensure backward compatibility. 
else: - return result[0] + return action @Deprecated(new="compute_single_action", error=False) def compute_action(self, *args, **kwargs): @@ -1098,15 +1165,21 @@ def compute_action(self, *args, **kwargs): def compute_actions( self, observations: TensorStructType, - state: List[TensorStructType] = None, - prev_action: TensorStructType = None, - prev_reward: TensorStructType = None, - info=None, - policy_id=DEFAULT_POLICY_ID, - full_fetch=False, - explore=None, + state: Optional[List[TensorStructType]] = None, + *, + prev_action: Optional[TensorStructType] = None, + prev_reward: Optional[TensorStructType] = None, + info: Optional[EnvInfoDict] = None, + policy_id: PolicyID = DEFAULT_POLICY_ID, + full_fetch: bool = False, + explore: Optional[bool] = None, + timestep: Optional[int] = None, + episodes: Optional[List[MultiAgentEpisode]] = None, + unsquash_actions: Optional[bool] = None, + clip_actions: Optional[bool] = None, + # Deprecated. normalize_actions=None, - clip_actions=None, + **kwargs, ): """Computes an action for the specified policy on the local Worker. @@ -1114,30 +1187,46 @@ def compute_actions( self.get_policy(policy_id) and call compute_actions() on it directly. Args: - observation (obj): observation from the environment. - state (dict): RNN hidden state, if any. If state is not None, + observation: observation from the environment. + state: RNN hidden state, if any. If state is not None, then all of compute_single_action(...) is returned (computed action, rnn state(s), logits dictionary). Otherwise compute_single_action(...)[0] is returned (computed action). - prev_action (obj): previous action value, if any - prev_reward (int): previous reward, if any - info (dict): info object, if any - policy_id (str): Policy to query (only applies to multi-agent). - full_fetch (bool): Whether to return extra action fetch results. + prev_action: Previous action value, if any. + prev_reward: Previous reward, if any. + info: Env info dict, if any. + policy_id: Policy to query (only applies to multi-agent). + full_fetch: Whether to return extra action fetch results. This is always set to True if RNN state is specified. - explore (bool): Whether to pick an exploitation or exploration + explore: Whether to pick an exploitation or exploration action (default: None -> use self.config["explore"]). - normalize_actions (bool): Should actions be unsquashed according - to the env's/Policy's action space? - clip_actions (bool): Should actions be clipped according to the - env's/Policy's action space? + timestep: The current (sampling) time step. + episodes: This provides access to all of the internal episodes' + state, which may be useful for model-based or multi-agent + algorithms. + unsquash_actions: Should actions be unsquashed according + to the env's/Policy's action space? If None, use + self.config["normalize_actions"]. + clip_actions: Should actions be clipped according to the + env's/Policy's action space? If None, use + self.config["clip_actions"]. + + Keyword Args: + kwargs: forward compatibility placeholder Returns: any: The computed action if full_fetch=False, or tuple: The full output of policy.compute_actions() if full_fetch=True or we have an RNN-based Policy. """ + if normalize_actions is not None: + deprecation_warning( + old="Trainer.compute_actions(`normalize_actions`=...)", + new="Trainer.compute_actions(`unsquash_actions`=...)", + error=False) + unsquash_actions = normalize_actions + # Preprocess obs and states. 
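# Calling the reworked `Trainer.compute_single_action()`; a sketch
# assuming `trainer` is any built Trainer (e.g. a PGTrainer on
# CartPole-v0). Everything after `state` is now keyword-only.
import gym

env = gym.make("CartPole-v0")
obs = env.reset()
action = trainer.compute_single_action(obs, explore=False)

# Or hand over a ready-made SampleBatch via `input_dict` (mutually
# exclusive with passing `observation` etc. individually):
from ray.rllib.policy.sample_batch import SampleBatch
action = trainer.compute_single_action(
    input_dict=SampleBatch({SampleBatch.OBS: obs}), explore=False)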
state_defined = state is not None policy = self.get_policy(policy_id) @@ -1162,23 +1251,38 @@ def compute_actions( state = list(zip(*filtered_state)) state = [np.stack(s) for s in state] + input_dict = {SampleBatch.OBS: obs_batch} + if prev_action: + input_dict[SampleBatch.PREV_ACTIONS] = prev_action + if prev_reward: + input_dict[SampleBatch.PREV_REWARDS] = prev_reward + if info: + input_dict[SampleBatch.INFOS] = info + for i, s in enumerate(state): + input_dict[f"state_in_{i}"] = s + # Batch compute actions - actions, states, infos = policy.compute_actions( - obs_batch, - state, - prev_action, - prev_reward, - info, - normalize_actions=normalize_actions, - clip_actions=clip_actions, - explore=explore) - - # Unbatch actions for the environment - atns, actions = space_utils.unbatch(actions), {} - for key, atn in zip(observations, atns): - actions[key] = atn - - # Unbatch states into a dict + actions, states, infos = policy.compute_actions_from_input_dict( + input_dict=input_dict, + explore=explore, + timestep=timestep, + episodes=episodes, + ) + + # Unbatch actions for the environment into a multi-agent dict. + single_actions = space_utils.unbatch(actions) + actions = {} + for key, a in zip(observations, single_actions): + # If we work in normalized action space (normalize_actions=True), + # we re-translate here into the env's action space. + if unsquash_actions: + a = space_utils.unsquash_action(a, policy.action_space_struct) + # Clip, according to env's action space. + elif clip_actions: + a = space_utils.clip_action(a, policy.action_space_struct) + actions[key] = a + + # Unbatch states into a multi-agent dict. unbatched_states = {} for idx, agent_id in enumerate(observations): unbatched_states[agent_id] = [s[idx] for s in states] @@ -1403,6 +1507,7 @@ def collect_metrics(self, selected_workers=selected_workers) @classmethod + @override(Trainable) def resource_help(cls, config: TrainerConfigDict) -> str: return ("\n\nYou can adjust the resource requests of RLlib agents by " "setting `num_workers`, `num_gpus`, and other configs. See " @@ -1738,23 +1843,25 @@ def with_updates(**overrides) -> Type["Trainer"]: "build_trainer()` function!") def _register_if_needed(self, env_object: Union[str, EnvType, None], - config): + config) -> Optional[str]: if isinstance(env_object, str): return env_object elif isinstance(env_object, type): name = env_object.__name__ - # Add convenience `_get_spaces` method. + if config.get("remote_worker_envs"): - def _get_spaces(s): - return s.observation_space, s.action_space + @ray.remote(num_cpus=0) + class _wrapper(env_object): + # Add convenience `_get_spaces` and `_is_multi_agent` + # methods. 
+ def _get_spaces(self): + return self.observation_space, self.action_space - env_object._get_spaces = _get_spaces + def _is_multi_agent(self): + return isinstance(self, MultiAgentEnv) - if config.get("remote_worker_envs"): - register_env( - name, - lambda cfg: ray.remote(num_cpus=0)(env_object).remote(cfg)) + register_env(name, lambda cfg: _wrapper.remote(cfg)) else: register_env(name, lambda cfg: env_object(cfg)) return name diff --git a/rllib/contrib/alpha_zero/core/alpha_zero_policy.py b/rllib/contrib/alpha_zero/core/alpha_zero_policy.py index 7b3b46e74747e..ad97829c04ba3 100644 --- a/rllib/contrib/alpha_zero/core/alpha_zero_policy.py +++ b/rllib/contrib/alpha_zero/core/alpha_zero_policy.py @@ -1,10 +1,11 @@ import numpy as np -from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY +from ray.rllib.policy.policy import Policy from ray.rllib.policy.torch_policy import TorchPolicy from ray.rllib.contrib.alpha_zero.core.mcts import Node, RootParentNode from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY torch, _ = try_import_torch() @@ -39,9 +40,9 @@ def compute_actions(self, **kwargs): input_dict = {"obs": obs_batch} - if prev_action_batch: + if prev_action_batch is not None: input_dict["prev_actions"] = prev_action_batch - if prev_reward_batch: + if prev_reward_batch is not None: input_dict["prev_rewards"] = prev_reward_batch return self.compute_actions_from_input_dict( diff --git a/rllib/contrib/bandits/agents/policy.py b/rllib/contrib/bandits/agents/policy.py index e47c91005232c..07d837b4fc150 100644 --- a/rllib/contrib/bandits/agents/policy.py +++ b/rllib/contrib/bandits/agents/policy.py @@ -9,11 +9,11 @@ ParametricLinearModelThompsonSampling, ParametricLinearModelUCB from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.modelv2 import restore_original_dimensions -from ray.rllib.policy.policy import LEARNER_STATS_KEY from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy import TorchPolicy from ray.rllib.utils.annotations import override +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.util.debug import log_once logger = logging.getLogger(__name__) diff --git a/rllib/contrib/bandits/examples/LinTS_train_wheel_env.py b/rllib/contrib/bandits/examples/LinTS_train_wheel_env.py index dfe3b8c85156d..4501a04357fee 100644 --- a/rllib/contrib/bandits/examples/LinTS_train_wheel_env.py +++ b/rllib/contrib/bandits/examples/LinTS_train_wheel_env.py @@ -7,6 +7,7 @@ from ray.rllib.contrib.bandits.agents import LinTSTrainer from ray.rllib.contrib.bandits.envs import WheelBanditEnv +from ray.rllib.utils.metrics.learner_info import LEARNER_INFO def plot_model_weights(means, covs): @@ -43,7 +44,7 @@ def plot_model_weights(means, covs): trainer.train() info = trainer.train() - print(info["info"]["learner"]) + print(info["info"][LEARNER_INFO]) # Get model parameters means = [model.arms[i].theta.numpy() for i in range(5)] diff --git a/rllib/contrib/maddpg/maddpg_policy.py b/rllib/contrib/maddpg/maddpg_policy.py index 86e417e5d3112..51a02f35afaea 100644 --- a/rllib/contrib/maddpg/maddpg_policy.py +++ b/rllib/contrib/maddpg/maddpg_policy.py @@ -1,6 +1,5 @@ import ray from ray.rllib.agents.dqn.dqn_tf_policy import minimize_and_clip -from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY from ray.rllib.evaluation.postprocessing import 
adjust_nstep from ray.rllib.models import ModelCatalog from ray.rllib.policy.sample_batch import SampleBatch @@ -9,6 +8,7 @@ from ray.rllib.policy.policy import Policy from ray.rllib.policy.tf_policy import TFPolicy from ray.rllib.utils.framework import try_import_tf, try_import_tfp +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY import logging from gym.spaces import Box, Discrete diff --git a/rllib/contrib/sumo/connector.py b/rllib/contrib/sumo/connector.py index 6b1d3d1d47e35..0b795c45c8421 100644 --- a/rllib/contrib/sumo/connector.py +++ b/rllib/contrib/sumo/connector.py @@ -162,7 +162,7 @@ def _stopping_condition(self, current_step_counter, until_end): return True return False - def step(self, until_end=False, agents=set()): + def step(self, until_end=False, agents=None): """ Runs a "learning" step and returns if the simulation has finished. This function in meant to be called by the RLLIB Environment. @@ -176,6 +176,9 @@ def step(self, until_end=False, agents=set()): Return: Bool. True iff the simulation is still ongoing. """ + if agents is None: + agents = set() + # Execute SUMO steps until the learning needs to happen current_step_counter = 0 logger.debug( diff --git a/rllib/env/base_env.py b/rllib/env/base_env.py index 4b2c77fe1532b..8ee302eb24683 100644 --- a/rllib/env/base_env.py +++ b/rllib/env/base_env.py @@ -1,5 +1,6 @@ from typing import Callable, Tuple, Optional, List, Dict, Any, TYPE_CHECKING +import ray from ray.rllib.env.external_env import ExternalEnv from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv from ray.rllib.env.multi_agent_env import MultiAgentEnv @@ -121,10 +122,14 @@ def to_base_env( env = _VectorEnvToBaseEnv(env) else: if remote_envs: + # Determine, whether the already existing sub-env (could + # be a ray.actor) is multi-agent or not. + multiagent = ray.get(env._is_multi_agent.remote()) if \ + hasattr(env, "_is_multi_agent") else False env = RemoteVectorEnv( make_env, num_envs, - multiagent=False, + multiagent=multiagent, remote_env_batch_wait_ms=remote_env_batch_wait_ms, existing_envs=[env], ) diff --git a/rllib/env/multi_agent_env.py b/rllib/env/multi_agent_env.py index ed5705bf725d0..4840de357585a 100644 --- a/rllib/env/multi_agent_env.py +++ b/rllib/env/multi_agent_env.py @@ -163,7 +163,8 @@ def make_multi_agent(env_name_or_creator): """ class MultiEnv(MultiAgentEnv): - def __init__(self, config): + def __init__(self, config=None): + config = config or {} num = config.pop("num_agents", 1) if isinstance(env_name_or_creator, str): self.agents = [ diff --git a/rllib/env/policy_server_input.py b/rllib/env/policy_server_input.py index 26e96673adb5c..c7148a94a8a2d 100644 --- a/rllib/env/policy_server_input.py +++ b/rllib/env/policy_server_input.py @@ -89,10 +89,18 @@ def get_metrics(): # and sends data and metrics into the queues. handler = _make_handler(self.rollout_worker, self.samples_queue, self.metrics_queue) - HTTPServer.__init__(self, (address, port), handler) - - logger.info("Starting connector server at {}:{}".format( - self.server_name, self.server_port)) + try: + import time + time.sleep(1) + HTTPServer.__init__(self, (address, port), handler) + except OSError: + print(f"Creating a PolicyServer on {address}:{port} failed!") + import time + time.sleep(1) + raise + + logger.info("Starting connector server at " + f"{self.server_name}:{self.server_port}") # Start the serving thread, listening on socket and handling commands. 
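# Why `prev_action_batch is not None` (alpha_zero_policy.py above) is
# the right check: a numpy array with more than one element has no
# single truth value, so the old `if prev_action_batch:` form raises
# as soon as a real batch is passed in.
import numpy as np

batch = np.array([[0], [1]])
try:
    if batch:  # old-style truthiness check
        pass
except ValueError as e:
    print("old check fails:", e)

if batch is not None:  # explicit presence test: always well-defined
    print("new check passes, shape =", batch.shape)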
serving_thread = threading.Thread( diff --git a/rllib/env/remote_vector_env.py b/rllib/env/remote_vector_env.py index aa2e958efee5a..2d09302f59c15 100644 --- a/rllib/env/remote_vector_env.py +++ b/rllib/env/remote_vector_env.py @@ -29,6 +29,8 @@ def __init__(self, existing_envs: Optional[List[ray.actor.ActorHandle]] = None): # Could be creating local or remote envs. self.make_env = make_env + # Whether the given `make_env` callable already returns ray.remote + # objects or not. self.make_env_creates_actors = False # Already existing env objects (generated by the RolloutWorker). self.existing_envs = existing_envs or [] @@ -50,9 +52,13 @@ def poll(self) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict, self.actors = [] while len(self.actors) < self.num_envs: self.actors.append(self.make_env(len(self.actors))) - # `self.make_env` produces gym.Envs (or other similar types, such + # `self.make_env` produces gym.Envs (or children thereof, such # as MultiAgentEnv): Need to auto-wrap it here. The problem with - # this is that custom methods wil get lost. + # this is that custom methods wil get lost. If you would like to + # keep your custom methods in your envs, you should provide the + # env class directly in your config (w/o tune.register_env()), + # such that your class will directly be made a @ray.remote + # (w/o the wrapping via `_Remote[Multi|Single]AgentEnv`). else: def make_remote_env(i): @@ -125,7 +131,15 @@ def make_remote_env(i): def send_actions(self, action_dict: MultiEnvDict) -> None: for env_id, actions in action_dict.items(): actor = self.actors[env_id] - obj_ref = actor.step.remote(actions) + # `actor` is a simple single-agent (remote) env, e.g. a gym.Env + # that was made a @ray.remote. + if not self.multiagent and self.make_env_creates_actors: + obj_ref = actor.step.remote(actions[_DUMMY_AGENT_ID]) + # `actor` is already a _RemoteSingleAgentEnv or + # _RemoteMultiAgentEnv wrapper + # (handles the multi-agent action_dict automatically). + else: + obj_ref = actor.step.remote(actions) self.pending[obj_ref] = actor @override(BaseEnv) diff --git a/rllib/env/tests/test_local_inference.sh b/rllib/env/tests/test_local_inference.sh deleted file mode 100755 index be910f173c620..0000000000000 --- a/rllib/env/tests/test_local_inference.sh +++ /dev/null @@ -1,42 +0,0 @@ -#!/bin/bash - -rm -f last_checkpoint.out -pkill -f cartpole_server.py -sleep 1 - -if [ -f test_local_inference.sh ]; then - basedir="../../examples/serving" -else - basedir="rllib/examples/serving" # In bazel. -fi - -# Start server with 2 workers (will listen on ports 9900 and 9901 for client -# connections). -# Do not attempt to restore from checkpoint; leads to errors on travis. -(python $basedir/cartpole_server.py --run=PPO --num-workers=2 --no-restore 2>&1 | grep -v 200) & -server_pid=$! - -echo "Waiting for server to start" -while ! curl localhost:9900; do - sleep 1 -done -while ! curl localhost:9901; do - sleep 1 -done - -# Start client 1 (port 9900). -sleep 2 -(python $basedir/cartpole_client.py --inference-mode=local --port=9900) & -client1_pid=$! - -# Start client 2 (port 9901). -sleep 2 -(python $basedir/cartpole_client.py --inference-mode=local --port=9901) & -client2_pid=$! - -# Start client 3 (also port 9901) and run it until it reaches 150.0 -# reward. Then stop everything. 
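# The hazard behind `step(self, until_end=False, agents=None)` in
# connector.py above: a default such as `agents=set()` is evaluated
# once, at function definition time, and the same set object is then
# shared by every call. A quick demonstration:
def step_buggy(agents=set()):
    agents.add(len(agents))
    return agents

assert step_buggy() == {0}
assert step_buggy() == {0, 1}  # state leaked from the first call!

def step_fixed(agents=None):
    if agents is None:
        agents = set()  # fresh set per call
    agents.add(len(agents))
    return agents

assert step_fixed() == {0}
assert step_fixed() == {0}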
-sleep 2 -python $basedir/cartpole_client.py --stop-reward=150.0 --inference-mode=local --port=9901 - -kill $server_pid $client1_pid $client2_pid || true diff --git a/rllib/env/tests/test_policy_client_server_setup.sh b/rllib/env/tests/test_policy_client_server_setup.sh new file mode 100755 index 0000000000000..4d458ee5b8dba --- /dev/null +++ b/rllib/env/tests/test_policy_client_server_setup.sh @@ -0,0 +1,63 @@ +#!/bin/bash + +rm -f last_checkpoint.out + +if [ "$1" == "local" ]; then + inference_mode=local +else + inference_mode=remote +fi + +if [ "$2" == "cartpole" ]; then + server_script=cartpole_server.py + client_script=cartpole_client.py + stop_criterion="--stop-reward=150.0" +else + server_script=unity3d_server.py + client_script=unity3d_dummy_client.py + stop_criterion="--num-episodes=10" +fi + +pkill -f $server_script +sleep 1 + +if [ -f test_policy_client_server_setup.sh ]; then + basedir="../../examples/serving" +else + basedir="rllib/examples/serving" # In bazel. +fi + + +# Start server with 2 workers (will listen on ports 9900 and 9901 for client +# connections). +# Do not attempt to restore from checkpoint; leads to errors on travis. +(python $basedir/$server_script --run=PPO --num-workers=2 --no-restore 2>&1 | grep -v 200) & +server_pid=$! + +echo "Waiting for server to start ..." +while ! curl localhost:9900; do + sleep 1 +done +echo "Remote worker #1 on port 9900 is up!" +while ! curl localhost:9901; do + sleep 1 +done +echo "Remote worker #2 on port 9901 is up!" + +# Start client 1 (connect to port 9900). +sleep 2 +(python $basedir/$client_script --inference-mode=$inference_mode --port=9900) & +client1_pid=$! + +# Start client 2 (connect to port 9901). +sleep 2 +(python $basedir/$client_script --inference-mode=$inference_mode --port=9901) & +client2_pid=$! + +# Start client 3 (also connecting to port 9901) and run it until it reaches +# x reward (CartPole) or n episodes (dummy Unity3D). +# Then stop everything. +sleep 2 +python $basedir/$client_script $stop_criterion --inference-mode=$inference_mode --port=9901 + +kill $server_pid $client1_pid $client2_pid || true diff --git a/rllib/env/tests/test_remote_inference.sh b/rllib/env/tests/test_remote_inference.sh deleted file mode 100755 index 1a9ead838576c..0000000000000 --- a/rllib/env/tests/test_remote_inference.sh +++ /dev/null @@ -1,41 +0,0 @@ -#!/bin/bash - -rm -f last_checkpoint.out -pkill -f cartpole_server.py -sleep 1 - -if [ -f test_local_inference.sh ]; then - basedir="../../examples/serving" -else - basedir="rllib/examples/serving" # In bazel. -fi - -# Do not attempt to restore from checkpoint; leads to errors on travis. -(python $basedir/cartpole_server.py --run=DQN --num-workers=2 --no-restore 2>&1 | grep -v 200) & -server_pid=$! - -echo "Waiting for server to start" -while ! curl localhost:9900; do - sleep 1 -done -while ! curl localhost:9901; do - sleep 1 -done - -# Start client 1 (port 9900). -sleep 2 -(python $basedir/cartpole_client.py --inference-mode=remote --port=9900) & -client1_pid=$! - -# Start client 2 (port 9901). -sleep 2 -(python $basedir/cartpole_client.py --inference-mode=remote --port=9901) & -client2_pid=$! - -# Start client 3 (also port 9901) and run it until it reaches 150.0 -# reward. Then stop everything. 
-sleep 2 -python $basedir/cartpole_client.py --stop-reward=150.0 --inference-mode=remote --port=9901 - -kill $server_pid $client1_pid $client2_pid || true - diff --git a/rllib/env/tests/test_remote_worker_envs.py b/rllib/env/tests/test_remote_worker_envs.py new file mode 100644 index 0000000000000..ba80c7e4cede1 --- /dev/null +++ b/rllib/env/tests/test_remote_worker_envs.py @@ -0,0 +1,98 @@ +import gym +import numpy as np +from pettingzoo.butterfly import pistonball_v4 +from supersuit import normalize_obs_v0, dtype_v0, color_reduction_v0 +import unittest + +import ray +from ray.rllib.agents.pg import pg +from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv +from ray.rllib.examples.env.random_env import RandomEnv, RandomMultiAgentEnv +from ray.rllib.examples.remote_vector_env_with_custom_api import \ + NonVectorizedEnvToBeVectorizedIntoRemoteVectorEnv +from ray import tune + + +# Function that outputs the environment you wish to register. +def env_creator(config): + env = pistonball_v4.env(local_ratio=config.get("local_ratio", 0.2)) + env = dtype_v0(env, dtype=np.float32) + env = color_reduction_v0(env, mode="R") + env = normalize_obs_v0(env) + return env + + +tune.register_env("cartpole", lambda env_ctx: gym.make("CartPole-v0")) + +tune.register_env("pistonball", + lambda config: PettingZooEnv(env_creator(config))) + + +class TestRemoteWorkerEnvSetting(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + ray.init(num_cpus=4) + + @classmethod + def tearDownClass(cls) -> None: + ray.shutdown() + + def test_remote_worker_env(self): + config = pg.DEFAULT_CONFIG.copy() + config["remote_worker_envs"] = True + config["num_envs_per_worker"] = 4 + + # Simple string env definition (gym.make(...)). + config["env"] = "CartPole-v0" + trainer = pg.PGTrainer(config=config) + print(trainer.train()) + trainer.stop() + + # Using tune.register. + config["env"] = "cartpole" + trainer = pg.PGTrainer(config=config) + print(trainer.train()) + trainer.stop() + + # Using class directly. + config["env"] = RandomEnv + trainer = pg.PGTrainer(config=config) + print(trainer.train()) + trainer.stop() + + # Using class directly: Sub-class of gym.Env, + # which implements its own API. + config["env"] = NonVectorizedEnvToBeVectorizedIntoRemoteVectorEnv + trainer = pg.PGTrainer(config=config) + print(trainer.train()) + trainer.stop() + + def test_remote_worker_env_multi_agent(self): + config = pg.DEFAULT_CONFIG.copy() + config["remote_worker_envs"] = True + config["num_envs_per_worker"] = 4 + + # Full classpath provided. + config["env"] = \ + "ray.rllib.examples.env.random_env.RandomMultiAgentEnv" + trainer = pg.PGTrainer(config=config) + print(trainer.train()) + trainer.stop() + + # Using tune.register. + config["env"] = "pistonball" + trainer = pg.PGTrainer(config=config) + print(trainer.train()) + trainer.stop() + + # Using class directly. 
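# The RandomMultiAgentEnv used just below is generated via RLlib's
# `make_multi_agent` helper (see the multi_agent_env.py hunk above),
# which also accepts a plain gym env id string. Sketch:
from ray.rllib.env.multi_agent_env import make_multi_agent

MultiCartPole = make_multi_agent("CartPole-v0")
env = MultiCartPole({"num_agents": 2})
obs = env.reset()
print(sorted(obs.keys()))  # one entry per agent id, e.g. [0, 1]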
+ config["env"] = RandomMultiAgentEnv + trainer = pg.PGTrainer(config=config) + print(trainer.train()) + trainer.stop() + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/env/wrappers/unity3d_env.py b/rllib/env/wrappers/unity3d_env.py index 2ec8fd6282945..2f9f75e79cb1e 100644 --- a/rllib/env/wrappers/unity3d_env.py +++ b/rllib/env/wrappers/unity3d_env.py @@ -6,6 +6,7 @@ from typing import Callable, Optional, Tuple from ray.rllib.env.multi_agent_env import MultiAgentEnv +from ray.rllib.policy.policy import PolicySpec from ray.rllib.utils.typing import MultiAgentDict, PolicyID, AgentID logger = logging.getLogger(__name__) @@ -304,10 +305,12 @@ def get_policy_configs_for_game( # Policies (Unity: "behaviors") and agent-to-policy mapping fns. if game_name == "SoccerStrikersVsGoalie": policies = { - "Goalie": (None, obs_spaces["Goalie"], action_spaces["Goalie"], - {}), - "Striker": (None, obs_spaces["Striker"], - action_spaces["Striker"], {}), + "Goalie": PolicySpec( + observation_space=obs_spaces["Goalie"], + action_space=action_spaces["Goalie"]), + "Striker": PolicySpec( + observation_space=obs_spaces["Striker"], + action_space=action_spaces["Striker"]), } def policy_mapping_fn(agent_id, episode, worker, **kwargs): @@ -315,8 +318,9 @@ def policy_mapping_fn(agent_id, episode, worker, **kwargs): else: policies = { - game_name: (None, obs_spaces[game_name], - action_spaces[game_name], {}), + game_name: PolicySpec( + observation_space=obs_spaces[game_name], + action_space=action_spaces[game_name]), } def policy_mapping_fn(agent_id, episode, worker, **kwargs): diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index 40745251e2b64..7c5415375d230 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -756,8 +756,11 @@ def postprocess_episode( "True. 
Alternatively, set no_done_at_end=True to " "allow this.") - other_batches = pre_batches.copy() - del other_batches[agent_id] + if len(pre_batches) > 1: + other_batches = pre_batches.copy() + del other_batches[agent_id] + else: + other_batches = {} pid = self.agent_key_to_policy_id[(episode_id, agent_id)] policy = self.policy_map[pid] if any(pre_batch[SampleBatch.DONES][:-1]) or len( diff --git a/rllib/evaluation/metrics.py b/rllib/evaluation/metrics.py index 06afe96d3fc6f..73c25f916f0bb 100644 --- a/rllib/evaluation/metrics.py +++ b/rllib/evaluation/metrics.py @@ -9,8 +9,8 @@ from ray.rllib.evaluation.rollout_metrics import RolloutMetrics from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.offline.off_policy_estimator import OffPolicyEstimate -from ray.rllib.policy.policy import LEARNER_STATS_KEY from ray.rllib.utils.annotations import DeveloperAPI +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.typing import GradInfoDict, LearnerStatsDict, ResultDict if TYPE_CHECKING: @@ -42,7 +42,6 @@ def get_learner_stats(grad_info: GradInfoDict) -> LearnerStatsDict: >>> print(get_stats(grad_info)) {"vf_loss": ..., "policy_loss": ...} """ - if LEARNER_STATS_KEY in grad_info: return grad_info[LEARNER_STATS_KEY] @@ -57,10 +56,15 @@ def get_learner_stats(grad_info: GradInfoDict) -> LearnerStatsDict: @DeveloperAPI def collect_metrics(local_worker: Optional["RolloutWorker"] = None, - remote_workers: List[ActorHandle] = [], - to_be_collected: List[ObjectRef] = [], + remote_workers: Optional[List[ActorHandle]] = None, + to_be_collected: Optional[List[ObjectRef]] = None, timeout_seconds: int = 180) -> ResultDict: """Gathers episode metrics from RolloutWorker instances.""" + if remote_workers is None: + remote_workers = [] + + if to_be_collected is None: + to_be_collected = [] episodes, to_be_collected = collect_episodes( local_worker, @@ -74,11 +78,16 @@ def collect_metrics(local_worker: Optional["RolloutWorker"] = None, @DeveloperAPI def collect_episodes( local_worker: Optional["RolloutWorker"] = None, - remote_workers: List[ActorHandle] = [], - to_be_collected: List[ObjectRef] = [], + remote_workers: Optional[List[ActorHandle]] = None, + to_be_collected: Optional[List[ObjectRef]] = None, timeout_seconds: int = 180 ) -> Tuple[List[Union[RolloutMetrics, OffPolicyEstimate]], List[ObjectRef]]: """Gathers new episodes metrics tuples from the given evaluators.""" + if remote_workers is None: + remote_workers = [] + + if to_be_collected is None: + to_be_collected = [] if remote_workers: pending = [ diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index a703b9f0a66e1..7151851587f73 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -860,14 +860,15 @@ def compute_gradients( summarize(samples))) if isinstance(samples, MultiAgentBatch): grad_out, info_out = {}, {} - if self.tf_sess is not None: - builder = TFRunBuilder(self.tf_sess, "compute_gradients") + if self.policy_config.get("framework") == "tf": for pid, batch in samples.policy_batches.items(): if pid not in self.policies_to_train: continue + policy = self.policy_map[pid] + builder = TFRunBuilder(policy.get_session(), + "compute_gradients") grad_out[pid], info_out[pid] = ( - self.policy_map[pid]._build_compute_gradients( - builder, batch)) + policy._build_compute_gradients(builder, batch)) grad_out = {k: builder.get(v) for k, v in grad_out.items()} info_out = {k: builder.get(v) for k, v in info_out.items()} else: @@ -897,14 
+898,21 @@ def apply_gradients(self, grads: ModelGradients) -> Dict[PolicyID, Any]: if log_once("apply_gradients"): logger.info("Apply gradients:\n\n{}\n".format(summarize(grads))) if isinstance(grads, dict): - if self.tf_sess is not None: - builder = TFRunBuilder(self.tf_sess, "apply_gradients") - outputs = { - pid: self.policy_map[pid]._build_apply_gradients( - builder, grad) - for pid, grad in grads.items() + if self.policy_config.get("framework") == "tf": + builders = {} + outputs = {} + for pid, grad in grads.items(): + if pid not in self.policies_to_train: + continue + policy = self.policy_map[pid] + builders[pid] = TFRunBuilder(policy.get_session(), + "apply_gradients") + outputs[pid] = policy._build_apply_gradients( + builders[pid], grad) + return { + pid: builders[pid].get(op) + for pid, op in outputs.items() } - return {k: builder.get(v) for k, v in outputs.items()} else: return { pid: self.policy_map[pid].apply_gradients(g) diff --git a/rllib/examples/centralized_critic.py b/rllib/examples/centralized_critic.py index 0737355dc0dfe..09fdb3b968dea 100644 --- a/rllib/examples/centralized_critic.py +++ b/rllib/examples/centralized_critic.py @@ -179,7 +179,7 @@ def central_vf_stats(policy, train_batch, grads): return { "vf_explained_var": explained_variance( train_batch[Postprocessing.VALUE_TARGETS], - policy._central_value_out), + policy._central_value_out) } diff --git a/rllib/examples/custom_keras_model.py b/rllib/examples/custom_keras_model.py index cec793dd17bb6..c1c419d50e545 100644 --- a/rllib/examples/custom_keras_model.py +++ b/rllib/examples/custom_keras_model.py @@ -11,9 +11,10 @@ from ray.rllib.models.tf.misc import normc_initializer from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.models.tf.visionnet import VisionNetwork as MyVisionNetwork -from ray.rllib.policy.policy import LEARNER_STATS_KEY from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \ + LEARNER_STATS_KEY tf1, tf, tfv = try_import_tf() @@ -110,7 +111,7 @@ def metrics(self): # Tests https://github.com/ray-project/ray/issues/7293 def check_has_custom_metric(result): - r = result["result"]["info"]["learner"] + r = result["result"]["info"][LEARNER_INFO] if DEFAULT_POLICY_ID in r: r = r[DEFAULT_POLICY_ID].get(LEARNER_STATS_KEY, r[DEFAULT_POLICY_ID]) diff --git a/rllib/examples/custom_model_loss_and_metrics.py b/rllib/examples/custom_model_loss_and_metrics.py index 9cea42cdf639a..6a38084f01188 100644 --- a/rllib/examples/custom_model_loss_and_metrics.py +++ b/rllib/examples/custom_model_loss_and_metrics.py @@ -19,9 +19,10 @@ from ray.rllib.examples.models.custom_loss_model import CustomLossModel, \ TorchCustomLossModel from ray.rllib.models import ModelCatalog -from ray.rllib.policy.policy import LEARNER_STATS_KEY from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \ + LEARNER_STATS_KEY tf1, tf, tfv = try_import_tf() @@ -83,9 +84,9 @@ # Torch metrics structure. 
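# Reading learner stats from a train() result under the new metrics
# constants (replacing the hard-coded "learner" key above). Sketch;
# assumes `trainer` is any built single-policy Trainer:
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \
    LEARNER_STATS_KEY

result = trainer.train()
stats = result["info"][LEARNER_INFO][DEFAULT_POLICY_ID][LEARNER_STATS_KEY]
print(stats)  # e.g. policy/vf losses, depending on the algorithm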
if args.framework == "torch": - assert LEARNER_STATS_KEY in info["learner"][DEFAULT_POLICY_ID] - assert "model" in info["learner"][DEFAULT_POLICY_ID] - assert "custom_metrics" in info["learner"][DEFAULT_POLICY_ID] + assert LEARNER_STATS_KEY in info[LEARNER_INFO][DEFAULT_POLICY_ID] + assert "model" in info[LEARNER_INFO][DEFAULT_POLICY_ID] + assert "custom_metrics" in info[LEARNER_INFO][DEFAULT_POLICY_ID] # TODO: (sven) Make sure the metrics structure gets unified between # tf and torch. Tf should work like current torch: @@ -96,4 +97,5 @@ # model: [return values of ModelV2's `metrics` method] # custom_metrics: [return values of callback: `on_learn_on_batch`] else: - assert "model" in info["learner"][DEFAULT_POLICY_ID][LEARNER_STATS_KEY] + assert "model" in info[LEARNER_INFO][DEFAULT_POLICY_ID][ + LEARNER_STATS_KEY] diff --git a/rllib/examples/deterministic_training.py b/rllib/examples/deterministic_training.py index 528e002971c43..e6fd21e56a9c3 100644 --- a/rllib/examples/deterministic_training.py +++ b/rllib/examples/deterministic_training.py @@ -10,6 +10,7 @@ from ray.rllib.examples.env.env_using_remote_actor import \ CartPoleWithRemoteParamServer, ParameterStorage from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID +from ray.rllib.utils.metrics.learner_info import LEARNER_INFO from ray.rllib.utils.test_utils import check parser = argparse.ArgumentParser() @@ -60,6 +61,7 @@ check(results1["hist_stats"], results2["hist_stats"]) # As well as training behavior (minibatch sequence during SGD # iterations). - check(results1["info"]["learner"][DEFAULT_POLICY_ID]["learner_stats"], - results2["info"]["learner"][DEFAULT_POLICY_ID]["learner_stats"]) + check( + results1["info"][LEARNER_INFO][DEFAULT_POLICY_ID]["learner_stats"], + results2["info"][LEARNER_INFO][DEFAULT_POLICY_ID]["learner_stats"]) ray.shutdown() diff --git a/rllib/examples/env/coin_game_non_vectorized_env.py b/rllib/examples/env/coin_game_non_vectorized_env.py index 5d725ade56d5d..e773bab36a6b9 100644 --- a/rllib/examples/env/coin_game_non_vectorized_env.py +++ b/rllib/examples/env/coin_game_non_vectorized_env.py @@ -13,7 +13,7 @@ from gym.utils import seeding from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.utils import override -from typing import Dict +from typing import Dict, Optional from ray.rllib.examples.env.utils.interfaces import InfoAccumulationInterface @@ -36,7 +36,9 @@ class CoinGame(InfoAccumulationInterface, MultiAgentEnv, gym.Env): np.array([-1, 0]), ] - def __init__(self, config: Dict = {}): + def __init__(self, config: Optional[Dict] = None): + if config is None: + config = {} self._validate_config(config) @@ -325,7 +327,10 @@ def _init_info(self): class AsymCoinGame(CoinGame): NAME = "AsymCoinGame" - def __init__(self, config: dict = {}): + def __init__(self, config: Optional[dict] = None): + if config is None: + config = {} + if "asymmetric" in config: assert config["asymmetric"] else: diff --git a/rllib/examples/env/coin_game_vectorized_env.py b/rllib/examples/env/coin_game_vectorized_env.py index a71fa4327d399..546a9b1a815b0 100644 --- a/rllib/examples/env/coin_game_vectorized_env.py +++ b/rllib/examples/env/coin_game_vectorized_env.py @@ -21,7 +21,9 @@ class VectorizedCoinGame(CoinGame): Vectorized Coin Game environment. 
""" - def __init__(self, config={}): + def __init__(self, config=None): + if config is None: + config = {} super().__init__(config) @@ -159,7 +161,10 @@ def _load_env(self, env_state): class AsymVectorizedCoinGame(VectorizedCoinGame): NAME = "AsymCoinGame" - def __init__(self, config={}): + def __init__(self, config=None): + if config is None: + config = {} + if "asymmetric" in config: assert config["asymmetric"] else: diff --git a/rllib/examples/env/matrix_sequential_social_dilemma.py b/rllib/examples/env/matrix_sequential_social_dilemma.py index 97d222b3cff20..9348a184890b8 100644 --- a/rllib/examples/env/matrix_sequential_social_dilemma.py +++ b/rllib/examples/env/matrix_sequential_social_dilemma.py @@ -8,7 +8,7 @@ import logging from abc import ABC from collections import Iterable -from typing import Dict +from typing import Dict, Optional import numpy as np from gym.spaces import Discrete @@ -39,7 +39,9 @@ class MatrixSequentialSocialDilemma(InfoAccumulationInterface, MultiAgentEnv, episode. """ - def __init__(self, config: Dict = {}): + def __init__(self, config: Optional[Dict] = None): + if config is None: + config = {} assert "reward_randomness" not in config.keys() assert self.PAYOUT_MATRIX is not None diff --git a/rllib/examples/env/random_env.py b/rllib/examples/env/random_env.py index b6b451fef7c33..ceeca23424c24 100644 --- a/rllib/examples/env/random_env.py +++ b/rllib/examples/env/random_env.py @@ -14,7 +14,9 @@ class RandomEnv(gym.Env): configured as well. """ - def __init__(self, config): + def __init__(self, config=None): + config = config or {} + # Action space. self.action_space = config.get("action_space", Discrete(2)) # Observation space from which to sample. @@ -63,3 +65,25 @@ def step(self, action): # Multi-agent version of the RandomEnv. RandomMultiAgentEnv = make_multi_agent(lambda c: RandomEnv(c)) + + +# Large observation space "pre-compiled" random env (for testing). +class RandomLargeObsSpaceEnv(RandomEnv): + def __init__(self, config=None): + config = config or {} + config.update({ + "observation_space": gym.spaces.Box(-1.0, 1.0, (5000, )) + }) + super().__init__(config=config) + + +# Large observation space + cont. actions "pre-compiled" random env +# (for testing). +class RandomLargeObsSpaceEnvContActions(RandomEnv): + def __init__(self, config=None): + config = config or {} + config.update({ + "observation_space": gym.spaces.Box(-1.0, 1.0, (5000, )), + "action_space": gym.spaces.Box(-1.0, 1.0, (5, )), + }) + super().__init__(config=config) diff --git a/rllib/examples/pettingzoo_env.py b/rllib/examples/pettingzoo_env.py index 5eeb962200849..661f03f012088 100644 --- a/rllib/examples/pettingzoo_env.py +++ b/rllib/examples/pettingzoo_env.py @@ -42,19 +42,17 @@ def env_creator(config): # Register env register_env("pistonball", lambda config: PettingZooEnv(env_creator(config))) - env = PettingZooEnv(env_creator(config)) - observation_space = env.observation_space - action_space = env.action_space - del env # Configuration for multiagent setup with policy sharing: config["multiagent"] = { - # Setup a single, shared policy for all agents. - "policies": { - "av": (None, observation_space, action_space, {}) - }, - # Map all agents to that policy. - "policy_mapping_fn": lambda agent_id, episode, **kwargs: "av", + # Setup a single, shared policy for all agents: "av". + # Use a simple set of strings (PolicyID) here. 
RLlib will + # automatically determine the policy class (Trainer's default class), + # observation- and action spaces (inferred from the env), and + # config overrides ({} in this case). + "policies": {"av"}, + # Map all agents to the "av" PolicyID. + "policy_mapping_fn": lambda agent_id, episode, worker, **kwargs: "av", } # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0. diff --git a/rllib/examples/remote_vector_env_with_custom_api.py b/rllib/examples/remote_vector_env_with_custom_api.py index 1dcc65eda89f8..c212249990611 100644 --- a/rllib/examples/remote_vector_env_with_custom_api.py +++ b/rllib/examples/remote_vector_env_with_custom_api.py @@ -65,7 +65,7 @@ class NonVectorizedEnvToBeVectorizedIntoRemoteVectorEnv(TaskSettableEnv): of gym.Env). """ - def __init__(self, config): + def __init__(self, config=None): self.action_space = gym.spaces.Box(0, 1, shape=(1, )) self.observation_space = gym.spaces.Box(0, 1, shape=(2, )) self.task = 1 @@ -108,7 +108,6 @@ def on_train_result(self, *, trainer, result: dict, **kwargs) -> None: # Specify your custom (single, non-vectorized) env directly as a # class. This way, RLlib can auto-create Actors from this class # and handle everything correctly. - # TODO: Test for multi-agent case. "env": NonVectorizedEnvToBeVectorizedIntoRemoteVectorEnv, # Set up our own callbacks. "callbacks": TaskSettingCallback, diff --git a/rllib/examples/rock_paper_scissors_multiagent.py b/rllib/examples/rock_paper_scissors_multiagent.py index bc7477a7f0716..0905314c1140b 100644 --- a/rllib/examples/rock_paper_scissors_multiagent.py +++ b/rllib/examples/rock_paper_scissors_multiagent.py @@ -9,19 +9,19 @@ import argparse import os +from pettingzoo.classic import rps_v2 import random from ray import tune from ray.rllib.agents.pg import PGTrainer, PGTFPolicy, PGTorchPolicy from ray.rllib.agents.registry import get_trainer_class +from ray.rllib.env import PettingZooEnv from ray.rllib.examples.policy.rock_paper_scissors_dummies import \ BeatLastHeuristic, AlwaysSameHeuristic from ray.rllib.policy.policy import PolicySpec from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.test_utils import check_learning_achieved from ray.tune.registry import register_env -from ray.rllib.env import PettingZooEnv -from pettingzoo.classic import rps_v2 tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() @@ -149,8 +149,8 @@ def entropy_policy_gradient_loss(policy, model, dist_class, train_batch): logits, _ = model.from_batch(train_batch) action_dist = dist_class(logits, model) if args.framework == "torch": - # required by PGTorchPolicy's stats fn. - policy.pi_err = torch.tensor([0.0]) + # Required by PGTorchPolicy's stats fn. 
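# The compact multi-agent config form enabled above: `policies` may be
# a plain set of policy ID strings, in which case RLlib infers the
# policy class, spaces, and config overrides. Sketch, assuming `config`
# is a Trainer config dict:
config["multiagent"] = {
    "policies": {"av"},
    "policy_mapping_fn": lambda agent_id, episode, worker, **kwargs: "av",
}
# Roughly equivalent explicit form (cf. the unity3d_env.py hunk above):
# from ray.rllib.policy.policy import PolicySpec
# config["multiagent"]["policies"] = {"av": PolicySpec()}  # unset
#     fields are auto-filled from the Trainer and the env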
+ model.tower_stats["policy_loss"] = torch.tensor([0.0]) return torch.mean(-0.1 * action_dist.entropy() - (action_dist.logp(train_batch["actions"]) * train_batch["advantages"])) diff --git a/rllib/examples/serving/cartpole_client.py b/rllib/examples/serving/cartpole_client.py index 4f9f247eda49b..a368e6b44b852 100755 --- a/rllib/examples/serving/cartpole_client.py +++ b/rllib/examples/serving/cartpole_client.py @@ -54,7 +54,7 @@ "(Policy-computed) ones.") parser.add_argument( "--stop-reward", - type=int, + type=float, default=9999, help="Stop once the specified reward is reached.") parser.add_argument( diff --git a/rllib/examples/serving/unity3d_client.py b/rllib/examples/serving/unity3d_client.py index 8c8784ebf18ab..f3089abd402ae 100644 --- a/rllib/examples/serving/unity3d_client.py +++ b/rllib/examples/serving/unity3d_client.py @@ -52,9 +52,13 @@ parser.add_argument( "--server", type=str, - default=SERVER_ADDRESS + ":" + str(SERVER_PORT), - help="The Policy server's address and port to connect to from this client." -) + default=SERVER_ADDRESS, + help="The Policy server's address to connect to from this client.") +parser.add_argument( + "--port", + type=int, + default=SERVER_PORT, + help="The port to use (on --server).") parser.add_argument( "--no-train", action="store_true", @@ -75,7 +79,7 @@ "learnt policy weights from the server?") parser.add_argument( "--stop-reward", - type=int, + type=float, default=9999, help="Stop once the specified reward is reached.") @@ -85,7 +89,7 @@ # Start the client for sending environment information (e.g. observations, # actions) to a policy server (listening on port 9900). client = PolicyClient( - "http://" + args.server, + "http://" + args.server + ":" + str(args.port), inference_mode=args.inference_mode, update_interval=args.update_interval_local_mode) diff --git a/rllib/examples/serving/unity3d_dummy_client.py b/rllib/examples/serving/unity3d_dummy_client.py new file mode 100644 index 0000000000000..93e7245f31a43 --- /dev/null +++ b/rllib/examples/serving/unity3d_dummy_client.py @@ -0,0 +1,144 @@ +""" +Dummy in-place replacement for the unity3d_client.py script +in case you don't have an actual Unity3D engine installed or just want +to test client/server connectivity with the unity3d_server.py script. + +This client script simply uses RLlib's RandomMultiAgentEnv to mimic +one of the ML Agents (Unity3D) example games (e.g. "3DBall"). + +To run this script on possibly different machines +against a central Policy server: + +1) Run (two separate shells/machines): +$ python unity3d_server.py --env 3DBall +$ python unity3d_dummy_client.py --env 3DBall --inference-mode=local +""" + +import argparse + +from ray.rllib.env.policy_client import PolicyClient +from ray.rllib.env.wrappers.unity3d_env import Unity3DEnv +from ray.rllib.examples.env.random_env import RandomMultiAgentEnv + +SERVER_ADDRESS = "localhost" +SERVER_PORT = 9900 + +parser = argparse.ArgumentParser() +parser.add_argument( + "--env", + type=str, + default="3DBall", + choices=[ + "3DBall", "3DBallHard", "FoodCollector", "GridFoodCollector", + "Pyramids", "Sorter", "Tennis", "VisualHallway", "Walker" + ], + help="The name of the Env to mimic. Only those examples supported so " + "far for which all agents have the same " + "observation- and action spaces (feel free to add more to this script!)") +parser.add_argument( + "--horizon", + type=int, + default=200, + help="The max. 
number of `step()`s for any episode (per agent) before "
+    "it'll be reset again automatically.")
+parser.add_argument(
+    "--server",
+    type=str,
+    default=SERVER_ADDRESS,
+    help="The Policy server's address to connect to from this client.")
+parser.add_argument(
+    "--port",
+    type=int,
+    default=SERVER_PORT,
+    help="The port to use (on --server).")
+parser.add_argument(
+    "--no-train",
+    action="store_true",
+    help="Whether to disable training (on the server side).")
+parser.add_argument(
+    "--inference-mode",
+    type=str,
+    default="local",
+    choices=["local", "remote"],
+    help="Whether to compute actions `local`ly or `remote`ly. Note that "
+    "`local` is much faster b/c observations/actions do not have to be "
+    "sent via the network.")
+parser.add_argument(
+    "--update-interval-local-mode",
+    type=float,
+    default=10.0,
+    help="For `inference-mode=local`, every how many seconds do we update "
+    "learnt policy weights from the server?")
+parser.add_argument(
+    "--num-episodes",
+    type=int,
+    default=10,
+    help="Stop once the specified number of episodes have been played.")
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+
+    # Start the client for sending environment information (e.g. observations,
+    # actions) to a policy server (listening on port 9900).
+    client = PolicyClient(
+        "http://" + args.server + ":" + str(args.port),
+        inference_mode=args.inference_mode,
+        update_interval=args.update_interval_local_mode)
+
+    # Get the multi-agent policies dict and agent->policy
+    # mapping-fn.
+    policies, policy_mapping_fn = \
+        Unity3DEnv.get_policy_configs_for_game(args.env)
+
+    # Make sure all policies' obs- and action spaces are the same.
+    # If not, we won't be able to mimic the Unity3D env using RLlib's
+    # RandomMultiAgentEnv.
+    first_policy_spec = next(iter(policies.values()))
+    for pid, policy_spec in policies.items():
+        assert policy_spec.observation_space == \
+            first_policy_spec.observation_space
+        assert policy_spec.action_space == first_policy_spec.action_space
+
+    # Create and reset the dummy RandomMultiAgentEnv that mimics the given
+    # Unity3D game (no actual Unity3D engine or editor required).
+    env = RandomMultiAgentEnv({
+        # Same number of agents as the actual Unity3D game would have.
+        "num_agents": len(policies),
+        # Make sure we stick to the user-given horizon using our
+        # RandomMultiAgentEnv options.
+        "max_episode_len": args.horizon,
+        "p_done": 0.0,
+        # Same obs- and action spaces as the actual Unity3D game would have.
+        "observation_space": first_policy_spec.observation_space,
+        "action_space": first_policy_spec.action_space,
+    })
+    obs = env.reset()
+    eid = client.start_episode(training_enabled=not args.no_train)
+
+    # Keep track of the total reward per episode.
+    total_rewards_this_episode = 0.0
+
+    # Loop through the env until n episodes have been completed.
+    num_episodes = 0
+    while True:
+        # Get actions from the Policy server given our current obs.
+        actions = client.get_action(eid, obs)
+        # Apply actions to our env.
+        obs, rewards, dones, infos = env.step(actions)
+        total_rewards_this_episode += sum(rewards.values())
+        # Log rewards and single-agent dones.
+        client.log_returns(eid, rewards, infos, multiagent_done_dict=dones)
+        # Check whether all agents are done and end the episode, if necessary.
+        if dones["__all__"]:
+            print("Episode done: Reward={}".format(total_rewards_this_episode))
+
+            num_episodes += 1
+            if num_episodes >= args.num_episodes:
+                quit(0)
+
+            # End the episode and reset dummy Env.
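+            # (Note: `end_episode()` is given the final observation so
+            # the server can close out the episode before a fresh one is
+            # started below.)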
+            total_rewards_this_episode = 0.0
+            client.end_episode(eid, obs)
+            obs = env.reset()
+            # Start a new episode.
+            eid = client.start_episode(training_enabled=not args.no_train)
diff --git a/rllib/examples/serving/unity3d_server.py b/rllib/examples/serving/unity3d_server.py
index 56c1a0089fe50..04ce5567fc165 100755
--- a/rllib/examples/serving/unity3d_server.py
+++ b/rllib/examples/serving/unity3d_server.py
@@ -31,24 +31,42 @@
 import os
 
 import ray
-from ray.tune import register_env
-from ray.rllib.agents.ppo import PPOTrainer
+from ray.rllib.agents.registry import get_trainer_class
 from ray.rllib.env.policy_server_input import PolicyServerInput
 from ray.rllib.env.wrappers.unity3d_env import Unity3DEnv
-from ray.rllib.examples.env.random_env import RandomMultiAgentEnv
 
 SERVER_ADDRESS = "localhost"
 SERVER_PORT = 9900
 CHECKPOINT_FILE = "last_checkpoint_{}.out"
 
 parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--run",
+    default="PPO",
+    choices=["DQN", "PPO"],
+    help="The RLlib-registered algorithm to use.")
+parser.add_argument(
+    "--framework",
+    choices=["tf", "tf2", "tfe", "torch"],
+    default="tf",
+    help="The DL framework specifier.")
+parser.add_argument(
+    "--num-workers",
+    type=int,
+    default=2,
+    help="The number of workers to use. Each worker will create "
+    "its own listening socket for incoming experiences.")
 parser.add_argument(
     "--env",
     type=str,
     default="3DBall",
-    choices=["3DBall", "SoccerStrikersVsGoalie"],
-    help="The name of the Env to run in the Unity3D editor. Either `3DBall` "
-    "or `SoccerStrikersVsGoalie` (feel free to add more to this script!)")
+    choices=[
+        "3DBall", "3DBallHard", "FoodCollector", "GridFoodCollector",
+        "Pyramids", "SoccerStrikersVsGoalie", "Sorter", "Tennis",
+        "VisualHallway", "Walker"
+    ],
+    help="The name of the Env to run in the Unity3D editor "
+    "(feel free to add more to this script!)")
 parser.add_argument(
     "--port",
     type=int,
@@ -71,11 +89,21 @@
     args = parser.parse_args()
     ray.init()
 
-    # Create a fake-env for the server. This env will never be used (neither
-    # for sampling, nor for evaluation) and its obs/action Spaces do not
-    # matter either (multi-agent config below defines Spaces per Policy).
-    register_env("fake_unity", lambda c: RandomMultiAgentEnv(c))
-
+    # `InputReader` generator (returns None if no input reader is needed on
+    # the respective worker).
+    def _input(ioctx):
+        # We are a remote worker, or the local worker with num_workers=0:
+        # Create a PolicyServerInput.
+        if ioctx.worker_index > 0 or ioctx.worker.num_workers == 0:
+            return PolicyServerInput(
+                ioctx, SERVER_ADDRESS, args.port + ioctx.worker_index -
+                (1 if ioctx.worker_index > 0 else 0))
+        # No InputReader (PolicyServerInput) needed.
+        else:
+            return None
+
+    # Get the multi-agent policies dict and agent->policy
+    # mapping-fn.
     policies, policy_mapping_fn = \
         Unity3DEnv.get_policy_configs_for_game(args.env)
 
@@ -83,27 +111,31 @@
     # build their own samplers (and also Policy objects iff
     # `inference_mode=local` on clients' command line).
     config = {
-        # Use the connector server to generate experiences.
-        "input": (
-            lambda ioctx: PolicyServerInput(ioctx, SERVER_ADDRESS, args.port)),
-        # Use a single worker process (w/ SyncSampler) to run the server.
-        "num_workers": 0,
+        # Indicate that the Trainer we set up here doesn't need an actual env.
+        # Allow spaces to be determined by user (see below).
+        "env": None,
+
+        # Use the `PolicyServerInput` to generate experiences.
+        "input": _input,
+        # Use n worker processes to listen on different ports.
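+        # (With the `_input` fn above, worker 1 listens on `args.port`,
+        # worker 2 on `args.port + 1`, etc.; if num_workers=0, the local
+        # worker itself serves on `args.port`.)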
+ "num_workers": args.num_workers, # Disable OPE, since the rollouts are coming from online clients. "input_evaluation": [], # Other settings. "train_batch_size": 256, "rollout_fragment_length": 20, - # Multi-agent setup for the particular env. + # Multi-agent setup for the given env. "multiagent": { "policies": policies, "policy_mapping_fn": policy_mapping_fn, }, - "framework": "tf", + # DL framework to use. + "framework": args.framework, } # Create the Trainer used for Policy serving. - trainer = PPOTrainer(env="fake_unity", config=config) + trainer = get_trainer_class(args.run)(config=config) # Attempt to restore from checkpoint if possible. checkpoint_path = CHECKPOINT_FILE.format(args.env) diff --git a/rllib/examples/trajectory_view_api.py b/rllib/examples/trajectory_view_api.py index 31ce04e879126..b4a288e013bd5 100644 --- a/rllib/examples/trajectory_view_api.py +++ b/rllib/examples/trajectory_view_api.py @@ -1,13 +1,15 @@ import argparse +import numpy as np import ray -from ray import tune +from ray.rllib.agents.ppo import PPOTrainer from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole from ray.rllib.examples.models.trajectory_view_utilizing_models import \ FrameStackingCartPoleModel, TorchFrameStackingCartPoleModel from ray.rllib.models.catalog import ModelCatalog from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.test_utils import check_learning_achieved +from ray import tune tf1, tf, tfv = try_import_tf() @@ -47,18 +49,19 @@ args = parser.parse_args() ray.init(num_cpus=3) + num_frames = 16 + ModelCatalog.register_custom_model( "frame_stack_model", FrameStackingCartPoleModel if args.framework != "torch" else TorchFrameStackingCartPoleModel) - tune.register_env("stateless_cartpole", lambda c: StatelessCartPole()) config = { - "env": "stateless_cartpole", + "env": StatelessCartPole, "model": { "vf_share_layers": True, "custom_model": "frame_stack_model", "custom_model_config": { - "num_frames": 16, + "num_frames": num_frames, }, # To compare against a simple LSTM: @@ -81,8 +84,45 @@ "timesteps_total": args.stop_timesteps, "episode_reward_mean": args.stop_reward, } - results = tune.run(args.run, config=config, stop=stop, verbose=2) + results = tune.run( + args.run, config=config, stop=stop, verbose=2, checkpoint_at_end=True) if args.as_test: check_learning_achieved(results, args.stop_reward) + + checkpoints = results.get_trial_checkpoints_paths( + trial=results.get_best_trial("episode_reward_mean", mode="max"), + metric="episode_reward_mean") + + checkpoint_path = checkpoints[0][0] + trainer = PPOTrainer(config) + trainer.restore(checkpoint_path) + + # Inference loop. + env = StatelessCartPole() + + # Run manual inference loop for n episodes. + for _ in range(10): + episode_reward = 0.0 + reward = 0.0 + action = 0 + done = False + obs = env.reset() + while not done: + # Create a dummy action using the same observation n times, + # as well as dummy prev-n-actions and prev-n-rewards. 
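+            # (The frame-stacking model expects `prev_n_obs` stacked
+            # along axis 0, i.e. shape=[num_frames, obs-dim]; here, we
+            # simply tile the current observation `num_frames` times.)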
+ action, state, logits = trainer.compute_single_action( + input_dict={ + "obs": obs, + "prev_n_obs": np.stack([obs for _ in range(num_frames)]), + "prev_n_actions": np.stack([0 for _ in range(num_frames)]), + "prev_n_rewards": np.stack( + [1.0 for _ in range(num_frames)]), + }, + full_fetch=True) + obs, reward, done, info = env.step(action) + episode_reward += reward + + print(f"Episode reward={episode_reward}") + ray.shutdown() diff --git a/rllib/execution/common.py b/rllib/execution/common.py index 3349541dac2f6..25e4428bffb63 100644 --- a/rllib/execution/common.py +++ b/rllib/execution/common.py @@ -22,9 +22,6 @@ LEARN_ON_BATCH_TIMER = "learn" LOAD_BATCH_TIMER = "load" -# Instant metrics (keys for metrics.info). -LEARNER_INFO = "learner" - # Asserts that an object is a type of SampleBatch. def _check_sample_batch_type(batch: SampleBatchType) -> None: diff --git a/rllib/execution/learner_thread.py b/rllib/execution/learner_thread.py index be7b028cdb04f..d8c6f93c146b1 100644 --- a/rllib/execution/learner_thread.py +++ b/rllib/execution/learner_thread.py @@ -3,10 +3,11 @@ import threading from typing import Dict, Optional -from ray.rllib.evaluation.metrics import get_learner_stats from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.execution.minibatch_buffer import MinibatchBuffer from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder, \ + LEARNER_INFO, LEARNER_STATS_KEY from ray.rllib.utils.timer import TimerStat from ray.rllib.utils.window_stat import WindowStat from ray.util.iter import _NextValueNotReady @@ -56,7 +57,7 @@ def __init__(self, local_worker: RolloutWorker, minibatch_buffer_size: int, self.load_wait_timer = TimerStat() self.daemon = True self.weights_updated = False - self.stats = {} + self.learner_info = {} self.stopped = False self.num_steps = 0 @@ -75,12 +76,24 @@ def step(self) -> Optional[_NextValueNotReady]: return _NextValueNotReady() with self.grad_timer: - fetches = self.local_worker.learn_on_batch(batch) + # Use LearnerInfoBuilder as a unified way to build the final + # results dict from `learn_on_loaded_batch` call(s). + # This makes sure results dicts always have the same structure + # no matter the setup (multi-GPU, multi-agent, minibatch SGD, + # tf vs torch). 
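+            # The finalized `learner_info` below is keyed by policy ID,
+            # e.g. (hypothetical values):
+            # {"default_policy": {"learner_stats": {"policy_loss": ...}}}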
+ learner_info_builder = LearnerInfoBuilder(num_devices=1) + multi_agent_results = self.local_worker.learn_on_batch(batch) + for pid, results in multi_agent_results.items(): + learner_info_builder.add_learn_on_batch_results(results, pid) + self.learner_info = learner_info_builder.finalize() + learner_stats = { + pid: info[LEARNER_STATS_KEY] + for pid, info in self.learner_info.items() + } self.weights_updated = True - self.stats = get_learner_stats(fetches) self.num_steps += 1 - self.outqueue.put((batch.count, self.stats)) + self.outqueue.put((batch.count, learner_stats)) self.learner_queue_size.push(self.inqueue.qsize()) def add_learner_metrics(self, result: Dict) -> Dict: @@ -91,7 +104,7 @@ def timer_to_ms(timer): result["info"].update({ "learner_queue": self.learner_queue_size.stats(), - "learner": copy.deepcopy(self.stats), + LEARNER_INFO: copy.deepcopy(self.learner_info), "timing_breakdown": { "learner_grad_time_ms": timer_to_ms(self.grad_timer), "learner_load_time_ms": timer_to_ms(self.load_timer), diff --git a/rllib/execution/multi_gpu_learner_thread.py b/rllib/execution/multi_gpu_learner_thread.py index 0d230878ff609..1120be7a77d4b 100644 --- a/rllib/execution/multi_gpu_learner_thread.py +++ b/rllib/execution/multi_gpu_learner_thread.py @@ -1,15 +1,15 @@ import logging -import threading - from six.moves import queue +import threading -from ray.rllib.evaluation.metrics import get_learner_stats -from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.execution.learner_thread import LearnerThread from ray.rllib.execution.minibatch_buffer import MinibatchBuffer +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.annotations import override from ray.rllib.utils.deprecation import deprecation_warning from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder, \ + LEARNER_STATS_KEY from ray.rllib.utils.timer import TimerStat from ray.rllib.evaluation.rollout_worker import RolloutWorker @@ -103,18 +103,14 @@ def __init__( self.train_batch_size = train_batch_size - # TODO: (sven) Allow multi-GPU to work for multi-agent as well. - self.policy = self.local_worker.policy_map[DEFAULT_POLICY_ID] + self.policy_map = self.local_worker.policy_map + self.devices = next(iter(self.policy_map.values())).devices - logger.info("MultiGPULearnerThread devices {}".format( - self.policy.devices)) - assert self.train_batch_size % len(self.policy.devices) == 0 - assert self.train_batch_size >= len(self.policy.devices),\ + logger.info("MultiGPULearnerThread devices {}".format(self.devices)) + assert self.train_batch_size % len(self.devices) == 0 + assert self.train_batch_size >= len(self.devices),\ "batch too small" - if set(self.local_worker.policy_map.keys()) != {DEFAULT_POLICY_ID}: - raise NotImplementedError("Multi-gpu mode for multi-agent") - self.tower_stack_indices = list(range(num_multi_gpu_tower_stacks)) # Two queues for tower stacks: @@ -146,18 +142,39 @@ def step(self) -> None: with self.load_wait_timer: buffer_idx, released = self.ready_tower_stacks_buffer.get() + get_num_samples_loaded_into_buffer = 0 with self.grad_timer: - fetches = self.policy.learn_on_loaded_batch( - offset=0, buffer_index=buffer_idx) - self.weights_updated = True - self.stats = {DEFAULT_POLICY_ID: get_learner_stats(fetches)} + # Use LearnerInfoBuilder as a unified way to build the final + # results dict from `learn_on_loaded_batch` call(s). 
+ # This makes sure results dicts always have the same structure + # no matter the setup (multi-GPU, multi-agent, minibatch SGD, + # tf vs torch). + learner_info_builder = LearnerInfoBuilder( + num_devices=len(self.devices)) + + for pid in self.policy_map.keys(): + # Not a policy-to-train. + if pid not in self.local_worker.policies_to_train: + continue + policy = self.policy_map[pid] + default_policy_results = policy.learn_on_loaded_batch( + offset=0, buffer_index=buffer_idx) + learner_info_builder.add_learn_on_batch_results( + default_policy_results) + self.weights_updated = True + get_num_samples_loaded_into_buffer += \ + policy.get_num_samples_loaded_into_buffer(buffer_idx) + + self.learner_info = learner_info_builder.finalize() + learner_stats = { + pid: self.learner_info[pid][LEARNER_STATS_KEY] + for pid in self.learner_info.keys() + } if released: self.idle_tower_stacks.put(buffer_idx) - self.outqueue.put( - (self.policy.get_num_samples_loaded_into_buffer(buffer_idx), - self.stats)) + self.outqueue.put((get_num_samples_loaded_into_buffer, learner_stats)) self.learner_queue_size.push(self.inqueue.qsize()) @@ -180,7 +197,7 @@ def run(self) -> None: def _step(self) -> None: s = self.multi_gpu_learner_thread - policy = s.policy + policy_map = s.policy_map # Get a new batch from the data (inqueue). with self.queue_timer: @@ -191,7 +208,14 @@ def _step(self) -> None: # Load the batch into the idle stack. with self.load_timer: - policy.load_batch_into_buffer(batch=batch, buffer_index=buffer_idx) + for pid in policy_map.keys(): + if pid not in s.local_worker.policies_to_train: + continue + policy = policy_map[pid] + policy.load_batch_into_buffer( + batch=batch if isinstance(batch, SampleBatch) else + batch.policy_batches[pid], + buffer_index=buffer_idx) # Tag just-loaded stack as "ready". 
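+        # (The learner's `step()` consumes this queue - via the
+        # `ready_tower_stacks_buffer` - and runs `learn_on_loaded_batch`
+        # on the tagged stack.)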
s.ready_tower_stacks.put(buffer_idx) diff --git a/rllib/execution/rollout_ops.py b/rllib/execution/rollout_ops.py index 364a814c8c996..1f65620b115eb 100644 --- a/rllib/execution/rollout_ops.py +++ b/rllib/execution/rollout_ops.py @@ -4,14 +4,15 @@ from ray.util.iter import from_actors, LocalIterator from ray.util.iter_metrics import SharedMetrics -from ray.rllib.evaluation.metrics import get_learner_stats from ray.rllib.evaluation.rollout_worker import get_global_worker from ray.rllib.evaluation.worker_set import WorkerSet from ray.rllib.execution.common import AGENT_STEPS_SAMPLED_COUNTER, \ - STEPS_SAMPLED_COUNTER, LEARNER_INFO, SAMPLE_TIMER, GRAD_WAIT_TIMER, \ + STEPS_SAMPLED_COUNTER, SAMPLE_TIMER, GRAD_WAIT_TIMER, \ _check_sample_batch_type, _get_shared_metrics from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ MultiAgentBatch +from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \ + LEARNER_STATS_KEY from ray.rllib.utils.sgd import standardized from ray.rllib.utils.typing import PolicyID, SampleBatchType, ModelGradients @@ -130,7 +131,9 @@ def __call__(self, item): (grads, info), count = item metrics = _get_shared_metrics() metrics.counters[STEPS_SAMPLED_COUNTER] += count - metrics.info[LEARNER_INFO] = get_learner_stats(info) + metrics.info[LEARNER_INFO] = { + DEFAULT_POLICY_ID: info + } if LEARNER_STATS_KEY in info else info metrics.timers[GRAD_WAIT_TIMER].push(time.perf_counter() - self.fetch_start_time) return grads, count @@ -162,15 +165,24 @@ def __init__(self, min_batch_size: int, count_steps_by: str = "env_steps"): def __call__(self, batch: SampleBatchType) -> List[SampleBatchType]: _check_sample_batch_type(batch) - self.buffer.append(batch) if self.count_steps_by == "env_steps": - self.count += batch.count + size = batch.count else: assert isinstance(batch, MultiAgentBatch), \ "`count_steps_by=agent_steps` only allowed in multi-agent " \ "environments!" - self.count += batch.agent_steps() + size = batch.agent_steps() + + # Incoming batch is an empty dummy batch -> Ignore. + # Possibly produced automatically by a PolicyServer to unblock + # an external env waiting for inputs from unresponsive/disconnected + # client(s). 
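+        # (The batch is appended to `self.buffer` only after this check,
+        # so empty dummy batches never count toward `min_batch_size`.)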
+ if size == 0: + return [] + + self.count += size + self.buffer.append(batch) if self.count >= self.min_batch_size: if self.count > self.min_batch_size * 2: diff --git a/rllib/execution/train_ops.py b/rllib/execution/train_ops.py index 6c0e089ef598a..e289d5a7f2fbb 100644 --- a/rllib/execution/train_ops.py +++ b/rllib/execution/train_ops.py @@ -1,22 +1,21 @@ import logging import numpy as np import math -import tree # pip install dm_tree from typing import List, Tuple, Any import ray -from ray.rllib.evaluation.metrics import get_learner_stats from ray.rllib.evaluation.worker_set import WorkerSet from ray.rllib.execution.common import \ AGENT_STEPS_TRAINED_COUNTER, APPLY_GRADS_TIMER, COMPUTE_GRADS_TIMER, \ - LAST_TARGET_UPDATE_TS, LEARNER_INFO, LEARN_ON_BATCH_TIMER, \ + LAST_TARGET_UPDATE_TS, LEARN_ON_BATCH_TIMER, \ LOAD_BATCH_TIMER, NUM_TARGET_UPDATES, STEPS_SAMPLED_COUNTER, \ STEPS_TRAINED_COUNTER, WORKER_UPDATE_TIMER, _check_sample_batch_type, \ _get_global_vars, _get_shared_metrics -from ray.rllib.policy.policy import LEARNER_STATS_KEY from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ MultiAgentBatch from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder, \ + LEARNER_INFO from ray.rllib.utils.sgd import do_minibatch_sgd from ray.rllib.utils.typing import PolicyID, SampleBatchType, ModelGradients @@ -62,7 +61,7 @@ def __call__(self, # train batch and loop through train batch `num_sgd_iter` times. if self.num_sgd_iter > 1 or self.sgd_minibatch_size > 0: lw = self.workers.local_worker() - info = do_minibatch_sgd( + learner_info = do_minibatch_sgd( batch, { pid: lw.get_policy(pid) for pid in self.policies @@ -70,9 +69,10 @@ def __call__(self, }, lw, self.num_sgd_iter, self.sgd_minibatch_size, []) # Single update step using train batch. else: - info = self.workers.local_worker().learn_on_batch(batch) + learner_info = \ + self.workers.local_worker().learn_on_batch(batch) - metrics.info[LEARNER_INFO] = info + metrics.info[LEARNER_INFO] = learner_info learn_timer.push_units_processed(batch.count) metrics.counters[STEPS_TRAINED_COUNTER] += batch.count if isinstance(batch, MultiAgentBatch): @@ -88,7 +88,7 @@ def __call__(self, e.set_weights.remote(weights, _get_global_vars()) # Also update global vars of the local worker. self.workers.local_worker().set_global_vars(_get_global_vars()) - return batch, info + return batch, learner_info class MultiGPUTrainOneStep: @@ -174,56 +174,43 @@ def __call__(self, # Execute minibatch SGD on loaded data. with learn_timer: - fetches = {} + # Use LearnerInfoBuilder as a unified way to build the final + # results dict from `learn_on_loaded_batch` call(s). + # This makes sure results dicts always have the same structure + # no matter the setup (multi-GPU, multi-agent, minibatch SGD, + # tf vs torch). + learner_info_builder = LearnerInfoBuilder( + num_devices=len(self.devices)) + for policy_id, samples_per_device in num_loaded_samples.items(): policy = self.local_worker.policy_map[policy_id] num_batches = max( 1, int(samples_per_device) // int(self.per_device_batch_size)) logger.debug("== sgd epochs for {} ==".format(policy_id)) - batch_fetches_all_towers = [] for _ in range(self.num_sgd_iter): permutation = np.random.permutation(num_batches) for batch_index in range(num_batches): # Learn on the pre-loaded data in the buffer. # Note: For minibatch SGD, the data is an offset into # the pre-loaded entire train batch. 
- batch_fetches = policy.learn_on_loaded_batch( + results = policy.learn_on_loaded_batch( permutation[batch_index] * self.per_device_batch_size, buffer_index=0) - # No towers: Single CPU. - if "tower_0" not in batch_fetches: - batch_fetches_all_towers.append(batch_fetches) - else: - batch_fetches_all_towers.append( - tree.map_structure_with_path( - lambda p, *s: all_tower_reduce(p, *s), - *(batch_fetches.pop( - "tower_{}".format(tower_num)) - for tower_num in range( - len(self.devices))))) - for k, v in batch_fetches.items(): - if k == LEARNER_STATS_KEY: - for k1, v1 in batch_fetches[k].items(): - batch_fetches_all_towers[-1][ - LEARNER_STATS_KEY][k1] = v1 - else: - batch_fetches_all_towers[-1][k] = v - - # Reduce mean across all minibatch SGD steps (axis=0 to keep - # all shapes as-is). - fetches[policy_id] = tree.map_structure( - lambda *s: None if s[0] is None else np.nanmean(s, axis=0), - *batch_fetches_all_towers) + learner_info_builder.add_learn_on_batch_results( + results, policy_id) + + # Tower reduce and finalize results. + learner_info = learner_info_builder.finalize() load_timer.push_units_processed(samples.count) learn_timer.push_units_processed(samples.count) metrics.counters[STEPS_TRAINED_COUNTER] += samples.count metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += samples.agent_steps() - metrics.info[LEARNER_INFO] = fetches + metrics.info[LEARNER_INFO] = learner_info if self.workers.remote_workers(): with metrics.timers[WORKER_UPDATE_TIMER]: @@ -234,24 +221,13 @@ def __call__(self, # Also update global vars of the local worker. self.workers.local_worker().set_global_vars(_get_global_vars()) - return samples, fetches + return samples, learner_info # Backward compatibility. TrainTFMultiGPU = MultiGPUTrainOneStep -def all_tower_reduce(path, *tower_data): - """Reduces stats across towers based on their stats-dict paths.""" - if len(path) == 1 and path[0] == "td_error": - return np.concatenate(tower_data, axis=0) - elif path[-1].startswith("min_"): - return np.nanmin(tower_data) - elif path[-1].startswith("max_"): - return np.nanmax(tower_data) - return np.nanmean(tower_data) - - class ComputeGradients: """Callable that computes gradients with respect to the policy loss. @@ -273,7 +249,12 @@ def __call__(self, samples: SampleBatchType) -> Tuple[ModelGradients, int]: metrics = _get_shared_metrics() with metrics.timers[COMPUTE_GRADS_TIMER]: grad, info = self.workers.local_worker().compute_gradients(samples) - metrics.info[LEARNER_INFO] = get_learner_stats(info) + # RolloutWorker.compute_gradients returns pure single agent stats + # in a non-multi agent setup. 
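+            # Wrapping single-agent infos under DEFAULT_POLICY_ID keeps
+            # `metrics.info[LEARNER_INFO]` uniformly structured as a
+            # {policy_id: stats} dict, matching the multi-agent case.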
+ if isinstance(samples, MultiAgentBatch): + metrics.info[LEARNER_INFO] = info + else: + metrics.info[LEARNER_INFO] = {DEFAULT_POLICY_ID: info} return grad, samples.count diff --git a/rllib/models/tests/test_preprocessors.py b/rllib/models/tests/test_preprocessors.py index 015efe6edd723..2107ddec0cbd0 100644 --- a/rllib/models/tests/test_preprocessors.py +++ b/rllib/models/tests/test_preprocessors.py @@ -10,7 +10,7 @@ get_preprocessor, NoPreprocessor, TupleFlatteningPreprocessor, \ OneHotPreprocessor, AtariRamPreprocessor, GenericPixelPreprocessor from ray.rllib.utils.test_utils import check, check_compute_single_action, \ - framework_iterator + check_train_results, framework_iterator class TestPreprocessors(unittest.TestCase): @@ -50,7 +50,9 @@ def test_preprocessing_disabled(self): for _ in framework_iterator(config): trainer = ppo.PPOTrainer(config=config) for i in range(num_iterations): - print(trainer.train()) + results = trainer.train() + check_train_results(results) + print(results) check_compute_single_action(trainer) trainer.stop() diff --git a/rllib/models/tf/complex_input_net.py b/rllib/models/tf/complex_input_net.py index 2236607d3f75f..c7323c41cab96 100644 --- a/rllib/models/tf/complex_input_net.py +++ b/rllib/models/tf/complex_input_net.py @@ -38,6 +38,8 @@ def __init__(self, obs_space, action_space, num_outputs, model_config, assert isinstance(self.original_space, (Dict, Tuple)), \ "`obs_space.original_space` must be [Dict|Tuple]!" + self.processed_obs_space = self.original_space if \ + model_config.get("_disable_preprocessor_api") else obs_space super().__init__(self.original_space, action_space, num_outputs, model_config, name) @@ -124,8 +126,10 @@ def forward(self, input_dict, state, seq_lens): if SampleBatch.OBS in input_dict and "obs_flat" in input_dict: orig_obs = input_dict[SampleBatch.OBS] else: - orig_obs = restore_original_dimensions(input_dict[SampleBatch.OBS], - self.obs_space, "tf") + orig_obs = restore_original_dimensions( + input_dict[SampleBatch.OBS], + self.processed_obs_space, + tensorlib="tf") # Push image observations through our CNNs. outs = [] for i, component in enumerate(tree.flatten(orig_obs)): diff --git a/rllib/models/torch/complex_input_net.py b/rllib/models/torch/complex_input_net.py index b795e4d5485c3..ac053bab6ccf3 100644 --- a/rllib/models/torch/complex_input_net.py +++ b/rllib/models/torch/complex_input_net.py @@ -40,6 +40,9 @@ def __init__(self, obs_space, action_space, num_outputs, model_config, assert isinstance(self.original_space, (Dict, Tuple)), \ "`obs_space.original_space` must be [Dict|Tuple]!" + self.processed_obs_space = self.original_space if \ + model_config.get("_disable_preprocessor_api") else obs_space + nn.Module.__init__(self) TorchModelV2.__init__(self, self.original_space, action_space, num_outputs, model_config, name) @@ -140,8 +143,10 @@ def forward(self, input_dict, state, seq_lens): if SampleBatch.OBS in input_dict and "obs_flat" in input_dict: orig_obs = input_dict[SampleBatch.OBS] else: - orig_obs = restore_original_dimensions(input_dict[SampleBatch.OBS], - self.obs_space, "tf") + orig_obs = restore_original_dimensions( + input_dict[SampleBatch.OBS], + self.processed_obs_space, + tensorlib="torch") # Push image observations through our CNNs. 
outs = [] for i, component in enumerate(tree.flatten(orig_obs)): diff --git a/rllib/models/torch/torch_modelv2.py b/rllib/models/torch/torch_modelv2.py index 5cde72c4422e6..a7cc38cef6c8a 100644 --- a/rllib/models/torch/torch_modelv2.py +++ b/rllib/models/torch/torch_modelv2.py @@ -46,6 +46,14 @@ def __init__(self, *args, **kwargs): name, framework="torch") + # Dict to store per multi-gpu tower stats into. + # In PyTorch multi-GPU, we use a single TorchPolicy and copy + # it's Model(s) n times (1 copy for each GPU). When computing the loss + # on each tower, we cannot store the stats (e.g. `entropy`) inside the + # policy object as this would lead to race conditions between the + # different towers all accessing the same property at the same time. + self.tower_stats = {} + @override(ModelV2) def variables(self, as_dict: bool = False) -> \ Union[List[TensorType], Dict[str, TensorType]]: diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py index 169dc0bad7f41..76bb4c6bb666a 100644 --- a/rllib/policy/eager_tf_policy.py +++ b/rllib/policy/eager_tf_policy.py @@ -11,13 +11,14 @@ from ray.util.debug import log_once from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.repeated_values import RepeatedValues -from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY +from ray.rllib.policy.policy import Policy from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils import add_mixins, force_list from ray.rllib.utils.annotations import override from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.spaces.space_utils import normalize_action from ray.rllib.utils.tf_ops import get_gpu_devices from ray.rllib.utils.threading import with_lock @@ -65,15 +66,17 @@ def convert_eager_inputs(func): @functools.wraps(func) def _func(*args, **kwargs): if tf.executing_eagerly(): - args = [_convert_to_tf(x) for x in args] + eager_args = [_convert_to_tf(x) for x in args] # TODO: (sven) find a way to remove key-specific hacks. - kwargs = { + eager_kwargs = { k: _convert_to_tf( v, dtype=tf.int64 if k == "timestep" else None) for k, v in kwargs.items() if k not in {"info_batch", "episodes"} } - return func(*args, **kwargs) + return func(*eager_args, **eager_kwargs) + else: + return func(*args, **kwargs) return _func @@ -182,6 +185,14 @@ def apply_gradients(self, grads): return TracedEagerPolicy +class OptimizerWrapper: + def __init__(self, tape): + self.tape = tape + + def compute_gradients(self, loss, var_list): + return list(zip(self.tape.gradient(loss, var_list), var_list)) + + def build_eager_tf_policy( name, loss_fn, @@ -323,8 +334,11 @@ def __init__(self, observation_space, action_space, config): if getattr(self, "exploration", None): optimizers = self.exploration.get_exploration_optimizer( optimizers) - # TODO: (sven) Allow tf policy to have more than 1 optimizer. - # Just like torch Policy does. + + # The list of local (tf) optimizers (one per loss term). + self._optimizers: List[LocalOptimizer] = optimizers + # Backward compatibility: A user's policy may only support a single + # loss term and optimizer (no lists). 
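+            # (Which of the two is used is controlled by the
+            # `_tf_policy_handles_more_than_one_loss` config flag; see
+            # `_compute_gradients()` and `_apply_gradients()` below.)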
self._optimizer: LocalOptimizer = \ optimizers[0] if optimizers else None @@ -432,6 +446,7 @@ def compute_actions(self, lambda s: tf.convert_to_tensor(s), obs_batch), }, _is_training=tf.constant(False)) + self._lazy_tensor_dict(input_dict) if prev_action_batch is not None: input_dict[SampleBatch.PREV_ACTIONS] = \ tf.convert_to_tensor(prev_action_batch) @@ -465,7 +480,6 @@ def compute_actions_from_input_dict( explore, timestep) @with_lock - @convert_eager_inputs @convert_eager_outputs def _compute_action_helper(self, input_dict, state_batches, episodes, explore, timestep): @@ -481,7 +495,8 @@ def _compute_action_helper(self, input_dict, state_batches, episodes, self._is_training = False self._state_in = state_batches or [] # Calculate RNN sequence lengths. - batch_size = tree.flatten(input_dict[SampleBatch.OBS])[0].shape[0] + batch_size = int( + tree.flatten(input_dict[SampleBatch.OBS])[0].shape[0]) seq_lens = tf.ones(batch_size, dtype=tf.int32) if state_batches \ else None @@ -528,7 +543,7 @@ def _compute_action_helper(self, input_dict, state_batches, episodes, dist_inputs, self.dist_class, state_out = \ action_distribution_fn( self, self.model, - input_dict[SampleBatch.CUR_OBS], + input_dict[SampleBatch.OBS], explore=explore, timestep=timestep, is_training=False) @@ -566,7 +581,7 @@ def _compute_action_helper(self, input_dict, state_batches, episodes, extra_fetches.update(extra_action_out_fn(self)) # Update our global timestep by the batch size. - self.global_timestep += int(batch_size) + self.global_timestep += batch_size return actions, state_out, extra_fetches @@ -725,51 +740,78 @@ def export_checkpoint(self, export_dir): def _get_is_training_placeholder(self): return tf.convert_to_tensor(self._is_training) - def _apply_gradients(self, grads_and_vars): - if apply_gradients_fn: - apply_gradients_fn(self, self._optimizer, grads_and_vars) - else: - self._optimizer.apply_gradients( - [(g, v) for g, v in grads_and_vars if g is not None]) - @with_lock def _compute_gradients(self, samples): """Computes and returns grads as eager tensors.""" - with tf.GradientTape(persistent=compute_gradients_fn is not None) \ - as tape: - loss = loss_fn(self, self.model, self.dist_class, samples) - + # Gather all variables for which to calculate losses. if isinstance(self.model, tf.keras.Model): variables = self.model.trainable_variables else: variables = self.model.trainable_variables() - if compute_gradients_fn: - - class OptimizerWrapper: - def __init__(self, tape): - self.tape = tape - - def compute_gradients(self, loss, var_list): - return list( - zip(self.tape.gradient(loss, var_list), var_list)) + # Calculate the loss(es) inside a tf GradientTape. + with tf.GradientTape(persistent=compute_gradients_fn is not None) \ + as tape: + losses = loss_fn(self, self.model, self.dist_class, samples) + losses = force_list(losses) - grads_and_vars = compute_gradients_fn(self, - OptimizerWrapper(tape), - loss) + # User provided a compute_gradients_fn. + if compute_gradients_fn: + # Wrap our tape inside a wrapper, such that the resulting + # object looks like a "classic" tf.optimizer. This way, custom + # compute_gradients_fn will work on both tf static graph + # and tf-eager. + optimizer = OptimizerWrapper(tape) + # More than one loss terms/optimizers. + if self.config["_tf_policy_handles_more_than_one_loss"]: + grads_and_vars = compute_gradients_fn( + self, [optimizer] * len(losses), losses) + # Only one loss and one optimizer. 
+                else:
+                    grads_and_vars = [
+                        compute_gradients_fn(self, optimizer, losses[0])
+                    ]
+            # Default: Compute gradients using the above tape.
             else:
-                grads_and_vars = list(
-                    zip(tape.gradient(loss, variables), variables))
+                grads_and_vars = [
+                    list(zip(tape.gradient(loss, variables), variables))
+                    for loss in losses
+                ]
 
             if log_once("grad_vars"):
-                for _, v in grads_and_vars:
-                    logger.info("Optimizing variable {}".format(v.name))
+                for g_and_v in grads_and_vars:
+                    for g, v in g_and_v:
+                        if g is not None:
+                            logger.info(f"Optimizing variable {v.name}")
+
+            # `grads_and_vars` is returned as a list (len=num
+            # optimizers/losses) of lists of (grad, var) tuples.
+            if self.config["_tf_policy_handles_more_than_one_loss"]:
+                grads = [[g for g, _ in g_and_v] for g_and_v in grads_and_vars]
+            # `grads_and_vars` is returned as a list of (grad, var) tuples.
+            else:
+                grads_and_vars = grads_and_vars[0]
+                grads = [g for g, _ in grads_and_vars]
 
-            grads = [g for g, v in grads_and_vars]
             stats = self._stats(self, samples, grads)
             return grads_and_vars, stats
 
+        def _apply_gradients(self, grads_and_vars):
+            if apply_gradients_fn:
+                if self.config["_tf_policy_handles_more_than_one_loss"]:
+                    apply_gradients_fn(self, self._optimizers, grads_and_vars)
+                else:
+                    apply_gradients_fn(self, self._optimizer, grads_and_vars)
+            else:
+                if self.config["_tf_policy_handles_more_than_one_loss"]:
+                    for i, o in enumerate(self._optimizers):
+                        o.apply_gradients([(g, v) for g, v in grads_and_vars[i]
+                                           if g is not None])
+                else:
+                    self._optimizer.apply_gradients(
+                        [(g, v) for g, v in grads_and_vars if g is not None])
+
         def _stats(self, outputs, samples, grads):
             fetches = {}
diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py
index 3f75a8429c98a..6fd89f0117b97 100644
--- a/rllib/policy/policy.py
+++ b/rllib/policy/policy.py
@@ -14,9 +14,8 @@
 from ray.rllib.utils.exploration.exploration import Exploration
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
 from ray.rllib.utils.from_config import from_config
-from ray.rllib.utils.spaces.space_utils import clip_action, \
-    get_base_struct_from_space, get_dummy_batch_for_space, unbatch, \
-    unsquash_action
+from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space, \
+    get_dummy_batch_for_space, unbatch
 from ray.rllib.utils.typing import AgentID, ModelGradients, ModelWeights, \
     TensorType, TensorStructType, TrainerConfigDict, Tuple, Union
 
@@ -28,10 +27,6 @@
 logger = logging.getLogger(__name__)
 
-# By convention, metrics from optimizing the loss can be reported in the
-# `grad_info` dict returned by learn_on_batch() / compute_grads() via this key.
-LEARNER_STATS_KEY = "learner_stats"
-
 # A policy spec used in the "config.multiagent.policies" specification dict
 # as values (keys are the policy IDs (str)). E.g.:
 # config:
@@ -180,16 +175,17 @@
     @DeveloperAPI
     def compute_single_action(
             self,
-            obs: TensorStructType,
+            obs: Optional[TensorStructType] = None,
             state: Optional[List[TensorType]] = None,
+            *,
             prev_action: Optional[TensorStructType] = None,
             prev_reward: Optional[TensorStructType] = None,
             info: dict = None,
+            input_dict: Optional[SampleBatch] = None,
             episode: Optional["MultiAgentEpisode"] = None,
-            clip_actions: bool = None,
             explore: Optional[bool] = None,
             timestep: Optional[int] = None,
-            unsquash_actions: bool = None,
+            # Kwargs placeholder for future compatibility.
             **kwargs) -> \
             Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]:
         """Unbatched version of compute_actions.
@@ -199,14 +195,13 @@ def compute_single_action( state: List of RNN state inputs, if any. prev_action: Previous action value, if any. prev_reward: Previous reward, if any. - info (dict): Info object, if any. - episode: this provides access to all - of the internal episode state, which may be useful for - model-based or multi-agent algorithms. - unsquash_actions: Should actions be unsquashed according to - the Policy's action space? - clip_actions: Should actions be clipped according to the - Policy's action space? + info: Info object, if any. + input_dict: A SampleBatch or input dict containing the + single (unbatched) Tensors to compute actions. If given, it'll + be used instead of `obs`, `state`, `prev_action|reward`, and + `info`. + episode: This provides access to all of the internal episode state, + which may be useful for model-based or multi-agent algorithms. explore: Whether to pick an exploitation or exploration action (default: None -> use self.config["explore"]). @@ -220,43 +215,37 @@ def compute_single_action( - state_outs: List of RNN state outputs, if any. - info: Dictionary of extra features, if any. """ - # If policy works in normalized space, we should unsquash the action. - # Use value of config.normalize_actions, if None. - unsquash_actions = \ - unsquash_actions if unsquash_actions is not None \ - else self.config["normalize_actions"] - clip_actions = clip_actions if clip_actions is not None else \ - self.config["clip_actions"] - - prev_action_batch = None - prev_reward_batch = None - info_batch = None + # Build the input-dict used for the call to + # `self.compute_actions_from_input_dict()`. + if input_dict is None: + input_dict = {SampleBatch.OBS: obs} + if state is not None: + for i, s in enumerate(state): + input_dict[f"state_in_{i}"] = s + if prev_action is not None: + input_dict[SampleBatch.PREV_ACTIONS] = prev_action + if prev_reward is not None: + input_dict[SampleBatch.PREV_REWARDS] = prev_reward + if info is not None: + input_dict[SampleBatch.INFOS] = info + + # Batch all data in input dict. + input_dict = tree.map_structure_with_path( + lambda p, s: (s if p == "seq_lens" else s.unsqueeze(0) if + torch and isinstance(s, torch.Tensor) else + np.expand_dims(s, 0)), + input_dict) + episodes = None - state_batch = None - if prev_action is not None: - prev_action_batch = [prev_action] - if prev_reward is not None: - prev_reward_batch = [prev_reward] - if info is not None: - info_batch = [info] if episode is not None: episodes = [episode] - if state is not None: - state_batch = [ - s.unsqueeze(0) - if torch and isinstance(s, torch.Tensor) else np.expand_dims( - s, 0) for s in state - ] - - out = self.compute_actions( - tree.map_structure(lambda s: np.array([s]), obs), - state_batch, - prev_action_batch=prev_action_batch, - prev_reward_batch=prev_reward_batch, - info_batch=info_batch, + + out = self.compute_actions_from_input_dict( + input_dict=SampleBatch(input_dict), episodes=episodes, explore=explore, - timestep=timestep) + timestep=timestep, + ) # Some policies don't return a tuple, but always just a single action. # E.g. ES and ARS. @@ -271,16 +260,6 @@ def compute_single_action( assert len(single_action) == 1 single_action = single_action[0] - # If we work in normalized action space (normalize_actions=True), - # we re-translate here into the env's action space. - if unsquash_actions: - single_action = unsquash_action(single_action, - self.action_space_struct) - # Clip, according to env's action space. 
- elif clip_actions: - single_action = clip_action(single_action, - self.action_space_struct) - # Return action, internal state(s), infos. return single_action, [s[0] for s in state_out], \ {k: v[0] for k, v in info.items()} @@ -288,7 +267,7 @@ def compute_single_action( @DeveloperAPI def compute_actions_from_input_dict( self, - input_dict: SampleBatch, + input_dict: Union[SampleBatch, Dict[str, TensorStructType]], explore: bool = None, timestep: Optional[int] = None, episodes: Optional[List["MultiAgentEpisode"]] = None, @@ -300,14 +279,19 @@ def compute_actions_from_input_dict( to construct the input_dict for the Model. Args: - input_dict (SampleBatch): A SampleBatch containing the Tensors + input_dict: A SampleBatch or input dict containing the Tensors to compute actions. `input_dict` already abides to the Policy's as well as the Model's view requirements and can thus be passed to the Model as-is. - explore (bool): Whether to pick an exploitation or exploration + explore: Whether to pick an exploitation or exploration action (default: None -> use self.config["explore"]). - timestep (Optional[int]): The current (sampling) time step. - kwargs: forward compatibility placeholder + timestep: The current (sampling) time step. + episodes: This provides access to all of the internal episodes' + state, which may be useful for model-based or multi-agent + algorithms. + + Keyword Args: + kwargs: Forward compatibility placeholder. Returns: Tuple: diff --git a/rllib/policy/policy_template.py b/rllib/policy/policy_template.py index ea231ed2abc8f..d3463df7eaf71 100644 --- a/rllib/policy/policy_template.py +++ b/rllib/policy/policy_template.py @@ -7,12 +7,13 @@ from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 -from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY +from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy import TorchPolicy from ray.rllib.utils import add_mixins, force_list, NullContextManager from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.framework import try_import_torch, try_import_jax +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.torch_ops import convert_to_non_torch_type from ray.rllib.utils.typing import ModelGradients, TensorType, \ TrainerConfigDict diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index 9192d5ba6d4d5..389278a1a4328 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -183,7 +183,7 @@ def concat_samples( >>> print(SampleBatch.concat_samples([b1, b2])) {"a": np.array([1, 2, 3]), "b": np.array([10, 11, 12])} """ - if isinstance(samples[0], MultiAgentBatch): + if any(isinstance(s, MultiAgentBatch) for s in samples): return MultiAgentBatch.concat_samples(samples) concatd_seq_lens = [] concat_samples = [] @@ -1171,7 +1171,12 @@ def concat_samples(samples: List["MultiAgentBatch"]) -> "MultiAgentBatch": policy_batches = collections.defaultdict(list) env_steps = 0 for s in samples: + # Some batches in `samples` are not MultiAgentBatch. if not isinstance(s, MultiAgentBatch): + # If empty SampleBatch: ok (just ignore). + if isinstance(s, SampleBatch) and len(s) <= 0: + continue + # Otherwise: Error. 
raise ValueError( "`MultiAgentBatch.concat_samples()` can only concat " "MultiAgentBatch types, not {}!".format(type(s).__name__)) diff --git a/rllib/policy/tests/test_compute_log_likelihoods.py b/rllib/policy/tests/test_compute_log_likelihoods.py index 52259d6ea6e60..330ea381bddf9 100644 --- a/rllib/policy/tests/test_compute_log_likelihoods.py +++ b/rllib/policy/tests/test_compute_log_likelihoods.py @@ -57,7 +57,7 @@ def do_test_log_likelihood(run, explore=True, # Do not unsquash actions # (remain in normalized [-1.0; 1.0] space). - unsquash_actions=False, + unsquash_action=False, )) # Test all taken actions for their log-likelihoods vs expected values. diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py index bebc9fa185b26..4f4deb15c05e3 100644 --- a/rllib/policy/tf_policy.py +++ b/rllib/policy/tf_policy.py @@ -10,15 +10,16 @@ import ray import ray.experimental.tf_utils from ray.util.debug import log_once -from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY +from ray.rllib.policy.policy import Policy from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.utils import force_list -from ray.rllib.utils.annotations import override, DeveloperAPI +from ray.rllib.utils.annotations import Deprecated, DeveloperAPI, override from ray.rllib.utils.debug import summarize -from ray.rllib.utils.annotations import Deprecated +from ray.rllib.utils.deprecation import deprecation_warning from ray.rllib.utils.framework import try_import_tf, get_variable +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.schedules import PiecewiseSchedule from ray.rllib.utils.spaces.space_utils import normalize_action from ray.rllib.utils.tf_ops import get_gpu_devices @@ -423,14 +424,18 @@ def compute_actions( timestep = timestep if timestep is not None else self.global_timestep builder = TFRunBuilder(self.get_session(), "compute_actions") + + input_dict = {SampleBatch.OBS: obs_batch} + if state_batches: + for i, s in enumerate(state_batches): + input_dict[f"state_in_{i}"] = s + if prev_action_batch is not None: + input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch + if prev_reward_batch is not None: + input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch + to_fetch = self._build_compute_actions( - builder, - obs_batch=obs_batch, - state_batches=state_batches, - prev_action_batch=prev_action_batch, - prev_reward_batch=prev_reward_batch, - explore=explore, - timestep=timestep) + builder, input_dict=input_dict, explore=explore, timestep=timestep) # Execute session run to get action (and other fetches). fetched = builder.get(to_fetch) @@ -1005,6 +1010,12 @@ def _build_compute_actions(self, # TODO: (sven) This can be deprecated after trajectory view API flag is # removed and always True. 
else: + if log_once("_build_compute_actions_input_dict"): + deprecation_warning( + old="_build_compute_actions(.., obs_batch=.., ..)", + new="_build_compute_actions(.., input_dict=..)", + error=False, + ) state_batches = state_batches or [] if len(self._state_inputs) != len(state_batches): raise ValueError( diff --git a/rllib/policy/tf_policy_template.py b/rllib/policy/tf_policy_template.py index fb7e9519ec878..f2ec7dfaadcc7 100644 --- a/rllib/policy/tf_policy_template.py +++ b/rllib/policy/tf_policy_template.py @@ -6,15 +6,16 @@ from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy from ray.rllib.policy import eager_tf_policy -from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY +from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_policy import TFPolicy from ray.rllib.utils import add_mixins, force_list from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.typing import AgentID, ModelGradients, PolicyID, \ - TensorType, TrainerConfigDict +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY +from ray.rllib.utils.typing import AgentID, ModelGradients, TensorType, \ + TrainerConfigDict if TYPE_CHECKING: from ray.rllib.evaluation import MultiAgentEpisode @@ -53,7 +54,7 @@ def build_tf_policy( extra_learn_fetches_fn: Optional[Callable[[Policy], Dict[ str, TensorType]]] = None, validate_spaces: Optional[Callable[ - [PolicyID, gym.Space, gym.Space, TrainerConfigDict], None]] = None, + [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None, before_init: Optional[Callable[ [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None, before_loss_init: Optional[Callable[[ diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index f50729d005ed2..bf1c69410ff83 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -5,21 +5,23 @@ import math import numpy as np import os -import time import threading +import time +import tree # pip install dm_tree from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union, \ TYPE_CHECKING import ray from ray.rllib.models.modelv2 import ModelV2 -from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper -from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY -from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 +from ray.rllib.policy.policy import Policy from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils import force_list, NullContextManager from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.schedules import PiecewiseSchedule from ray.rllib.utils.spaces.space_utils import normalize_action from ray.rllib.utils.threading import with_lock @@ -703,6 +705,34 @@ def apply_gradients(self, gradients: ModelGradients) -> None: self._optimizers[0].step() + @DeveloperAPI + def get_tower_stats(self, stats_name: str) -> List[TensorStructType]: + """Returns list of per-tower stats, copied to this 
Policy's device. + + Args: + stats_name: The name of the stats to average over (this str + must exist as a key inside each tower's `tower_stats` dict). + + Returns: + The list of stats tensor (structs) of all towers, copied to this + Policy's device. + + Raises: + AssertionError: If the `stats_name` cannot be found in any one + of the tower's `tower_stats` dicts. + """ + data = [] + for tower in self.model_gpu_towers: + if stats_name in tower.tower_stats: + data.append( + tree.map_structure(lambda s: s.to(self.device), + tower.tower_stats[stats_name])) + assert len(data) > 0, \ + f"Stats `{stats_name}` not found in any of the towers (you have " \ + f"{len(self.model_gpu_towers)} towers in total)! Make " \ + "sure you call the loss function on at least one of the towers." + return data + @override(Policy) @DeveloperAPI def get_weights(self) -> ModelWeights: diff --git a/rllib/tests/test_exec_api.py b/rllib/tests/test_exec_api.py index b415c4faadf46..11339f08640b5 100644 --- a/rllib/tests/test_exec_api.py +++ b/rllib/tests/test_exec_api.py @@ -4,6 +4,7 @@ from ray.rllib.agents.a3c import A2CTrainer from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \ STEPS_TRAINED_COUNTER +from ray.rllib.utils.metrics.learner_info import LEARNER_INFO from ray.rllib.utils.test_utils import framework_iterator @@ -29,7 +30,7 @@ def test_exec_plan_stats(ray_start_regular): result = trainer.train() assert isinstance(result, dict) assert "info" in result - assert "learner" in result["info"] + assert LEARNER_INFO in result["info"] assert STEPS_SAMPLED_COUNTER in result["info"] assert STEPS_TRAINED_COUNTER in result["info"] assert "timers" in result diff --git a/rllib/tests/test_supported_multi_agent.py b/rllib/tests/test_supported_multi_agent.py index 0f4063bb2e886..2c114cec4d02f 100644 --- a/rllib/tests/test_supported_multi_agent.py +++ b/rllib/tests/test_supported_multi_agent.py @@ -4,7 +4,9 @@ from ray.rllib.agents.registry import get_trainer_class from ray.rllib.examples.env.multi_agent import MultiAgentCartPole, \ MultiAgentMountainCar -from ray.rllib.utils.test_utils import framework_iterator +from ray.rllib.policy.policy import PolicySpec +from ray.rllib.utils.test_utils import check_train_results, \ + framework_iterator from ray.tune import register_env @@ -13,7 +15,23 @@ def check_support_multiagent(alg, config): lambda _: MultiAgentMountainCar({"num_agents": 2})) register_env("multi_agent_cartpole", lambda _: MultiAgentCartPole({"num_agents": 2})) - config["log_level"] = "ERROR" + + # Simulate a simple multi-agent setup. 
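+    # (MultiAgentCartPole/-MountainCar use int agent IDs 0 and 1, so the
+    # mapping fn below can index directly into `policy_ids`:
+    # agent 0 -> "policy_0", agent 1 -> "policy_1".)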
+ policies = { + "policy_0": PolicySpec(config={"gamma": 0.99}), + "policy_1": PolicySpec(config={"gamma": 0.95}), + } + policy_ids = list(policies.keys()) + + def policy_mapping_fn(agent_id, episode, worker, **kwargs): + pol_id = policy_ids[agent_id] + return pol_id + + config["multiagent"] = { + "policies": policies, + "policy_mapping_fn": policy_mapping_fn, + } + for fw in framework_iterator(config): if fw in ["tf2", "tfe"] and \ alg in ["A3C", "APEX", "APEX_DDPG", "IMPALA"]: @@ -25,7 +43,9 @@ def check_support_multiagent(alg, config): a = get_trainer_class(alg)( config=config, env="multi_agent_cartpole") - print(a.train()) + results = a.train() + check_train_results(results) + print(results) a.stop() diff --git a/rllib/tests/test_supported_spaces.py b/rllib/tests/test_supported_spaces.py index 993558e77d223..d290d3ef87f68 100644 --- a/rllib/tests/test_supported_spaces.py +++ b/rllib/tests/test_supported_spaces.py @@ -69,6 +69,11 @@ def _do_check(alg, config, a_name, o_name): try: a = get_trainer_class(alg)(config=config, env=RandomEnv) + except ray.exceptions.RayActorError as e: + if isinstance(e.args[2], UnsupportedSpaceException): + stat = "unsupported" + else: + raise except UnsupportedSpaceException: stat = "unsupported" else: @@ -99,10 +104,11 @@ def _do_check(alg, config, a_name, o_name): _do_check(alg, config, a_name, o_name) # Do the remaining obs spaces. assert len(OBSERVATION_SPACES_TO_TEST) >= len(ACTION_SPACES_TO_TEST) + fixed_action_key = next(iter(ACTION_SPACES_TO_TEST.keys())) for i, o_name in enumerate(OBSERVATION_SPACES_TO_TEST.keys()): if i < len(ACTION_SPACES_TO_TEST): continue - _do_check(alg, config, "discrete", o_name) + _do_check(alg, config, fixed_action_key, o_name) class TestSupportedSpacesPG(unittest.TestCase): diff --git a/rllib/utils/__init__.py b/rllib/utils/__init__.py index 4f1f33083f01c..e720bfebfc468 100644 --- a/rllib/utils/__init__.py +++ b/rllib/utils/__init__.py @@ -11,7 +11,7 @@ from ray.rllib.utils.schedules import LinearSchedule, PiecewiseSchedule, \ PolynomialSchedule, ExponentialSchedule, ConstantSchedule from ray.rllib.utils.test_utils import check, check_compute_single_action, \ - framework_iterator + check_train_results, framework_iterator from ray.tune.utils import merge_dicts, deep_update @@ -77,6 +77,7 @@ def __exit__(self, *args): "add_mixins", "check", "check_compute_single_action", + "check_train_results", "deep_update", "deprecation_warning", "fc", diff --git a/rllib/utils/exploration/stochastic_sampling.py b/rllib/utils/exploration/stochastic_sampling.py index daa6089d483b4..593233625de15 100644 --- a/rllib/utils/exploration/stochastic_sampling.py +++ b/rllib/utils/exploration/stochastic_sampling.py @@ -1,7 +1,7 @@ import functools import gym import numpy as np -from typing import Union +from typing import Optional, Union from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.models.modelv2 import ModelV2 @@ -61,11 +61,12 @@ def __init__(self, dtype=np.int64) @override(Exploration) - def get_exploration_action(self, - *, - action_distribution: ActionDistribution, - timestep: Union[int, TensorType], - explore: bool = True): + def get_exploration_action( + self, + *, + action_distribution: ActionDistribution, + timestep: Optional[Union[int, TensorType]] = None, + explore: bool = True): if self.framework == "torch": return self._get_torch_exploration_action(action_distribution, timestep, explore) @@ -74,7 +75,7 @@ def get_exploration_action(self, timestep, explore) def _get_tf_exploration_action_op(self, 
action_dist, timestep, explore): - ts = timestep if timestep is not None else self.last_timestep + 1 + ts = self.last_timestep + 1 stochastic_actions = tf.cond( pred=tf.convert_to_tensor(ts < self.random_timesteps), @@ -100,10 +101,7 @@ def _get_tf_exploration_action_op(self, action_dist, timestep, explore): # Increment `last_timestep` by 1 (or set to `timestep`). if self.framework in ["tf2", "tfe"]: - if timestep is None: - self.last_timestep.assign_add(1) - else: - self.last_timestep.assign(timestep) + self.last_timestep.assign_add(1) return action, logp else: assign_op = (tf1.assign_add(self.last_timestep, 1) diff --git a/rllib/utils/metrics/__init__.py b/rllib/utils/metrics/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/rllib/utils/metrics/learner_info.py b/rllib/utils/metrics/learner_info.py new file mode 100644 index 0000000000000..ebe44a7c9fcda --- /dev/null +++ b/rllib/utils/metrics/learner_info.py @@ -0,0 +1,84 @@ +from collections import defaultdict +import numpy as np +import tree # pip install dm_tree +from typing import Dict + +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID +from ray.rllib.utils.typing import PolicyID + +# Instant metrics (keys for metrics.info). +LEARNER_INFO = "learner" +# By convention, metrics from optimizing the loss can be reported in the +# `grad_info` dict returned by learn_on_batch() / compute_grads() via this key. +LEARNER_STATS_KEY = "learner_stats" + + +class LearnerInfoBuilder: + def __init__(self, num_devices: int = 1): + self.num_devices = num_devices + self.results_all_towers = defaultdict(list) + self.is_finalized = False + + def add_learn_on_batch_results( + self, + results: Dict, + policy_id: PolicyID = DEFAULT_POLICY_ID, + ) -> None: + """Adds a policy.learn_on_(loaded)?_batch() result to this builder. + + Args: + results: The results returned by Policy.learn_on_batch or + Policy.learn_on_loaded_batch. + policy_id: The policy's ID, whose learn_on_(loaded)_batch method + returned `results`. + """ + assert not self.is_finalized, \ + "LearnerInfo already finalized! Cannot add more results." + + # No towers: Single CPU. + if "tower_0" not in results: + self.results_all_towers[policy_id].append(results) + # Multi-GPU case: + else: + self.results_all_towers[policy_id].append( + tree.map_structure_with_path( + lambda p, *s: all_tower_reduce(p, *s), + *(results.pop("tower_{}".format(tower_num)) + for tower_num in range(self.num_devices)))) + for k, v in results.items(): + if k == LEARNER_STATS_KEY: + for k1, v1 in results[k].items(): + self.results_all_towers[policy_id][-1][ + LEARNER_STATS_KEY][k1] = v1 + else: + self.results_all_towers[policy_id][-1][k] = v + + def finalize(self): + self.is_finalized = True + + info = {} + for policy_id, results_all_towers in self.results_all_towers.items(): + # Reduce mean across all minibatch SGD steps (axis=0 to keep + # all shapes as-is). + info[policy_id] = tree.map_structure( + lambda *s: None if s[0] is None else np.nanmean(s, axis=0), + *results_all_towers) + + return info + + +def all_tower_reduce(path, *tower_data): + """Reduces stats across towers based on their stats-dict paths.""" + # TD-errors: Need to stay per batch item in order to be able to update + # each item's weight in a prioritized replay buffer. + if len(path) == 1 and path[0] == "td_error": + return np.concatenate(tower_data, axis=0) + + # Min stats: Reduce min. + if path[-1].startswith("min_"): + return np.nanmin(tower_data) + # Max stats: Reduce max. 
+ elif path[-1].startswith("max_"): + return np.nanmax(tower_data) + # Everything else: Reduce mean. + return np.nanmean(tower_data) diff --git a/rllib/utils/multi_agent.py b/rllib/utils/multi_agent.py index b23726cb393db..50d5227c54e75 100644 --- a/rllib/utils/multi_agent.py +++ b/rllib/utils/multi_agent.py @@ -1,9 +1,13 @@ +from typing import Tuple + from ray.rllib.policy.policy import PolicySpec from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID -from ray.rllib.utils.typing import PartialTrainerConfigDict +from ray.rllib.utils.typing import MultiAgentPolicyConfigDict, \ + PartialTrainerConfigDict -def check_multi_agent(config: PartialTrainerConfigDict): +def check_multi_agent(config: PartialTrainerConfigDict) -> \ + Tuple[MultiAgentPolicyConfigDict, bool]: """Checks whether a (partial) config defines a multi-agent setup. Args: @@ -11,18 +15,25 @@ def check_multi_agent(config: PartialTrainerConfigDict): to check for multi-agent. Returns: - Tuple[MultiAgentPolicyConfigDict, bool]: The resulting (all - fixed) multi-agent policy dict and whether we have a - multi-agent setup or not. + The resulting (all fixed) multi-agent policy dict and whether we + have a multi-agent setup or not. """ multiagent_config = config["multiagent"] policies = multiagent_config.get("policies") + + # Nothing specified in config dict -> Assume simple single agent setup + # with DEFAULT_POLICY_ID as only policy. if not policies: policies = {DEFAULT_POLICY_ID} + # Policies given as set (of PolicyIDs) -> Set up each policy automatically + # via empty PolicySpec (will make RLlib infer obs- and action spaces + # as well as the Policy's class). if isinstance(policies, set): policies = multiagent_config["policies"] = { pid: PolicySpec() for pid in policies } + # Is this a multi-agent setup? True, iff there is more than one + # policy or the only policy's ID is not DEFAULT_POLICY_ID. is_multiagent = len(policies) > 1 or DEFAULT_POLICY_ID not in policies return policies, is_multiagent diff --git a/rllib/utils/sgd.py b/rllib/utils/sgd.py index b163c2a36fcd4..6b4f060a95598 100644 --- a/rllib/utils/sgd.py +++ b/rllib/utils/sgd.py @@ -1,38 +1,17 @@ """Utils for minibatch SGD across multiple RLlib policies.""" -import numpy as np import logging -from collections import defaultdict +import numpy as np import random -from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch, \ MultiAgentBatch +from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder logger = logging.getLogger(__name__) -def averaged(kv, axis=None): - """Average the value lists of a dictionary. - - For non-scalar values, we simply pick the first value. - - Args: - kv (dict): dictionary with values that are lists of floats. - - Returns: - dictionary with single averaged float as values. - """ - out = {} - for k, v in kv.items(): - if v[0] is not None and not isinstance(v[0], dict): - out[k] = np.mean(v, axis=axis) - else: - out[k] = v[0] - return out - - -def standardized(array): +def standardized(array: np.ndarray): """Normalize the values in an array. Args: @@ -107,7 +86,12 @@ def do_minibatch_sgd(samples, policies, local_worker, num_sgd_iter, if isinstance(samples, SampleBatch): samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples}, samples.count) - fetches = defaultdict(dict) + # Use LearnerInfoBuilder as a unified way to build the final + # results dict from `learn_on_loaded_batch` call(s).
+ # This makes sure results dicts always have the same structure + # no matter the setup (multi-GPU, multi-agent, minibatch SGD, + # tf vs torch). + learner_info_builder = LearnerInfoBuilder(num_devices=1) for policy_id in policies.keys(): if policy_id not in samples.policy_batches: continue @@ -116,23 +100,14 @@ def do_minibatch_sgd(samples, policies, local_worker, num_sgd_iter, for field in standardize_fields: batch[field] = standardized(batch[field]) - learner_stats = defaultdict(list) - model_stats = defaultdict(list) - custom_callbacks_stats = defaultdict(list) - for i in range(num_sgd_iter): for minibatch in minibatches(batch, sgd_minibatch_size): - batch_fetches = (local_worker.learn_on_batch( + results = (local_worker.learn_on_batch( MultiAgentBatch({ policy_id: minibatch }, minibatch.count)))[policy_id] - for k, v in batch_fetches.get(LEARNER_STATS_KEY, {}).items(): - learner_stats[k].append(v) - for k, v in batch_fetches.get("model", {}).items(): - model_stats[k].append(v) - for k, v in batch_fetches.get("custom_metrics", {}).items(): - custom_callbacks_stats[k].append(v) - fetches[policy_id][LEARNER_STATS_KEY] = averaged(learner_stats) - fetches[policy_id]["model"] = averaged(model_stats) - fetches[policy_id]["custom_metrics"] = averaged(custom_callbacks_stats) - return fetches + learner_info_builder.add_learn_on_batch_results( + results, policy_id) + + learner_info = learner_info_builder.finalize() + return learner_info diff --git a/rllib/utils/test_utils.py b/rllib/utils/test_utils.py index f119d3806968f..5fcb16da6471e 100644 --- a/rllib/utils/test_utils.py +++ b/rllib/utils/test_utils.py @@ -1,10 +1,12 @@ from collections import Counter import copy -import gym +from gym.spaces import Box import logging import numpy as np +import random import re import time +import tree # pip install dm_tree from typing import Any, Dict, List import yaml @@ -29,7 +31,8 @@ def framework_iterator(config=None, frameworks=("tf2", "tf", "tfe", "torch"), - session=False): + session=False, + with_eager_tracing=False): """A generator that allows for looping through n frameworks for testing. Provides the correct config entries ("framework") as well @@ -44,6 +47,8 @@ and yield that as second return value (otherwise yield (fw, None)). Also sets a seed (42) on the session to make the test deterministic. + with_eager_tracing: Include `eager_tracing=True` in the returned + configs, when framework=[tfe|tf2]. Yields: str: If enter_session is False: @@ -103,7 +108,15 @@ elif fw == "tf": assert not tf1.executing_eagerly() - yield fw if session is False else (fw, sess) + # Additionally loop through eager_tracing=True + False, if necessary. + if fw in ["tf2", "tfe"] and with_eager_tracing: + for tracing in [True, False]: + config["eager_tracing"] = tracing + yield fw if session is False else (fw, sess) + config["eager_tracing"] = False + # Yield current framework + tf-session (if necessary). + else: + yield fw if session is False else (fw, sess) # Exit any context we may have entered. if eager_ctx: @@ -260,31 +273,6 @@ def check(x, y, decimals=5, atol=None, rtol=None, false=False): "ERROR: x ({}) is the same as y ({})!".format(x, y) -def check_learning_achieved(tune_results, min_reward, evaluation=False): - """Throws an error if `min_reward` is not reached within tune_results. - - Checks the last iteration found in tune_results for its - "episode_reward_mean" value and compares it to `min_reward`.
- - Args: - tune_results: The tune.run returned results object. - min_reward (float): The min reward that must be reached. - - Raises: - ValueError: If `min_reward` not reached. - """ - # Get maximum reward of all trials - # (check if at least one trial achieved some learning) - avg_rewards = [(trial.last_result["episode_reward_mean"] - if not evaluation else - trial.last_result["evaluation"]["episode_reward_mean"]) - for trial in tune_results.trials] - best_avg_reward = max(avg_rewards) - if best_avg_reward < min_reward: - raise ValueError("`stop-reward` of {} not reached!".format(min_reward)) - print("ok") - - def check_compute_single_action(trainer, include_state=False, include_prev_action_reward=False): @@ -300,17 +288,120 @@ def check_compute_single_action(trainer, Raises: ValueError: If anything unexpected happens. """ + # Have to import this here to avoid circular dependency. + from ray.rllib.policy.sample_batch import SampleBatch + + # Some Trainers may not abide by the standard API. try: pol = trainer.get_policy() except AttributeError: pol = trainer.policy + # Get the policy's model. model = pol.model action_space = pol.action_space + def _test(what, method_to_test, obs_space, full_fetch, explore, timestep, + unsquash, clip): + call_kwargs = {} + if what is trainer: + call_kwargs["full_fetch"] = full_fetch + + obs = obs_space.sample() + if isinstance(obs_space, Box): + obs = np.clip(obs, -1.0, 1.0) + state_in = None + if include_state: + state_in = model.get_initial_state() + if not state_in: + state_in = [] + i = 0 + while f"state_in_{i}" in model.view_requirements: + state_in.append(model.view_requirements[f"state_in_{i}"] + .space.sample()) + i += 1 + action_in = action_space.sample() \ + if include_prev_action_reward else None + reward_in = 1.0 if include_prev_action_reward else None + + if method_to_test == "input_dict": + assert what is pol + + input_dict = {SampleBatch.OBS: obs} + if include_prev_action_reward: + input_dict[SampleBatch.PREV_ACTIONS] = action_in + input_dict[SampleBatch.PREV_REWARDS] = reward_in + if state_in: + for i, s in enumerate(state_in): + input_dict[f"state_in_{i}"] = s + input_dict_batched = SampleBatch( + tree.map_structure(lambda s: np.expand_dims(s, 0), input_dict)) + action = pol.compute_actions_from_input_dict( + input_dict=input_dict_batched, + explore=explore, + timestep=timestep, + **call_kwargs) + # Unbatch everything to be able to compare against single + # action below. + # ARS and ES return action batches as lists. + if isinstance(action[0], list): + action = (np.array(action[0]), action[1], action[2]) + action = tree.map_structure(lambda s: s[0], action) + + try: + action2 = pol.compute_single_action( + input_dict=input_dict, + explore=explore, + timestep=timestep, + **call_kwargs) + # Make sure these are the same, unless we have exploration + # switched on (or noisy layers).
+ if not explore and not pol.config.get("noisy"): + check(action, action2) + except TypeError: + pass + else: + action = what.compute_single_action( + obs, + state_in, + prev_action=action_in, + prev_reward=reward_in, + explore=explore, + timestep=timestep, + unsquash_action=unsquash, + clip_action=clip, + **call_kwargs) + + state_out = None + if state_in or full_fetch or what is pol: + action, state_out, _ = action + if state_out: + for si, so in zip(state_in, state_out): + check(list(si.shape), so.shape) + + # Test whether unsquash/clipping works on the Trainer's + # compute_single_action method: Both flags should force the action + # to be within the space's bounds. + if method_to_test == "single" and what == trainer: + if not action_space.contains(action) and \ + (clip or unsquash or not isinstance(action_space, Box)): + raise ValueError( + f"Returned action ({action}) of trainer/policy {what} " + f"not in Env's action_space {action_space}") + # We are operating in normalized space: Expect only smaller action + # values. + if isinstance(action_space, Box) and not unsquash and \ + what.config.get("normalize_actions") and \ + np.any(np.abs(action) > 3.0): + raise ValueError( + f"Returned action ({action}) of trainer/policy {what} " + "should be in normalized space, but seems too large/small " + "for that!") + + # Loop through: Policy vs Trainer; Different API methods to calculate + # actions; unsquash option; clip option; full fetch or not. for what in [pol, trainer]: if what is trainer: - method_to_test = trainer.compute_single_action # Get the obs-space from Workers.env (not Policy) due to possible # pre-processor up front. worker_set = getattr(trainer, "workers", @@ -323,53 +414,134 @@ def check_compute_single_action(trainer, lambda p: p.observation_space) obs_space = getattr(obs_space, "original_space", obs_space) else: - method_to_test = pol.compute_single_action obs_space = pol.observation_space - for explore in [True, False]: - for full_fetch in ([False, True] if what is trainer else [False]): - call_kwargs = {} - if what is trainer: - call_kwargs["full_fetch"] = full_fetch - else: - call_kwargs["clip_actions"] = True - - obs = obs_space.sample() - if isinstance(obs_space, gym.spaces.Box): - obs = np.clip(obs, -1.0, 1.0) - state_in = None - if include_state: - state_in = model.get_initial_state() - if not state_in: - state_in = [] - i = 0 - while f"state_in_{i}" in model.view_requirements: - state_in.append(model.view_requirements[ - f"state_in_{i}"].space.sample()) - i += 1 - action_in = action_space.sample() \ - if include_prev_action_reward else None - reward_in = 1.0 if include_prev_action_reward else None - action = method_to_test( - obs, - state_in, - prev_action=action_in, - prev_reward=reward_in, - explore=explore, - **call_kwargs) + for method_to_test in ["single"] + \ + (["input_dict"] if what is pol else []): + for explore in [True, False]: + for full_fetch in ([False, True] + if what is trainer else [False]): + timestep = random.randint(0, 100000) + for unsquash in [True, False]: + for clip in ([False] if unsquash else [True, False]): + _test(what, method_to_test, obs_space, full_fetch, + explore, timestep, unsquash, clip) - state_out = None - if state_in or full_fetch or what is pol: - action, state_out, _ = action - if state_out: - for si, so in zip(state_in, state_out): - check(list(si.shape), so.shape) - if not action_space.contains(action): - raise ValueError( - "Returned action ({}) of trainer/policy {} not in " - "Env's action_space " - "({})!".format(action, 
what, action_space)) +def check_learning_achieved(tune_results, min_reward, evaluation=False): + """Throws an error if `min_reward` is not reached within tune_results. + + Checks the last iteration found in tune_results for its + "episode_reward_mean" value and compares it to `min_reward`. + + Args: + tune_results: The tune.run returned results object. + min_reward (float): The min reward that must be reached. + + Raises: + ValueError: If `min_reward` not reached. + """ + # Get maximum reward of all trials + # (check if at least one trial achieved some learning) + avg_rewards = [(trial.last_result["episode_reward_mean"] + if not evaluation else + trial.last_result["evaluation"]["episode_reward_mean"]) + for trial in tune_results.trials] + best_avg_reward = max(avg_rewards) + if best_avg_reward < min_reward: + raise ValueError("`stop-reward` of {} not reached!".format(min_reward)) + print("ok") + + +def check_train_results(train_results): + """Checks proper structure of a Trainer.train() returned dict. + + Args: + train_results: The train results dict to check. + + Raises: + AssertionError: If `train_results` doesn't have the proper structure or + data in it. + """ + # Import these here to avoid circular dependencies. + from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID + from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \ + LEARNER_STATS_KEY + from ray.rllib.utils.multi_agent import check_multi_agent + + # Assert that some keys are where we would expect them. + for key in [ + "agent_timesteps_total", + "config", + "custom_metrics", + "episode_len_mean", + "episode_reward_max", + "episode_reward_mean", + "episode_reward_min", + "episodes_total", + "hist_stats", + "info", + "iterations_since_restore", + "num_healthy_workers", + "perf", + "policy_reward_max", + "policy_reward_mean", + "policy_reward_min", + "sampler_perf", + "time_since_restore", + "time_this_iter_s", + "timesteps_since_restore", + "timesteps_total", + "timers", + "time_total_s", + "training_iteration", + ]: + assert key in train_results, \ + f"'{key}' not found in `train_results` ({train_results})!" + + _, is_multi_agent = check_multi_agent(train_results["config"]) + + # Check in particular the "info" dict. + info = train_results["info"] + assert LEARNER_INFO in info, \ + f"'learner' not in train_results['info'] ({info})!" + assert "num_steps_trained" in info, \ + f"'num_steps_trained' not in train_results['info'] ({info})!" + + learner_info = info[LEARNER_INFO] + + # Make sure we have a default_policy key if we are not in a + # multi-agent setup. + if not is_multi_agent: + # APEX algos sometimes have an empty learner info dict (no metrics + # collected yet). + assert len(learner_info) == 0 or DEFAULT_POLICY_ID in learner_info, \ + f"'{DEFAULT_POLICY_ID}' not found in " \ + f"train_results['info']['learner'] ({learner_info})!" + + for pid, policy_stats in learner_info.items(): + if pid == "batch_count": + continue + # Expect td-errors to be per batch-item. + if "td_error" in policy_stats: + configured_b = train_results["config"]["train_batch_size"] + actual_b = policy_stats["td_error"].shape[0] + # R2D2 case. + if (configured_b - actual_b) / actual_b > 0.1: + assert configured_b / ( + train_results["config"]["model"]["max_seq_len"] + + train_results["config"]["burn_in"]) == actual_b + + # Make sure each policy has the LEARNER_STATS_KEY under it.
+ assert LEARNER_STATS_KEY in policy_stats + learner_stats = policy_stats[LEARNER_STATS_KEY] + for key, value in learner_stats.items(): + # Min- and max-stats should be single values. + if key.startswith("min_") or key.startswith("max_"): + assert np.isscalar( + value), f"'{key}' value not a scalar ({value})!" + + return train_results def run_learning_tests_from_yaml( diff --git a/rllib/utils/tf_ops.py b/rllib/utils/tf_ops.py index 20b0ea3d75f98..1b577be7ef727 100644 --- a/rllib/utils/tf_ops.py +++ b/rllib/utils/tf_ops.py @@ -146,7 +146,7 @@ def zero_logps_from_actions(actions: TensorStructType) -> TensorType: # `deterministic_actions` or `stochastic_actions`). In case # actions are just [B], zeros_like works just fine here, but if # actions are [B, ...], we have to reduce logp back to just [B]. - if len(logp_.shape) > 1: + while len(logp_.shape) > 1: logp_ = logp_[:, 0] return logp_ diff --git a/rllib/utils/tf_run_builder.py b/rllib/utils/tf_run_builder.py index 82b904bd13164..28a48558f73e7 100644 --- a/rllib/utils/tf_run_builder.py +++ b/rllib/utils/tf_run_builder.py @@ -59,7 +59,10 @@ def get(self, to_fetch): _count = 0 -def run_timeline(sess, ops, debug_name, feed_dict={}, timeline_dir=None): +def run_timeline(sess, ops, debug_name, feed_dict=None, timeline_dir=None): + if feed_dict is None: + feed_dict = {} + if timeline_dir: from tensorflow.python.client import timeline diff --git a/rllib/utils/torch_ops.py b/rllib/utils/torch_ops.py index a27be53cc2695..90ccc64aad126 100644 --- a/rllib/utils/torch_ops.py +++ b/rllib/utils/torch_ops.py @@ -48,8 +48,8 @@ def atanh(x): def concat_multi_gpu_td_errors(policy): td_error = torch.cat( [ - getattr(t, "td_error", torch.tensor([0.0])).to(policy.device) - for t in policy.model_gpu_towers + t.tower_stats.get("td_error", torch.tensor([0.0])).to( + policy.device) for t in policy.model_gpu_towers ], dim=0) policy.td_error = td_error @@ -132,7 +132,7 @@ def explained_variance(y, pred): y_var = torch.var(y, dim=[0]) diff_var = torch.var(y - pred, dim=[0]) min_ = torch.tensor([-1.0]).to(pred.device) - return torch.max(min_, 1 - (diff_var / y_var)) + return torch.max(min_, 1 - (diff_var / y_var))[0] def global_norm(tensors): diff --git a/src/mock/ray/gcs/gcs_server/gcs_node_manager.h b/src/mock/ray/gcs/gcs_server/gcs_node_manager.h index 483464c1ff6eb..56be36f4c87ff 100644 --- a/src/mock/ray/gcs/gcs_server/gcs_node_manager.h +++ b/src/mock/ray/gcs/gcs_server/gcs_node_manager.h @@ -17,6 +17,7 @@ namespace gcs { class MockGcsNodeManager : public GcsNodeManager { public: + MockGcsNodeManager() : GcsNodeManager(nullptr, nullptr) {} MOCK_METHOD(void, HandleRegisterNode, (const rpc::RegisterNodeRequest &request, rpc::RegisterNodeReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); diff --git a/src/mock/ray/gcs/gcs_server/gcs_placement_group_scheduler.h b/src/mock/ray/gcs/gcs_server/gcs_placement_group_scheduler.h index f612e6d1d2841..627e3357879e7 100644 --- a/src/mock/ray/gcs/gcs_server/gcs_placement_group_scheduler.h +++ b/src/mock/ray/gcs/gcs_server/gcs_placement_group_scheduler.h @@ -1,4 +1,4 @@ -// Copyright The Ray Authors. +// Copyright 2021 The Ray Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License.
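// [Editor's note, not part of the patch] The default constructor added to
// MockGcsNodeManager above (passing nullptr for both GCS dependencies) is what
// lets unit tests instantiate the mock directly. A minimal usage sketch, under
// the assumption that the standard gtest/gmock headers and the rpc types shown
// above are available:

#include "gmock/gmock.h"
#include "gtest/gtest.h"

using ::testing::_;

TEST(MockGcsNodeManagerSketch, HandleRegisterNodeIsMockable) {
  ray::gcs::MockGcsNodeManager node_manager;  // No real GCS wiring required.
  EXPECT_CALL(node_manager, HandleRegisterNode(_, _, _)).Times(1);
  ray::rpc::RegisterNodeRequest request;
  ray::rpc::RegisterNodeReply reply;
  node_manager.HandleRegisterNode(request, &reply, /*send_reply_callback=*/nullptr);
}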
@@ -30,8 +30,8 @@ class MockGcsPlacementGroupSchedulerInterface public: MOCK_METHOD(void, ScheduleUnplacedBundles, (std::shared_ptr placement_group, - std::function)> failure_callback, - std::function)> success_callback), + PGSchedulingFailureCallback failure_callback, + PGSchedulingSuccessfulCallback success_callback), (override)); MOCK_METHOD((absl::flat_hash_map>), GetBundlesOnNode, (const NodeID &node_id), (override)); @@ -63,11 +63,12 @@ namespace gcs { class MockGcsScheduleStrategy : public GcsScheduleStrategy { public: - MOCK_METHOD(ScheduleMap, Schedule, - (std::vector> & bundles, - const std::unique_ptr &context, - GcsResourceScheduler &gcs_resource_scheduler), - (override)); + MOCK_METHOD( + ScheduleResult, Schedule, + (const std::vector> &bundles, + const std::unique_ptr &context, + GcsResourceScheduler &gcs_resource_scheduler), + (override)); }; } // namespace gcs @@ -78,11 +79,12 @@ namespace gcs { class MockGcsPackStrategy : public GcsPackStrategy { public: - MOCK_METHOD(ScheduleMap, Schedule, - (std::vector> & bundles, - const std::unique_ptr &context, - GcsResourceScheduler &gcs_resource_scheduler), - (override)); + MOCK_METHOD( + ScheduleResult, Schedule, + (const std::vector> &bundles, + const std::unique_ptr &context, + GcsResourceScheduler &gcs_resource_scheduler), + (override)); }; } // namespace gcs @@ -93,11 +95,12 @@ namespace gcs { class MockGcsSpreadStrategy : public GcsSpreadStrategy { public: - MOCK_METHOD(ScheduleMap, Schedule, - (std::vector> & bundles, - const std::unique_ptr &context, - GcsResourceScheduler &gcs_resource_scheduler), - (override)); + MOCK_METHOD( + ScheduleResult, Schedule, + (const std::vector> &bundles, + const std::unique_ptr &context, + GcsResourceScheduler &gcs_resource_scheduler), + (override)); }; } // namespace gcs @@ -108,11 +111,12 @@ namespace gcs { class MockGcsStrictPackStrategy : public GcsStrictPackStrategy { public: - MOCK_METHOD(ScheduleMap, Schedule, - (std::vector> & bundles, - const std::unique_ptr &context, - GcsResourceScheduler &gcs_resource_scheduler), - (override)); + MOCK_METHOD( + ScheduleResult, Schedule, + (const std::vector> &bundles, + const std::unique_ptr &context, + GcsResourceScheduler &gcs_resource_scheduler), + (override)); }; } // namespace gcs @@ -123,11 +127,12 @@ namespace gcs { class MockGcsStrictSpreadStrategy : public GcsStrictSpreadStrategy { public: - MOCK_METHOD(ScheduleMap, Schedule, - (std::vector> & bundles, - const std::unique_ptr &context, - GcsResourceScheduler &gcs_resource_scheduler), - (override)); + MOCK_METHOD( + ScheduleResult, Schedule, + (const std::vector> &bundles, + const std::unique_ptr &context, + GcsResourceScheduler &gcs_resource_scheduler), + (override)); }; } // namespace gcs @@ -160,8 +165,8 @@ class MockGcsPlacementGroupScheduler : public GcsPlacementGroupScheduler { public: MOCK_METHOD(void, ScheduleUnplacedBundles, (std::shared_ptr placement_group, - std::function)> failure_handler, - std::function)> success_handler), + PGSchedulingFailureCallback failure_handler, + PGSchedulingSuccessfulCallback success_handler), (override)); MOCK_METHOD(void, DestroyPlacementGroupBundleResourcesIfExists, (const PlacementGroupID &placement_group_id), (override)); diff --git a/src/mock/ray/gcs/gcs_server/gcs_resource_manager.h b/src/mock/ray/gcs/gcs_server/gcs_resource_manager.h index d981be23a5472..764bee572cabc 100644 --- a/src/mock/ray/gcs/gcs_server/gcs_resource_manager.h +++ b/src/mock/ray/gcs/gcs_server/gcs_resource_manager.h @@ -17,6 +17,7 @@ namespace gcs { class 
MockGcsResourceManager : public GcsResourceManager { public: + using GcsResourceManager::GcsResourceManager; MOCK_METHOD(void, HandleGetResources, (const rpc::GetResourcesRequest &request, rpc::GetResourcesReply *reply, rpc::SendReplyCallback send_reply_callback), diff --git a/src/mock/ray/gcs/pubsub/gcs_pub_sub.h b/src/mock/ray/gcs/pubsub/gcs_pub_sub.h new file mode 100644 index 0000000000000..21e500da0a002 --- /dev/null +++ b/src/mock/ray/gcs/pubsub/gcs_pub_sub.h @@ -0,0 +1,27 @@ +// Copyright 2021 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace ray { +namespace gcs { + +class MockGcsPubSub : public GcsPubSub { + public: + MOCK_METHOD(Status, Publish, + (const std::string &channel, const std::string &id, const std::string &data, + const StatusCallback &done), + (override)); +}; + +} // namespace gcs +} // namespace ray diff --git a/src/mock/ray/gcs/store_client/in_memory_store_client.h b/src/mock/ray/gcs/store_client/in_memory_store_client.h new file mode 100644 index 0000000000000..08af16a075a17 --- /dev/null +++ b/src/mock/ray/gcs/store_client/in_memory_store_client.h @@ -0,0 +1,66 @@ +// Copyright 2021 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
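// [Editor's note, not part of the patch] The `using
// GcsResourceManager::GcsResourceManager;` line above pulls every base-class
// constructor into the mock, so MockGcsResourceManager stays constructible
// even though its base has no default constructor. The same pattern in
// isolation (all names here are illustrative):

class BaseWithDeps {
 public:
  explicit BaseWithDeps(int dependency) : dependency_(dependency) {}
  virtual ~BaseWithDeps() = default;

 private:
  int dependency_;
};

class MockBaseWithDeps : public BaseWithDeps {
 public:
  using BaseWithDeps::BaseWithDeps;  // Tests can now write MockBaseWithDeps(42).
};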
+ +namespace ray { +namespace gcs { + +class MockInMemoryStoreClient : public InMemoryStoreClient { + public: + MOCK_METHOD(Status, AsyncPut, + (const std::string &table_name, const std::string &key, + const std::string &data, const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncPutWithIndex, + (const std::string &table_name, const std::string &key, + const std::string &index_key, const std::string &data, + const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncGet, + (const std::string &table_name, const std::string &key, + const OptionalItemCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncGetByIndex, + (const std::string &table_name, const std::string &index_key, + (const MapCallback &callback)), + (override)); + MOCK_METHOD(Status, AsyncGetAll, + (const std::string &table_name, + (const MapCallback &callback)), + (override)); + MOCK_METHOD(Status, AsyncDelete, + (const std::string &table_name, const std::string &key, + const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncDeleteWithIndex, + (const std::string &table_name, const std::string &key, + const std::string &index_key, const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncBatchDelete, + (const std::string &table_name, const std::vector &keys, + const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncBatchDeleteWithIndex, + (const std::string &table_name, const std::vector &keys, + const std::vector &index_keys, + const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncDeleteByIndex, + (const std::string &table_name, const std::string &index_key, + const StatusCallback &callback), + (override)); + MOCK_METHOD(int, GetNextJobID, (), (override)); +}; + +} // namespace gcs +} // namespace ray diff --git a/src/mock/ray/gcs/store_client/redis_store_client.h b/src/mock/ray/gcs/store_client/redis_store_client.h new file mode 100644 index 0000000000000..153a69755d3b7 --- /dev/null +++ b/src/mock/ray/gcs/store_client/redis_store_client.h @@ -0,0 +1,67 @@ +// Copyright 2021 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
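// [Editor's note, not part of the patch] Note the extra parentheses around
// (const MapCallback &callback) in the mock above: MOCK_METHOD is a macro, so
// any parameter or return type whose spelling contains a comma has to be
// wrapped in one more pair of parentheses to survive preprocessing. A
// self-contained illustration (Lookup is a made-up interface):

#include <map>
#include <string>
#include "gmock/gmock.h"

class Lookup {
 public:
  virtual ~Lookup() = default;
  virtual std::map<std::string, std::string> GetAll() = 0;
};

class MockLookup : public Lookup {
 public:
  // The comma inside std::map<std::string, std::string> is why the
  // return type needs its own parentheses here.
  MOCK_METHOD((std::map<std::string, std::string>), GetAll, (), (override));
};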
+ +namespace ray { +namespace gcs { + +class MockRedisStoreClient : public RedisStoreClient { + public: + MockRedisStoreClient() : RedisStoreClient(nullptr) {} + MOCK_METHOD(Status, AsyncPut, + (const std::string &table_name, const std::string &key, + const std::string &data, const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncPutWithIndex, + (const std::string &table_name, const std::string &key, + const std::string &index_key, const std::string &data, + const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncGet, + (const std::string &table_name, const std::string &key, + const OptionalItemCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncGetByIndex, + (const std::string &table_name, const std::string &index_key, + (const MapCallback &callback)), + (override)); + MOCK_METHOD(Status, AsyncGetAll, + (const std::string &table_name, + (const MapCallback &callback)), + (override)); + MOCK_METHOD(Status, AsyncDelete, + (const std::string &table_name, const std::string &key, + const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncDeleteWithIndex, + (const std::string &table_name, const std::string &key, + const std::string &index_key, const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncBatchDelete, + (const std::string &table_name, const std::vector &keys, + const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncBatchDeleteWithIndex, + (const std::string &table_name, const std::vector &keys, + const std::vector &index_keys, + const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncDeleteByIndex, + (const std::string &table_name, const std::string &index_key, + const StatusCallback &callback), + (override)); + MOCK_METHOD(int, GetNextJobID, (), (override)); +}; + +} // namespace gcs +} // namespace ray diff --git a/src/mock/ray/gcs/store_client/store_client.h b/src/mock/ray/gcs/store_client/store_client.h new file mode 100644 index 0000000000000..6f4e3b5382735 --- /dev/null +++ b/src/mock/ray/gcs/store_client/store_client.h @@ -0,0 +1,66 @@ +// Copyright 2021 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
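// [Editor's note, not part of the patch] A common way to use these async store
// mocks is to fire the callback inline from a lambda action, so tests never
// depend on a real event loop. This sketch assumes OptionalItemCallback<T>
// delivers a Status plus the optionally found value; check the real signature
// in the store client headers before copying:

#include <string>
#include "gmock/gmock.h"

using ::testing::_;

void SketchSynchronousAsyncGet() {
  // Constructible thanks to the nullptr constructor added above.
  ray::gcs::MockRedisStoreClient store;
  EXPECT_CALL(store, AsyncGet("jobs", "job_1", _))
      .WillOnce([](const std::string &, const std::string &,
                   const ray::gcs::OptionalItemCallback<std::string> &callback) {
        // Deliver the "found" result immediately instead of via Redis.
        callback(ray::Status::OK(), std::string("serialized_job_data"));
        return ray::Status::OK();
      });
}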
+ +namespace ray { +namespace gcs { + +class MockStoreClient : public StoreClient { + public: + MOCK_METHOD(Status, AsyncPut, + (const std::string &table_name, const std::string &key, + const std::string &data, const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncPutWithIndex, + (const std::string &table_name, const std::string &key, + const std::string &index_key, const std::string &data, + const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncGet, + (const std::string &table_name, const std::string &key, + const OptionalItemCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncGetByIndex, + (const std::string &table_name, const std::string &index_key, + (const MapCallback &callback)), + (override)); + MOCK_METHOD(Status, AsyncGetAll, + (const std::string &table_name, + (const MapCallback &callback)), + (override)); + MOCK_METHOD(Status, AsyncDelete, + (const std::string &table_name, const std::string &key, + const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncDeleteWithIndex, + (const std::string &table_name, const std::string &key, + const std::string &index_key, const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncBatchDelete, + (const std::string &table_name, const std::vector &keys, + const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncBatchDeleteWithIndex, + (const std::string &table_name, const std::vector &keys, + const std::vector &index_keys, + const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, AsyncDeleteByIndex, + (const std::string &table_name, const std::string &index_key, + const StatusCallback &callback), + (override)); + MOCK_METHOD(int, GetNextJobID, (), (override)); +}; + +} // namespace gcs +} // namespace ray diff --git a/src/mock/ray/pubsub/publisher.h b/src/mock/ray/pubsub/publisher.h new file mode 100644 index 0000000000000..7094a9afadeac --- /dev/null +++ b/src/mock/ray/pubsub/publisher.h @@ -0,0 +1,100 @@ +// Copyright 2021 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
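// [Editor's note, not part of the patch] Since MockInMemoryStoreClient,
// MockRedisStoreClient, and MockStoreClient expose the same method set, a
// gtest typed test can cover them all once. A sketch, assuming the listed
// mocks are default-constructible in the test environment:

#include "gmock/gmock.h"
#include "gtest/gtest.h"

template <typename MockT>
class StoreClientMockTest : public ::testing::Test {};

using StoreClientMockTypes =
    ::testing::Types<ray::gcs::MockRedisStoreClient, ray::gcs::MockStoreClient>;
TYPED_TEST_SUITE(StoreClientMockTest, StoreClientMockTypes);

TYPED_TEST(StoreClientMockTest, GetNextJobIDIsMockable) {
  TypeParam store;
  EXPECT_CALL(store, GetNextJobID()).WillOnce(::testing::Return(1));
  EXPECT_EQ(1, store.GetNextJobID());
}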
+ +namespace ray { +namespace pubsub { +namespace pub_internal { + +template +class MockSubscriptionIndex : public SubscriptionIndex { + public: +}; + +} // namespace pub_internal +} // namespace pubsub +} // namespace ray + +namespace ray { +namespace pubsub { +namespace pub_internal { + +class MockLongPollConnection : public LongPollConnection { + public: +}; + +} // namespace pub_internal +} // namespace pubsub +} // namespace ray + +namespace ray { +namespace pubsub { +namespace pub_internal { + +class MockSubscriber : public Subscriber { + public: +}; + +} // namespace pub_internal +} // namespace pubsub +} // namespace ray + +namespace ray { +namespace pubsub { + +class MockPublisherInterface : public PublisherInterface { + public: + MOCK_METHOD(bool, RegisterSubscription, + (const rpc::ChannelType channel_type, const SubscriberID &subscriber_id, + const std::string &key_id_binary), + (override)); + MOCK_METHOD(void, Publish, + (const rpc::ChannelType channel_type, const rpc::PubMessage &pub_message, + const std::string &key_id_binary), + (override)); + MOCK_METHOD(void, PublishFailure, + (const rpc::ChannelType channel_type, const std::string &key_id_binary), + (override)); + MOCK_METHOD(bool, UnregisterSubscription, + (const rpc::ChannelType channel_type, const SubscriberID &subscriber_id, + const std::string &key_id_binary), + (override)); +}; + +} // namespace pubsub +} // namespace ray + +namespace ray { +namespace pubsub { + +class MockPublisher : public Publisher { + public: + MOCK_METHOD(bool, RegisterSubscription, + (const rpc::ChannelType channel_type, const SubscriberID &subscriber_id, + const std::string &key_id_binary), + (override)); + MOCK_METHOD(void, Publish, + (const rpc::ChannelType channel_type, const rpc::PubMessage &pub_message, + const std::string &key_id_binary), + (override)); + MOCK_METHOD(void, PublishFailure, + (const rpc::ChannelType channel_type, const std::string &key_id_binary), + (override)); + MOCK_METHOD(bool, UnregisterSubscription, + (const rpc::ChannelType channel_type, const SubscriberID &subscriber_id, + const std::string &key_id_binary), + (override)); +}; + +} // namespace pubsub +} // namespace ray diff --git a/src/mock/ray/pubsub/subscriber.h b/src/mock/ray/pubsub/subscriber.h new file mode 100644 index 0000000000000..38dc5f32afb65 --- /dev/null +++ b/src/mock/ray/pubsub/subscriber.h @@ -0,0 +1,155 @@ +// Copyright 2021 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
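// [Editor's note, not part of the patch] The publisher mocks above make it
// easy to assert the register -> publish -> unregister sequence a caller is
// expected to follow. A sketch with placeholder key ids:

#include <string>
#include "gmock/gmock.h"

using ::testing::_;
using ::testing::InSequence;

void SketchPublisherProtocol() {
  ray::pubsub::MockPublisherInterface publisher;
  const std::string key_id_binary = "object_id_binary";  // Placeholder value.
  InSequence seq;  // Enforce that the expectations below happen in order.
  EXPECT_CALL(publisher, RegisterSubscription(_, _, key_id_binary));
  EXPECT_CALL(publisher, Publish(_, _, key_id_binary));
  EXPECT_CALL(publisher, UnregisterSubscription(_, _, key_id_binary));
}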
+ +namespace ray { +namespace pubsub { + +template +class MockSubscriptionInfo : public SubscriptionInfo { + public: +}; + +} // namespace pubsub +} // namespace ray + +namespace ray { +namespace pubsub { + +class MockSubscribeChannelInterface : public SubscribeChannelInterface { + public: + MOCK_METHOD(void, Subscribe, + (const rpc::Address &publisher_address, const std::string &key_id_binary, + SubscriptionCallback subscription_callback, + SubscriptionFailureCallback subscription_failure_callback), + (override)); + MOCK_METHOD(bool, Unsubscribe, + (const rpc::Address &publisher_address, const std::string &key_id_binary), + (override)); + MOCK_METHOD(void, HandlePublishedMessage, + (const rpc::Address &publisher_address, const rpc::PubMessage &pub_message), + (const, override)); + MOCK_METHOD(void, HandlePublisherFailure, (const rpc::Address &publisher_address), + (override)); + MOCK_METHOD(void, HandlePublisherFailure, + (const rpc::Address &publisher_address, const std::string &key_id_binary), + (override)); + MOCK_METHOD(bool, SubscriptionExists, (const PublisherID &publisher_id), (override)); + MOCK_METHOD(const rpc::ChannelType, GetChannelType, (), (const, override)); + MOCK_METHOD(bool, CheckNoLeaks, (), (const, override)); + MOCK_METHOD(std::string, DebugString, (), (const, override)); +}; + +} // namespace pubsub +} // namespace ray + +namespace ray { +namespace pubsub { + +template +class MockSubscriberChannel : public SubscriberChannel { + public: + MOCK_METHOD(void, Subscribe, + (const rpc::Address &publisher_address, const std::string &key_id, + SubscriptionCallback subscription_callback, + SubscriptionFailureCallback subscription_failure_callback), + (override)); + MOCK_METHOD(bool, Unsubscribe, + (const rpc::Address &publisher_address, const std::string &key_id), + (override)); + MOCK_METHOD(bool, CheckNoLeaks, (), (const, override)); + MOCK_METHOD(void, HandlePublishedMessage, + (const rpc::Address &publisher_address, const rpc::PubMessage &pub_message), + (const, override)); + MOCK_METHOD(void, HandlePublisherFailure, (const rpc::Address &publisher_address), + (override)); + MOCK_METHOD(void, HandlePublisherFailure, + (const rpc::Address &publisher_address, const std::string &key_id_binary), + (override)); + MOCK_METHOD(bool, SubscriptionExists, (const PublisherID &publisher_id), (override)); + MOCK_METHOD(const rpc::ChannelType, GetChannelType, (), (const, override)); + MOCK_METHOD(std::string, DebugString, (), (const, override)); +}; + +} // namespace pubsub +} // namespace ray + +namespace ray { +namespace pubsub { + +class MockWaitForObjectEvictionChannel : public WaitForObjectEvictionChannel { + public: +}; + +} // namespace pubsub +} // namespace ray + +namespace ray { +namespace pubsub { + +class MockWaitForRefRemovedChannel : public WaitForRefRemovedChannel { + public: +}; + +} // namespace pubsub +} // namespace ray + +namespace ray { +namespace pubsub { + +class MockObjectLocationsChannel : public ObjectLocationsChannel { + public: +}; + +} // namespace pubsub +} // namespace ray + +namespace ray { +namespace pubsub { + +class MockSubscriberInterface : public SubscriberInterface { + public: + MOCK_METHOD(void, Subscribe, + (std::unique_ptr sub_message, + const rpc::ChannelType channel_type, const rpc::Address &publisher_address, + const std::string &key_id_binary, + SubscriptionCallback subscription_callback, + SubscriptionFailureCallback subscription_failure_callback), + (override)); + MOCK_METHOD(bool, Unsubscribe, + (const rpc::ChannelType channel_type, const 
rpc::Address &publisher_address, + const std::string &key_id_binary), + (override)); + MOCK_METHOD(std::string, DebugString, (), (const, override)); +}; + +} // namespace pubsub +} // namespace ray + +namespace ray { +namespace pubsub { + +class MockSubscriberClientInterface : public SubscriberClientInterface { + public: + MOCK_METHOD(void, PubsubLongPolling, + (const rpc::PubsubLongPollingRequest &request, + const rpc::ClientCallback &callback), + (override)); + MOCK_METHOD(void, PubsubCommandBatch, + (const rpc::PubsubCommandBatchRequest &request, + const rpc::ClientCallback &callback), + (override)); +}; + +} // namespace pubsub +} // namespace ray diff --git a/src/mock/ray/raylet_client/raylet_client.h b/src/mock/ray/raylet_client/raylet_client.h index cafd952e5d6e4..f48f37b90c507 100644 --- a/src/mock/ray/raylet_client/raylet_client.h +++ b/src/mock/ray/raylet_client/raylet_client.h @@ -35,6 +35,12 @@ class MockWorkerLeaseInterface : public WorkerLeaseInterface { const ray::rpc::ClientCallback &callback, const int64_t backlog_size), (override)); + MOCK_METHOD( + void, RequestWorkerLease, + (const rpc::TaskSpec &task_spec, + const ray::rpc::ClientCallback &callback, + const int64_t backlog_size), + (override)); MOCK_METHOD(ray::Status, ReturnWorker, (int worker_port, const WorkerID &worker_id, bool disconnect_worker), (override)); @@ -66,7 +72,7 @@ class MockResourceReserveInterface : public ResourceReserveInterface { (override)); MOCK_METHOD( void, CancelResourceReserve, - (BundleSpecification & bundle_spec, + (const BundleSpecification &bundle_spec, const ray::rpc::ClientCallback &callback), (override)); MOCK_METHOD(void, ReleaseUnusedBundles, @@ -106,31 +112,6 @@ class MockResourceTrackingInterface : public ResourceTrackingInterface { namespace ray { class MockRayletClientInterface : public RayletClientInterface { - public: - MOCK_METHOD(void, GetSystemConfig, - (const rpc::ClientCallback &callback), - (override)); - MOCK_METHOD(void, GetGcsServerAddress, - (const rpc::ClientCallback &callback), - (override)); -}; - -} // namespace ray - -namespace ray { -namespace raylet { - -class MockRayletConnection : public RayletConnection { - public: -}; - -} // namespace raylet -} // namespace ray - -namespace ray { -namespace raylet { - -class MockRayletClient : public RayletClient { public: MOCK_METHOD(ray::Status, WaitForDirectActorCallArgs, (const std::vector &references, int64_t tag), @@ -141,6 +122,13 @@ class MockRayletClient : public RayletClient { const ray::rpc::ClientCallback &callback, const int64_t backlog_size), (override)); + MOCK_METHOD( + void, RequestWorkerLease, + (const rpc::TaskSpec &resource_spec, + const ray::rpc::ClientCallback &callback, + const int64_t backlog_size), + (override)); + MOCK_METHOD(ray::Status, ReturnWorker, (int worker_port, const WorkerID &worker_id, bool disconnect_worker), (override)); @@ -164,7 +152,7 @@ class MockRayletClient : public RayletClient { (override)); MOCK_METHOD( void, CancelResourceReserve, - (BundleSpecification & bundle_spec, + (const BundleSpecification &bundle_spec, const ray::rpc::ClientCallback &callback), (override)); MOCK_METHOD(void, ReleaseUnusedBundles, @@ -191,5 +179,4 @@ class MockRayletClient : public RayletClient { (override)); }; -} // namespace raylet } // namespace ray diff --git a/src/mock/ray/rpc/worker/core_worker_client.h b/src/mock/ray/rpc/worker/core_worker_client.h new file mode 100644 index 0000000000000..a4646cef99e16 --- /dev/null +++ b/src/mock/ray/rpc/worker/core_worker_client.h @@ -0,0 +1,123 @@ +// 
Copyright 2021 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace ray { +namespace rpc { + +class MockWorkerAddress : public WorkerAddress { + public: +}; + +} // namespace rpc +} // namespace ray + +namespace ray { +namespace rpc { + +class MockCoreWorkerClientInterface : public ray::pubsub::MockSubscriberClientInterface, + public CoreWorkerClientInterface { + public: + MOCK_METHOD(const rpc::Address &, Addr, (), (const, override)); + MOCK_METHOD(void, PushActorTask, + (std::unique_ptr request, bool skip_queue, + const ClientCallback &callback), + (override)); + MOCK_METHOD(void, PushNormalTask, + (std::unique_ptr request, + const ClientCallback &callback), + (override)); + MOCK_METHOD(void, StealTasks, + (std::unique_ptr request, + const ClientCallback &callback), + (override)); + MOCK_METHOD(void, DirectActorCallArgWaitComplete, + (const DirectActorCallArgWaitCompleteRequest &request, + const ClientCallback &callback), + (override)); + MOCK_METHOD(void, GetObjectStatus, + (const GetObjectStatusRequest &request, + const ClientCallback &callback), + (override)); + MOCK_METHOD(void, WaitForActorOutOfScope, + (const WaitForActorOutOfScopeRequest &request, + const ClientCallback &callback), + (override)); + MOCK_METHOD(void, PubsubLongPolling, + (const PubsubLongPollingRequest &request, + const ClientCallback &callback), + (override)); + MOCK_METHOD(void, PubsubCommandBatch, + (const PubsubCommandBatchRequest &request, + const ClientCallback &callback), + (override)); + MOCK_METHOD(void, UpdateObjectLocationBatch, + (const UpdateObjectLocationBatchRequest &request, + const ClientCallback &callback), + (override)); + MOCK_METHOD(void, GetObjectLocationsOwner, + (const GetObjectLocationsOwnerRequest &request, + const ClientCallback &callback), + (override)); + MOCK_METHOD(void, KillActor, + (const KillActorRequest &request, + const ClientCallback &callback), + (override)); + MOCK_METHOD(void, CancelTask, + (const CancelTaskRequest &request, + const ClientCallback &callback), + (override)); + MOCK_METHOD(void, RemoteCancelTask, + (const RemoteCancelTaskRequest &request, + const ClientCallback &callback), + (override)); + MOCK_METHOD(void, GetCoreWorkerStats, + (const GetCoreWorkerStatsRequest &request, + const ClientCallback &callback), + (override)); + MOCK_METHOD(void, LocalGC, + (const LocalGCRequest &request, + const ClientCallback &callback), + (override)); + MOCK_METHOD(void, SpillObjects, + (const SpillObjectsRequest &request, + const ClientCallback &callback), + (override)); + MOCK_METHOD(void, RestoreSpilledObjects, + (const RestoreSpilledObjectsRequest &request, + const ClientCallback &callback), + (override)); + MOCK_METHOD(void, DeleteSpilledObjects, + (const DeleteSpilledObjectsRequest &request, + const ClientCallback &callback), + (override)); + MOCK_METHOD(void, AddSpilledUrl, + (const AddSpilledUrlRequest &request, + const ClientCallback &callback), + (override)); + MOCK_METHOD(void, PlasmaObjectReady, + (const PlasmaObjectReadyRequest 
&request, + const ClientCallback &callback), + (override)); + MOCK_METHOD(void, Exit, + (const ExitRequest &request, const ClientCallback &callback), + (override)); + MOCK_METHOD(void, AssignObjectOwner, + (const AssignObjectOwnerRequest &request, + const ClientCallback &callback), + (override)); + MOCK_METHOD(int64_t, ClientProcessedUpToSeqno, (), (override)); +}; + +} // namespace rpc +} // namespace ray diff --git a/src/mock/ray/rpc/worker/core_worker_client_pool.h b/src/mock/ray/rpc/worker/core_worker_client_pool.h new file mode 100644 index 0000000000000..d4e1ec607e5a2 --- /dev/null +++ b/src/mock/ray/rpc/worker/core_worker_client_pool.h @@ -0,0 +1,23 @@ +// Copyright 2021 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace ray { +namespace rpc { + +class MockCoreWorkerClientPool : public CoreWorkerClientPool { + public: +}; + +} // namespace rpc +} // namespace ray diff --git a/src/ray/common/bundle_spec.cc b/src/ray/common/bundle_spec.cc index 339a492360d21..c5b4a711e0275 100644 --- a/src/ray/common/bundle_spec.cc +++ b/src/ray/common/bundle_spec.cc @@ -74,6 +74,10 @@ PlacementGroupID BundleSpecification::PlacementGroupId() const { return PlacementGroupID::FromBinary(message_->bundle_id().placement_group_id()); } +NodeID BundleSpecification::NodeId() const { + return NodeID::FromBinary(message_->node_id()); +} + int64_t BundleSpecification::Index() const { return message_->bundle_id().bundle_index(); } @@ -89,16 +93,19 @@ std::string BundleSpecification::DebugString() const { std::string FormatPlacementGroupResource(const std::string &original_resource_name, const PlacementGroupID &group_id, int64_t bundle_index) { - std::string str; + std::stringstream os; if (bundle_index >= 0) { - str = original_resource_name + "_group_" + std::to_string(bundle_index) + "_" + - group_id.Hex(); + os << original_resource_name << kGroupKeyword << std::to_string(bundle_index) << "_" + << group_id.Hex(); } else { RAY_CHECK(bundle_index == -1) << "Invalid index " << bundle_index; - str = original_resource_name + "_group_" + group_id.Hex(); + os << original_resource_name << kGroupKeyword << group_id.Hex(); } - RAY_CHECK(GetOriginalResourceName(str) == original_resource_name) << str; - return str; + std::string result = os.str(); + RAY_DCHECK(GetOriginalResourceName(result) == original_resource_name) + << "Generated: " << GetOriginalResourceName(result) + << " Original: " << original_resource_name; + return result; } std::string FormatPlacementGroupResource(const std::string &original_resource_name, @@ -109,12 +116,12 @@ std::string FormatPlacementGroupResource(const std::string &original_resource_na bool IsBundleIndex(const std::string &resource, const PlacementGroupID &group_id, const int bundle_index) { - return resource.find("_group_" + std::to_string(bundle_index) + "_" + group_id.Hex()) != - std::string::npos; + return resource.find(kGroupKeyword + std::to_string(bundle_index) + "_" + + group_id.Hex()) != std::string::npos; } std::string 
GetOriginalResourceName(const std::string &resource) { - auto idx = resource.find("_group_"); + auto idx = resource.find(kGroupKeyword); RAY_CHECK(idx >= 0) << "This isn't a placement group resource " << resource; return resource.substr(0, idx); } diff --git a/src/ray/common/bundle_spec.h b/src/ray/common/bundle_spec.h index 8437704509b58..bca5396fdc71a 100644 --- a/src/ray/common/bundle_spec.h +++ b/src/ray/common/bundle_spec.h @@ -32,6 +32,9 @@ typedef std::function ScheduleBundleCallback; /// address and the raylet's port. typedef std::function SpillbackBundleCallback; +const std::string kGroupKeyword = "_group_"; +const size_t kGroupKeywordSize = kGroupKeyword.size(); + class BundleSpecification : public MessageWrapper { public: /// Construct from a protobuf message object. @@ -54,6 +57,9 @@ class BundleSpecification : public MessageWrapper { // Return the Placement Group id which the Bundle belong to. PlacementGroupID PlacementGroupId() const; + // Get a node ID that this bundle is scheduled on. + NodeID NodeId() const; + // Return the index of the bundle. int64_t Index() const; diff --git a/src/ray/common/client_connection.cc b/src/ray/common/client_connection.cc index 73743820b2b9b..7eb51a953e215 100644 --- a/src/ray/common/client_connection.cc +++ b/src/ray/common/client_connection.cc @@ -19,7 +19,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/src/ray/common/id.h b/src/ray/common/id.h index 0fc5d45599392..889128e81df11 100644 --- a/src/ray/common/id.h +++ b/src/ray/common/id.h @@ -492,6 +492,7 @@ std::string BaseID::Hex() const { constexpr char hex[] = "0123456789abcdef"; const uint8_t *id = Data(); std::string result; + result.reserve(T::Size()); for (size_t i = 0; i < T::Size(); i++) { unsigned int val = id[i]; result.push_back(hex[val >> 4]); diff --git a/src/ray/common/network_util.h b/src/ray/common/network_util.h index 08bef7ae873af..8f268ec46b389 100644 --- a/src/ray/common/network_util.h +++ b/src/ray/common/network_util.h @@ -16,7 +16,7 @@ #include #include -#include +#include #include #include #include diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index 53e0bf4d72450..cc11cfbd2a6e6 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -183,8 +183,7 @@ RAY_CONFIG(int64_t, worker_register_timeout_seconds, 30) RAY_CONFIG(int64_t, redis_db_connect_retries, 50) RAY_CONFIG(int64_t, redis_db_connect_wait_milliseconds, 100) -/// Timeout, in milliseconds, to wait before retrying a failed pull in the -/// ObjectManager. +/// The object manager's global timer interval in milliseconds. RAY_CONFIG(int, object_manager_timer_freq_ms, 100) /// Timeout, in milliseconds, to wait before retrying a failed pull in the @@ -221,14 +220,8 @@ RAY_CONFIG(int32_t, maximum_profile_table_rows_count, 10 * 1000) /// message. RAY_CONFIG(uint32_t, object_store_get_max_ids_to_print_in_warning, 20) -// TODO: fix win32 timeout in ci and unify these two. -#ifdef _MSC_VER /// Number of threads used by rpc server in gcs server. RAY_CONFIG(uint32_t, gcs_server_rpc_server_thread_num, 1) -#else -/// Number of threads used by rpc server in gcs server. -RAY_CONFIG(uint32_t, gcs_server_rpc_server_thread_num, 8) -#endif /// Allow up to 5 seconds for connecting to gcs service. /// Note: this only takes effect when gcs service is enabled. 
RAY_CONFIG(int64_t, gcs_service_connect_retries, 50) @@ -241,8 +234,10 @@ RAY_CONFIG(uint64_t, gcs_redis_heartbeat_interval_milliseconds, 100) RAY_CONFIG(uint32_t, gcs_lease_worker_retry_interval_ms, 200) /// Duration to wait between retries for creating actor in gcs server. RAY_CONFIG(uint32_t, gcs_create_actor_retry_interval_ms, 200) -/// Duration to wait between retries for creating placement group in gcs server. -RAY_CONFIG(uint32_t, gcs_create_placement_group_retry_interval_ms, 200) +/// Exponential backoff params for gcs to retry creating a placement group +RAY_CONFIG(uint32_t, gcs_create_placement_group_retry_min_interval_ms, 200) +RAY_CONFIG(uint32_t, gcs_create_placement_group_retry_max_interval_ms, 5000) +RAY_CONFIG(double, gcs_create_placement_group_retry_multiplier, 1.5); /// Maximum number of destroyed actors in GCS server memory cache. RAY_CONFIG(uint32_t, maximum_gcs_destroyed_actor_cached_count, 100000) /// Maximum number of dead nodes in GCS server memory cache. @@ -317,6 +312,9 @@ RAY_CONFIG(uint32_t, agent_restart_interval_ms, 1000) /// Wait timeout for dashboard agent register. RAY_CONFIG(uint32_t, agent_register_timeout_ms, 30 * 1000) +/// Max restart count for the dashboard agent. +RAY_CONFIG(uint32_t, agent_max_restart_count, 5) + /// If the agent manager fails to communicate with the dashboard agent, we will retry /// after this interval. RAY_CONFIG(uint32_t, agent_manager_retry_interval_ms, 1000); @@ -330,7 +328,7 @@ RAY_CONFIG(int64_t, max_resource_shapes_per_load_report, 100) RAY_CONFIG(bool, report_worker_backlog, true) /// The timeout for synchronous GCS requests in seconds. -RAY_CONFIG(int64_t, gcs_server_request_timeout_seconds, 5) +RAY_CONFIG(int64_t, gcs_server_request_timeout_seconds, 60) /// Whether to enable worker prestarting: https://github.com/ray-project/ray/issues/12052 RAY_CONFIG(bool, enable_worker_prestart, true) @@ -478,7 +476,7 @@ RAY_CONFIG(int64_t, grpc_keepalive_time_ms, 10000); RAY_CONFIG(int64_t, grpc_keepalive_timeout_ms, 20000); /// Whether to use log reporter in event framework -RAY_CONFIG(bool, event_log_reporter_enabled, false) +RAY_CONFIG(bool, event_log_reporter_enabled, true) /// Whether to use log reporter in event framework RAY_CONFIG(bool, actor_register_async, true) diff --git a/src/ray/common/runtime_env_manager.cc b/src/ray/common/runtime_env_manager.cc index 2ec95cdecee8f..9e39488fa9149 100644 --- a/src/ray/common/runtime_env_manager.cc +++ b/src/ray/common/runtime_env_manager.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "ray/common/runtime_env_manager.h" + #include "ray/util/logging.h" namespace ray { @@ -20,17 +21,12 @@ void RuntimeEnvManager::AddURIReference(const std::string &hex_id, const rpc::RuntimeEnv &runtime_env) { const auto &uris = runtime_env.uris(); for (const auto &uri : uris) { - AddURIReference(hex_id, uri); - } -} - -void RuntimeEnvManager::AddURIReference(const std::string &hex_id, - const std::string &uri) { - if (unused_uris_.count(uri)) { - unused_uris_.erase(uri); + if (unused_uris_.count(uri)) { + unused_uris_.erase(uri); + } + uri_reference_[uri]++; + id_to_uris_[hex_id].push_back(uri); } - uri_reference_[uri]++; - id_to_uris_[hex_id].push_back(uri); } const std::vector &RuntimeEnvManager::GetReferences( diff --git a/src/ray/common/runtime_env_manager.h b/src/ray/common/runtime_env_manager.h index 510aa5fe53aa9..f9c59d74784bb 100644 --- a/src/ray/common/runtime_env_manager.h +++ b/src/ray/common/runtime_env_manager.h @@ -13,6 +13,7 @@ // limitations under the License. 
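The three placement group retry settings above replace a single fixed interval with a capped exponential backoff. A minimal sketch of how such a schedule can be computed from those knobs; the helper and its standalone form are assumptions, only the three parameter values come from the config above:

#include <algorithm>
#include <cstdint>
#include <iostream>

// Sketch: capped exponential backoff. With min=200ms, multiplier=1.5 and
// max=5000ms, the schedule is 200, 300, 450, 675, ... capped at 5000.
uint64_t RetryDelayMs(uint64_t min_interval_ms, uint64_t max_interval_ms,
                      double multiplier, int attempt) {
  double delay = static_cast<double>(min_interval_ms);
  for (int i = 0; i < attempt; i++) {
    delay *= multiplier;
  }
  return std::min(static_cast<uint64_t>(delay), max_interval_ms);
}

int main() {
  for (int attempt = 0; attempt < 10; attempt++) {
    std::cout << RetryDelayMs(200, 5000, 1.5, attempt) << " ";
  }
  // Prints: 200 300 450 675 1012 1518 2278 3417 5000 5000
  return 0;
}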
#pragma once #include + #include "ray/common/id.h" #include "src/ray/protobuf/common.pb.h" @@ -37,12 +38,6 @@ class RuntimeEnvManager { /// \param[in] runtime_env The runtime env used by the id. void AddURIReference(const std::string &hex_id, const rpc::RuntimeEnv &runtime_env); - /// Increase the reference of URI by URI and runtime_env. - /// - /// \param[in] hex_id The id of the runtime env. It can be an actor or job id. - /// \param[in] uri The URI referenced by the id. - void AddURIReference(const std::string &hex_id, const std::string &uri); - /// Get the reference of URIs by id. /// /// \param[in] hex_id The id of to look. diff --git a/src/ray/common/task/task_spec.cc b/src/ray/common/task/task_spec.cc index 353406fd3c820..0c3d77beb5993 100644 --- a/src/ray/common/task/task_spec.cc +++ b/src/ray/common/task/task_spec.cc @@ -132,8 +132,10 @@ ray::FunctionDescriptor TaskSpecification::FunctionDescriptor() const { return ray::FunctionDescriptorBuilder::FromProto(message_->function_descriptor()); } +rpc::RuntimeEnv TaskSpecification::RuntimeEnv() const { return message_->runtime_env(); } + std::string TaskSpecification::SerializedRuntimeEnv() const { - return message_->serialized_runtime_env(); + return message_->runtime_env().serialized_runtime_env(); } bool TaskSpecification::HasRuntimeEnv() const { @@ -145,8 +147,7 @@ int TaskSpecification::GetRuntimeEnvHash() const { if (RayConfig::instance().worker_resource_limits_enabled()) { required_resource = GetRequiredResources().GetResourceMap(); } - WorkerCacheKey env = {OverrideEnvironmentVariables(), SerializedRuntimeEnv(), - required_resource}; + WorkerCacheKey env = {SerializedRuntimeEnv(), required_resource}; return env.IntHash(); } @@ -239,11 +240,6 @@ std::string TaskSpecification::GetDebuggerBreakpoint() const { return message_->debugger_breakpoint(); } -std::unordered_map -TaskSpecification::OverrideEnvironmentVariables() const { - return MapFromProtobuf(message_->override_environment_variables()); -} - bool TaskSpecification::IsDriverTask() const { return message_->type() == TaskType::DRIVER_TASK; } @@ -398,11 +394,9 @@ std::string TaskSpecification::CallSiteString() const { } WorkerCacheKey::WorkerCacheKey( - const std::unordered_map override_environment_variables, const std::string serialized_runtime_env, const std::unordered_map required_resources) - : override_environment_variables(override_environment_variables), - serialized_runtime_env(serialized_runtime_env), + : serialized_runtime_env(serialized_runtime_env), required_resources(std::move(required_resources)) {} bool WorkerCacheKey::operator==(const WorkerCacheKey &k) const { @@ -411,8 +405,7 @@ bool WorkerCacheKey::operator==(const WorkerCacheKey &k) const { } bool WorkerCacheKey::EnvIsEmpty() const { - return override_environment_variables.size() == 0 && - (serialized_runtime_env == "" || serialized_runtime_env == "{}") && + return (serialized_runtime_env == "" || serialized_runtime_env == "{}") && required_resources.empty(); } @@ -424,19 +417,6 @@ std::size_t WorkerCacheKey::Hash() const { // runtime envs. hash_ = 0; } else { - std::vector> env_vars( - override_environment_variables.begin(), override_environment_variables.end()); - // The environment doesn't depend the order of the variables, so the hash should not - // either. Sort the variables so different permutations yield the same hash. 
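For context on the RuntimeEnvManager refactor above (the per-URI AddURIReference overload is gone and the proto-driven loop now does the bookkeeping inline), here is a minimal standalone sketch of the reference-counting scheme with simplified types; the real class takes an rpc::RuntimeEnv proto and also drives deletion of unused URIs:

#include <cstdint>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Sketch of the URI bookkeeping done in RuntimeEnvManager::AddURIReference.
class UriRefCounter {
 public:
  void AddURIReferences(const std::string &hex_id,
                        const std::vector<std::string> &uris) {
    for (const auto &uri : uris) {
      // A URI that regains a reference is no longer a deletion candidate.
      unused_uris_.erase(uri);
      uri_reference_[uri]++;
      id_to_uris_[hex_id].push_back(uri);
    }
  }

  const std::vector<std::string> &GetReferences(const std::string &hex_id) const {
    static const std::vector<std::string> empty;
    auto it = id_to_uris_.find(hex_id);
    return it == id_to_uris_.end() ? empty : it->second;
  }

 private:
  std::unordered_map<std::string, int64_t> uri_reference_;
  std::unordered_map<std::string, std::vector<std::string>> id_to_uris_;
  std::unordered_set<std::string> unused_uris_;
};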
- std::sort(env_vars.begin(), env_vars.end()); - for (auto &pair : env_vars) { - // TODO(architkulkarni): boost::hash_combine isn't guaranteed to be equal during - // separate runs of a program, which may cause problems if these hashes are - // communicated between different Raylets and compared. - boost::hash_combine(hash_, pair.first); - boost::hash_combine(hash_, pair.second); - } - boost::hash_combine(hash_, serialized_runtime_env); std::vector> resource_vars( diff --git a/src/ray/common/task/task_spec.h b/src/ray/common/task/task_spec.h index 8b10b163cc3cc..7ec1480c3fa66 100644 --- a/src/ray/common/task/task_spec.h +++ b/src/ray/common/task/task_spec.h @@ -100,6 +100,8 @@ class TaskSpecification : public MessageWrapper { ray::FunctionDescriptor FunctionDescriptor() const; + [[nodiscard]] rpc::RuntimeEnv RuntimeEnv() const; + std::string SerializedRuntimeEnv() const; bool HasRuntimeEnv() const; @@ -170,8 +172,6 @@ class TaskSpecification : public MessageWrapper { std::string GetDebuggerBreakpoint() const; - std::unordered_map OverrideEnvironmentVariables() const; - bool IsDriverTask() const; Language GetLanguage() const; @@ -275,13 +275,10 @@ class WorkerCacheKey { /// Create a cache key with the given environment variable overrides and serialized /// runtime_env. /// - /// \param override_environment_variables The environment variable overrides set in this /// worker. \param serialized_runtime_env The JSON-serialized runtime env for this /// worker. \param required_resources The required resouce. - WorkerCacheKey( - const std::unordered_map override_environment_variables, - const std::string serialized_runtime_env, - const std::unordered_map required_resources); + WorkerCacheKey(const std::string serialized_runtime_env, + const std::unordered_map required_resources); bool operator==(const WorkerCacheKey &k) const; @@ -293,8 +290,7 @@ class WorkerCacheKey { /// Get the hash for this worker's environment. /// - /// \return The hash of the override_environment_variables and the serialized - /// runtime_env. + /// \return The hash of the serialized runtime_env. std::size_t Hash() const; /// Get the int-valued hash for this worker's environment, useful for portability in @@ -304,8 +300,6 @@ class WorkerCacheKey { int IntHash() const; private: - /// The environment variable overrides for this worker. - const std::unordered_map override_environment_variables; /// The JSON-serialized runtime env for this worker. const std::string serialized_runtime_env; /// The required resources for this worker. 
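The WorkerCacheKey changes above drop the env-var loop but keep the sort-then-combine idiom for the resource map. A self-contained sketch of that idiom; boost::hash_combine is the real Boost API, while the surrounding function is illustrative:

#include <algorithm>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

#include <boost/functional/hash.hpp>

// Sketch: hash a serialized runtime env plus an unordered resource map so
// that every permutation of the map hashes identically, in the same shape
// as WorkerCacheKey::Hash() above.
std::size_t EnvHash(const std::string &serialized_runtime_env,
                    std::vector<std::pair<std::string, double>> resources) {
  std::size_t hash = 0;
  boost::hash_combine(hash, serialized_runtime_env);
  // The environment doesn't depend on the order of the entries, so the hash
  // must not either: sort before combining.
  std::sort(resources.begin(), resources.end());
  for (const auto &entry : resources) {
    boost::hash_combine(hash, entry.first);
    boost::hash_combine(hash, entry.second);
  }
  return hash;
}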
diff --git a/src/ray/common/task/task_util.h b/src/ray/common/task/task_util.h index c011829c2603d..57ee5b811663e 100644 --- a/src/ray/common/task/task_util.h +++ b/src/ray/common/task/task_util.h @@ -106,8 +106,7 @@ class TaskSpecBuilder { const BundleID &bundle_id, bool placement_group_capture_child_tasks, const std::string &debugger_breakpoint, const std::string &serialized_runtime_env = "{}", - const std::unordered_map &override_environment_variables = - {}, + const std::vector &runtime_env_uris = {}, const std::string &concurrency_group_name = "") { message_->set_type(TaskType::NORMAL_TASK); message_->set_name(name); @@ -129,11 +128,11 @@ class TaskSpecBuilder { message_->set_placement_group_capture_child_tasks( placement_group_capture_child_tasks); message_->set_debugger_breakpoint(debugger_breakpoint); - message_->set_serialized_runtime_env(serialized_runtime_env); - message_->set_concurrency_group_name(concurrency_group_name); - for (const auto &env : override_environment_variables) { - (*message_->mutable_override_environment_variables())[env.first] = env.second; + message_->mutable_runtime_env()->set_serialized_runtime_env(serialized_runtime_env); + for (const std::string &uri : runtime_env_uris) { + message_->mutable_runtime_env()->add_uris(uri); } + message_->set_concurrency_group_name(concurrency_group_name); return *this; } diff --git a/src/ray/core_worker/common.h b/src/ray/core_worker/common.h index 016d16ddc8851..dfb4fd9a39f28 100644 --- a/src/ray/core_worker/common.h +++ b/src/ray/core_worker/common.h @@ -60,14 +60,13 @@ struct TaskOptions { std::unordered_map &resources, const std::string &concurrency_group_name = "", const std::string &serialized_runtime_env = "{}", - const std::unordered_map - &override_environment_variables = {}) + const std::vector &runtime_env_uris = {}) : name(name), num_returns(num_returns), resources(resources), concurrency_group_name(concurrency_group_name), serialized_runtime_env(serialized_runtime_env), - override_environment_variables(override_environment_variables) {} + runtime_env_uris(runtime_env_uris) {} /// The name of this task. std::string name; @@ -77,12 +76,10 @@ struct TaskOptions { std::unordered_map resources; /// The name of the concurrency group in which this task will be executed. std::string concurrency_group_name; - // Runtime Env used by this task. Propagated to child actors and tasks. + // Runtime Env used by this task. Propagated to child actors and tasks. std::string serialized_runtime_env; - /// Environment variables to update for this task. Maps a variable name to its - /// value. Can override existing environment variables and introduce new ones. - /// Propagated to child actors and/or tasks. - const std::unordered_map override_environment_variables; + // URIs contained in the runtime_env. + std::vector runtime_env_uris; }; /// Options for actor creation tasks. 
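The builder change above stops writing serialized_runtime_env as a flat TaskSpec field and instead fills the nested RuntimeEnv message. A standalone sketch of that population step; the structs are hand-rolled stand-ins for the generated rpc::TaskSpec/rpc::RuntimeEnv types, whose real accessors (mutable_runtime_env, set_serialized_runtime_env, add_uris) appear in the hunk above:

#include <string>
#include <vector>

// Stand-ins for the generated protobuf messages.
struct RuntimeEnvMsg {
  std::string serialized_runtime_env;
  std::vector<std::string> uris;
};
struct TaskSpecMsg {
  RuntimeEnvMsg runtime_env;
};

void FillRuntimeEnv(TaskSpecMsg *msg, const std::string &serialized_runtime_env,
                    const std::vector<std::string> &runtime_env_uris) {
  // Mirrors message_->mutable_runtime_env()->set_serialized_runtime_env(...).
  msg->runtime_env.serialized_runtime_env = serialized_runtime_env;
  for (const std::string &uri : runtime_env_uris) {
    // Mirrors message_->mutable_runtime_env()->add_uris(uri). Carrying the
    // URIs beside the JSON blob lets receivers pin them without parsing it.
    msg->runtime_env.uris.push_back(uri);
  }
}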
@@ -97,8 +94,7 @@ struct ActorCreationOptions { BundleID placement_options = std::make_pair(PlacementGroupID::Nil(), -1), bool placement_group_capture_child_tasks = true, const std::string &serialized_runtime_env = "{}", - const std::unordered_map &override_environment_variables = - {}, + const std::vector &runtime_env_uris = {}, const std::vector &concurrency_groups = {}) : max_restarts(max_restarts), max_task_retries(max_task_retries), @@ -113,7 +109,7 @@ struct ActorCreationOptions { placement_options(placement_options), placement_group_capture_child_tasks(placement_group_capture_child_tasks), serialized_runtime_env(serialized_runtime_env), - override_environment_variables(override_environment_variables), + runtime_env_uris(runtime_env_uris), concurrency_groups(concurrency_groups.begin(), concurrency_groups.end()){}; /// Maximum number of times that the actor should be restarted if it dies @@ -155,10 +151,8 @@ struct ActorCreationOptions { bool placement_group_capture_child_tasks = true; // Runtime Env used by this actor. Propagated to child actors and tasks. std::string serialized_runtime_env; - /// Environment variables to update for this actor. Maps a variable name to its - /// value. Can override existing environment variables and introduce new ones. - /// Propagated to child actors and/or tasks. - const std::unordered_map override_environment_variables; + // URIs contained in the runtime_env. + std::vector runtime_env_uris; /// The actor concurrency groups to indicate how this actor perform its /// methods concurrently. const std::vector concurrency_groups; diff --git a/src/ray/core_worker/context.cc b/src/ray/core_worker/context.cc index 37e7797e62676..ab8f6c1884764 100644 --- a/src/ray/core_worker/context.cc +++ b/src/ray/core_worker/context.cc @@ -168,12 +168,7 @@ bool WorkerContext::ShouldCaptureChildTasksInPlacementGroup() const { } const std::string &WorkerContext::GetCurrentSerializedRuntimeEnv() const { - return serialized_runtime_env_; -} - -const std::unordered_map - &WorkerContext::GetCurrentOverrideEnvironmentVariables() const { - return override_environment_variables_; + return runtime_env_.serialized_runtime_env(); } void WorkerContext::SetCurrentTaskId(const TaskID &task_id) { @@ -186,10 +181,9 @@ void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) { if (task_spec.IsNormalTask()) { current_task_is_direct_call_ = true; // TODO(architkulkarni): Once workers are cached by runtime env, we should - // only set serialized_runtime_env_ once and then RAY_CHECK that we + // only set runtime_env_ once and then RAY_CHECK that we // never see a new one. 
- serialized_runtime_env_ = task_spec.SerializedRuntimeEnv(); - override_environment_variables_ = task_spec.OverrideEnvironmentVariables(); + runtime_env_ = task_spec.RuntimeEnv(); } else if (task_spec.IsActorCreationTask()) { RAY_CHECK(current_actor_id_.IsNil()); current_actor_id_ = task_spec.ActorCreationId(); @@ -199,8 +193,7 @@ void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) { is_detached_actor_ = task_spec.IsDetachedActor(); current_actor_placement_group_id_ = task_spec.PlacementGroupBundleId().first; placement_group_capture_child_tasks_ = task_spec.PlacementGroupCaptureChildTasks(); - serialized_runtime_env_ = task_spec.SerializedRuntimeEnv(); - override_environment_variables_ = task_spec.OverrideEnvironmentVariables(); + runtime_env_ = task_spec.RuntimeEnv(); } else if (task_spec.IsActorTask()) { RAY_CHECK(current_actor_id_ == task_spec.ActorId()); } else { diff --git a/src/ray/core_worker/context.h b/src/ray/core_worker/context.h index a403ee367c973..3c5f35718235a 100644 --- a/src/ray/core_worker/context.h +++ b/src/ray/core_worker/context.h @@ -42,9 +42,6 @@ class WorkerContext { const std::string &GetCurrentSerializedRuntimeEnv() const; - const std::unordered_map - &GetCurrentOverrideEnvironmentVariables() const; - // TODO(edoakes): remove this once Python core worker uses the task interfaces. void SetCurrentTaskId(const TaskID &task_id); @@ -98,10 +95,8 @@ class WorkerContext { PlacementGroupID current_actor_placement_group_id_; // Whether or not we should implicitly capture parent's placement group. bool placement_group_capture_child_tasks_; - // The JSON-serialized runtime env for the current actor or task. - std::string serialized_runtime_env_ = "{}"; - // The environment variable overrides for the current actor or task. - std::unordered_map override_environment_variables_; + // The runtime env for the current actor or task. + rpc::RuntimeEnv runtime_env_; /// The id of the (main) thread that constructed this worker context. boost::thread::id main_thread_id_; diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index e9251cbf990ac..4270a31f5561f 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -27,34 +27,11 @@ namespace ray { namespace core { +namespace { // Duration between internal book-keeping heartbeats. const uint64_t kInternalHeartbeatMillis = 1000; -void BuildCommonTaskSpec( - TaskSpecBuilder &builder, const JobID &job_id, const TaskID &task_id, - const std::string name, const TaskID ¤t_task_id, const uint64_t task_index, - const TaskID &caller_id, const rpc::Address &address, const RayFunction &function, - const std::vector> &args, uint64_t num_returns, - const std::unordered_map &required_resources, - const std::unordered_map &required_placement_resources, - const BundleID &bundle_id, bool placement_group_capture_child_tasks, - const std::string debugger_breakpoint, const std::string &serialized_runtime_env, - const std::unordered_map &override_environment_variables, - const std::string &concurrency_group_name = "") { - // Build common task spec. - builder.SetCommonTaskSpec( - task_id, name, function.GetLanguage(), function.GetFunctionDescriptor(), job_id, - current_task_id, task_index, caller_id, address, num_returns, required_resources, - required_placement_resources, bundle_id, placement_group_capture_child_tasks, - debugger_breakpoint, serialized_runtime_env, override_environment_variables, - concurrency_group_name); - // Set task arguments. 
-  for (const auto &arg : args) {
-    builder.AddArg(*arg);
-  }
-}
-
 JobID GetProcessJobID(const CoreWorkerOptions &options) {
   if (options.worker_type == WorkerType::DRIVER) {
     RAY_CHECK(!options.job_id.IsNil());
@@ -89,6 +66,16 @@ ObjectLocation CreateObjectLocation(const rpc::GetObjectLocationsOwnerReply &rep
 /// The global instance of `CoreWorkerProcess`.
 std::unique_ptr core_worker_process;
 
+/// Terminate the process without cleaning up resources.
+/// It will flush the log if logging_enabled is set to true.
+void QuickExit(bool logging_enabled) {
+  if (logging_enabled) {
+    RayLog::ShutDownRayLog();
+  }
+  _Exit(1);
+}
+}  // namespace
+
 thread_local std::weak_ptr CoreWorkerProcess::current_core_worker_;
 
 void CoreWorkerProcess::Initialize(const CoreWorkerOptions &options) {
@@ -103,10 +90,11 @@ void CoreWorkerProcess::Shutdown() {
   }
   RAY_CHECK(core_worker_process->options_.worker_type == WorkerType::DRIVER)
       << "The `Shutdown` interface is for driver only.";
-  RAY_CHECK(core_worker_process->global_worker_);
-  core_worker_process->global_worker_->Disconnect();
-  core_worker_process->global_worker_->Shutdown();
-  core_worker_process->RemoveWorker(core_worker_process->global_worker_);
+  auto global_worker = core_worker_process->GetGlobalWorker();
+  RAY_CHECK(global_worker);
+  global_worker->Disconnect();
+  global_worker->Shutdown();
+  core_worker_process->RemoveWorker(global_worker);
   core_worker_process.reset();
 }
 
@@ -147,18 +135,8 @@ CoreWorkerProcess::CoreWorkerProcess(const CoreWorkerOptions &options)
   // NOTE(kfstorm): any initialization depending on RayConfig must happen after this line.
   InitializeSystemConfig();
 
-  if (options_.num_workers == 1) {
-    // We need to create the worker instance here if:
-    // 1. This is a driver process. In this case, the driver is ready to use right after
-    // the CoreWorkerProcess::Initialize.
-    // 2. This is a Python worker process. In this case, Python will invoke some core
-    // worker APIs before `CoreWorkerProcess::RunTaskExecutionLoop` is called. So we need
-    // to create the worker instance here. One example of invocations is
-    // https://github.com/ray-project/ray/blob/45ce40e5d44801193220d2c546be8de0feeef988/python/ray/worker.py#L1281.
-    if (options_.worker_type == WorkerType::DRIVER ||
-        options_.language == Language::PYTHON) {
-      CreateWorker();
-    }
+  if (ShouldCreateGlobalWorkerOnConstruction()) {
+    CreateWorker();
   }
 
   // Assume stats module will be initialized exactly once per process.
@@ -256,11 +234,23 @@ void CoreWorkerProcess::InitializeSystemConfig() {
   RayConfig::instance().initialize(promise.get_future().get());
 }
 
+bool CoreWorkerProcess::ShouldCreateGlobalWorkerOnConstruction() const {
+  // We need to create the worker instance here if:
+  // 1. This is a driver process. In this case, the driver is ready to use right after
+  // the CoreWorkerProcess::Initialize.
+  // 2. This is a Python worker process. In this case, Python will invoke some core
+  // worker APIs before `CoreWorkerProcess::RunTaskExecutionLoop` is called. So we need
+  // to create the worker instance here. One example of invocations is
+  // https://github.com/ray-project/ray/blob/45ce40e5d44801193220d2c546be8de0feeef988/python/ray/worker.py#L1281.
+ return options_.num_workers == 1 && (options_.worker_type == WorkerType::DRIVER || + options_.language == Language::PYTHON); +} + std::shared_ptr CoreWorkerProcess::TryGetWorker(const WorkerID &worker_id) { if (!core_worker_process) { return nullptr; } - absl::ReaderMutexLock workers_lock(&core_worker_process->worker_map_mutex_); + absl::ReaderMutexLock workers_lock(&core_worker_process->mutex_); auto it = core_worker_process->workers_.find(worker_id); if (it != core_worker_process->workers_.end()) { return it->second; @@ -271,8 +261,19 @@ std::shared_ptr CoreWorkerProcess::TryGetWorker(const WorkerID &work CoreWorker &CoreWorkerProcess::GetCoreWorker() { EnsureInitialized(); if (core_worker_process->options_.num_workers == 1) { - RAY_CHECK(core_worker_process->global_worker_) << "global_worker_ must not be NULL"; - return *core_worker_process->global_worker_; + auto global_worker = core_worker_process->GetGlobalWorker(); + if (core_worker_process->ShouldCreateGlobalWorkerOnConstruction() && !global_worker) { + // This could only happen when the worker has already been shutdown. + // In this case, we should exit without crashing. + // TODO (scv119): A better solution could be returning error code + // and handling it at language frontend. + RAY_LOG(ERROR) << "The global worker has already been shutdown. This happens when " + "the language frontend accesses the Ray's worker after it is " + "shutdown. The process will exit"; + QuickExit(core_worker_process->options_.enable_logging); + } + RAY_CHECK(global_worker) << "global_worker_ must not be NULL"; + return *global_worker; } auto ptr = current_core_worker_.lock(); RAY_CHECK(ptr != nullptr) @@ -283,7 +284,7 @@ CoreWorker &CoreWorkerProcess::GetCoreWorker() { void CoreWorkerProcess::SetCurrentThreadWorkerId(const WorkerID &worker_id) { EnsureInitialized(); if (core_worker_process->options_.num_workers == 1) { - RAY_CHECK(core_worker_process->global_worker_->GetWorkerID() == worker_id); + RAY_CHECK(core_worker_process->GetGlobalWorker()->GetWorkerID() == worker_id); return; } current_core_worker_ = core_worker_process->GetWorker(worker_id); @@ -291,23 +292,28 @@ void CoreWorkerProcess::SetCurrentThreadWorkerId(const WorkerID &worker_id) { std::shared_ptr CoreWorkerProcess::GetWorker( const WorkerID &worker_id) const { - absl::ReaderMutexLock lock(&worker_map_mutex_); + absl::ReaderMutexLock lock(&mutex_); auto it = workers_.find(worker_id); RAY_CHECK(it != workers_.end()) << "Worker " << worker_id << " not found."; return it->second; } +std::shared_ptr CoreWorkerProcess::GetGlobalWorker() { + absl::ReaderMutexLock lock(&mutex_); + return global_worker_; +} + std::shared_ptr CoreWorkerProcess::CreateWorker() { auto worker = std::make_shared( options_, global_worker_id_ != WorkerID::Nil() ? 
global_worker_id_ : WorkerID::FromRandom()); RAY_LOG(DEBUG) << "Worker " << worker->GetWorkerID() << " is created."; + absl::WriterMutexLock lock(&mutex_); if (options_.num_workers == 1) { global_worker_ = worker; } current_core_worker_ = worker; - absl::MutexLock lock(&worker_map_mutex_); workers_.emplace(worker->GetWorkerID(), worker); RAY_CHECK(workers_.size() <= static_cast(options_.num_workers)); return worker; @@ -315,6 +321,7 @@ std::shared_ptr CoreWorkerProcess::CreateWorker() { void CoreWorkerProcess::RemoveWorker(std::shared_ptr worker) { worker->WaitForShutdown(); + absl::WriterMutexLock lock(&mutex_); if (global_worker_) { RAY_CHECK(global_worker_ == worker); } else { @@ -322,7 +329,6 @@ void CoreWorkerProcess::RemoveWorker(std::shared_ptr worker) { } current_core_worker_.reset(); { - absl::MutexLock lock(&worker_map_mutex_); workers_.erase(worker->GetWorkerID()); RAY_LOG(INFO) << "Removed worker " << worker->GetWorkerID(); } @@ -336,9 +342,10 @@ void CoreWorkerProcess::RunTaskExecutionLoop() { RAY_CHECK(core_worker_process->options_.worker_type == WorkerType::WORKER); if (core_worker_process->options_.num_workers == 1) { // Run the task loop in the current thread only if the number of workers is 1. - auto worker = core_worker_process->global_worker_ - ? core_worker_process->global_worker_ - : core_worker_process->CreateWorker(); + auto worker = core_worker_process->GetGlobalWorker(); + if (!worker) { + worker = core_worker_process->CreateWorker(); + } worker->RunTaskExecutionLoop(); core_worker_process->RemoveWorker(worker); } else { @@ -370,9 +377,9 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ periodical_runner_(io_service_), task_queue_length_(0), num_executed_tasks_(0), - task_execution_service_work_(task_execution_service_), resource_ids_(new ResourceMappingType()), - grpc_service_(io_service_, *this) { + grpc_service_(io_service_, *this), + task_execution_service_work_(task_execution_service_) { RAY_LOG(DEBUG) << "Constructing CoreWorker, worker_id: " << worker_id; // Initialize task receivers. @@ -409,11 +416,8 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ // Avoid using FATAL log or RAY_CHECK here because they may create a core dump file. RAY_LOG(ERROR) << "Failed to register worker " << worker_id << " to Raylet. " << raylet_client_status; - if (options_.enable_logging) { - RayLog::ShutDownRayLog(); - } // Quit the process immediately. - _Exit(1); + QuickExit(options_.enable_logging); } connected_ = true; @@ -427,7 +431,8 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ // Start RPC server after all the task receivers are properly initialized and we have // our assigned port from the raylet. 
core_worker_server_ = std::make_unique( - WorkerTypeString(options_.worker_type), assigned_port); + WorkerTypeString(options_.worker_type), assigned_port, + options_.node_ip_address == "127.0.0.1"); core_worker_server_->RegisterService(grpc_service_); core_worker_server_->Run(); @@ -526,10 +531,6 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ options_.worker_type != WorkerType::RESTORE_WORKER), /*get_current_call_site=*/boost::bind(&CoreWorker::CurrentCallSite, this))); memory_store_.reset(new CoreWorkerMemoryStore( - [this](const RayObject &object, const ObjectID &object_id) { - PutObjectIntoPlasma(object, object_id); - return Status::OK(); - }, reference_counter_, local_raylet_client_, options_.check_signals, [this](const RayObject &obj) { // Run this on the event loop to avoid calling back into the language runtime @@ -737,6 +738,11 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ }, event_stats_print_interval_ms); } + + // Set event context for current core worker thread. + RayEventContext::Instance().SetEventContext( + ray::rpc::Event_SourceType::Event_SourceType_CORE_WORKER, + {{"worker_id", worker_id.Hex()}}); } void CoreWorker::Shutdown() { @@ -933,17 +939,25 @@ void CoreWorker::RegisterToGcs() { } void CoreWorker::CheckForRayletFailure() { + bool should_shutdown = false; // When running worker process in container, the worker parent process is not raylet. // So we add RAY_RAYLET_PID enviroment to ray worker process. if (auto env_pid = RayConfig::instance().RAYLET_PID(); !env_pid.empty()) { auto pid = static_cast(std::stoi(env_pid)); if (!IsProcessAlive(pid)) { RAY_LOG(ERROR) << "Raylet failed. Shutting down. Raylet PID: " << pid; - Shutdown(); + should_shutdown = true; } } else if (!IsParentProcessAlive()) { RAY_LOG(ERROR) << "Raylet failed. Shutting down."; - Shutdown(); + should_shutdown = true; + } + if (should_shutdown) { + if (options_.worker_type == WorkerType::WORKER) { + task_execution_service_.post([this]() { Shutdown(); }, "CoreWorker.Shutdown"); + } else { + Shutdown(); + } } } @@ -992,36 +1006,6 @@ CoreWorker::GetAllReferenceCounts() const { return counts; } -void CoreWorker::PutObjectIntoPlasma(const RayObject &object, const ObjectID &object_id) { - bool object_exists; - // This call will only be used by PromoteObjectToPlasma, which means that the - // object will always owned by us. - RAY_CHECK_OK(plasma_store_provider_->Put( - object, object_id, /* owner_address = */ rpc_address_, &object_exists)); - if (!object_exists) { - // Tell the raylet to pin the object **after** it is created. - RAY_LOG(DEBUG) << "Pinning put object " << object_id; - local_raylet_client_->PinObjectIDs( - rpc_address_, {object_id}, - [this, object_id](const Status &status, const rpc::PinObjectIDsReply &reply) { - // Only release the object once the raylet has responded to avoid the race - // condition that the object could be evicted before the raylet pins it. 
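The CheckForRayletFailure rewrite above defers Shutdown() onto the task execution loop for worker processes instead of tearing down inline on the timer thread. A minimal sketch of that defer-to-event-loop pattern in plain Boost.Asio; instrumented_io_context layers stats on top of this, and all names below are illustrative:

#include <iostream>

#include <boost/asio.hpp>

int main() {
  boost::asio::io_context loop;
  auto work = boost::asio::make_work_guard(loop);

  // Some other thread or timer notices a failure and schedules the shutdown
  // on the loop, so it runs on the thread that owns the loop's state.
  boost::asio::post(loop, [&work] {
    std::cout << "shutting down from the event-loop thread" << std::endl;
    work.reset();  // allow run() to return once queued handlers finish
  });

  loop.run();
  return 0;
}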
- if (!plasma_store_provider_->Release(object_id).ok()) { - RAY_LOG(ERROR) << "Failed to release ObjectID (" << object_id - << "), might cause a leak in plasma."; - } - }); - } - RAY_CHECK(memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), object_id)); -} - -void CoreWorker::PromoteObjectToPlasma(const ObjectID &object_id) { - auto value = memory_store_->GetOrPromoteToPlasma(object_id); - if (value) { - PutObjectIntoPlasma(*value, object_id); - } -} - const rpc::Address &CoreWorker::GetRpcAddress() const { return rpc_address_; } rpc::Address CoreWorker::GetOwnerAddress(const ObjectID &object_id) const { @@ -1061,7 +1045,6 @@ void CoreWorker::GetOwnershipInfo(const ObjectID &object_id, rpc::Address *owner "which task will create them. " "If this was not how your object ID was generated, please file an issue " "at https://github.com/ray-project/ray/issues/"; - RAY_LOG(DEBUG) << "Promoted object to plasma " << object_id; rpc::GetObjectStatusReply object_status; // Optimization: if the object exists, serialize and inline its status. This also @@ -1635,6 +1618,37 @@ std::unordered_map AddPlacementGroupConstraint( return resources; } +void CoreWorker::BuildCommonTaskSpec( + TaskSpecBuilder &builder, const JobID &job_id, const TaskID &task_id, + const std::string &name, const TaskID ¤t_task_id, uint64_t task_index, + const TaskID &caller_id, const rpc::Address &address, const RayFunction &function, + const std::vector> &args, uint64_t num_returns, + const std::unordered_map &required_resources, + const std::unordered_map &required_placement_resources, + const BundleID &bundle_id, bool placement_group_capture_child_tasks, + const std::string &debugger_breakpoint, const std::string &serialized_runtime_env, + const std::vector &runtime_env_uris, + const std::string &concurrency_group_name) { + // Build common task spec. + builder.SetCommonTaskSpec( + task_id, name, function.GetLanguage(), function.GetFunctionDescriptor(), job_id, + current_task_id, task_index, caller_id, address, num_returns, required_resources, + required_placement_resources, bundle_id, placement_group_capture_child_tasks, + debugger_breakpoint, + // TODO(SongGuyang): Move the logic of `prepare_runtime_env` from Python to Core + // Worker. A common process is needed. + // If runtime env is not provided, use job config. Only for Java and C++ because it + // has been set in Python by `prepare_runtime_env`. + (serialized_runtime_env.empty() || serialized_runtime_env == "{}") + ? job_config_->runtime_env().serialized_runtime_env() + : serialized_runtime_env, + runtime_env_uris, concurrency_group_name); + // Set task arguments. + for (const auto &arg : args) { + builder.AddArg(*arg); + } +} + std::vector CoreWorker::SubmitTask( const RayFunction &function, const std::vector> &args, const TaskOptions &task_options, int max_retries, bool retry_exceptions, @@ -1652,21 +1666,13 @@ std::vector CoreWorker::SubmitTask( auto task_name = task_options.name.empty() ? 
function.GetFunctionDescriptor()->DefaultTaskName() : task_options.name; - // Propagate existing environment variable overrides, but override them with any new - // ones - std::unordered_map current_override_environment_variables = - worker_context_.GetCurrentOverrideEnvironmentVariables(); - std::unordered_map override_environment_variables = - task_options.override_environment_variables; - override_environment_variables.insert(current_override_environment_variables.begin(), - current_override_environment_variables.end()); // TODO(ekl) offload task building onto a thread pool for performance - BuildCommonTaskSpec( - builder, worker_context_.GetCurrentJobID(), task_id, task_name, - worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(), rpc_address_, - function, args, task_options.num_returns, constrained_resources, required_resources, - placement_options, placement_group_capture_child_tasks, debugger_breakpoint, - task_options.serialized_runtime_env, override_environment_variables); + BuildCommonTaskSpec(builder, worker_context_.GetCurrentJobID(), task_id, task_name, + worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(), + rpc_address_, function, args, task_options.num_returns, + constrained_resources, required_resources, placement_options, + placement_group_capture_child_tasks, debugger_breakpoint, + task_options.serialized_runtime_env, task_options.runtime_env_uris); builder.SetNormalTaskSpec(max_retries, retry_exceptions); TaskSpecification task_spec = builder.Build(); RAY_LOG(DEBUG) << "Submit task " << task_spec.DebugString(); @@ -1702,12 +1708,7 @@ Status CoreWorker::CreateActor(const RayFunction &function, const JobID job_id = worker_context_.GetCurrentJobID(); // Propagate existing environment variable overrides, but override them with any new // ones - std::unordered_map current_override_environment_variables = - worker_context_.GetCurrentOverrideEnvironmentVariables(); - std::unordered_map override_environment_variables = - actor_creation_options.override_environment_variables; - override_environment_variables.insert(current_override_environment_variables.begin(), - current_override_environment_variables.end()); + std::vector return_ids; TaskSpecBuilder builder; auto new_placement_resources = AddPlacementGroupConstraint(actor_creation_options.placement_resources, @@ -1728,7 +1729,7 @@ Status CoreWorker::CreateActor(const RayFunction &function, actor_creation_options.placement_group_capture_child_tasks, "", /* debugger_breakpoint */ actor_creation_options.serialized_runtime_env, - override_environment_variables); + actor_creation_options.runtime_env_uris); auto actor_handle = std::make_unique( actor_id, GetCallerId(), rpc_address_, job_id, @@ -1905,7 +1906,6 @@ std::vector CoreWorker::SubmitActorTask( const auto task_name = task_options.name.empty() ? 
function.GetFunctionDescriptor()->DefaultTaskName() : task_options.name; - const std::unordered_map override_environment_variables = {}; BuildCommonTaskSpec(builder, actor_handle->CreationJobID(), actor_task_id, task_name, worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(), rpc_address_, function, args, num_returns, task_options.resources, @@ -1913,7 +1913,7 @@ std::vector CoreWorker::SubmitActorTask( true, /* placement_group_capture_child_tasks */ "", /* debugger_breakpoint */ "{}", /* serialized_runtime_env */ - override_environment_variables, + {}, /* runtime_env_uris */ task_options.concurrency_group_name); // NOTE: placement_group_capture_child_tasks and runtime_env will // be ignored in the actor because we should always follow the actor's option. @@ -2184,6 +2184,14 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec, task_queue_length_ -= 1; num_executed_tasks_ += 1; + // Modify the worker's per function counters. + std::string func_name = task_spec.FunctionDescriptor()->CallString(); + { + absl::MutexLock l(&task_counter_.tasks_counter_mutex_); + task_counter_.Add(TaskCounter::kPending, func_name, -1); + task_counter_.Add(TaskCounter::kRunning, func_name, 1); + } + if (!options_.is_local_mode) { worker_context_.SetCurrentTask(task_spec); SetCurrentTaskId(task_spec.TaskId()); @@ -2279,8 +2287,16 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec, resource_ids_.reset(new ResourceMappingType()); } } - RAY_LOG(INFO) << "Finished executing task " << task_spec.TaskId() - << ", status=" << status; + + // Modify the worker's per function counters. + { + absl::MutexLock l(&task_counter_.tasks_counter_mutex_); + task_counter_.Add(TaskCounter::kRunning, func_name, -1); + task_counter_.Add(TaskCounter::kFinished, func_name, 1); + } + + RAY_LOG(DEBUG) << "Finished executing task " << task_spec.TaskId() + << ", status=" << status; if (status.IsCreationTaskError()) { Exit(rpc::WorkerExitType::CREATION_TASK_ERROR, creation_task_exception_pb_bytes); } else if (status.IsIntentionalSystemExit()) { @@ -2447,8 +2463,15 @@ void CoreWorker::HandlePushTask(const rpc::PushTaskRequest &request, return; } - // Increment the task_queue_length + // Increment the task_queue_length and per function counter. task_queue_length_ += 1; + std::string func_name = + FunctionDescriptorBuilder::FromProto(request.task_spec().function_descriptor()) + ->CallString(); + { + absl::MutexLock l(&task_counter_.tasks_counter_mutex_); + task_counter_.Add(TaskCounter::kPending, func_name, 1); + } // For actor tasks, we just need to post a HandleActorTask instance to the task // execution service. @@ -2855,13 +2878,10 @@ void CoreWorker::HandleCancelTask(const rpc::CancelTaskRequest &request, << " has received a force kill request after the cancellation. Killing " "a worker..."; Disconnect(); - if (options_.enable_logging) { - RayLog::ShutDownRayLog(); - } - // NOTE(hchen): Use `_Exit()` to force-exit this process without doing cleanup. + // NOTE(hchen): Use `QuickExit()` to force-exit this process without doing cleanup. // `exit()` will destruct static objects in an incorrect order, which will lead to // core dumps. 
-      _Exit(1);
+      QuickExit(options_.enable_logging);
     }
   }
 
@@ -2894,13 +2914,10 @@ void CoreWorker::HandleKillActor(const rpc::KillActorRequest &request,
              "please create the Java actor with some dynamic options to make it being "
              "hosted in a dedicated worker process.";
     }
-    if (options_.enable_logging) {
-      RayLog::ShutDownRayLog();
-    }
-    // NOTE(hchen): Use `_Exit()` to force-exit this process without doing cleanup.
+    // NOTE(hchen): Use `QuickExit()` to force-exit this process without doing cleanup.
     // `exit()` will destruct static objects in an incorrect order, which will lead to
     // core dumps.
-    _Exit(1);
+    QuickExit(options_.enable_logging);
   } else {
     Exit(rpc::WorkerExitType::INTENDED_EXIT);
   }
 }
 
@@ -3057,15 +3074,16 @@ void CoreWorker::HandleExit(const rpc::ExitRequest &request, rpc::ExitReply *rep
   // any object pinning RPCs in flight.
   bool is_idle = !own_objects && pins_in_flight == 0;
   reply->set_success(is_idle);
-  send_reply_callback(Status::OK(),
-                      [this, is_idle]() {
-                        // If the worker is idle, we exit.
-                        if (is_idle) {
-                          Exit(rpc::WorkerExitType::IDLE_EXIT);
-                        }
-                      },
-                      // We need to kill it regardless if the RPC failed.
-                      [this]() { Exit(rpc::WorkerExitType::INTENDED_EXIT); });
+  send_reply_callback(
+      Status::OK(),
+      [this, is_idle]() {
+        // If the worker is idle, we exit.
+        if (is_idle) {
+          Exit(rpc::WorkerExitType::IDLE_EXIT);
+        }
+      },
+      // We need to kill it regardless of whether the RPC failed.
+      [this]() { Exit(rpc::WorkerExitType::INTENDED_EXIT); });
 }
 
 void CoreWorker::HandleAssignObjectOwner(const rpc::AssignObjectOwnerRequest &request,
@@ -3191,6 +3209,25 @@ std::shared_ptr CoreWorker::GetGcsClient() const { return gcs_cl
 
 bool CoreWorker::IsExiting() const { return exiting_; }
 
+std::unordered_map> CoreWorker::GetActorCallStats()
+    const {
+  absl::MutexLock l(&task_counter_.tasks_counter_mutex_);
+  std::unordered_map> total_counts;
+
+  for (const auto &count : task_counter_.pending_tasks_counter_map_) {
+    total_counts[count.first].resize(3, 0);
+    total_counts[count.first][0] = count.second;
+  }
+  for (const auto &count : task_counter_.running_tasks_counter_map_) {
+    total_counts[count.first][1] = count.second;
+  }
+  for (const auto &count : task_counter_.finished_tasks_counter_map_) {
+    total_counts[count.first][2] = count.second;
+  }
+
+  return total_counts;
+}
+
 Status CoreWorker::WaitForActorRegistered(const std::vector &ids) {
   std::vector actor_ids;
   for (const auto &id : ids) {
diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h
index 3ef1e2476f6d2..883a1b013ff81 100644
--- a/src/ray/core_worker/core_worker.h
+++ b/src/ray/core_worker/core_worker.h
@@ -294,23 +294,29 @@ class CoreWorkerProcess {
 
   void InitializeSystemConfig();
 
+  /// Check whether the global worker should be created on construction.
+  bool ShouldCreateGlobalWorkerOnConstruction() const;
+
   /// Get the `CoreWorker` instance by worker ID.
   ///
   /// \param[in] workerId The worker ID.
   /// \return The `CoreWorker` instance.
   std::shared_ptr GetWorker(const WorkerID &worker_id) const
-      LOCKS_EXCLUDED(worker_map_mutex_);
+      LOCKS_EXCLUDED(mutex_);
 
   /// Create a new `CoreWorker` instance.
   ///
   /// \return The newly created `CoreWorker` instance.
-  std::shared_ptr CreateWorker() LOCKS_EXCLUDED(worker_map_mutex_);
+  std::shared_ptr CreateWorker() LOCKS_EXCLUDED(mutex_);
 
   /// Remove an existing `CoreWorker` instance.
   ///
   /// \param[in] The existing `CoreWorker` instance.
   /// \return Void.
- void RemoveWorker(std::shared_ptr worker) LOCKS_EXCLUDED(worker_map_mutex_); + void RemoveWorker(std::shared_ptr worker) LOCKS_EXCLUDED(mutex_); + + /// Get the `GlobalWorker` instance, if the number of workers is 1. + std::shared_ptr GetGlobalWorker() LOCKS_EXCLUDED(mutex_); /// The various options. const CoreWorkerOptions options_; @@ -320,17 +326,16 @@ class CoreWorkerProcess { static thread_local std::weak_ptr current_core_worker_; /// The only core worker instance, if the number of workers is 1. - std::shared_ptr global_worker_; + std::shared_ptr global_worker_ GUARDED_BY(mutex_); /// The worker ID of the global worker, if the number of workers is 1. const WorkerID global_worker_id_; /// Map from worker ID to worker. - std::unordered_map> workers_ - GUARDED_BY(worker_map_mutex_); + std::unordered_map> workers_ GUARDED_BY(mutex_); - /// To protect accessing the `workers_` map. - mutable absl::Mutex worker_map_mutex_; + /// To protect access to workers_ and global_worker_ + mutable absl::Mutex mutex_; }; /// The root class that contains all the core and language-independent functionalities @@ -440,22 +445,6 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// (local, submitted_task) reference counts. For debugging purposes. std::unordered_map> GetAllReferenceCounts() const; - /// Put an object into plasma. It's a version of Put that directly put the - /// object into plasma and also pin the object. - /// - /// \param[in] The ray object. - /// \param[in] object_id The object ID to serialize. - /// appended to the serialized object ID. - void PutObjectIntoPlasma(const RayObject &object, const ObjectID &object_id); - - /// Promote an object to plasma. If the - /// object already exists locally, it will be put into the plasma store. If - /// it doesn't yet exist, it will be spilled to plasma once available. - /// - /// \param[in] object_id The object ID to serialize. - /// appended to the serialized object ID. - void PromoteObjectToPlasma(const ObjectID &object_id); - /// Get the RPC address of this worker. /// /// \param[out] The RPC address of this worker. @@ -1044,7 +1033,24 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// Return true if the core worker is in the exit process. bool IsExiting() const; + /// Retrieve the current statistics about tasks being received and executing. + /// \return an unordered_map mapping function name to list of (num_received, + /// num_executing, num_executed). It is a std map instead of absl due to its + /// interface with language bindings. + std::unordered_map> GetActorCallStats() const; + private: + void BuildCommonTaskSpec( + TaskSpecBuilder &builder, const JobID &job_id, const TaskID &task_id, + const std::string &name, const TaskID ¤t_task_id, uint64_t task_index, + const TaskID &caller_id, const rpc::Address &address, const RayFunction &function, + const std::vector> &args, uint64_t num_returns, + const std::unordered_map &required_resources, + const std::unordered_map &required_placement_resources, + const BundleID &bundle_id, bool placement_group_capture_child_tasks, + const std::string &debugger_breakpoint, const std::string &serialized_runtime_env, + const std::vector &runtime_env_uris, + const std::string &concurrency_group_name = ""); void SetCurrentTaskId(const TaskID &task_id); void SetActorId(const ActorID &actor_id); @@ -1366,12 +1372,6 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// Number of executed tasks. std::atomic num_executed_tasks_; - /// Event loop where tasks are processed. 
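The hunk above folds worker_map_mutex_ into a single mutex_ guarding both workers_ and global_worker_, with reader locks on lookup paths and writer locks on mutation paths. A self-contained sketch of that pattern using the same absl primitives and annotation macros as the header above; the class and member names here are illustrative:

#include <memory>
#include <string>

#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/synchronization/mutex.h"

struct Worker {
  std::string id;
};

class WorkerRegistry {
 public:
  std::shared_ptr<Worker> GetGlobal() const {
    absl::ReaderMutexLock lock(&mutex_);  // shared: read-only access
    return global_;
  }

  std::shared_ptr<Worker> Create(const std::string &id) {
    auto worker = std::make_shared<Worker>(Worker{id});
    absl::WriterMutexLock lock(&mutex_);  // exclusive: mutates both members
    global_ = worker;
    workers_.emplace(id, worker);
    return worker;
  }

 private:
  mutable absl::Mutex mutex_;
  std::shared_ptr<Worker> global_ GUARDED_BY(mutex_);
  absl::flat_hash_map<std::string, std::shared_ptr<Worker>> workers_
      GUARDED_BY(mutex_);
};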
-  instrumented_io_context task_execution_service_;
-
-  /// The asio work to keep task_execution_service_ alive.
-  boost::asio::io_service::work task_execution_service_work_;
-
   /// Profiler including a background thread that pushes profiling events to the GCS.
   std::shared_ptr profiler_;
@@ -1390,6 +1390,14 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
   // Interface that receives tasks from direct actor calls.
   std::unique_ptr direct_task_receiver_;
 
+  /// Event loop where tasks are processed.
+  /// task_execution_service_ should be destructed first to avoid
+  /// issues like https://github.com/ray-project/ray/issues/18857
+  instrumented_io_context task_execution_service_;
+
+  /// The asio work to keep task_execution_service_ alive.
+  boost::asio::io_service::work task_execution_service_work_;
+
   // Queue of tasks to resubmit when the specified time passes.
   std::deque> to_resubmit_ GUARDED_BY(mutex_);
 
@@ -1408,14 +1416,47 @@
   void PlasmaCallback(SetResultCallback success, std::shared_ptr ray_object,
                       ObjectID object_id, void *py_future);
 
-  /// Whether we are shutting down and not running further tasks.
-  bool exiting_ = false;
+  /// Whether we are shutting down and not running further tasks.
+  /// When exiting_ is set to true, HandlePushTask becomes a no-op.
+  std::atomic<bool> exiting_ = false;
 
   int64_t max_direct_call_object_size_;
 
   friend class CoreWorkerTest;
 
   std::unique_ptr job_config_;
+
+  /// Simple container for per-function task counters. The counters will be
+  /// keyed by the function name in the task spec.
+  struct TaskCounter {
+    /// A task can be in only one of the following states. The pending state in
+    /// particular covers the span from the RPC call arriving to the task
+    /// beginning execution.
+    enum TaskStatusType { kPending, kRunning, kFinished };
+
+    /// This mutex should be used by the caller to ensure consistency when
+    /// transitioning a task's state.
+ mutable absl::Mutex tasks_counter_mutex_; + absl::flat_hash_map pending_tasks_counter_map_ + GUARDED_BY(tasks_counter_mutex_); + absl::flat_hash_map running_tasks_counter_map_ + GUARDED_BY(tasks_counter_mutex_); + absl::flat_hash_map finished_tasks_counter_map_ + GUARDED_BY(tasks_counter_mutex_); + + void Add(TaskStatusType type, const std::string &func_name, int value) { + tasks_counter_mutex_.AssertHeld(); + if (type == kPending) { + pending_tasks_counter_map_[func_name] += value; + } else if (type == kRunning) { + running_tasks_counter_map_[func_name] += value; + } else if (type == kFinished) { + finished_tasks_counter_map_[func_name] += value; + } else { + RAY_CHECK(false) << "This line should not be reached."; + } + } + }; + TaskCounter task_counter_; }; } // namespace core diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc index 70a2626847574..6af083669f5ed 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc @@ -217,10 +217,9 @@ Java_io_ray_runtime_object_NativeObjectStore_nativeGetOwnerAddress(JNIEnv *env, } JNIEXPORT jbyteArray JNICALL -Java_io_ray_runtime_object_NativeObjectStore_nativePromoteAndGetOwnershipInfo( - JNIEnv *env, jclass, jbyteArray objectId) { +Java_io_ray_runtime_object_NativeObjectStore_nativeGetOwnershipInfo(JNIEnv *env, jclass, + jbyteArray objectId) { auto object_id = JavaByteArrayToId(env, objectId); - CoreWorkerProcess::GetCoreWorker().PromoteObjectToPlasma(object_id); rpc::Address address; // TODO(ekl) send serialized object status to Java land. std::string serialized_object_status; diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h index 8001bbf20df06..9358f4473c228 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h +++ b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h @@ -105,13 +105,12 @@ Java_io_ray_runtime_object_NativeObjectStore_nativeGetOwnerAddress(JNIEnv *, jcl /* * Class: io_ray_runtime_object_NativeObjectStore - * Method: nativePromoteAndGetOwnershipInfo + * Method: nativeGetOwnershipInfo * Signature: ([B)[B */ JNIEXPORT jbyteArray JNICALL -Java_io_ray_runtime_object_NativeObjectStore_nativePromoteAndGetOwnershipInfo(JNIEnv *, - jclass, - jbyteArray); +Java_io_ray_runtime_object_NativeObjectStore_nativeGetOwnershipInfo(JNIEnv *, jclass, + jbyteArray); /* * Class: io_ray_runtime_object_NativeObjectStore diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc index dd05bc76aa6e0..56a0ad473c64d 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc @@ -235,9 +235,9 @@ inline ActorCreationOptions ToActorCreationOptions(JNIEnv *env, ray_namespace, /*is_asyncio=*/false, placement_options, - true, - "{}", - {}, + /*placement_group_capture_child_tasks=*/true, + /*serialized_runtime_env=*/"{}", + /*runtime_env_uris=*/{}, concurrency_groups}; return actor_creation_options; } diff --git a/src/ray/core_worker/reference_count.h b/src/ray/core_worker/reference_count.h index 58c67a2010213..e0a9a783dd657 100644 --- a/src/ray/core_worker/reference_count.h +++ 
b/src/ray/core_worker/reference_count.h @@ -14,8 +14,6 @@ #pragma once -#include - #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" diff --git a/src/ray/core_worker/reference_count_test.cc b/src/ray/core_worker/reference_count_test.cc index 5877d7f654dfc..c265bc7af753b 100644 --- a/src/ray/core_worker/reference_count_test.cc +++ b/src/ray/core_worker/reference_count_test.cc @@ -16,9 +16,9 @@ #include +#include "absl/functional/bind_front.h" #include "gmock/gmock.h" #include "gtest/gtest.h" - #include "ray/common/asio/instrumented_io_context.h" #include "ray/common/asio/periodical_runner.h" #include "ray/common/ray_object.h" @@ -270,7 +270,7 @@ class MockWorkerClient : public MockCoreWorkerClientInterface { auto borrower_callback = [=]() { auto ref_removed_callback = - boost::bind(&ReferenceCounter::HandleRefRemoved, &rc_, _1); + absl::bind_front(&ReferenceCounter::HandleRefRemoved, &rc_); rc_.SetRefRemovedCallback(object_id, contained_in_id, owner_address, ref_removed_callback); }; @@ -656,7 +656,7 @@ TEST(MemoryStoreIntegrationTest, TestSimple) { auto subscriber = std::make_shared(); auto rc = std::shared_ptr(new ReferenceCounter( rpc::WorkerAddress(rpc::Address()), publisher.get(), subscriber.get())); - CoreWorkerMemoryStore store(nullptr, rc); + CoreWorkerMemoryStore store(rc); // Tests putting an object with no references is ignored. RAY_CHECK(store.Put(buffer, id2)); diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.cc b/src/ray/core_worker/store_provider/memory_store/memory_store.cc index 680c9c13616bc..b32b612166820 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.cc +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.cc @@ -139,13 +139,11 @@ std::shared_ptr GetRequest::Get(const ObjectID &object_id) const { } CoreWorkerMemoryStore::CoreWorkerMemoryStore( - std::function store_in_plasma, std::shared_ptr counter, std::shared_ptr raylet_client, std::function check_signals, std::function unhandled_exception_handler) - : store_in_plasma_(store_in_plasma), - ref_counter_(counter), + : ref_counter_(std::move(counter)), raylet_client_(raylet_client), check_signals_(check_signals), unhandled_exception_handler_(unhandled_exception_handler) {} @@ -186,24 +184,6 @@ std::shared_ptr CoreWorkerMemoryStore::GetIfExists(const ObjectID &ob return ptr; } -std::shared_ptr CoreWorkerMemoryStore::GetOrPromoteToPlasma( - const ObjectID &object_id) { - absl::MutexLock lock(&mu_); - auto iter = objects_.find(object_id); - if (iter != objects_.end()) { - auto obj = iter->second; - obj->SetAccessed(); - if (obj->IsInPlasmaError()) { - return nullptr; - } - return obj; - } - RAY_CHECK(store_in_plasma_ != nullptr) - << "Cannot promote object without plasma provider callback."; - promoted_to_plasma_.insert(object_id); - return nullptr; -} - bool CoreWorkerMemoryStore::Put(const RayObject &object, const ObjectID &object_id) { std::vector)>> async_callbacks; auto object_entry = std::make_shared(object.GetData(), object.GetMetadata(), @@ -212,7 +192,6 @@ bool CoreWorkerMemoryStore::Put(const RayObject &object, const ObjectID &object_ // TODO(edoakes): we should instead return a flag to the caller to put the object in // plasma. 
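The test change above swaps boost::bind with a _1 placeholder for absl::bind_front, which binds the leading arguments and forwards the rest without placeholders. A standalone sketch of the difference; the Counter type is illustrative:

#include <iostream>

#include "absl/functional/bind_front.h"

struct Counter {
  void HandleRefRemoved(int object_id) {
    std::cout << "ref removed for " << object_id << std::endl;
  }
};

int main() {
  Counter c;
  // Equivalent to boost::bind(&Counter::HandleRefRemoved, &c, _1), but the
  // trailing argument is forwarded automatically, no placeholder needed.
  auto cb = absl::bind_front(&Counter::HandleRefRemoved, &c);
  cb(42);  // prints "ref removed for 42"
  return 0;
}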
- bool should_put_in_plasma = false; { absl::MutexLock lock(&mu_); @@ -228,15 +207,6 @@ bool CoreWorkerMemoryStore::Put(const RayObject &object, const ObjectID &object_ object_async_get_requests_.erase(async_callback_it); } - auto promoted_it = promoted_to_plasma_.find(object_id); - if (promoted_it != promoted_to_plasma_.end()) { - RAY_CHECK(store_in_plasma_ != nullptr); - // Only need to promote to plasma if it wasn't already put into plasma - // by the task that created the object. - should_put_in_plasma = !object.IsInPlasmaError(); - promoted_to_plasma_.erase(promoted_it); - } - bool should_add_entry = true; auto object_request_iter = object_get_requests_.find(object_id); if (object_request_iter != object_get_requests_.end()) { @@ -268,14 +238,6 @@ bool CoreWorkerMemoryStore::Put(const RayObject &object, const ObjectID &object_ } } - // Must be called without holding the lock because store_in_plasma_ goes - // through the regular CoreWorker::Put() codepath, which calls into the - // in-memory store (would cause deadlock). - if (should_put_in_plasma) { - store_in_plasma_(object, object_id); - stored_in_direct_memory = false; - } - // It's important for performance to run the callbacks outside the lock. for (const auto &cb : async_callbacks) { cb(object_entry); diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.h b/src/ray/core_worker/store_provider/memory_store/memory_store.h index 542fac1ea2ea6..70bebac7f01a5 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.h +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.h @@ -44,12 +44,10 @@ class CoreWorkerMemoryStore { public: /// Create a memory store. /// - /// \param[in] store_in_plasma If not null, this is used to spill to plasma. /// \param[in] counter If not null, this enables ref counting for local objects, /// and the `remove_after_get` flag for Get() will be ignored. /// \param[in] raylet_client If not null, used to notify tasks blocked / unblocked. CoreWorkerMemoryStore( - std::function store_in_plasma = nullptr, std::shared_ptr counter = nullptr, std::shared_ptr raylet_client = nullptr, std::function check_signals = nullptr, @@ -104,14 +102,6 @@ class CoreWorkerMemoryStore { void GetAsync(const ObjectID &object_id, std::function)> callback); - /// Get a single object if available. If the object is not local yet, or if the object - /// is local but is ErrorType::OBJECT_IN_PLASMA, then nullptr will be returned, and - /// the store will ensure the object is promoted to plasma once available. - /// - /// \param[in] object_id The object id to get. - /// \return pointer to the local object, or nullptr if promoted to plasma. - std::shared_ptr GetOrPromoteToPlasma(const ObjectID &object_id); - /// Delete a list of objects from the object store. /// NOTE(swang): Objects that contain IsInPlasmaError will not be /// deleted from the in-memory store. Instead, any future Get @@ -187,9 +177,6 @@ class CoreWorkerMemoryStore { /// properly. void EraseObjectAndUpdateStats(const ObjectID &object_id) EXCLUSIVE_LOCKS_REQUIRED(mu_); - /// Optional callback for putting objects into the plasma store. - std::function store_in_plasma_; - /// If enabled, holds a reference to local worker ref counter. TODO(ekl) make this /// mandatory once Java is supported. std::shared_ptr ref_counter_ = nullptr; @@ -200,9 +187,6 @@ class CoreWorkerMemoryStore { /// Protects the data structures below. mutable absl::Mutex mu_; - /// Set of objects that should be promoted to plasma once available. 
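Put() above collects the satisfied async-get callbacks while holding mu_ and only invokes them after the critical section ("It's important for performance to run the callbacks outside the lock"). A minimal sketch of that pattern with simplified types; calling user callbacks under the mutex risks deadlock when a callback re-enters the store:

#include <functional>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Sketch: register waiters under the lock, run them after unlocking, in the
// same shape as CoreWorkerMemoryStore::Put above. For brevity, GetAsync here
// assumes the key is not yet present.
class KvStore {
 public:
  using Callback = std::function<void(const std::string &)>;

  void GetAsync(const std::string &key, Callback cb) {
    std::lock_guard<std::mutex> lock(mu_);
    waiters_[key].push_back(std::move(cb));
  }

  void Put(const std::string &key, const std::string &value) {
    std::vector<Callback> to_run;
    {
      std::lock_guard<std::mutex> lock(mu_);
      data_[key] = value;
      auto it = waiters_.find(key);
      if (it != waiters_.end()) {
        to_run = std::move(it->second);
        waiters_.erase(it);
      }
    }
    // Outside the lock: callbacks may safely call back into this store.
    for (const auto &cb : to_run) {
      cb(value);
    }
  }

 private:
  std::mutex mu_;
  std::unordered_map<std::string, std::string> data_;
  std::unordered_map<std::string, std::vector<Callback>> waiters_;
};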
- absl::flat_hash_set promoted_to_plasma_ GUARDED_BY(mu_); - /// Map from object ID to `RayObject`. /// NOTE: This map should be modified by EmplaceObjectAndUpdateStats and /// EraseObjectAndUpdateStats. diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index 29d95cb8fa9b8..3e0ddd631d45f 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -16,7 +16,7 @@ #include #include -#include +#include #include #include "absl/container/flat_hash_map.h" diff --git a/src/ray/core_worker/test/direct_task_transport_mock_test.cc b/src/ray/core_worker/test/direct_task_transport_mock_test.cc index 0af5c20c4eb15..8312d79a0bc43 100644 --- a/src/ray/core_worker/test/direct_task_transport_mock_test.cc +++ b/src/ray/core_worker/test/direct_task_transport_mock_test.cc @@ -28,7 +28,7 @@ using namespace ::testing; class DirectTaskTransportTest : public ::testing::Test { public: void SetUp() override { - raylet_client = std::make_shared(); + raylet_client = std::make_shared(); task_finisher = std::make_shared(); actor_creator = std::make_shared(); lease_policy = std::make_shared(); @@ -57,7 +57,7 @@ class DirectTaskTransportTest : public ::testing::Test { } std::unique_ptr task_submitter; - std::shared_ptr raylet_client; + std::shared_ptr raylet_client; std::shared_ptr task_finisher; std::shared_ptr actor_creator; std::shared_ptr lease_policy; diff --git a/src/ray/core_worker/test/direct_task_transport_test.cc b/src/ray/core_worker/test/direct_task_transport_test.cc index 473136255bc72..d5716f10a3981 100644 --- a/src/ray/core_worker/test/direct_task_transport_test.cc +++ b/src/ray/core_worker/test/direct_task_transport_test.cc @@ -161,6 +161,14 @@ class MockRayletClient : public WorkerLeaseInterface { callbacks.push_back(callback); } + void RequestWorkerLease( + const rpc::TaskSpec &task_spec, + const ray::rpc::ClientCallback &callback, + const int64_t backlog_size = -1) override { + num_workers_requested += 1; + callbacks.push_back(callback); + } + void ReleaseUnusedWorkers( const std::vector &workers_in_use, const rpc::ClientCallback &callback) override {} @@ -246,11 +254,18 @@ class MockActorCreator : public ActorCreatorInterface { } void AsyncWaitForActorRegisterFinish(const ActorID &, - gcs::StatusCallback callback) override {} + gcs::StatusCallback callback) override { + callbacks.push_back(callback); + } - bool IsActorInRegistering(const ActorID &actor_id) const override { return false; } + [[nodiscard]] bool IsActorInRegistering(const ActorID &actor_id) const override { + return actor_pending; + } ~MockActorCreator() {} + + std::list callbacks; + bool actor_pending = false; }; class MockLeasePolicy : public LeasePolicyInterface { @@ -272,30 +287,6 @@ class MockLeasePolicy : public LeasePolicyInterface { int num_lease_policy_consults = 0; }; -TEST(TestMemoryStore, TestPromoteToPlasma) { - size_t num_plasma_puts = 0; - auto mem = std::make_shared( - [&](const RayObject &obj, const ObjectID &obj_id) { num_plasma_puts += 1; }); - ObjectID obj1 = ObjectID::FromRandom(); - ObjectID obj2 = ObjectID::FromRandom(); - auto data = GenerateRandomObject(); - ASSERT_TRUE(mem->Put(*data, obj1)); - - // Test getting an already existing object. - ASSERT_TRUE(mem->GetOrPromoteToPlasma(obj1) != nullptr); - ASSERT_TRUE(num_plasma_puts == 0); - - // Testing getting an object that doesn't exist yet causes promotion. 
- ASSERT_TRUE(mem->GetOrPromoteToPlasma(obj2) == nullptr); - ASSERT_TRUE(num_plasma_puts == 0); - ASSERT_FALSE(mem->Put(*data, obj2)); - ASSERT_TRUE(num_plasma_puts == 1); - - // The next time you get it, it's already there so no need to promote. - ASSERT_TRUE(mem->GetOrPromoteToPlasma(obj2) != nullptr); - ASSERT_TRUE(num_plasma_puts == 1); -} - TEST(LocalDependencyResolverTest, TestNoDependencies) { auto store = std::make_shared(); auto task_finisher = std::make_shared(); @@ -308,6 +299,77 @@ TEST(LocalDependencyResolverTest, TestNoDependencies) { ASSERT_EQ(task_finisher->num_inlined_dependencies, 0); } +TEST(LocalDependencyResolverTest, TestActorAndObjectDependencies1) { + // Actor dependency resolved first. + auto store = std::make_shared(); + auto task_finisher = std::make_shared(); + MockActorCreator actor_creator; + LocalDependencyResolver resolver(*store, *task_finisher, actor_creator); + TaskSpecification task; + ObjectID obj = ObjectID::FromRandom(); + task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj.Binary()); + + ActorID actor_id = ActorID::Of(JobID::FromInt(0), TaskID::Nil(), 0); + ObjectID actor_handle_id = ObjectID::ForActorHandle(actor_id); + task.GetMutableMessage().add_args()->add_nested_inlined_refs()->set_object_id( + actor_handle_id.Binary()); + + int num_resolved = 0; + actor_creator.actor_pending = true; + resolver.ResolveDependencies(task, [&](const Status &) { num_resolved++; }); + ASSERT_EQ(num_resolved, 0); + ASSERT_EQ(resolver.NumPendingTasks(), 1); + + for (const auto &cb : actor_creator.callbacks) { + cb(Status()); + } + ASSERT_EQ(num_resolved, 0); + + std::string meta = std::to_string(static_cast(rpc::ErrorType::OBJECT_IN_PLASMA)); + auto metadata = const_cast(reinterpret_cast(meta.data())); + auto meta_buffer = std::make_shared(metadata, meta.size()); + auto data = RayObject(nullptr, meta_buffer, std::vector()); + ASSERT_TRUE(store->Put(data, obj)); + ASSERT_EQ(num_resolved, 1); + + ASSERT_EQ(resolver.NumPendingTasks(), 0); +} + +TEST(LocalDependencyResolverTest, TestActorAndObjectDependencies2) { + // Object dependency resolved first. 
+ auto store = std::make_shared(); + auto task_finisher = std::make_shared(); + MockActorCreator actor_creator; + LocalDependencyResolver resolver(*store, *task_finisher, actor_creator); + TaskSpecification task; + ObjectID obj = ObjectID::FromRandom(); + task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj.Binary()); + + ActorID actor_id = ActorID::Of(JobID::FromInt(0), TaskID::Nil(), 0); + ObjectID actor_handle_id = ObjectID::ForActorHandle(actor_id); + task.GetMutableMessage().add_args()->add_nested_inlined_refs()->set_object_id( + actor_handle_id.Binary()); + + int num_resolved = 0; + actor_creator.actor_pending = true; + resolver.ResolveDependencies(task, [&](const Status &) { num_resolved++; }); + ASSERT_EQ(num_resolved, 0); + ASSERT_EQ(resolver.NumPendingTasks(), 1); + + std::string meta = std::to_string(static_cast(rpc::ErrorType::OBJECT_IN_PLASMA)); + auto metadata = const_cast(reinterpret_cast(meta.data())); + auto meta_buffer = std::make_shared(metadata, meta.size()); + auto data = RayObject(nullptr, meta_buffer, std::vector()); + ASSERT_EQ(num_resolved, 0); + ASSERT_TRUE(store->Put(data, obj)); + + for (const auto &cb : actor_creator.callbacks) { + cb(Status()); + } + ASSERT_EQ(num_resolved, 1); + ASSERT_EQ(resolver.NumPendingTasks(), 0); +} + TEST(LocalDependencyResolverTest, TestHandlePlasmaPromotion) { auto store = std::make_shared(); auto task_finisher = std::make_shared(); diff --git a/src/ray/core_worker/test/memory_store_test.cc b/src/ray/core_worker/test/memory_store_test.cc index 84a7c8f7996ac..feee9973db850 100644 --- a/src/ray/core_worker/test/memory_store_test.cc +++ b/src/ray/core_worker/test/memory_store_test.cc @@ -12,10 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "absl/synchronization/mutex.h" - #include "ray/core_worker/store_provider/memory_store/memory_store.h" +#include "absl/synchronization/mutex.h" #include "gtest/gtest.h" #include "ray/common/test_util.h" @@ -29,8 +28,7 @@ TEST(TestMemoryStore, TestReportUnhandledErrors) { std::shared_ptr provider = std::make_shared( - nullptr, nullptr, nullptr, nullptr, - [&](const RayObject &obj) { unhandled_count++; }); + nullptr, nullptr, nullptr, [&](const RayObject &obj) { unhandled_count++; }); RayObject obj1(rpc::ErrorType::TASK_EXECUTION_EXCEPTION); RayObject obj2(rpc::ErrorType::TASK_EXECUTION_EXCEPTION); auto id1 = ObjectID::FromRandom(); @@ -52,7 +50,7 @@ TEST(TestMemoryStore, TestReportUnhandledErrors) { RAY_CHECK(provider->Put(obj1, id1)); RAY_CHECK(provider->Put(obj1, id2)); RAY_UNUSED(provider->Get({id1}, 1, 100, context, false, &results)); - provider->GetOrPromoteToPlasma(id2); + RAY_UNUSED(provider->Get({id2}, 1, 100, context, false, &results)); provider->Delete({id1, id2}); ASSERT_EQ(unhandled_count, 0); @@ -68,8 +66,7 @@ TEST(TestMemoryStore, TestReportUnhandledErrors) { TEST(TestMemoryStore, TestMemoryStoreStats) { /// Simple validation for test memory store stats. std::shared_ptr provider = - std::make_shared(nullptr, nullptr, nullptr, nullptr, - nullptr); + std::make_shared(nullptr, nullptr, nullptr, nullptr); // Iterate through the memory store and compare the values that are obtained by // GetMemoryStoreStatisticalData. 
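The dependency_resolver.cc hunk below makes a task complete only once both its actor-dependency and object-dependency counters have drained, which is exactly what the two new tests above exercise in both resolution orders. A minimal standalone sketch of that invariant, using hypothetical names rather than Ray's actual types:

// Sketch only: a task completes exactly once, and only after BOTH the
// actor-dependency counter and the object-dependency counter reach zero,
// whichever kind of dependency resolves first. Names are hypothetical.
#include <cassert>
#include <functional>

struct TaskState {
  int actor_deps_remaining;
  int obj_deps_remaining;
  std::function<void()> on_complete;

  void OnActorDepResolved() {
    if (--actor_deps_remaining == 0 && obj_deps_remaining == 0) on_complete();
  }
  void OnObjectDepResolved() {
    if (--obj_deps_remaining == 0 && actor_deps_remaining == 0) on_complete();
  }
};

int main() {
  int completions = 0;
  TaskState state{/*actor deps*/ 1, /*object deps*/ 1,
                  [&completions] { ++completions; }};
  state.OnActorDepResolved();
  assert(completions == 0);  // the object dependency is still pending
  state.OnObjectDepResolved();
  assert(completions == 1);  // fires exactly once, after both resolve
  return 0;
}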
diff --git a/src/ray/core_worker/transport/dependency_resolver.cc b/src/ray/core_worker/transport/dependency_resolver.cc index 3948c3732f1c4..da52aff657627 100644 --- a/src/ray/core_worker/transport/dependency_resolver.cc +++ b/src/ray/core_worker/transport/dependency_resolver.cc @@ -139,11 +139,13 @@ void LocalDependencyResolver::ResolveDependencies( for (const auto &actor_id : state->actor_dependencies) { actor_creator_.AsyncWaitForActorRegisterFinish( - actor_id, [state, on_complete](Status status) { + actor_id, [this, state, on_complete](const Status &status) { if (!status.ok()) { state->status = status; } - if (--state->actor_dependencies_remaining == 0) { + if (--state->actor_dependencies_remaining == 0 && + state->obj_dependencies_remaining == 0) { + num_pending_--; on_complete(state->status); } }); diff --git a/src/ray/gcs/asio.h b/src/ray/gcs/asio.h index fdcbbbf3cc3ef..d37083986ae1e 100644 --- a/src/ray/gcs/asio.h +++ b/src/ray/gcs/asio.h @@ -38,7 +38,7 @@ #include #include -#include +#include #include #include diff --git a/src/ray/gcs/gcs_client/service_based_accessor.cc b/src/ray/gcs/gcs_client/service_based_accessor.cc index e3bdcd96d79ab..6e54cb6b4f047 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.cc +++ b/src/ray/gcs/gcs_client/service_based_accessor.cc @@ -731,7 +731,7 @@ Status ServiceBasedNodeResourceInfoAccessor::AsyncUpdateResources( }); }; - sequencer_.Post(node_id, operation); + sequencer_.Post(node_id, std::move(operation)); return Status::OK(); } diff --git a/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc b/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc index 223ee7ca71b52..a4950fabb0f14 100644 --- a/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc +++ b/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc @@ -36,6 +36,7 @@ class GlobalStateAccessorTest : public ::testing::Test { config.grpc_server_name = "MockedGcsServer"; config.grpc_server_thread_num = 1; config.redis_address = "127.0.0.1"; + config.node_ip_address = "127.0.0.1"; config.enable_sharding_conn = false; config.redis_port = TEST_REDIS_SERVER_PORTS.front(); diff --git a/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc b/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc index 0e51ca7b84cce..0adf74b5c4e8b 100644 --- a/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc +++ b/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc @@ -47,6 +47,7 @@ class ServiceBasedGcsClientTest : public ::testing::Test { config_.grpc_server_name = "MockedGcsServer"; config_.grpc_server_thread_num = 1; config_.redis_address = "127.0.0.1"; + config_.node_ip_address = "127.0.0.1"; config_.enable_sharding_conn = false; config_.redis_port = TEST_REDIS_SERVER_PORTS.front(); // Tests legacy code paths. The poller and broadcaster have their own dedicated unit diff --git a/src/ray/gcs/gcs_server/gcs_actor_distribution.cc b/src/ray/gcs/gcs_server/gcs_actor_distribution.cc index 6eb523cdf730b..bec0fb7b89f7c 100644 --- a/src/ray/gcs/gcs_server/gcs_actor_distribution.cc +++ b/src/ray/gcs/gcs_server/gcs_actor_distribution.cc @@ -13,6 +13,7 @@ // limitations under the License. 
#include "ray/gcs/gcs_server/gcs_actor_distribution.h" + #include "ray/util/event.h" namespace ray { @@ -49,6 +50,9 @@ GcsBasedActorScheduler::GcsBasedActorScheduler( gcs_resource_scheduler_(std::move(gcs_resource_scheduler)) {} NodeID GcsBasedActorScheduler::SelectNode(std::shared_ptr actor) { + if (actor->GetActorWorkerAssignment()) { + ResetActorWorkerAssignment(actor.get()); + } // TODO(Chong-Li): Java actors may not need a sole assignment (worker process). bool need_sole_actor_worker_assignment = true; if (auto selected_actor_worker_assignment = SelectOrAllocateActorWorkerAssignment( @@ -221,5 +225,31 @@ void GcsBasedActorScheduler::HandleWorkerLeaseRejectedReply( Reschedule(actor); } +void GcsBasedActorScheduler::AddResourcesChangedListener(std::function listener) { + RAY_CHECK(listener != nullptr); + resource_changed_listeners_.emplace_back(std::move(listener)); +} + +void GcsBasedActorScheduler::NotifyClusterResourcesChanged() { + for (auto &listener : resource_changed_listeners_) { + listener(); + } +} + +void GcsBasedActorScheduler::ResetActorWorkerAssignment(GcsActor *actor) { + if (gcs_resource_manager_->ReleaseResources( + actor->GetActorWorkerAssignment()->GetNodeID(), + actor->GetActorWorkerAssignment()->GetResources())) { + NotifyClusterResourcesChanged(); + }; + actor->SetActorWorkerAssignment(nullptr); +} + +void GcsBasedActorScheduler::OnActorDestruction(std::shared_ptr actor) { + if (actor && actor->GetActorWorkerAssignment()) { + ResetActorWorkerAssignment(actor.get()); + } +} + } // namespace gcs } // namespace ray \ No newline at end of file diff --git a/src/ray/gcs/gcs_server/gcs_actor_distribution.h b/src/ray/gcs/gcs_server/gcs_actor_distribution.h index b8e2b6b2bd6d4..55f0f492e9a74 100644 --- a/src/ray/gcs/gcs_server/gcs_actor_distribution.h +++ b/src/ray/gcs/gcs_server/gcs_actor_distribution.h @@ -93,6 +93,14 @@ class GcsBasedActorScheduler : public GcsActorScheduler { virtual ~GcsBasedActorScheduler() = default; + /// Handle the destruction of an actor. + /// + /// \param actor The actor to be destoryed. + void OnActorDestruction(std::shared_ptr actor) override; + + /// Add resources changed event handler. + void AddResourcesChangedListener(std::function listener); + protected: /// Select a node for the actor based on cluster resources. /// @@ -143,8 +151,17 @@ class GcsBasedActorScheduler : public GcsActorScheduler { void HandleWorkerLeaseRejectedReply(std::shared_ptr actor, const rpc::RequestWorkerLeaseReply &reply); + /// Reset the actor's current assignment, while releasing acquired resources. + void ResetActorWorkerAssignment(GcsActor *actor); + + /// Notify that the cluster resources are changed. + void NotifyClusterResourcesChanged(); + std::shared_ptr gcs_resource_manager_; + /// The resource changed listeners. 
+ std::vector> resource_changed_listeners_; + /// Gcs resource scheduler std::shared_ptr gcs_resource_scheduler_; }; diff --git a/src/ray/gcs/gcs_server/gcs_actor_manager.cc b/src/ray/gcs/gcs_server/gcs_actor_manager.cc index c9f8c62375a6f..48a469f1cb377 100644 --- a/src/ray/gcs/gcs_server/gcs_actor_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_actor_manager.cc @@ -107,6 +107,7 @@ void GcsActor::SetActorWorkerAssignment( ///////////////////////////////////////////////////////////////////////////////////////// GcsActorManager::GcsActorManager( + boost::asio::io_context &io_context, std::shared_ptr scheduler, std::shared_ptr gcs_table_storage, std::shared_ptr gcs_pub_sub, RuntimeEnvManager &runtime_env_manager, @@ -115,7 +116,8 @@ GcsActorManager::GcsActorManager( std::function, boost::posix_time::milliseconds)> run_delayed, const rpc::ClientFactoryFn &worker_client_factory) - : gcs_actor_scheduler_(std::move(scheduler)), + : io_context_(io_context), + gcs_actor_scheduler_(std::move(scheduler)), gcs_table_storage_(std::move(gcs_table_storage)), gcs_pub_sub_(std::move(gcs_pub_sub)), worker_client_factory_(worker_client_factory), @@ -126,6 +128,17 @@ GcsActorManager::GcsActorManager( actor_gc_delay_(RayConfig::instance().gcs_actor_table_min_duration_ms()) { RAY_CHECK(worker_client_factory_); RAY_CHECK(destroy_owned_placement_group_if_needed_); + if (RayConfig::instance().gcs_actor_scheduling_enabled()) { + auto gcs_actor_scheduler = + std::dynamic_pointer_cast(gcs_actor_scheduler_); + gcs_actor_scheduler->AddResourcesChangedListener([this] { + bool posted = GetSchedulePendingActorsPosted(); + if (!posted) { + SetSchedulePendingActorsPosted(true); + io_context_.post([this] { SchedulePendingActors(); }); + } + }); + } } void GcsActorManager::HandleRegisterActor(const rpc::RegisterActorRequest &request, @@ -187,13 +200,13 @@ void GcsActorManager::HandleGetActorInfo(const rpc::GetActorInfoRequest &request const auto ®istered_actor_iter = registered_actors_.find(actor_id); if (registered_actor_iter != registered_actors_.end()) { - reply->mutable_actor_table_data()->CopyFrom( - registered_actor_iter->second->GetActorTableData()); + reply->unsafe_arena_set_allocated_actor_table_data( + registered_actor_iter->second->GetMutableActorTableData()); } else { const auto &destroyed_actor_iter = destroyed_actors_.find(actor_id); if (destroyed_actor_iter != destroyed_actors_.end()) { - reply->mutable_actor_table_data()->CopyFrom( - destroyed_actor_iter->second->GetActorTableData()); + reply->unsafe_arena_set_allocated_actor_table_data( + destroyed_actor_iter->second->GetMutableActorTableData()); } } @@ -210,10 +223,12 @@ void GcsActorManager::HandleGetAllActorInfo(const rpc::GetAllActorInfoRequest &r ++counts_[CountType::GET_ALL_ACTOR_INFO_REQUEST]; if (request.show_dead_jobs() == false) { for (const auto &iter : registered_actors_) { - reply->add_actor_table_data()->CopyFrom(iter.second->GetActorTableData()); + reply->mutable_actor_table_data()->UnsafeArenaAddAllocated( + const_cast(iter.second->GetMutableActorTableData())); } for (const auto &iter : destroyed_actors_) { - reply->add_actor_table_data()->CopyFrom(iter.second->GetActorTableData()); + reply->mutable_actor_table_data()->UnsafeArenaAddAllocated( + const_cast(iter.second->GetMutableActorTableData())); } RAY_LOG(DEBUG) << "Finished getting all actor info."; GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); @@ -227,7 +242,9 @@ void GcsActorManager::HandleGetAllActorInfo(const rpc::GetAllActorInfoRequest &r [reply, send_reply_callback]( 
const std::unordered_map &result) { for (const auto &pair : result) { - reply->add_actor_table_data()->CopyFrom(pair.second); + // TODO(yic): Fix const cast + reply->mutable_actor_table_data()->UnsafeArenaAddAllocated( + const_cast(&pair.second)); } GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); RAY_LOG(DEBUG) << "Finished getting all actor info."; @@ -258,7 +275,8 @@ void GcsActorManager::HandleGetNamedActorInfo( RAY_LOG(WARNING) << stream.str(); status = Status::NotFound(stream.str()); } else { - reply->mutable_actor_table_data()->CopyFrom(iter->second->GetActorTableData()); + reply->unsafe_arena_set_allocated_actor_table_data( + iter->second->GetMutableActorTableData()); RAY_LOG(DEBUG) << "Finished getting actor info, job id = " << actor_id.JobId() << ", actor id = " << actor_id; } @@ -275,10 +293,9 @@ void GcsActorManager::HandleListNamedActors(const rpc::ListNamedActorsRequest &r std::vector> actors = ListNamedActors(request.all_namespaces(), ray_namespace); for (const auto &actor : actors) { - rpc::NamedActorInfo named_actor_info; - named_actor_info.set_ray_namespace(actor.first); - named_actor_info.set_name(actor.second); - reply->add_named_actors_list()->CopyFrom(named_actor_info); + auto named_actor_info = reply->add_named_actors_list(); + named_actor_info->set_ray_namespace(actor.first); + named_actor_info->set_name(actor.second); } GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); ++counts_[CountType::LIST_NAMED_ACTORS_REQUEST]; @@ -381,13 +398,9 @@ Status GcsActorManager::RegisterActor(const ray::rpc::RegisterActorRequest &requ // owner to determine when the actor should be removed. PollOwnerForActorOutOfScope(actor); } else { - // If it's a detached actor, we need to register the runtime env it used to GC - auto job_id = JobID::FromBinary(request.task_spec().job_id()); - const auto &uris = runtime_env_manager_.GetReferences(job_id.Hex()); - auto actor_id_hex = actor->GetActorID().Hex(); - for (const auto &uri : uris) { - runtime_env_manager_.AddURIReference(actor_id_hex, uri); - } + // If it's a detached actor, we need to register the runtime env it uses for GC. + runtime_env_manager_.AddURIReference(actor->GetActorID().Hex(), + request.task_spec().runtime_env()); } // The backend storage is supposed to be reliable, so the status must be ok.
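The listener wired up in the GcsActorManager constructor above posts SchedulePendingActors to the io_context at most once per burst of resource-changed events, guarded by the schedule_pending_actors_posted_ flag that the hunks below manage. A minimal sketch of that post-once (debouncing) pattern, with a plain queue of closures standing in for boost::asio::io_context; all names here are illustrative, not Ray's API:

#include <cassert>
#include <functional>
#include <queue>

// Stand-in for an io_context: stores posted closures, runs them later.
struct FakeIoContext {
  std::queue<std::function<void()>> tasks;
  void post(std::function<void()> fn) { tasks.push(std::move(fn)); }
  void run() {
    while (!tasks.empty()) {
      auto fn = std::move(tasks.front());
      tasks.pop();
      fn();
    }
  }
};

struct Manager {
  FakeIoContext &io;
  bool schedule_posted = false;
  int schedule_runs = 0;

  explicit Manager(FakeIoContext &ctx) : io(ctx) {}

  // Called on every resources-changed event; posts at most one job.
  void OnResourcesChanged() {
    if (!schedule_posted) {
      schedule_posted = true;
      io.post([this] { SchedulePending(); });
    }
  }

  void SchedulePending() {
    schedule_posted = false;  // allow the next event burst to post again
    ++schedule_runs;
  }
};

int main() {
  FakeIoContext io;
  Manager mgr(io);
  // A burst of three events results in a single scheduling pass.
  mgr.OnResourcesChanged();
  mgr.OnResourcesChanged();
  mgr.OnResourcesChanged();
  io.run();
  assert(mgr.schedule_runs == 1);
  return 0;
}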
@@ -575,6 +588,11 @@ void GcsActorManager::DestroyActor(const ActorID &actor_id) { RAY_LOG(INFO) << "Tried to destroy actor that does not exist " << actor_id; return; } + + if (RayConfig::instance().gcs_actor_scheduling_enabled()) { + gcs_actor_scheduler_->OnActorDestruction(it->second); + } + const auto &task_id = it->second->GetCreationTaskSpecification().TaskId(); it->second->GetMutableActorTableData()->mutable_task_spec()->Clear(); it->second->GetMutableActorTableData()->set_timestamp(current_sys_time_ms()); @@ -957,6 +975,7 @@ void GcsActorManager::OnActorCreationSuccess(const std::shared_ptr &ac } void GcsActorManager::SchedulePendingActors() { + schedule_pending_actors_posted_ = false; if (pending_actors_.empty()) { return; } @@ -968,6 +987,14 @@ void GcsActorManager::SchedulePendingActors() { } } +bool GcsActorManager::GetSchedulePendingActorsPosted() const { + return schedule_pending_actors_posted_; +} + +void GcsActorManager::SetSchedulePendingActorsPosted(bool posted) { + schedule_pending_actors_posted_ = posted; +} + void GcsActorManager::Initialize(const GcsInitData &gcs_init_data) { const auto &jobs = gcs_init_data.Jobs(); std::unordered_map> node_to_workers; diff --git a/src/ray/gcs/gcs_server/gcs_actor_manager.h b/src/ray/gcs/gcs_server/gcs_actor_manager.h index 569c9b2b19172..9050eb4dfc9fe 100644 --- a/src/ray/gcs/gcs_server/gcs_actor_manager.h +++ b/src/ray/gcs/gcs_server/gcs_actor_manager.h @@ -21,6 +21,7 @@ #include "ray/common/runtime_env_manager.h" #include "ray/common/task/task_execution_spec.h" #include "ray/common/task/task_spec.h" +#include "ray/gcs/gcs_server/gcs_actor_distribution.h" #include "ray/gcs/gcs_server/gcs_actor_scheduler.h" #include "ray/gcs/gcs_server/gcs_init_data.h" #include "ray/gcs/gcs_server/gcs_table_storage.h" @@ -86,7 +87,8 @@ class GcsActor { break; } - actor_table_data_.set_serialized_runtime_env(task_spec.serialized_runtime_env()); + actor_table_data_.set_serialized_runtime_env( + task_spec.runtime_env().serialized_runtime_env()); } /// Get the node id on which this actor is created. @@ -193,6 +195,7 @@ class GcsActorManager : public rpc::ActorInfoHandler { /// \param gcs_table_storage Used to flush actor data to storage. /// \param gcs_pub_sub Used to publish gcs message. GcsActorManager( + boost::asio::io_context &io_context, std::shared_ptr scheduler, std::shared_ptr gcs_table_storage, std::shared_ptr gcs_pub_sub, RuntimeEnvManager &runtime_env_manager, @@ -341,6 +344,10 @@ class GcsActorManager : public rpc::ActorInfoHandler { std::string DebugString() const; + bool GetSchedulePendingActorsPosted() const; + + void SetSchedulePendingActorsPosted(bool posted); + private: /// A data structure representing an actor's owner. struct Owner { @@ -485,6 +492,7 @@ class GcsActorManager : public rpc::ActorInfoHandler { /// according to its owner, or the owner dies. absl::flat_hash_map> owners_; + boost::asio::io_context &io_context_; /// The scheduler to schedule all registered actors. std::shared_ptr gcs_actor_scheduler_; /// Used to update actor information upon creation, deletion, etc. @@ -508,6 +516,9 @@ class GcsActorManager : public rpc::ActorInfoHandler { run_delayed_; const boost::posix_time::milliseconds actor_gc_delay_; + /// Indicate whether a call of SchedulePendingActors has been posted. + bool schedule_pending_actors_posted_; + // Debug info. 
enum CountType { REGISTER_ACTOR_REQUEST = 0, diff --git a/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc b/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc index 81d476a80854b..28b7a7453eef0 100644 --- a/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc +++ b/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc @@ -232,7 +232,7 @@ void GcsActorScheduler::LeaseWorkerFromNode(std::shared_ptr actor, // backlog in GCS. int backlog_size = report_worker_backlog_ ? 0 : -1; lease_client->RequestWorkerLease( - actor->GetCreationTaskSpecification(), + actor->GetActorTableData().task_spec(), [this, actor, node](const Status &status, const rpc::RequestWorkerLeaseReply &reply) { HandleWorkerLeaseReply(actor, node, status, reply); diff --git a/src/ray/gcs/gcs_server/gcs_actor_scheduler.h b/src/ray/gcs/gcs_server/gcs_actor_scheduler.h index 34d7d3ea3a186..d34c97767ff63 100644 --- a/src/ray/gcs/gcs_server/gcs_actor_scheduler.h +++ b/src/ray/gcs/gcs_server/gcs_actor_scheduler.h @@ -75,6 +75,11 @@ class GcsActorSchedulerInterface { virtual void ReleaseUnusedWorkers( const std::unordered_map> &node_to_workers) = 0; + /// Handle the destruction of an actor. + /// + /// \param actor The actor to be destroyed. + virtual void OnActorDestruction(std::shared_ptr actor) = 0; + virtual ~GcsActorSchedulerInterface() {} }; @@ -146,6 +151,11 @@ class GcsActorScheduler : public GcsActorSchedulerInterface { void ReleaseUnusedWorkers( const std::unordered_map> &node_to_workers) override; + /// Handle the destruction of an actor. + /// + /// \param actor The actor to be destroyed. + void OnActorDestruction(std::shared_ptr actor) override {} + protected: /// The GcsLeasedWorker is a kind of abstraction of a remote leased worker inside the raylet. It /// contains the address of the remote leased worker as well as the leased resources and the diff --git a/src/ray/gcs/gcs_server/gcs_node_manager.cc b/src/ray/gcs/gcs_server/gcs_node_manager.cc index 91782a712db5b..c84e19372a6b4 100644 --- a/src/ray/gcs/gcs_server/gcs_node_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_node_manager.cc @@ -84,11 +84,15 @@ void GcsNodeManager::HandleUnregisterNode(const rpc::UnregisterNodeRequest &requ void GcsNodeManager::HandleGetAllNodeInfo(const rpc::GetAllNodeInfoRequest &request, rpc::GetAllNodeInfoReply *reply, rpc::SendReplyCallback send_reply_callback) { + // The unsafe arena allocation is safe here because entry.second's lifetime is + // longer than the reply's. + // The reply is sent when send_reply_callback is invoked and is not used after + // that, while the entry remains valid.
for (const auto &entry : alive_nodes_) { - reply->add_node_info_list()->CopyFrom(*entry.second); + reply->mutable_node_info_list()->UnsafeArenaAddAllocated(entry.second.get()); } for (const auto &entry : dead_nodes_) { - reply->add_node_info_list()->CopyFrom(*entry.second); + reply->mutable_node_info_list()->UnsafeArenaAddAllocated(entry.second.get()); } GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); ++counts_[CountType::GET_ALL_NODE_INFO_REQUEST]; diff --git a/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc b/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc index f41f9d45bd6e7..7879a9fd71bce 100644 --- a/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc @@ -45,22 +45,26 @@ std::string GcsPlacementGroup::GetRayNamespace() const { return placement_group_table_data_.ray_namespace(); } -std::vector> GcsPlacementGroup::GetBundles() const { - const auto &bundles = placement_group_table_data_.bundles(); - std::vector> ret_bundles; - for (const auto &bundle : bundles) { - ret_bundles.push_back(std::make_shared(bundle)); +std::vector> &GcsPlacementGroup::GetBundles() + const { + // Fill the cache if it wasn't. + if (cached_bundle_specs_.empty()) { + const auto &bundles = placement_group_table_data_.bundles(); + for (const auto &bundle : bundles) { + cached_bundle_specs_.push_back(std::make_shared(bundle)); + } } - return ret_bundles; + return cached_bundle_specs_; } -std::vector> GcsPlacementGroup::GetUnplacedBundles() - const { - const auto &bundles = placement_group_table_data_.bundles(); - std::vector> unplaced_bundles; - for (const auto &bundle : bundles) { - if (NodeID::FromBinary(bundle.node_id()).IsNil()) { - unplaced_bundles.push_back(std::make_shared(bundle)); +std::vector> +GcsPlacementGroup::GetUnplacedBundles() const { + const auto &bundle_specs = GetBundles(); + + std::vector> unplaced_bundles; + for (const auto &bundle : bundle_specs) { + if (bundle->NodeId().IsNil()) { + unplaced_bundles.push_back(bundle); } } return unplaced_bundles; @@ -83,6 +87,8 @@ std::string GcsPlacementGroup::DebugString() const { } rpc::Bundle *GcsPlacementGroup::GetMutableBundle(int bundle_index) { + // Invalidate the cache. + cached_bundle_specs_.clear(); return placement_group_table_data_.mutable_bundles(bundle_index); } @@ -176,7 +182,7 @@ void GcsPlacementGroupManager::RegisterPlacementGroup( .emplace_back(std::move(callback)); registered_placement_groups_.emplace(placement_group->GetPlacementGroupID(), placement_group); - pending_placement_groups_.emplace_back(placement_group); + AddToPendingQueue(placement_group); RAY_CHECK_OK(gcs_table_storage_->PlacementGroupTable().Put( placement_group_id, placement_group->GetPlacementGroupTableData(), @@ -221,7 +227,8 @@ PlacementGroupID GcsPlacementGroupManager::GetPlacementGroupIDByName( } void GcsPlacementGroupManager::OnPlacementGroupCreationFailed( - std::shared_ptr placement_group, bool is_feasible) { + std::shared_ptr placement_group, ExponentialBackOff backoff, + bool is_feasible) { RAY_LOG(DEBUG) << "Failed to create placement group " << placement_group->GetName() << ", id: " << placement_group->GetPlacementGroupID() << ", try again."; @@ -229,7 +236,6 @@ void GcsPlacementGroupManager::OnPlacementGroupCreationFailed( // We will attempt to schedule this placement_group once an eligible node is // registered. 
infeasible_placement_groups_.emplace_back(std::move(placement_group)); - MarkSchedulingDone(); } else { auto state = placement_group->GetState(); RAY_CHECK(state == rpc::PlacementGroupTableData::RESCHEDULING || @@ -241,14 +247,13 @@ // NOTE: If a node is dead, the placement group scheduler should try to recover the // group by rescheduling the bundles of the dead node. This should have higher // priority than trying to place other placement groups. - pending_placement_groups_.emplace_front(std::move(placement_group)); + AddToPendingQueue(std::move(placement_group), /* rank */ 0); } else { - pending_placement_groups_.emplace_back(std::move(placement_group)); + AddToPendingQueue(std::move(placement_group), std::nullopt, backoff); } - - MarkSchedulingDone(); - RetryCreatingPlacementGroup(); } + io_context_.post([this] { SchedulePendingPlacementGroups(); }); + MarkSchedulingDone(); } void GcsPlacementGroupManager::OnPlacementGroupCreationSuccess( @@ -256,16 +261,11 @@ RAY_LOG(INFO) << "Successfully created placement group " << placement_group->GetName() << ", id: " << placement_group->GetPlacementGroupID(); placement_group->UpdateState(rpc::PlacementGroupTableData::CREATED); - // Mark the scheduling done firstly. - MarkSchedulingDone(); auto placement_group_id = placement_group->GetPlacementGroupID(); RAY_CHECK_OK(gcs_table_storage_->PlacementGroupTable().Put( placement_group_id, placement_group->GetPlacementGroupTableData(), [this, placement_group_id](Status status) { RAY_CHECK_OK(status); - - SchedulePendingPlacementGroups(); - // Invoke all callbacks for all `WaitPlacementGroupUntilReady` requests of this // placement group and remove all of them from // placement_group_to_create_callbacks_. @@ -278,6 +278,8 @@ placement_group_to_create_callbacks_.erase(pg_to_create_iter); } })); + io_context_.post([this] { SchedulePendingPlacementGroups(); }); + MarkSchedulingDone(); } void GcsPlacementGroupManager::SchedulePendingPlacementGroups() { @@ -294,16 +296,28 @@ bool is_new_placement_group_scheduled = false; while (!pending_placement_groups_.empty() && !is_new_placement_group_scheduled) { - const auto placement_group = pending_placement_groups_.front(); - pending_placement_groups_.pop_front(); + auto iter = pending_placement_groups_.begin(); + if (iter->first > absl::GetCurrentTimeNanos()) { + // The rank is the time at which the entry is due, and the container is an + // ordered tree, so every other entry is due after this one. If the first + // entry is not due yet, stop here; the periodic Tick() will retry later. + break; + } + auto backoff = iter->second.first; + auto placement_group = std::move(iter->second.second); + pending_placement_groups_.erase(iter); + const auto &placement_group_id = placement_group->GetPlacementGroupID(); // Do not reschedule if the placement group has been removed already.
if (registered_placement_groups_.contains(placement_group_id)) { MarkSchedulingStarted(placement_group_id); gcs_placement_group_scheduler_->ScheduleUnplacedBundles( placement_group, - [this](std::shared_ptr placement_group, bool is_insfeasble) { - OnPlacementGroupCreationFailed(std::move(placement_group), is_insfeasble); + [this, backoff](std::shared_ptr placement_group, + bool is_infeasible) { + OnPlacementGroupCreationFailed(std::move(placement_group), backoff, + is_infeasible); }, [this](std::shared_ptr placement_group) { OnPlacementGroupCreationSuccess(std::move(placement_group)); @@ -312,6 +326,7 @@ } // If the placement group is not registered == removed. } + ++counts_[CountType::SCHEDULING_PENDING_PLACEMENT_GROUP]; } void GcsPlacementGroupManager::HandleCreatePlacementGroup( @@ -393,18 +408,10 @@ void GcsPlacementGroupManager::RemovePlacementGroup( } // Remove a placement group from a pending list if exists. - auto pending_it = std::find_if( - pending_placement_groups_.begin(), pending_placement_groups_.end(), - [placement_group_id](const std::shared_ptr &placement_group) { - return placement_group->GetPlacementGroupID() == placement_group_id; - }); - if (pending_it != pending_placement_groups_.end()) { - // The placement group was pending scheduling, remove it from the queue. - pending_placement_groups_.erase(pending_it); - } + RemoveFromPendingQueue(placement_group_id); // Remove a placement group from infeasible queue if exists. - pending_it = std::find_if( + auto pending_it = std::find_if( infeasible_placement_groups_.begin(), infeasible_placement_groups_.end(), [placement_group_id](const std::shared_ptr &placement_group) { return placement_group->GetPlacementGroupID() == placement_group_id; @@ -573,9 +580,36 @@ void GcsPlacementGroupManager::WaitPlacementGroup( } } -void GcsPlacementGroupManager::RetryCreatingPlacementGroup() { - execute_after(io_context_, [this] { SchedulePendingPlacementGroups(); }, - RayConfig::instance().gcs_create_placement_group_retry_interval_ms()); +void GcsPlacementGroupManager::AddToPendingQueue( + std::shared_ptr pg, std::optional rank, + std::optional exp_backer) { + if (!rank) { + rank = absl::GetCurrentTimeNanos(); + } + + if (!exp_backer) { + exp_backer = ExponentialBackOff( + 1000000 * + RayConfig::instance().gcs_create_placement_group_retry_min_interval_ms(), + RayConfig::instance().gcs_create_placement_group_retry_multiplier(), + 1000000 * + RayConfig::instance().gcs_create_placement_group_retry_max_interval_ms()); + } else { + *rank += static_cast(exp_backer->Next()); + } + auto val = std::make_pair(*exp_backer, std::move(pg)); + pending_placement_groups_.emplace(*rank, std::move(val)); +} + +void GcsPlacementGroupManager::RemoveFromPendingQueue(const PlacementGroupID &pg_id) { + auto it = std::find_if(pending_placement_groups_.begin(), + pending_placement_groups_.end(), [&pg_id](const auto &val) { + return val.second.second->GetPlacementGroupID() == pg_id; + }); + // The placement group was pending scheduling, remove it from the queue. + if (it != pending_placement_groups_.end()) { + pending_placement_groups_.erase(it); + } } void GcsPlacementGroupManager::OnNodeDead(const NodeID &node_id) { @@ -593,7 +627,7 @@ // creating until a node with the resources is added. We will solve it in the next PR.
if (iter->second->GetState() != rpc::PlacementGroupTableData::RESCHEDULING) { iter->second->UpdateState(rpc::PlacementGroupTableData::RESCHEDULING); - pending_placement_groups_.emplace_front(iter->second); + AddToPendingQueue(iter->second, 0); } } } @@ -609,9 +643,9 @@ void GcsPlacementGroupManager::OnNodeAdd(const NodeID &node_id) { // Move all the infeasible placement groups to the pending queue so that we can // reschedule them. if (infeasible_placement_groups_.size() > 0) { - auto end_it = pending_placement_groups_.end(); - pending_placement_groups_.insert(end_it, infeasible_placement_groups_.cbegin(), - infeasible_placement_groups_.cend()); + for (auto &pg : infeasible_placement_groups_) { + AddToPendingQueue(std::move(pg)); + } infeasible_placement_groups_.clear(); } SchedulePendingPlacementGroups(); @@ -667,14 +701,16 @@ void GcsPlacementGroupManager::Tick() { // Note that we don't currently have a known race condition that requires this, but we // added as a safety check. https://github.com/ray-project/ray/pull/18419 SchedulePendingPlacementGroups(); - execute_after(io_context_, [this] { Tick(); }, 1000 /* milliseconds */); + execute_after( + io_context_, [this] { Tick(); }, 1000 /* milliseconds */); } void GcsPlacementGroupManager::UpdatePlacementGroupLoad() { std::shared_ptr placement_group_load = std::make_shared(); int total_cnt = 0; - for (const auto &pending_pg_spec : pending_placement_groups_) { + for (const auto &elem : pending_placement_groups_) { + const auto pending_pg_spec = elem.second.second; auto placement_group_data = placement_group_load->add_placement_group_data(); auto placement_group_table_data = pending_pg_spec->GetPlacementGroupTableData(); placement_group_data->Swap(&placement_group_table_data); @@ -710,7 +746,7 @@ void GcsPlacementGroupManager::Initialize(const GcsInitData &gcs_init_data) { if (item.second.state() == rpc::PlacementGroupTableData::PENDING || item.second.state() == rpc::PlacementGroupTableData::RESCHEDULING) { - pending_placement_groups_.emplace_back(std::move(placement_group)); + AddToPendingQueue(std::move(placement_group)); } if (item.second.state() == rpc::PlacementGroupTableData::CREATED || @@ -749,6 +785,8 @@ std::string GcsPlacementGroupManager::DebugString() const { << counts_[CountType::WAIT_PLACEMENT_GROUP_UNTIL_READY_REQUEST] << ", GetNamedPlacementGroup request count: " << counts_[CountType::GET_NAMED_PLACEMENT_GROUP_REQUEST] + << ", Scheduling pending placement group count: " + << counts_[CountType::SCHEDULING_PENDING_PLACEMENT_GROUP] << ", Registered placement groups count: " << registered_placement_groups_.size() << ", Named placement group count: " << num_pgs << ", Pending placement groups count: " << pending_placement_groups_.size() diff --git a/src/ray/gcs/gcs_server/gcs_placement_group_manager.h b/src/ray/gcs/gcs_server/gcs_placement_group_manager.h index bc3407fd8ac02..93bc68d306e43 100644 --- a/src/ray/gcs/gcs_server/gcs_placement_group_manager.h +++ b/src/ray/gcs/gcs_server/gcs_placement_group_manager.h @@ -13,8 +13,12 @@ // limitations under the License. #pragma once +#include + +#include #include +#include "absl/container/btree_map.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "ray/common/asio/instrumented_io_context.h" @@ -89,10 +93,10 @@ class GcsPlacementGroup { std::string GetRayNamespace() const; /// Get the bundles of this placement_group (including unplaced). 
- std::vector> GetBundles() const; + std::vector> &GetBundles() const; /// Get the unplaced bundles of this placement group. - std::vector> GetUnplacedBundles() const; + std::vector> GetUnplacedBundles() const; /// Get the Strategy rpc::PlacementStrategy GetStrategy() const; @@ -121,9 +125,14 @@ bool IsDetached() const; private: + FRIEND_TEST(GcsPlacementGroupManagerTest, TestPlacementGroupBundleCache); /// The placement_group meta data which contains the task specification as well as the /// state of the gcs placement_group and so on (see gcs.proto). rpc::PlacementGroupTableData placement_group_table_data_; + /// Creating a bundle specification requires heavy computation because it needs to + /// compute formatted strings for all resources (heavy string operations). To reduce + /// CPU usage, we cache the bundle specs. + mutable std::vector> cached_bundle_specs_; }; /// GcsPlacementGroupManager is responsible for managing the lifecycle of all placement @@ -209,7 +218,7 @@ class GcsPlacementGroupManager : public rpc::PlacementGroupInfoHandler { /// \param placement_group The placement_group whose creation task is infeasible. /// \param is_feasible Whether the scheduler can retry it currently. void OnPlacementGroupCreationFailed(std::shared_ptr placement_group, - bool is_feasible = true); + ExponentialBackOff backoff, bool is_feasible); /// Handle placement_group creation task success. This should be called when the /// placement_group creation task has been scheduled successfully. @@ -277,6 +286,19 @@ std::string DebugString() const; private: + /// Push a placement group onto the pending queue. + /// + /// \param pg The placement group we are adding. + /// \param rank The rank for this placement group. Semantically, it's the time at + /// which this placement group should be scheduled; by default it's the current + /// time. + /// \param exp_backer The exponential backoff. A default one is used if it's not + /// set. It generates the deferred schedule time for this placement group. + void AddToPendingQueue(std::shared_ptr pg, + std::optional rank = std::nullopt, + std::optional exp_backer = std::nullopt); + void RemoveFromPendingQueue(const PlacementGroupID &pg_id); + /// Try to create placement group after a short time. void RetryCreatingPlacementGroup(); @@ -322,12 +344,17 @@ absl::flat_hash_map> registered_placement_groups_; - /// The pending placement_groups which will not be scheduled until there's a resource - /// change. - /// NOTE: When we remove placement group, we need to look for - /// `pending_placement_groups_` and delete the specific placement group, so we can't use - /// `std::priority_queue`. - std::deque> pending_placement_groups_; + /// The pending placement_groups which will not be scheduled until there's a + /// resource change. The pending queue is represented as an ordered map, where + /// the key is the time at which the placement group should be scheduled and the + /// value is a pair of the actual placement group and an exponential backoff. + /// When an error happens, we retry later simply by inserting an element into + /// the queue with a larger key. With this, we don't need to post a retry job to + /// the io context. When scheduling pending placement groups, we always start + /// with the one with the smallest key.
+ absl::btree_multimap>> + pending_placement_groups_; /// The infeasible placement_groups that can't be scheduled currently. std::deque> infeasible_placement_groups_; @@ -363,9 +390,14 @@ class GcsPlacementGroupManager : public rpc::PlacementGroupInfoHandler { GET_ALL_PLACEMENT_GROUP_REQUEST = 3, WAIT_PLACEMENT_GROUP_UNTIL_READY_REQUEST = 4, GET_NAMED_PLACEMENT_GROUP_REQUEST = 5, - CountType_MAX = 6, + SCHEDULING_PENDING_PLACEMENT_GROUP = 6, + CountType_MAX = 7, }; uint64_t counts_[CountType::CountType_MAX] = {0}; + + FRIEND_TEST(GcsPlacementGroupManagerMockTest, PendingQueuePriorityReschedule); + FRIEND_TEST(GcsPlacementGroupManagerMockTest, PendingQueuePriorityFailed); + FRIEND_TEST(GcsPlacementGroupManagerMockTest, PendingQueuePriorityOrder); }; } // namespace gcs diff --git a/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc b/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc index c2ca3c3c8cd40..7c9391315a945 100644 --- a/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc +++ b/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc @@ -39,7 +39,7 @@ GcsPlacementGroupScheduler::GcsPlacementGroupScheduler( } std::vector GcsScheduleStrategy::GetRequiredResourcesFromBundles( - const std::vector> &bundles) { + const std::vector> &bundles) { std::vector required_resources; for (const auto &bundle : bundles) { required_resources.push_back(bundle->GetRequiredResources()); @@ -48,7 +48,7 @@ std::vector GcsScheduleStrategy::GetRequiredResourcesFromBundles( } ScheduleResult GcsScheduleStrategy::GenerateScheduleResult( - const std::vector> &bundles, + const std::vector> &bundles, const std::vector &selected_nodes, const SchedulingResultStatus &status) { ScheduleMap schedule_map; if (status == SUCCESS && !selected_nodes.empty()) { @@ -62,7 +62,7 @@ ScheduleResult GcsScheduleStrategy::GenerateScheduleResult( } ScheduleResult GcsStrictPackStrategy::Schedule( - const std::vector> &bundles, + const std::vector> &bundles, const std::unique_ptr &context, GcsResourceScheduler &gcs_resource_scheduler) { const auto &required_resources = GetRequiredResourcesFromBundles(bundles); @@ -73,7 +73,7 @@ ScheduleResult GcsStrictPackStrategy::Schedule( } ScheduleResult GcsPackStrategy::Schedule( - const std::vector> &bundles, + const std::vector> &bundles, const std::unique_ptr &context, GcsResourceScheduler &gcs_resource_scheduler) { // The current algorithm is to select a node and deploy as many bundles as possible. @@ -87,7 +87,7 @@ ScheduleResult GcsPackStrategy::Schedule( } ScheduleResult GcsSpreadStrategy::Schedule( - const std::vector> &bundles, + const std::vector> &bundles, const std::unique_ptr &context, GcsResourceScheduler &gcs_resource_scheduler) { const auto &required_resources = GetRequiredResourcesFromBundles(bundles); @@ -98,7 +98,7 @@ ScheduleResult GcsSpreadStrategy::Schedule( } ScheduleResult GcsStrictSpreadStrategy::Schedule( - const std::vector> &bundles, + const std::vector> &bundles, const std::unique_ptr &context, GcsResourceScheduler &gcs_resource_scheduler) { // TODO(ffbin): A bundle may require special resources, such as GPU. 
We need to @@ -211,7 +211,7 @@ void GcsPlacementGroupScheduler::MarkScheduleCancelled( } void GcsPlacementGroupScheduler::PrepareResources( - const std::shared_ptr &bundle, + const std::shared_ptr &bundle, const absl::optional> &node, const StatusCallback &callback) { if (!node.has_value()) { @@ -240,7 +240,7 @@ void GcsPlacementGroupScheduler::PrepareResources( } void GcsPlacementGroupScheduler::CommitResources( - const std::shared_ptr &bundle, + const std::shared_ptr &bundle, const absl::optional> &node, const StatusCallback callback) { RAY_CHECK(node.has_value()); @@ -265,7 +265,7 @@ void GcsPlacementGroupScheduler::CommitResources( } void GcsPlacementGroupScheduler::CancelResourceReserve( - const std::shared_ptr &bundle_spec, + const std::shared_ptr &bundle_spec, const absl::optional> &node) { if (!node.has_value()) { RAY_LOG(INFO) << "Node for a placement group id " << bundle_spec->PlacementGroupId() @@ -660,7 +660,7 @@ void BundleLocationIndex::AddNodes( LeaseStatusTracker::LeaseStatusTracker( std::shared_ptr placement_group, - const std::vector> &unplaced_bundles, + const std::vector> &unplaced_bundles, const ScheduleMap &schedule_map) : placement_group_(placement_group), bundles_to_schedule_(unplaced_bundles) { preparing_bundle_locations_ = std::make_shared(); @@ -675,13 +675,13 @@ LeaseStatusTracker::LeaseStatusTracker( } bool LeaseStatusTracker::MarkPreparePhaseStarted( - const NodeID &node_id, std::shared_ptr bundle) { + const NodeID &node_id, const std::shared_ptr &bundle) { const auto &bundle_id = bundle->BundleId(); return node_to_bundles_when_preparing_[node_id].emplace(bundle_id).second; } void LeaseStatusTracker::MarkPrepareRequestReturned( - const NodeID &node_id, const std::shared_ptr bundle, + const NodeID &node_id, const std::shared_ptr &bundle, const Status &status) { RAY_CHECK(prepare_request_returned_count_ <= bundles_to_schedule_.size()); auto leasing_bundles = node_to_bundles_when_preparing_.find(node_id); @@ -715,7 +715,7 @@ bool LeaseStatusTracker::AllPrepareRequestsSuccessful() const { } void LeaseStatusTracker::MarkCommitRequestReturned( - const NodeID &node_id, const std::shared_ptr bundle, + const NodeID &node_id, const std::shared_ptr &bundle, const Status &status) { commit_request_returned_count_ += 1; // If the request succeeds, record it. @@ -762,7 +762,7 @@ const std::shared_ptr &LeaseStatusTracker::GetBundleLocations() return bundle_locations_; } -const std::vector> +const std::vector> &LeaseStatusTracker::GetBundlesToSchedule() const { return bundles_to_schedule_; } diff --git a/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.h b/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.h index bdfee4276dec5..4e921ab13e248 100644 --- a/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.h +++ b/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.h @@ -49,9 +49,8 @@ struct pair_hash { }; using ScheduleMap = std::unordered_map; using ScheduleResult = std::pair; -using BundleLocations = - absl::flat_hash_map>, - pair_hash>; +using BundleLocations = absl::flat_hash_map< + BundleID, std::pair>, pair_hash>; class GcsPlacementGroupSchedulerInterface { public: @@ -112,7 +111,7 @@ class GcsScheduleStrategy { public: virtual ~GcsScheduleStrategy() {} virtual ScheduleResult Schedule( - const std::vector> &bundles, + const std::vector> &bundles, const std::unique_ptr &context, GcsResourceScheduler &gcs_resource_scheduler) = 0; @@ -122,7 +121,7 @@ class GcsScheduleStrategy { /// \param bundles Bundles to be scheduled. /// \return Required resources. 
std::vector GetRequiredResourcesFromBundles( - const std::vector> &bundles); /// Generate `ScheduleResult` from bundles and nodes. /// @@ -131,7 +130,7 @@ /// \param status Status of the scheduling result. /// \return The scheduling result from the required resource. ScheduleResult GenerateScheduleResult( - const std::vector> &bundles, + const std::vector> &bundles, const std::vector &selected_nodes, const SchedulingResultStatus &status); }; @@ -141,7 +140,7 @@ class GcsPackStrategy : public GcsScheduleStrategy { public: ScheduleResult Schedule( - const std::vector> &bundles, + const std::vector> &bundles, const std::unique_ptr &context, GcsResourceScheduler &gcs_resource_scheduler) override; }; @@ -150,7 +149,7 @@ class GcsSpreadStrategy : public GcsScheduleStrategy { public: ScheduleResult Schedule( - const std::vector> &bundles, + const std::vector> &bundles, const std::unique_ptr &context, GcsResourceScheduler &gcs_resource_scheduler) override; }; @@ -160,7 +159,7 @@ class GcsStrictPackStrategy : public GcsScheduleStrategy { public: ScheduleResult Schedule( - const std::vector> &bundles, + const std::vector> &bundles, const std::unique_ptr &context, GcsResourceScheduler &gcs_resource_scheduler) override; }; @@ -171,7 +170,7 @@ class GcsStrictSpreadStrategy : public GcsScheduleStrategy { public: ScheduleResult Schedule( - const std::vector> &bundles, + const std::vector> &bundles, const std::unique_ptr &context, GcsResourceScheduler &gcs_resource_scheduler) override; }; @@ -192,7 +191,7 @@ class LeaseStatusTracker { public: LeaseStatusTracker( std::shared_ptr placement_group, - const std::vector> &unplaced_bundles, + const std::vector> &unplaced_bundles, const ScheduleMap &schedule_map); ~LeaseStatusTracker() = default; @@ -202,7 +201,7 @@ /// \param bundle Bundle specification the node is supposed to prepare. /// \return False if the prepare phase was already started. True otherwise. bool MarkPreparePhaseStarted(const NodeID &node_id, - std::shared_ptr bundle); + const std::shared_ptr &bundle); /// Indicate to the tracker that all prepare requests have returned. /// @@ -210,9 +209,9 @@ /// \param bundle Bundle specification the node was supposed to schedule. /// \param status Status of the prepare response. /// \param void - void MarkPrepareRequestReturned(const NodeID &node_id, - std::shared_ptr bundle, - const Status &status); + void MarkPrepareRequestReturned( + const NodeID &node_id, const std::shared_ptr &bundle, + const Status &status); /// Used to know if all prepare requests are returned. /// @@ -230,7 +229,7 @@ /// \param bundle Bundle specification the node was supposed to schedule. /// \param status Status of the returned commit request. void MarkCommitRequestReturned(const NodeID &node_id, - const std::shared_ptr bundle, + const std::shared_ptr &bundle, const Status &status); /// Used to know if all commit requests are returned. @@ -251,7 +250,8 @@ /// Return bundles that should be scheduled. /// /// \return List of bundle specifications that are supposed to be scheduled.
- const std::vector> &GetBundlesToSchedule() const; + [[nodiscard]] const std::vector> + &GetBundlesToSchedule() const; /// This method returns bundle locations that succeeded in preparing resources. /// @@ -324,7 +324,7 @@ node_to_bundles_when_preparing_; /// Bundles to schedule. - std::vector> bundles_to_schedule_; + std::vector> bundles_to_schedule_; /// Location of bundles. std::shared_ptr bundle_locations_; @@ -460,7 +460,7 @@ class GcsPlacementGroupScheduler : public GcsPlacementGroupSchedulerInterface { /// \param node A node to prepare resources for a given bundle. /// \param callback void PrepareResources( - const std::shared_ptr &bundle, + const std::shared_ptr &bundle, const absl::optional> &node, const StatusCallback &callback); /// @@ -470,7 +470,7 @@ /// \param bundle A bundle to schedule on a node. /// \param node A node to commit resources for a given bundle. /// \param callback - void CommitResources(const std::shared_ptr &bundle, + void CommitResources(const std::shared_ptr &bundle, const absl::optional> &node, const StatusCallback callback); /// @@ -481,7 +481,7 @@ /// \param bundle A description of the bundle to return. /// \param node The node that the worker will be returned for. void CancelResourceReserve( - const std::shared_ptr &bundle_spec, + const std::shared_ptr &bundle_spec, const absl::optional> &node); /// Get an existing lease client or connect a new one. diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc index 6a4d60c685e9b..84821c6af1d3a 100644 --- a/src/ray/gcs/gcs_server/gcs_server.cc +++ b/src/ray/gcs/gcs_server/gcs_server.cc @@ -35,7 +35,7 @@ GcsServer::GcsServer(const ray::gcs::GcsServerConfig &config, : config_(config), main_service_(main_service), rpc_server_(config.grpc_server_name, config.grpc_server_port, - config.grpc_server_thread_num, + config.node_ip_address == "127.0.0.1", config.grpc_server_thread_num, /*keepalive_time_ms=*/RayConfig::instance().grpc_keepalive_time_ms()), client_call_manager_(main_service), raylet_client_pool_( @@ -267,7 +267,8 @@ void GcsServer::InitGcsActorManager(const GcsInitData &gcs_init_data) { client_factory); } gcs_actor_manager_ = std::make_shared( - std::move(scheduler), gcs_table_storage_, gcs_pub_sub_, *runtime_env_manager_, + main_service_, std::move(scheduler), gcs_table_storage_, gcs_pub_sub_, + *runtime_env_manager_, [this](const ActorID &actor_id) { gcs_placement_group_manager_->CleanPlacementGroupIfNeededWhenActorDead(actor_id); }, @@ -478,7 +479,7 @@ void GcsServer::InstallEventListeners() { gcs_placement_group_manager_->CleanPlacementGroupIfNeededWhenJobDead(*job_id); }); - // Install scheduling policy event listeners. + // Install scheduling event listeners. if (RayConfig::instance().gcs_actor_scheduling_enabled()) { gcs_resource_manager_->AddResourcesChangedListener([this] { main_service_.post([this] { @@ -513,9 +514,10 @@ void GcsServer::PrintDebugInfo() { // TODO(ffbin): We will get the session_dir in the next PR, and write the log to // gcs_debug_state.txt.
RAY_LOG(INFO) << stream.str(); - execute_after(main_service_, [this] { PrintDebugInfo(); }, - (RayConfig::instance().gcs_dump_debug_log_interval_minutes() * - 60000) /* milliseconds */); + execute_after( + main_service_, [this] { PrintDebugInfo(); }, + (RayConfig::instance().gcs_dump_debug_log_interval_minutes() * + 60000) /* milliseconds */); } void GcsServer::PrintAsioStats() { @@ -524,8 +526,9 @@ void GcsServer::PrintAsioStats() { RayConfig::instance().event_stats_print_interval_ms(); if (event_stats_print_interval_ms != -1 && RayConfig::instance().event_stats()) { RAY_LOG(INFO) << "Event stats:\n\n" << main_service_.StatsString() << "\n\n"; - execute_after(main_service_, [this] { PrintAsioStats(); }, - event_stats_print_interval_ms /* milliseconds */); + execute_after( + main_service_, [this] { PrintAsioStats(); }, + event_stats_print_interval_ms /* milliseconds */); } } diff --git a/src/ray/gcs/gcs_server/gcs_server.h b/src/ray/gcs/gcs_server/gcs_server.h index cadb70a3f3541..507ab2820cab7 100644 --- a/src/ray/gcs/gcs_server/gcs_server.h +++ b/src/ray/gcs/gcs_server/gcs_server.h @@ -16,7 +16,6 @@ #include "ray/common/asio/instrumented_io_context.h" #include "ray/common/runtime_env_manager.h" -#include "ray/gcs/gcs_server/gcs_actor_distribution.h" #include "ray/gcs/gcs_server/gcs_heartbeat_manager.h" #include "ray/gcs/gcs_server/gcs_init_data.h" #include "ray/gcs/gcs_server/gcs_kv_manager.h" diff --git a/src/ray/gcs/gcs_server/gcs_table_storage.h b/src/ray/gcs/gcs_server/gcs_table_storage.h index 84a70a347ebf7..ed48cf71abdf2 100644 --- a/src/ray/gcs/gcs_server/gcs_table_storage.h +++ b/src/ray/gcs/gcs_server/gcs_table_storage.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include "ray/common/asio/instrumented_io_context.h" @@ -51,8 +52,8 @@ using rpc::WorkerTableData; template class GcsTable { public: - explicit GcsTable(std::shared_ptr &store_client) - : store_client_(store_client) {} + explicit GcsTable(std::shared_ptr store_client) + : store_client_(std::move(store_client)) {} virtual ~GcsTable() = default; @@ -106,8 +107,8 @@ class GcsTable { template class GcsTableWithJobId : public GcsTable { public: - explicit GcsTableWithJobId(std::shared_ptr &store_client) - : GcsTable(store_client) {} + explicit GcsTableWithJobId(std::shared_ptr store_client) + : GcsTable(std::move(store_client)) {} /// Write data to the table asynchronously. 
/// @@ -152,16 +153,16 @@ class GcsTableWithJobId : public GcsTable { class GcsJobTable : public GcsTable { public: - explicit GcsJobTable(std::shared_ptr &store_client) - : GcsTable(store_client) { + explicit GcsJobTable(std::shared_ptr store_client) + : GcsTable(std::move(store_client)) { table_name_ = TablePrefix_Name(TablePrefix::JOB); } }; class GcsActorTable : public GcsTableWithJobId { public: - explicit GcsActorTable(std::shared_ptr &store_client) - : GcsTableWithJobId(store_client) { + explicit GcsActorTable(std::shared_ptr store_client) + : GcsTableWithJobId(std::move(store_client)) { table_name_ = TablePrefix_Name(TablePrefix::ACTOR); } @@ -172,16 +173,16 @@ class GcsActorTable : public GcsTableWithJobId { class GcsPlacementGroupTable : public GcsTable { public: - explicit GcsPlacementGroupTable(std::shared_ptr &store_client) - : GcsTable(store_client) { + explicit GcsPlacementGroupTable(std::shared_ptr store_client) + : GcsTable(std::move(store_client)) { table_name_ = TablePrefix_Name(TablePrefix::PLACEMENT_GROUP); } }; class GcsTaskTable : public GcsTableWithJobId { public: - explicit GcsTaskTable(std::shared_ptr &store_client) - : GcsTableWithJobId(store_client) { + explicit GcsTaskTable(std::shared_ptr store_client) + : GcsTableWithJobId(std::move(store_client)) { table_name_ = TablePrefix_Name(TablePrefix::TASK); } @@ -191,8 +192,8 @@ class GcsTaskTable : public GcsTableWithJobId { class GcsTaskLeaseTable : public GcsTableWithJobId { public: - explicit GcsTaskLeaseTable(std::shared_ptr &store_client) - : GcsTableWithJobId(store_client) { + explicit GcsTaskLeaseTable(std::shared_ptr store_client) + : GcsTableWithJobId(std::move(store_client)) { table_name_ = TablePrefix_Name(TablePrefix::TASK_LEASE); } @@ -203,8 +204,8 @@ class GcsTaskLeaseTable : public GcsTableWithJobId { class GcsTaskReconstructionTable : public GcsTableWithJobId { public: - explicit GcsTaskReconstructionTable(std::shared_ptr &store_client) - : GcsTableWithJobId(store_client) { + explicit GcsTaskReconstructionTable(std::shared_ptr store_client) + : GcsTableWithJobId(std::move(store_client)) { table_name_ = TablePrefix_Name(TablePrefix::TASK_RECONSTRUCTION); } @@ -214,8 +215,8 @@ class GcsTaskReconstructionTable class GcsObjectTable : public GcsTableWithJobId { public: - explicit GcsObjectTable(std::shared_ptr &store_client) - : GcsTableWithJobId(store_client) { + explicit GcsObjectTable(std::shared_ptr store_client) + : GcsTableWithJobId(std::move(store_client)) { table_name_ = TablePrefix_Name(TablePrefix::OBJECT); } @@ -225,56 +226,56 @@ class GcsObjectTable : public GcsTableWithJobId { class GcsNodeTable : public GcsTable { public: - explicit GcsNodeTable(std::shared_ptr &store_client) - : GcsTable(store_client) { + explicit GcsNodeTable(std::shared_ptr store_client) + : GcsTable(std::move(store_client)) { table_name_ = TablePrefix_Name(TablePrefix::NODE); } }; class GcsNodeResourceTable : public GcsTable { public: - explicit GcsNodeResourceTable(std::shared_ptr &store_client) - : GcsTable(store_client) { + explicit GcsNodeResourceTable(std::shared_ptr store_client) + : GcsTable(std::move(store_client)) { table_name_ = TablePrefix_Name(TablePrefix::NODE_RESOURCE); } }; class GcsPlacementGroupScheduleTable : public GcsTable { public: - explicit GcsPlacementGroupScheduleTable(std::shared_ptr &store_client) - : GcsTable(store_client) { + explicit GcsPlacementGroupScheduleTable(std::shared_ptr store_client) + : GcsTable(std::move(store_client)) { table_name_ = 
TablePrefix_Name(TablePrefix::PLACEMENT_GROUP_SCHEDULE);
 }
};

class GcsResourceUsageBatchTable : public GcsTable {
 public:
- explicit GcsResourceUsageBatchTable(std::shared_ptr<StoreClient> &store_client)
- : GcsTable(store_client) {
+ explicit GcsResourceUsageBatchTable(std::shared_ptr<StoreClient> store_client)
+ : GcsTable(std::move(store_client)) {
 table_name_ = TablePrefix_Name(TablePrefix::RESOURCE_USAGE_BATCH);
 }
};

class GcsProfileTable : public GcsTable {
 public:
- explicit GcsProfileTable(std::shared_ptr<StoreClient> &store_client)
- : GcsTable(store_client) {
+ explicit GcsProfileTable(std::shared_ptr<StoreClient> store_client)
+ : GcsTable(std::move(store_client)) {
 table_name_ = TablePrefix_Name(TablePrefix::PROFILE);
 }
};

class GcsWorkerTable : public GcsTable {
 public:
- explicit GcsWorkerTable(std::shared_ptr<StoreClient> &store_client)
- : GcsTable(store_client) {
+ explicit GcsWorkerTable(std::shared_ptr<StoreClient> store_client)
+ : GcsTable(std::move(store_client)) {
 table_name_ = TablePrefix_Name(TablePrefix::WORKERS);
 }
};

class GcsInternalConfigTable : public GcsTable {
 public:
- explicit GcsInternalConfigTable(std::shared_ptr<StoreClient> &store_client)
- : GcsTable(store_client) {
+ explicit GcsInternalConfigTable(std::shared_ptr<StoreClient> store_client)
+ : GcsTable(std::move(store_client)) {
 table_name_ = TablePrefix_Name(TablePrefix::INTERNAL_CONFIG);
 }
};
@@ -285,6 +286,27 @@ class GcsInternalConfigTable : public GcsTable {
/// derive from this class and override class member variables.
class GcsTableStorage {
 public:
+ explicit GcsTableStorage(std::shared_ptr<StoreClient> store_client)
+ : store_client_(std::move(store_client)) {
+ job_table_ = std::make_unique<GcsJobTable>(store_client_);
+ actor_table_ = std::make_unique<GcsActorTable>(store_client_);
+ placement_group_table_ = std::make_unique<GcsPlacementGroupTable>(store_client_);
+ task_table_ = std::make_unique<GcsTaskTable>(store_client_);
+ task_lease_table_ = std::make_unique<GcsTaskLeaseTable>(store_client_);
+ task_reconstruction_table_ =
+ std::make_unique<GcsTaskReconstructionTable>(store_client_);
+ object_table_ = std::make_unique<GcsObjectTable>(store_client_);
+ node_table_ = std::make_unique<GcsNodeTable>(store_client_);
+ node_resource_table_ = std::make_unique<GcsNodeResourceTable>(store_client_);
+ placement_group_schedule_table_ =
+ std::make_unique<GcsPlacementGroupScheduleTable>(store_client_);
+ resource_usage_batch_table_ =
+ std::make_unique<GcsResourceUsageBatchTable>(store_client_);
+ profile_table_ = std::make_unique<GcsProfileTable>(store_client_);
+ worker_table_ = std::make_unique<GcsWorkerTable>(store_client_);
+ system_config_table_ = std::make_unique<GcsInternalConfigTable>(store_client_);
+ }
+
 GcsJobTable &JobTable() {
 RAY_CHECK(job_table_ != nullptr);
 return *job_table_;
@@ -383,26 +405,8 @@ class GcsTableStorage {
 /// that uses redis as storage.
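The new GcsTableStorage constructor above centralizes table construction in the base class, so each storage backend (see the Redis and in-memory subclasses just below) reduces to a one-line delegating constructor that only chooses a StoreClient. A condensed sketch of the shape of that refactor, with simplified stand-ins for the Gcs* types:

#include <memory>
#include <utility>

struct StoreClient { virtual ~StoreClient() = default; };
struct RedisStoreClient : StoreClient {};

struct JobTable { explicit JobTable(std::shared_ptr<StoreClient>) {} };
struct ActorTable { explicit ActorTable(std::shared_ptr<StoreClient>) {} };

class TableStorage {
 public:
  // All tables are built once here, from whichever client the subclass picks.
  explicit TableStorage(std::shared_ptr<StoreClient> store_client)
      : store_client_(std::move(store_client)),
        job_table_(std::make_unique<JobTable>(store_client_)),
        actor_table_(std::make_unique<ActorTable>(store_client_)) {}

 protected:
  std::shared_ptr<StoreClient> store_client_;
  std::unique_ptr<JobTable> job_table_;
  std::unique_ptr<ActorTable> actor_table_;
};

// A backend now only decides which StoreClient to construct.
class RedisTableStorage : public TableStorage {
 public:
  RedisTableStorage() : TableStorage(std::make_shared<RedisStoreClient>()) {}
};

int main() { RedisTableStorage storage; }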
class RedisGcsTableStorage : public GcsTableStorage { public: - explicit RedisGcsTableStorage(std::shared_ptr redis_client) { - store_client_ = std::make_shared(redis_client); - job_table_.reset(new GcsJobTable(store_client_)); - actor_table_.reset(new GcsActorTable(store_client_)); - placement_group_table_.reset(new GcsPlacementGroupTable(store_client_)); - task_table_.reset(new GcsTaskTable(store_client_)); - task_lease_table_.reset(new GcsTaskLeaseTable(store_client_)); - task_reconstruction_table_.reset(new GcsTaskReconstructionTable(store_client_)); - object_table_.reset(new GcsObjectTable(store_client_)); - node_table_.reset(new GcsNodeTable(store_client_)); - node_resource_table_.reset(new GcsNodeResourceTable(store_client_)); - placement_group_schedule_table_.reset( - new GcsPlacementGroupScheduleTable(store_client_)); - placement_group_schedule_table_.reset( - new GcsPlacementGroupScheduleTable(store_client_)); - resource_usage_batch_table_.reset(new GcsResourceUsageBatchTable(store_client_)); - profile_table_.reset(new GcsProfileTable(store_client_)); - worker_table_.reset(new GcsWorkerTable(store_client_)); - system_config_table_.reset(new GcsInternalConfigTable(store_client_)); - } + explicit RedisGcsTableStorage(std::shared_ptr redis_client) + : GcsTableStorage(std::make_shared(std::move(redis_client))) {} }; /// \class InMemoryGcsTableStorage @@ -410,24 +416,8 @@ class RedisGcsTableStorage : public GcsTableStorage { /// that uses memory as storage. class InMemoryGcsTableStorage : public GcsTableStorage { public: - explicit InMemoryGcsTableStorage(instrumented_io_context &main_io_service) { - store_client_ = std::make_shared(main_io_service); - job_table_.reset(new GcsJobTable(store_client_)); - actor_table_.reset(new GcsActorTable(store_client_)); - placement_group_table_.reset(new GcsPlacementGroupTable(store_client_)); - task_table_.reset(new GcsTaskTable(store_client_)); - task_lease_table_.reset(new GcsTaskLeaseTable(store_client_)); - task_reconstruction_table_.reset(new GcsTaskReconstructionTable(store_client_)); - object_table_.reset(new GcsObjectTable(store_client_)); - node_table_.reset(new GcsNodeTable(store_client_)); - node_resource_table_.reset(new GcsNodeResourceTable(store_client_)); - placement_group_schedule_table_.reset( - new GcsPlacementGroupScheduleTable(store_client_)); - resource_usage_batch_table_.reset(new GcsResourceUsageBatchTable(store_client_)); - profile_table_.reset(new GcsProfileTable(store_client_)); - worker_table_.reset(new GcsWorkerTable(store_client_)); - system_config_table_.reset(new GcsInternalConfigTable(store_client_)); - } + explicit InMemoryGcsTableStorage(instrumented_io_context &main_io_service) + : GcsTableStorage(std::make_shared(main_io_service)) {} }; } // namespace gcs diff --git a/src/ray/gcs/gcs_server/test/gcs_actor_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_actor_manager_test.cc index b921fd2acd2a0..f43d40dd392ac 100644 --- a/src/ray/gcs/gcs_server/test/gcs_actor_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_actor_manager_test.cc @@ -33,6 +33,7 @@ class MockActorScheduler : public gcs::GcsActorSchedulerInterface { void Reschedule(std::shared_ptr actor) {} void ReleaseUnusedWorkers( const std::unordered_map> &node_to_workers) {} + void OnActorDestruction(std::shared_ptr actor) {} MOCK_METHOD1(CancelOnNode, std::vector(const NodeID &node_id)); MOCK_METHOD2(CancelOnWorker, ActorID(const NodeID &node_id, const WorkerID &worker_id)); @@ -105,8 +106,8 @@ class GcsActorManagerTest : public ::testing::Test { 
store_client_ = std::make_shared<gcs::InMemoryStoreClient>(io_service_);
 gcs_table_storage_ = std::make_shared<gcs::InMemoryGcsTableStorage>(io_service_);
 gcs_actor_manager_.reset(new gcs::GcsActorManager(
- mock_actor_scheduler_, gcs_table_storage_, gcs_pub_sub_, *runtime_env_mgr_,
- [](const ActorID &actor_id) {},
+ io_service_, mock_actor_scheduler_, gcs_table_storage_, gcs_pub_sub_,
+ *runtime_env_mgr_, [](const ActorID &actor_id) {},
 [this](const JobID &job_id) { return job_namespace_table_[job_id]; },
 [this](std::function<void()> fn, boost::posix_time::milliseconds delay) {
 if (skip_delay_) {
@@ -953,6 +954,7 @@ TEST_F(GcsActorManagerTest, TestRayNamespace) {
 }

 TEST_F(GcsActorManagerTest, TestActorTableDataDelayedGC) {
+ google::protobuf::Arena arena;
 skip_delay_ = false;
 auto job_id_1 = JobID::FromInt(1);
 auto request1 = Mocker::GenRegisterActorRequest(job_id_1, /*max_restarts=*/0,
@@ -971,7 +973,8 @@ TEST_F(GcsActorManagerTest, TestActorTableDataDelayedGC) {

 {
 rpc::GetAllActorInfoRequest request;
- rpc::GetAllActorInfoReply reply;
+ auto &reply =
+ *google::protobuf::Arena::CreateMessage<rpc::GetAllActorInfoReply>(&arena);
 bool called = false;
 auto callback = [&called](Status status, std::function<void()> success,
 std::function<void()> failure) { called = true; };
@@ -981,7 +984,8 @@ TEST_F(GcsActorManagerTest, TestActorTableDataDelayedGC) {
 }
 {
 rpc::GetAllActorInfoRequest request;
- rpc::GetAllActorInfoReply reply;
+ auto &reply =
+ *google::protobuf::Arena::CreateMessage<rpc::GetAllActorInfoReply>(&arena);
 request.set_show_dead_jobs(true);
 std::promise promise;
 auto callback = [&promise](Status status, std::function<void()> success,
@@ -994,7 +998,8 @@ TEST_F(GcsActorManagerTest, TestActorTableDataDelayedGC) {
 delayed_to_run_();
 {
 rpc::GetAllActorInfoRequest request;
- rpc::GetAllActorInfoReply reply;
+ auto &reply =
+ *google::protobuf::Arena::CreateMessage<rpc::GetAllActorInfoReply>(&arena);
 request.set_show_dead_jobs(true);
 std::promise promise;
 auto callback = [&promise](Status status, std::function<void()> success,
diff --git a/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_mock_test.cc b/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_mock_test.cc
new file mode 100644
index 0000000000000..0829caf3e0d91
--- /dev/null
+++ b/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_mock_test.cc
@@ -0,0 +1,139 @@
+// Copyright 2017 The Ray Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
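The test above allocates its replies with google::protobuf::Arena::CreateMessage instead of on the stack, matching the `option cc_enable_arenas = true` lines added to the .proto files later in this series. A runnable sketch of the same pattern using a well-known arena-enabled message type bundled with protobuf:

#include <google/protobuf/arena.h>
#include <google/protobuf/struct.pb.h>
#include <iostream>

int main() {
  google::protobuf::Arena arena;
  // Arena-owned message: never delete it manually; keep the arena alive
  // for as long as the message is in use.
  auto *value =
      google::protobuf::Arena::CreateMessage<google::protobuf::Value>(&arena);
  value->set_string_value("hello");
  std::cout << value->string_value() << std::endl;
}  // One bulk free here, instead of one delete per message.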
+
+// clang-format off
+#include "gtest/gtest.h"
+#include "gmock/gmock.h"
+#include "ray/gcs/gcs_server/gcs_actor_manager.h"
+#include "ray/gcs/gcs_server/gcs_actor_scheduler.h"
+#include "mock/ray/gcs/store_client/store_client.h"
+#include "mock/ray/gcs/gcs_server/gcs_node_manager.h"
+#include "mock/ray/raylet_client/raylet_client.h"
+#include "mock/ray/pubsub/subscriber.h"
+#include "mock/ray/gcs/pubsub/gcs_pub_sub.h"
+#include "mock/ray/rpc/worker/core_worker_client.h"
+// clang-format on
+using namespace ::testing;
+
+namespace ray {
+namespace gcs {
+struct MockCallback {
+ MOCK_METHOD(void, Call, ((std::shared_ptr<GcsActor>)));
+ void operator()(std::shared_ptr<GcsActor> a) { return Call(a); }
+};
+
+class GcsActorSchedulerTest : public Test {
+ public:
+ void SetUp() override {
+ store_client = std::make_shared<MockStoreClient>();
+ actor_table = std::make_unique<GcsActorTable>(store_client);
+ gcs_node_manager = std::make_unique<MockGcsNodeManager>();
+ pub_sub = std::make_shared<MockGcsPubSub>();
+ raylet_client = std::make_shared<MockRayletClientInterface>();
+ core_worker_client = std::make_shared<rpc::MockCoreWorkerClientInterface>();
+ client_pool = std::make_shared<rpc::NodeManagerClientPool>(
+ [this](const rpc::Address &) { return raylet_client; });
+ actor_scheduler = std::make_unique<RayletBasedActorScheduler>(
+ io_context, *actor_table, *gcs_node_manager, pub_sub,
+ [this](auto a) { schedule_failure_handler(a); },
+ [this](auto a, const rpc::PushTaskReply) { schedule_success_handler(a); },
+ client_pool, [this](const rpc::Address &) { return core_worker_client; });
+ auto node_info = std::make_shared<rpc::GcsNodeInfo>();
+ node_info->set_state(rpc::GcsNodeInfo::ALIVE);
+ node_id = NodeID::FromRandom();
+ node_info->set_node_id(node_id.Binary());
+ worker_id = WorkerID::FromRandom();
+ gcs_node_manager->AddNode(node_info);
+ }
+ std::shared_ptr<MockRayletClientInterface> raylet_client;
+ instrumented_io_context io_context;
+ std::shared_ptr<MockStoreClient> store_client;
+ std::unique_ptr<GcsActorTable> actor_table;
+ std::unique_ptr<RayletBasedActorScheduler> actor_scheduler;
+ std::unique_ptr<MockGcsNodeManager> gcs_node_manager;
+ std::shared_ptr<MockGcsPubSub> pub_sub;
+ std::shared_ptr<rpc::MockCoreWorkerClientInterface> core_worker_client;
+ std::shared_ptr<rpc::NodeManagerClientPool> client_pool;
+ MockCallback schedule_failure_handler;
+ MockCallback schedule_success_handler;
+ NodeID node_id;
+ WorkerID worker_id;
+};
+
+TEST_F(GcsActorSchedulerTest, KillWorkerLeak1) {
+ // Ensure the worker is not leaked in the following case:
+ // 1. GCS starts to lease a worker
+ // 2. GCS cancels the actor
+ // 3. The lease reply arrives with a granted worker
+ // We'd like to test that the worker is eventually released.
+ // The worker is released by killing the actor.
+ auto actor_id = ActorID::FromHex("f4ce02420592ca68c1738a0d01000000");
+ rpc::ActorTableData actor_data;
+ actor_data.set_state(rpc::ActorTableData::PENDING_CREATION);
+ actor_data.set_actor_id(actor_id.Binary());
+ auto actor = std::make_shared<GcsActor>(actor_data);
+ std::function<void(const Status &, const rpc::RequestWorkerLeaseReply &)> cb;
+ EXPECT_CALL(*raylet_client, RequestWorkerLease(Matcher<const rpc::TaskSpec &>(), _, _))
+ .WillOnce(testing::SaveArg<1>(&cb));
+ // Ensure the actor is killed
+ EXPECT_CALL(*core_worker_client, KillActor(_, _));
+ actor_scheduler->Schedule(actor);
+ actor->GetMutableActorTableData()->set_state(rpc::ActorTableData::DEAD);
+ actor_scheduler->CancelOnNode(node_id);
+ ray::rpc::RequestWorkerLeaseReply reply;
+ reply.mutable_worker_address()->set_raylet_id(node_id.Binary());
+ reply.mutable_worker_address()->set_worker_id(worker_id.Binary());
+ cb(Status::OK(), reply);
+}
+
+TEST_F(GcsActorSchedulerTest, KillWorkerLeak2) {
+ // Ensure the worker is not leaked in the following case:
+ // 1. The actor is in pending creation
+ // 2. GCS pushes the creation task to run on the worker
+ // 3. The task is cancelled
+ // 4. The task-creation reply is received
+ // We'd like to test that the worker is eventually released.
+ // The worker is released by killing the actor.
+ auto actor_id = ActorID::FromHex("f4ce02420592ca68c1738a0d01000000");
+ rpc::ActorTableData actor_data;
+ actor_data.set_state(rpc::ActorTableData::PENDING_CREATION);
+ actor_data.set_actor_id(actor_id.Binary());
+ auto actor = std::make_shared<GcsActor>(actor_data);
+ rpc::ClientCallback<rpc::RequestWorkerLeaseReply> request_worker_lease_cb;
+ // Ensure the actor is killed
+ EXPECT_CALL(*core_worker_client, KillActor(_, _));
+ EXPECT_CALL(*raylet_client, RequestWorkerLease(Matcher<const rpc::TaskSpec &>(), _, _))
+ .WillOnce(testing::SaveArg<1>(&request_worker_lease_cb));
+
+ std::function<void(Status)> async_put_with_index_cb;
+ // Leasing succeeds
+ EXPECT_CALL(*store_client, AsyncPutWithIndex(_, _, _, _, _))
+ .WillOnce(DoAll(SaveArg<4>(&async_put_with_index_cb), Return(Status::OK())));
+ actor_scheduler->Schedule(actor);
+ rpc::RequestWorkerLeaseReply reply;
+ reply.mutable_worker_address()->set_raylet_id(node_id.Binary());
+ reply.mutable_worker_address()->set_worker_id(worker_id.Binary());
+ request_worker_lease_cb(Status::OK(), reply);
+
+ rpc::ClientCallback<rpc::PushTaskReply> push_normal_task_cb;
+ // The worker starts to run the task
+ EXPECT_CALL(*core_worker_client, PushNormalTask(_, _))
+ .WillOnce(testing::SaveArg<1>(&push_normal_task_cb));
+ async_put_with_index_cb(Status::OK());
+ actor->GetMutableActorTableData()->set_state(rpc::ActorTableData::DEAD);
+ actor_scheduler->CancelOnWorker(node_id, worker_id);
+ push_normal_task_cb(Status::OK(), rpc::PushTaskReply());
+}
+} // namespace gcs
+} // namespace ray
diff --git a/src/ray/gcs/gcs_server/test/gcs_based_actor_scheduler_test.cc b/src/ray/gcs/gcs_server/test/gcs_based_actor_scheduler_test.cc
index ada5f0094872b..48793907f117f 100644
--- a/src/ray/gcs/gcs_server/test/gcs_based_actor_scheduler_test.cc
+++ b/src/ray/gcs/gcs_server/test/gcs_based_actor_scheduler_test.cc
@@ -147,13 +147,15 @@ TEST_F(GcsBasedActorSchedulerTest, TestNotEnoughClusterResources) {
 ASSERT_TRUE(actor->GetNodeID().IsNil());
 }

-TEST_F(GcsBasedActorSchedulerTest, TestScheduleOneActor) {
+TEST_F(GcsBasedActorSchedulerTest, TestScheduleAndDestroyOneActor) {
 // Add a node with 64 memory units and 8 CPU.
 std::unordered_map<std::string, double> node_resources = {{kMemory_ResourceLabel, 64},
 {kCPU_ResourceLabel, 8}};
 auto node = AddNewNode(node_resources);
 auto node_id = NodeID::FromBinary(node->node_id());
 ASSERT_EQ(1, gcs_node_manager_->GetAllAliveNodes().size());
+ auto cluster_resources_before_scheduling = gcs_resource_manager_->GetClusterResources();
+ ASSERT_TRUE(cluster_resources_before_scheduling.contains(node_id));

 // Schedule an actor (requiring 32 memory units and 4 CPU).
 std::unordered_map<std::string, double> required_placement_resources = {
@@ -182,6 +184,20 @@ TEST_F(GcsBasedActorSchedulerTest, TestScheduleAndDestroyOneActor) {
 ASSERT_EQ(actor, success_actors_.front());
 ASSERT_EQ(actor->GetNodeID(), node_id);
 ASSERT_EQ(actor->GetWorkerID(), worker_id);
+
+ auto cluster_resources_after_scheduling = gcs_resource_manager_->GetClusterResources();
+ ASSERT_TRUE(cluster_resources_after_scheduling.contains(node_id));
+ ASSERT_FALSE(
+ cluster_resources_before_scheduling[node_id].GetAvailableResources().IsEqual(
+ cluster_resources_after_scheduling[node_id].GetAvailableResources()));
+
+ // When destroying an actor, its acquired resources have to be returned.
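Both KillWorkerLeak tests drive the race by capturing the asynchronous RPC callback with EXPECT_CALL + SaveArg and firing it only after cancellation has happened. A self-contained sketch of that capture-then-fire technique, with illustrative types rather than Ray's interfaces:

#include <functional>
#include <string>
#include "gmock/gmock.h"
#include "gtest/gtest.h"

using Callback = std::function<void(const std::string &reply)>;

class LeaseClient {
 public:
  virtual ~LeaseClient() = default;
  virtual void RequestLease(int worker_id, const Callback &cb) = 0;
};

class MockLeaseClient : public LeaseClient {
 public:
  MOCK_METHOD(void, RequestLease, (int worker_id, const Callback &cb), (override));
};

TEST(CallbackCapture, ReplyAfterCancel) {
  MockLeaseClient client;
  Callback captured;
  // Instead of answering inline, stash the callback for later.
  EXPECT_CALL(client, RequestLease(42, ::testing::_))
      .WillOnce(::testing::SaveArg<1>(&captured));

  client.RequestLease(
      42, [](const std::string &reply) { EXPECT_EQ(reply, "granted"); });
  // ... the test can now cancel, flip actor state, etc. ...
  captured("granted");  // Deliver the "RPC reply" at the chosen moment.
}

int main(int argc, char **argv) {
  ::testing::InitGoogleMock(&argc, argv);
  return RUN_ALL_TESTS();
}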
+ gcs_actor_scheduler_->OnActorDestruction(actor); + auto cluster_resources_after_destruction = gcs_resource_manager_->GetClusterResources(); + ASSERT_TRUE(cluster_resources_after_destruction.contains(node_id)); + ASSERT_TRUE( + cluster_resources_before_scheduling[node_id].GetAvailableResources().IsEqual( + cluster_resources_after_destruction[node_id].GetAvailableResources())); } TEST_F(GcsBasedActorSchedulerTest, TestBalancedSchedule) { diff --git a/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_mock_test.cc b/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_mock_test.cc new file mode 100644 index 0000000000000..e017fb793bafe --- /dev/null +++ b/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_mock_test.cc @@ -0,0 +1,174 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// clang-format off +#include "gtest/gtest.h" +#include "gmock/gmock.h" +#include "ray/gcs/gcs_server/gcs_placement_group_manager.h" +#include "mock/ray/gcs/gcs_server/gcs_placement_group_manager.h" +#include "mock/ray/gcs/gcs_server/gcs_placement_group_scheduler.h" +#include "mock/ray/gcs/gcs_server/gcs_resource_manager.h" +#include "mock/ray/gcs/store_client/store_client.h" +#include "ray/gcs/test/gcs_test_util.h" +// clang-format on + +using namespace ::testing; +using namespace ray; +using namespace ray::gcs; +namespace ray { +namespace gcs { + +class GcsPlacementGroupManagerMockTest : public Test { + public: + void SetUp() override { + store_client_ = std::make_shared(); + gcs_table_storage_ = std::make_shared(store_client_); + gcs_placement_group_scheduler_ = + std::make_shared(); + resource_manager_ = + std::make_shared(io_context_, nullptr, nullptr, true); + + gcs_placement_group_manager_ = std::make_unique( + io_context_, gcs_placement_group_scheduler_, gcs_table_storage_, + *resource_manager_, [](auto &) { return ""; }); + } + + std::unique_ptr gcs_placement_group_manager_; + std::shared_ptr gcs_placement_group_scheduler_; + std::shared_ptr gcs_table_storage_; + std::shared_ptr store_client_; + std::shared_ptr resource_manager_; + instrumented_io_context io_context_; +}; + +TEST_F(GcsPlacementGroupManagerMockTest, PendingQueuePriorityReschedule) { + // Test priority works + // When return with reschedule, it should be given with the highest pri + auto req = + Mocker::GenCreatePlacementGroupRequest("", rpc::PlacementStrategy::SPREAD, 1); + auto pg = std::make_shared(req, ""); + auto cb = [](Status s) {}; + PGSchedulingFailureCallback failure_callback; + PGSchedulingSuccessfulCallback success_callback; + StatusCallback put_cb; + EXPECT_CALL(*store_client_, AsyncPut(_, _, _, _)) + .WillOnce(DoAll(SaveArg<3>(&put_cb), Return(Status::OK()))); + EXPECT_CALL(*gcs_placement_group_scheduler_, ScheduleUnplacedBundles(_, _, _)) + .WillOnce(DoAll(SaveArg<1>(&failure_callback), SaveArg<2>(&success_callback))); + auto now = absl::GetCurrentTimeNanos(); + gcs_placement_group_manager_->RegisterPlacementGroup(pg, cb); + auto &pending_queue = 
gcs_placement_group_manager_->pending_placement_groups_; + ASSERT_EQ(1, pending_queue.size()); + ASSERT_LE(now, pending_queue.begin()->first); + ASSERT_GE(absl::GetCurrentTimeNanos(), pending_queue.begin()->first); + put_cb(Status::OK()); + pg->UpdateState(rpc::PlacementGroupTableData::RESCHEDULING); + failure_callback(pg, true); + ASSERT_EQ(1, pending_queue.size()); + ASSERT_GE(0, pending_queue.begin()->first); +} + +TEST_F(GcsPlacementGroupManagerMockTest, PendingQueuePriorityFailed) { + // Test priority works + // When return with a failure, exp backoff should work + auto req = + Mocker::GenCreatePlacementGroupRequest("", rpc::PlacementStrategy::SPREAD, 1); + auto pg = std::make_shared(req, ""); + auto cb = [](Status s) {}; + PGSchedulingFailureCallback failure_callback; + PGSchedulingSuccessfulCallback success_callback; + StatusCallback put_cb; + EXPECT_CALL(*store_client_, AsyncPut(_, _, _, _)) + .WillOnce(DoAll(SaveArg<3>(&put_cb), Return(Status::OK()))); + EXPECT_CALL(*gcs_placement_group_scheduler_, ScheduleUnplacedBundles(_, _, _)) + .Times(2) + .WillRepeatedly( + DoAll(SaveArg<1>(&failure_callback), SaveArg<2>(&success_callback))); + auto now = absl::GetCurrentTimeNanos(); + gcs_placement_group_manager_->RegisterPlacementGroup(pg, cb); + auto &pending_queue = gcs_placement_group_manager_->pending_placement_groups_; + ASSERT_EQ(1, pending_queue.size()); + ASSERT_LE(now, pending_queue.begin()->first); + ASSERT_GE(absl::GetCurrentTimeNanos(), pending_queue.begin()->first); + put_cb(Status::OK()); + pg->UpdateState(rpc::PlacementGroupTableData::PENDING); + now = absl::GetCurrentTimeNanos(); + failure_callback(pg, true); + auto exp_backer = ExponentialBackOff( + 1000000 * RayConfig::instance().gcs_create_placement_group_retry_min_interval_ms(), + RayConfig::instance().gcs_create_placement_group_retry_multiplier(), + 1000000 * RayConfig::instance().gcs_create_placement_group_retry_max_interval_ms()); + auto next = exp_backer.Next(); + ASSERT_DOUBLE_EQ( + next, + 1000000 * RayConfig::instance().gcs_create_placement_group_retry_min_interval_ms()); + ASSERT_EQ(1, pending_queue.size()); + auto rank = pending_queue.begin()->first; + ASSERT_LE(now + next, rank); + // ScheduleUnplacedBundles is not called here + gcs_placement_group_manager_->SchedulePendingPlacementGroups(); + ASSERT_EQ(1, pending_queue.size()); + ASSERT_EQ(rank, pending_queue.begin()->first); + + absl::SleepFor(absl::Milliseconds(1) + + absl::Nanoseconds(rank - absl::GetCurrentTimeNanos())); + gcs_placement_group_manager_->SchedulePendingPlacementGroups(); + ASSERT_EQ(0, pending_queue.size()); + pg->UpdateState(rpc::PlacementGroupTableData::PENDING); + now = absl::GetCurrentTimeNanos(); + failure_callback(pg, true); + next = RayConfig::instance().gcs_create_placement_group_retry_multiplier() * next; + ASSERT_EQ(1, pending_queue.size()); + ASSERT_LE(now + next, pending_queue.begin()->first); +} + +TEST_F(GcsPlacementGroupManagerMockTest, PendingQueuePriorityOrder) { + // Test priority works + // Add two pgs + // Fail one and make sure it's scheduled later + auto req1 = + Mocker::GenCreatePlacementGroupRequest("", rpc::PlacementStrategy::SPREAD, 1); + auto pg1 = std::make_shared(req1, ""); + auto req2 = + Mocker::GenCreatePlacementGroupRequest("", rpc::PlacementStrategy::SPREAD, 1); + auto pg2 = std::make_shared(req2, ""); + auto cb = [](Status s) {}; + PGSchedulingFailureCallback failure_callback; + PGSchedulingSuccessfulCallback success_callback; + StatusCallback put_cb; + EXPECT_CALL(*store_client_, AsyncPut(_, _, _, _)) + 
.Times(2) + .WillRepeatedly(DoAll(SaveArg<3>(&put_cb), Return(Status::OK()))); + EXPECT_CALL(*gcs_placement_group_scheduler_, ScheduleUnplacedBundles(_, _, _)) + .Times(2) + .WillRepeatedly( + DoAll(SaveArg<1>(&failure_callback), SaveArg<2>(&success_callback))); + gcs_placement_group_manager_->RegisterPlacementGroup(pg1, cb); + gcs_placement_group_manager_->RegisterPlacementGroup(pg2, cb); + auto &pending_queue = gcs_placement_group_manager_->pending_placement_groups_; + ASSERT_EQ(2, pending_queue.size()); + put_cb(Status::OK()); + ASSERT_EQ(1, pending_queue.size()); + // PG1 is scheduled first, so PG2 is in pending queue + ASSERT_EQ(pg2, pending_queue.begin()->second.second); + failure_callback(pg1, true); + ASSERT_EQ(2, pending_queue.size()); + gcs_placement_group_manager_->SchedulePendingPlacementGroups(); + // PG2 is scheduled for the next, so PG1 is in pending queue + ASSERT_EQ(1, pending_queue.size()); + ASSERT_EQ(pg1, pending_queue.begin()->second.second); +} + +} // namespace gcs +} // namespace ray diff --git a/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_test.cc index 7c941aa27f815..8eeed97f7eca6 100644 --- a/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_test.cc @@ -22,6 +22,7 @@ #include "ray/gcs/test/gcs_test_util.h" namespace ray { +namespace gcs { using ::testing::_; using StatusCallback = std::function; @@ -135,6 +136,8 @@ class GcsPlacementGroupManagerTest : public ::testing::Test { EXPECT_TRUE(WaitForCondition(condition, 10 * 1000)); } + ExponentialBackOff GetExpBackOff() { return ExponentialBackOff(0, 1); } + std::shared_ptr mock_placement_group_scheduler_; std::unique_ptr gcs_placement_group_manager_; std::unordered_map job_namespace_table_; @@ -148,6 +151,26 @@ class GcsPlacementGroupManagerTest : public ::testing::Test { std::shared_ptr redis_client_; }; +TEST_F(GcsPlacementGroupManagerTest, TestPlacementGroupBundleCache) { + auto request = Mocker::GenCreatePlacementGroupRequest(); + std::atomic registered_placement_group_count(0); + RegisterPlacementGroup(request, + [®istered_placement_group_count](const Status &status) { + ++registered_placement_group_count; + }); + ASSERT_EQ(registered_placement_group_count, 1); + WaitForExpectedPgCount(1); + auto placement_group = mock_placement_group_scheduler_->placement_groups_.back(); + ASSERT_TRUE(placement_group->cached_bundle_specs_.empty()); + // Fill the cache and verify it. + const auto &bundle_specs = placement_group->GetBundles(); + ASSERT_EQ(placement_group->cached_bundle_specs_, bundle_specs); + ASSERT_FALSE(placement_group->cached_bundle_specs_.empty()); + // Invalidate the cache and verify it. 
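TestPlacementGroupBundleCache above checks a memoize-and-invalidate contract: GetBundles() fills cached_bundle_specs_, and GetMutableBundle() clears the cache because handing out mutable state can invalidate whatever the cache was derived from. A minimal sketch of that contract with illustrative types (not Ray's GcsPlacementGroup):

#include <cassert>
#include <string>
#include <utility>
#include <vector>

class PlacementGroup {
 public:
  explicit PlacementGroup(std::vector<std::string> bundles)
      : bundles_(std::move(bundles)) {}

  // Builds the derived representation once and reuses it afterwards.
  // (Using empty() as the "not filled" sentinel is enough for a sketch.)
  const std::vector<std::string> &GetBundles() const {
    if (cached_bundle_specs_.empty()) {
      cached_bundle_specs_ = bundles_;  // Stand-in for building BundleSpecs.
    }
    return cached_bundle_specs_;
  }

  // Mutation may change the underlying data, so drop the cache.
  std::string *GetMutableBundle(size_t i) {
    cached_bundle_specs_.clear();
    return &bundles_[i];
  }

 private:
  std::vector<std::string> bundles_;
  mutable std::vector<std::string> cached_bundle_specs_;
};

int main() {
  PlacementGroup pg({"b0", "b1"});
  assert(pg.GetBundles().size() == 2);  // Fills the cache.
  pg.GetMutableBundle(0);               // Invalidates it.
  assert(pg.GetBundles().size() == 2);  // Rebuilt on demand.
}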
+ RAY_UNUSED(placement_group->GetMutableBundle(0)); + ASSERT_TRUE(placement_group->cached_bundle_specs_.empty()); +} + TEST_F(GcsPlacementGroupManagerTest, TestBasic) { auto request = Mocker::GenCreatePlacementGroupRequest(); std::atomic registered_placement_group_count(0); @@ -176,7 +199,8 @@ TEST_F(GcsPlacementGroupManagerTest, TestSchedulingFailed) { auto placement_group = mock_placement_group_scheduler_->placement_groups_.back(); mock_placement_group_scheduler_->placement_groups_.clear(); - gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group); + gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group, + GetExpBackOff(), true); gcs_placement_group_manager_->SchedulePendingPlacementGroups(); ASSERT_EQ(mock_placement_group_scheduler_->placement_groups_.size(), 1); mock_placement_group_scheduler_->placement_groups_.clear(); @@ -240,7 +264,8 @@ TEST_F(GcsPlacementGroupManagerTest, TestRescheduleWhenNodeAdd) { mock_placement_group_scheduler_->placement_groups_.pop_back(); // If the creation of placement group fails, it will be rescheduled after a short time. - gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group); + gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group, + GetExpBackOff(), true); WaitForExpectedPgCount(1); } @@ -255,7 +280,8 @@ TEST_F(GcsPlacementGroupManagerTest, TestRemovingPendingPlacementGroup) { auto placement_group = mock_placement_group_scheduler_->placement_groups_.back(); mock_placement_group_scheduler_->placement_groups_.clear(); - gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group); + gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group, + GetExpBackOff(), true); ASSERT_EQ(placement_group->GetState(), rpc::PlacementGroupTableData::PENDING); const auto &placement_group_id = placement_group->GetPlacementGroupID(); gcs_placement_group_manager_->RemovePlacementGroup(placement_group_id, @@ -291,7 +317,8 @@ TEST_F(GcsPlacementGroupManagerTest, TestRemovingLeasingPlacementGroup) { gcs_placement_group_manager_->RemovePlacementGroup(placement_group_id, [](const Status &status) {}); ASSERT_EQ(placement_group->GetState(), rpc::PlacementGroupTableData::REMOVED); - gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group); + gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group, + GetExpBackOff(), true); // Make sure it is not rescheduled gcs_placement_group_manager_->SchedulePendingPlacementGroups(); @@ -354,7 +381,6 @@ TEST_F(GcsPlacementGroupManagerTest, TestRescheduleWhenNodeDead) { placement_group->GetMutableBundle(0)->set_node_id(NodeID::FromRandom().Binary()); placement_group->GetMutableBundle(1)->set_node_id(NodeID::FromRandom().Binary()); mock_placement_group_scheduler_->placement_groups_.pop_back(); - // If a node dies, we will set the bundles above it to be unplaced and reschedule the // placement group. The placement group state is set to `RESCHEDULING` and will be // scheduled first. 
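The retry tests earlier in this file assert a specific ExponentialBackOff behaviour: the first Next() equals the initial interval, and subsequent values grow geometrically up to a cap. A minimal sketch that reproduces exactly what those assertions check; it is not Ray's implementation:

#include <algorithm>
#include <iostream>

class ExponentialBackoff {
 public:
  ExponentialBackoff(double initial_ns, double multiplier, double max_ns)
      : current_(initial_ns), multiplier_(multiplier), max_(max_ns) {}

  // Returns the delay to use now; later calls grow geometrically, clamped.
  double Next() {
    double result = current_;
    current_ = std::min(current_ * multiplier_, max_);
    return result;
  }

 private:
  double current_, multiplier_, max_;
};

int main() {
  ExponentialBackoff backoff(/*initial_ns=*/1e6, /*multiplier=*/2.0,
                             /*max_ns=*/8e6);
  for (int i = 0; i < 5; i++) {
    std::cout << backoff.Next() << "\n";  // 1e6 2e6 4e6 8e6 8e6
  }
}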
@@ -373,14 +399,15 @@ TEST_F(GcsPlacementGroupManagerTest, TestRescheduleWhenNodeDead) { placement_group->GetPlacementGroupID()); const auto &bundles = mock_placement_group_scheduler_->placement_groups_[0]->GetBundles(); - EXPECT_TRUE(NodeID::FromBinary(bundles[0]->GetMutableMessage().node_id()).IsNil()); - EXPECT_FALSE(NodeID::FromBinary(bundles[1]->GetMutableMessage().node_id()).IsNil()); + EXPECT_TRUE(NodeID::FromBinary(bundles[0]->GetMessage().node_id()).IsNil()); + EXPECT_FALSE(NodeID::FromBinary(bundles[1]->GetMessage().node_id()).IsNil()); // If `RESCHEDULING` placement group fails to create, we will schedule it again first. placement_group = mock_placement_group_scheduler_->placement_groups_.back(); mock_placement_group_scheduler_->placement_groups_.pop_back(); ASSERT_EQ(mock_placement_group_scheduler_->placement_groups_.size(), 0); - gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group); + gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group, + GetExpBackOff(), true); WaitForExpectedPgCount(1); ASSERT_EQ(mock_placement_group_scheduler_->placement_groups_[0]->GetPlacementGroupID(), placement_group->GetPlacementGroupID()); @@ -526,7 +553,8 @@ TEST_F(GcsPlacementGroupManagerTest, TestSchedulingCanceledWhenPgIsInfeasible) { mock_placement_group_scheduler_->placement_groups_.clear(); // Mark it non-retryable. - gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group, false); + gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group, + GetExpBackOff(), false); // Schedule twice to make sure it will not be scheduled afterward. gcs_placement_group_manager_->SchedulePendingPlacementGroups(); @@ -607,6 +635,7 @@ TEST_F(GcsPlacementGroupManagerTest, TestRayNamespace) { } } +} // namespace gcs } // namespace ray int main(int argc, char **argv) { diff --git a/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc b/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc index cbe1ba78495f4..5d265ac1bbb59 100644 --- a/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc @@ -33,6 +33,7 @@ class GcsServerTest : public ::testing::Test { config.grpc_server_name = "MockedGcsServer"; config.grpc_server_thread_num = 1; config.redis_address = "127.0.0.1"; + config.node_ip_address = "127.0.0.1"; config.enable_sharding_conn = false; config.redis_port = TEST_REDIS_SERVER_PORTS.front(); gcs_server_.reset(new gcs::GcsServer(config, io_service_)); diff --git a/src/ray/gcs/gcs_server/test/gcs_server_test_util.h b/src/ray/gcs/gcs_server/test/gcs_server_test_util.h index 249ac5a9fdd6a..409239801e2c5 100644 --- a/src/ray/gcs/gcs_server/test/gcs_server_test_util.h +++ b/src/ray/gcs/gcs_server/test/gcs_server_test_util.h @@ -79,6 +79,14 @@ struct GcsServerMocker { callbacks.push_back(callback); } + void RequestWorkerLease( + const rpc::TaskSpec &spec, + const rpc::ClientCallback &callback, + const int64_t backlog_size = -1) override { + num_workers_requested += 1; + callbacks.push_back(callback); + } + /// WorkerLeaseInterface void ReleaseUnusedWorkers( const std::vector &workers_in_use, @@ -180,7 +188,7 @@ struct GcsServerMocker { /// ResourceReserveInterface void CancelResourceReserve( - BundleSpecification &bundle_spec, + const BundleSpecification &bundle_spec, const ray::rpc::ClientCallback &callback) override { num_return_requested += 1; diff --git a/src/ray/gcs/pubsub/gcs_pub_sub.h b/src/ray/gcs/pubsub/gcs_pub_sub.h index b871a02b13ddd..70828a3679691 100644 --- 
a/src/ray/gcs/pubsub/gcs_pub_sub.h
+++ b/src/ray/gcs/pubsub/gcs_pub_sub.h
@@ -96,6 +96,9 @@ class GcsPubSub {

 std::string DebugString() const;

+ protected:
+ GcsPubSub() : GcsPubSub(nullptr) {}
+
 private:
 /// Represents a caller's command to subscribe or unsubscribe to a given
 /// channel.
diff --git a/src/ray/gcs/redis_context.h b/src/ray/gcs/redis_context.h
index c7244aac80549..443c42f9dee69 100644
--- a/src/ray/gcs/redis_context.h
+++ b/src/ray/gcs/redis_context.h
@@ -15,7 +15,7 @@
 #pragma once

 #include
-#include
+#include
 #include
 #include
 #include
diff --git a/src/ray/object_manager/object_buffer_pool.cc b/src/ray/object_manager/object_buffer_pool.cc
index e6e214b3062f2..b25439cd7203c 100644
--- a/src/ray/object_manager/object_buffer_pool.cc
+++ b/src/ray/object_manager/object_buffer_pool.cc
@@ -14,6 +14,7 @@

 #include "ray/object_manager/object_buffer_pool.h"

+#include "absl/time/time.h"
 #include "ray/common/status.h"
 #include "ray/util/logging.h"

@@ -21,26 +22,49 @@ namespace ray {

 ObjectBufferPool::ObjectBufferPool(const std::string &store_socket_name,
 uint64_t chunk_size)
- : default_chunk_size_(chunk_size) {
- store_socket_name_ = store_socket_name;
+ : store_socket_name_(store_socket_name), default_chunk_size_(chunk_size) {
 RAY_CHECK_OK(store_client_.Connect(store_socket_name_.c_str(), "", 0, 300));
 }

 ObjectBufferPool::~ObjectBufferPool() {
- // Abort everything in progress.
- auto create_buf_state_copy = create_buffer_state_;
- for (const auto &pair : create_buf_state_copy) {
- AbortCreate(pair.first);
+ absl::MutexLock lock(&pool_mutex_);
+ auto inflight_ops = create_buffer_ops_;
+ pool_mutex_.Unlock();
+
+ for (const auto &[id, cond_var] : inflight_ops) {
+ cond_var->SignalAll();
+ }
+ auto no_inflight = [this]() {
+ pool_mutex_.AssertReaderHeld();
+ return create_buffer_ops_.empty();
+ };
+ // No new requests should arrive at this point, so wait until there is no
+ // inflight operation before reacquiring pool_mutex_; log an error on timeout.
+ if (!pool_mutex_.LockWhenWithTimeout(absl::Condition(&no_inflight), absl::Seconds(5))) {
+ RAY_LOG(ERROR)
+ << create_buffer_ops_.size() << " remaining inflight create buffer operations "
+ << "during ObjectBufferPool destruction. Either abort these operations before "
+ << "destroying ObjectBufferPool, or refactor ObjectBufferPool to make it "
+ "unnecessary to wait for the operations' completion.";
 }
+
+ // Abort unfinished buffers in progress, then clear the map. Clearing happens
+ // after the loop because absl::flat_hash_map::erase(it) invalidates `it`
+ // without returning a successor.
+ for (const auto &entry : create_buffer_state_) {
+ RAY_CHECK_OK(store_client_.Release(entry.first));
+ RAY_CHECK_OK(store_client_.Abort(entry.first));
+ }
+ create_buffer_state_.clear();
 RAY_CHECK_OK(store_client_.Disconnect());
 }

-uint64_t ObjectBufferPool::GetNumChunks(uint64_t data_size) {
+uint64_t ObjectBufferPool::GetNumChunks(uint64_t data_size) const {
 return (data_size + default_chunk_size_ - 1) / default_chunk_size_;
 }

-uint64_t ObjectBufferPool::GetBufferLength(uint64_t chunk_index, uint64_t data_size) {
+uint64_t ObjectBufferPool::GetBufferLength(uint64_t chunk_index,
+ uint64_t data_size) const {
 return (chunk_index + 1) * default_chunk_size_ > data_size ?
data_size % default_chunk_size_ : default_chunk_size_; @@ -49,7 +73,7 @@ uint64_t ObjectBufferPool::GetBufferLength(uint64_t chunk_index, uint64_t data_s std::pair, ray::Status> ObjectBufferPool::CreateObjectReader(const ObjectID &object_id, rpc::Address owner_address) { - std::lock_guard lock(pool_mutex_); + absl::MutexLock lock(&pool_mutex_); std::vector object_ids{object_id}; std::vector object_buffers(1); @@ -76,53 +100,21 @@ ray::Status ObjectBufferPool::CreateChunk(const ObjectID &object_id, const rpc::Address &owner_address, uint64_t data_size, uint64_t metadata_size, uint64_t chunk_index) { - std::unique_lock lock(pool_mutex_); - if (create_buffer_state_.count(object_id) == 0) { - int64_t object_size = data_size - metadata_size; - // Try to create shared buffer. - std::shared_ptr data; - - // Release the buffer pool lock during the blocking create call. - lock.unlock(); - Status s = store_client_.CreateAndSpillIfNeeded( - object_id, owner_address, object_size, NULL, metadata_size, &data, - plasma::flatbuf::ObjectSource::ReceivedFromRemoteRaylet); - lock.lock(); - - // Another thread may have succeeded in creating the chunk while the lock - // was released. In that case skip the remainder of the creation block. - if (create_buffer_state_.count(object_id) == 0) { - std::vector buffer; - if (!s.ok()) { - // Create failed. The object may already exist locally. If something else went - // wrong, another chunk will succeed in creating the buffer, and this - // chunk will eventually make it here via pull requests. - return ray::Status::IOError(s.message()); - } - // Read object into store. - uint8_t *mutable_data = data->Data(); - uint64_t num_chunks = GetNumChunks(data_size); - create_buffer_state_.emplace( - std::piecewise_construct, std::forward_as_tuple(object_id), - std::forward_as_tuple(BuildChunks(object_id, mutable_data, data_size, data))); - RAY_LOG(DEBUG) << "Created object " << object_id - << " in plasma store, number of chunks: " << num_chunks - << ", chunk index: " << chunk_index; - RAY_CHECK(create_buffer_state_[object_id].chunk_info.size() == num_chunks); - } - } - if (create_buffer_state_[object_id].chunk_state[chunk_index] != - CreateChunkState::AVAILABLE) { + absl::MutexLock lock(&pool_mutex_); + RAY_RETURN_NOT_OK(EnsureBufferExists(object_id, owner_address, data_size, metadata_size, + chunk_index)); + auto &state = create_buffer_state_.at(object_id); + if (state.chunk_state[chunk_index] != CreateChunkState::AVAILABLE) { // There can be only one reference to this chunk at any given time. 
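EnsureBufferExists, whose body continues below, is a single-flight guard: one thread performs the expensive buffer creation while pool_mutex_ is released, and other threads wait on a per-object absl::CondVar and re-check once woken. A distilled, runnable sketch of the pattern, simplified to int keys and an inline stand-in for the blocking create call (the manual Unlock/Lock mirrors the patch's own structure):

#include <memory>

#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/synchronization/mutex.h"

class SingleFlightCache {
 public:
  int GetOrCreate(int key) ABSL_LOCKS_EXCLUDED(mu_) {
    absl::MutexLock lock(&mu_);
    while (true) {
      auto done = results_.find(key);
      if (done != results_.end()) return done->second;  // Already created.
      auto inflight = inflight_.find(key);
      if (inflight == inflight_.end()) break;  // Nobody is creating it: we do.
      // Another thread is creating it: wait, then re-check from the top.
      auto waiter = inflight->second;
      waiter->Wait(&mu_);  // Atomically releases mu_, reacquires on wakeup.
    }
    auto cond_var = std::make_shared<absl::CondVar>();
    inflight_.emplace(key, cond_var);
    mu_.Unlock();           // The expensive call runs without the lock held,
    int value = key * key;  // stand-in for the blocking create-buffer call.
    mu_.Lock();
    results_.emplace(key, value);
    inflight_.erase(key);
    cond_var->SignalAll();  // Wake every waiter for this key.
    return value;
  }

 private:
  absl::Mutex mu_;
  absl::flat_hash_map<int, std::shared_ptr<absl::CondVar>> inflight_
      ABSL_GUARDED_BY(mu_);
  absl::flat_hash_map<int, int> results_ ABSL_GUARDED_BY(mu_);
};

int main() {
  SingleFlightCache cache;
  return cache.GetOrCreate(3) - 9;  // Exits 0.
}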
return ray::Status::IOError("Chunk already received by a different thread."); } - create_buffer_state_[object_id].chunk_state[chunk_index] = CreateChunkState::REFERENCED; + state.chunk_state[chunk_index] = CreateChunkState::REFERENCED; return ray::Status::OK(); } void ObjectBufferPool::WriteChunk(const ObjectID &object_id, const uint64_t chunk_index, const std::string &data) { - std::lock_guard lock(pool_mutex_); + absl::MutexLock lock(&pool_mutex_); auto it = create_buffer_state_.find(object_id); if (it == create_buffer_state_.end() || it->second.chunk_state.at(chunk_index) != CreateChunkState::REFERENCED) { @@ -148,7 +140,7 @@ void ObjectBufferPool::WriteChunk(const ObjectID &object_id, const uint64_t chun } void ObjectBufferPool::AbortCreate(const ObjectID &object_id) { - std::lock_guard lock(pool_mutex_); + absl::MutexLock lock(&pool_mutex_); auto it = create_buffer_state_.find(object_id); if (it != create_buffer_state_.end()) { RAY_LOG(INFO) << "Not enough memory to create requested object " << object_id @@ -179,13 +171,84 @@ std::vector ObjectBufferPool::BuildChunks( return chunks; } +ray::Status ObjectBufferPool::EnsureBufferExists(const ObjectID &object_id, + const rpc::Address &owner_address, + uint64_t data_size, + uint64_t metadata_size, + uint64_t chunk_index) { + while (true) { + // Buffer for object_id already exists. + if (create_buffer_state_.contains(object_id)) { + return ray::Status::OK(); + } + + auto it = create_buffer_ops_.find(object_id); + if (it == create_buffer_ops_.end()) { + // No inflight create buffer operation, proceed to start one. + break; + } + + auto cond_var = it->second; + // Release pool_mutex_ while waiting, until the current inflight create buffer + // operation finishes. + cond_var->Wait(&pool_mutex_); + } + + // Indicate that there is an inflight create buffer operation, by inserting into + // create_buffer_ops_. + RAY_CHECK( + create_buffer_ops_.insert({object_id, std::make_shared()}).second); + const int64_t object_size = + static_cast(data_size) - static_cast(metadata_size); + std::shared_ptr data; + + // Release pool_mutex_ during the blocking create call. + pool_mutex_.Unlock(); + Status s = store_client_.CreateAndSpillIfNeeded( + object_id, owner_address, static_cast(object_size), nullptr, + static_cast(metadata_size), &data, + plasma::flatbuf::ObjectSource::ReceivedFromRemoteRaylet); + pool_mutex_.Lock(); + + // No other thread could have created the buffer. + RAY_CHECK(!create_buffer_state_.contains(object_id)); + + // Remove object_id from create_buffer_ops_ to indicate to the waiting ops that the + // inflight operation has finished. Wake up waiters so they can either start another + // create buffer op, or proceed after the buffer has been created. + { + auto it = create_buffer_ops_.find(object_id); + it->second->SignalAll(); + create_buffer_ops_.erase(it); + } + + if (!s.ok()) { + // Create failed. Buffer creation will be tried by another chunk. + // And this chunk will eventually make it here via retried pull requests. + return ray::Status::IOError(s.message()); + } + + // Read object into store. 
+ uint8_t *mutable_data = data->Data(); + uint64_t num_chunks = GetNumChunks(data_size); + create_buffer_state_.emplace( + std::piecewise_construct, std::forward_as_tuple(object_id), + std::forward_as_tuple(BuildChunks(object_id, mutable_data, data_size, data))); + RAY_CHECK(create_buffer_state_[object_id].chunk_info.size() == num_chunks); + RAY_LOG(DEBUG) << "Created object " << object_id + << " in plasma store, number of chunks: " << num_chunks + << ", chunk index: " << chunk_index; + + return ray::Status::OK(); +} + void ObjectBufferPool::FreeObjects(const std::vector &object_ids) { - std::lock_guard lock(pool_mutex_); + absl::MutexLock lock(&pool_mutex_); RAY_CHECK_OK(store_client_.Delete(object_ids)); } std::string ObjectBufferPool::DebugString() const { - std::lock_guard lock(pool_mutex_); + absl::MutexLock lock(&pool_mutex_); std::stringstream result; result << "BufferPool:"; result << "\n- create buffer state map size: " << create_buffer_state_.size(); diff --git a/src/ray/object_manager/object_buffer_pool.h b/src/ray/object_manager/object_buffer_pool.h index 05c51e5e00117..b2722a3eceecc 100644 --- a/src/ray/object_manager/object_buffer_pool.h +++ b/src/ray/object_manager/object_buffer_pool.h @@ -16,12 +16,14 @@ #include #include -#include +#include #include #include -#include #include +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/synchronization/mutex.h" #include "ray/common/id.h" #include "ray/common/status.h" #include "ray/object_manager/memory_object_reader.h" @@ -68,14 +70,14 @@ class ObjectBufferPool { /// /// \param data_size The size of the object + metadata. /// \return The number of chunks into which the object will be split. - uint64_t GetNumChunks(uint64_t data_size); + uint64_t GetNumChunks(uint64_t data_size) const; /// Computes the buffer length of a chunk of an object. /// /// \param chunk_index The chunk index for which to obtain the buffer length. /// \param data_size The size of the object + metadata. /// \return The buffer length of the chunk at chunk_index. - uint64_t GetBufferLength(uint64_t chunk_index, uint64_t data_size); + uint64_t GetBufferLength(uint64_t chunk_index, uint64_t data_size) const; /// Returns an object reader for read. /// @@ -85,7 +87,7 @@ class ObjectBufferPool { /// this method. An IOError status is returned if the Get call on the plasma store /// fails, and the MemoryObjectReader will be empty. std::pair, ray::Status> CreateObjectReader( - const ObjectID &object_id, rpc::Address owner_address); + const ObjectID &object_id, rpc::Address owner_address) LOCKS_EXCLUDED(pool_mutex_); /// Returns a chunk of an empty object at the given chunk_index. The object chunk /// serves as the buffer that is to be written to by a connection receiving an @@ -106,7 +108,7 @@ class ObjectBufferPool { /// (with no intermediate AbortCreateChunk). ray::Status CreateChunk(const ObjectID &object_id, const rpc::Address &owner_address, uint64_t data_size, uint64_t metadata_size, - uint64_t chunk_index); + uint64_t chunk_index) LOCKS_EXCLUDED(pool_mutex_); /// Write to a Chunk of an object. If all chunks of an object is written, /// it seals the object. @@ -119,34 +121,44 @@ class ObjectBufferPool { /// \param chunk_index The index of the chunk. /// \param data The data to write into the chunk. void WriteChunk(const ObjectID &object_id, uint64_t chunk_index, - const std::string &data); + const std::string &data) LOCKS_EXCLUDED(pool_mutex_); /// Free a list of objects from object store. 
 ///
 /// \param object_ids The list of ObjectIDs to be deleted.
 /// \return Void.
- void FreeObjects(const std::vector<ObjectID> &object_ids);
+ void FreeObjects(const std::vector<ObjectID> &object_ids) LOCKS_EXCLUDED(pool_mutex_);

 /// Abort the create operation associated with an object. This destroys the buffer
 /// state, including create operations in progress for all chunks of the object.
- void AbortCreate(const ObjectID &object_id);
+ void AbortCreate(const ObjectID &object_id) LOCKS_EXCLUDED(pool_mutex_);

 /// Returns debug string for class.
 ///
 /// \return string.
- std::string DebugString() const;
+ std::string DebugString() const LOCKS_EXCLUDED(pool_mutex_);

 private:
 /// Splits an object into ceil(data_size/chunk_size) chunks, which will
 /// either be read or written to in parallel.
 std::vector<ChunkInfo> BuildChunks(const ObjectID &object_id, uint8_t *data,
 uint64_t data_size,
- std::shared_ptr<Buffer> buffer_ref);
+ std::shared_ptr<Buffer> buffer_ref)
+ EXCLUSIVE_LOCKS_REQUIRED(pool_mutex_);
+
+ /// Ensures the buffer for the object exists, creating it if needed.
+ /// Returns OK if the buffer exists.
+ /// Must hold pool_mutex_ when calling this function. pool_mutex_ can be released
+ /// during the call.
+ ray::Status EnsureBufferExists(const ObjectID &object_id,
+ const rpc::Address &owner_address, uint64_t data_size,
+ uint64_t metadata_size, uint64_t chunk_index)
+ EXCLUSIVE_LOCKS_REQUIRED(pool_mutex_);

 /// The state of a chunk associated with a create operation.
 enum class CreateChunkState : unsigned int { AVAILABLE = 0, REFERENCED, SEALED };

- /// Holds the state of a create buffer.
+ /// Holds the state of creating chunks. Members are protected by pool_mutex_.
 struct CreateBufferState {
 CreateBufferState() {}
 CreateBufferState(std::vector<ChunkInfo> chunk_info)
@@ -166,18 +178,29 @@ class ObjectBufferPool {
 /// Returned when GetChunk or CreateChunk fails.
 const ChunkInfo errored_chunk_ = {0, nullptr, 0, nullptr};

- /// Mutex on public methods for thread-safe operations on
- /// get_buffer_state_, create_buffer_state_, and store_client_.
- mutable std::mutex pool_mutex_;
+ /// Socket name of plasma store.
+ const std::string store_socket_name_;
+
 /// Determines the maximum chunk size to be transferred by a single thread.
 const uint64_t default_chunk_size_;
+
+ /// Mutex to protect create_buffer_ops_, create_buffer_state_ and the following
+ /// invariants:
+ /// - create_buffer_ops_ contains an object_id iff there is an inflight operation to
+ /// create the buffer for the object.
+ /// - An object_id cannot appear in both create_buffer_ops_ and create_buffer_state_.
+ mutable absl::Mutex pool_mutex_;
+ /// Makes sure each object has at most one inflight create buffer operation.
+ /// Other operations can wait on the absl::CondVar for the operation
+ /// to complete. If successful, the corresponding entry in create_buffer_state_
+ /// will be created.
+ absl::flat_hash_map<ObjectID, std::shared_ptr<absl::CondVar>> create_buffer_ops_
+ GUARDED_BY(pool_mutex_);
 /// The state of a buffer that's currently being used.
- std::unordered_map<ObjectID, CreateBufferState> create_buffer_state_;
+ absl::flat_hash_map<ObjectID, CreateBufferState> create_buffer_state_
+ GUARDED_BY(pool_mutex_);
 /// Plasma client pool.
 plasma::PlasmaClient store_client_;
- /// Socket name of plasma store.
- std::string store_socket_name_; }; } // namespace ray diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc index 8e4dd703b91fb..3ee951d75553d 100644 --- a/src/ray/object_manager/object_manager.cc +++ b/src/ray/object_manager/object_manager.cc @@ -88,6 +88,7 @@ ObjectManager::ObjectManager( buffer_pool_(config_.store_socket_name, config_.object_chunk_size), rpc_work_(rpc_service_), object_manager_server_("ObjectManager", config_.object_manager_port, + config_.object_manager_address == "127.0.0.1", config_.rpc_service_threads_number), object_manager_service_(rpc_service_, *this), client_call_manager_(main_service, config_.rpc_service_threads_number), @@ -441,17 +442,18 @@ void ObjectManager::PushObjectInternal(const ObjectID &object_id, const NodeID & [=]() { // Post to the multithreaded RPC event loop so that data is copied // off of the main thread. - SendObjectChunk(push_id, object_id, node_id, chunk_id, rpc_client, - [=](const Status &status) { - // Post back to the main event loop because the - // PushManager is thread-safe. - main_service_->post( - [this, node_id, object_id]() { - push_manager_->OnChunkComplete(node_id, object_id); - }, - "ObjectManager.Push"); - }, - std::move(chunk_reader)); + SendObjectChunk( + push_id, object_id, node_id, chunk_id, rpc_client, + [=](const Status &status) { + // Post back to the main event loop because the + // PushManager is thread-safe. + main_service_->post( + [this, node_id, object_id]() { + push_manager_->OnChunkComplete(node_id, object_id); + }, + "ObjectManager.Push"); + }, + chunk_reader); }, "ObjectManager.Push"); }); diff --git a/src/ray/object_manager/object_manager.h b/src/ray/object_manager/object_manager.h index 3aaa847f03381..c0519a38306bd 100644 --- a/src/ray/object_manager/object_manager.h +++ b/src/ray/object_manager/object_manager.h @@ -17,7 +17,7 @@ #include #include #include -#include +#include #include #include #include @@ -49,6 +49,8 @@ namespace ray { struct ObjectManagerConfig { + /// The IP address this object manager is running on. + std::string object_manager_address; /// The port that the object manager should use to listen for connections /// from other object managers. If this is 0, the object manager will choose /// its own port. @@ -56,7 +58,7 @@ struct ObjectManagerConfig { /// The object manager's global timer frequency. unsigned int timer_freq_ms; /// The time in milliseconds to wait before retrying a pull - /// that fails due to node id lookup. + /// that failed. unsigned int pull_timeout_ms; /// Object chunk size, in bytes uint64_t object_chunk_size; diff --git a/src/ray/object_manager/plasma/store.cc b/src/ray/object_manager/plasma/store.cc index ff9e98ddb765c..0b8b24dbac56d 100644 --- a/src/ray/object_manager/plasma/store.cc +++ b/src/ray/object_manager/plasma/store.cc @@ -32,7 +32,7 @@ #include #include -#include +#include #include #include #include @@ -53,6 +53,7 @@ #include "ray/object_manager/plasma/protocol.h" #include "ray/util/util.h" +namespace ph = boost::placeholders; namespace fb = plasma::flatbuf; namespace plasma { @@ -297,7 +298,9 @@ void PlasmaStore::ConnectClient(const boost::system::error_code &error) { if (!error) { // Accept a new local client and dispatch it to the node manager. auto new_connection = Client::Create( - boost::bind(&PlasmaStore::ProcessMessage, this, _1, _2, _3), std::move(socket_)); + // NOLINTNEXTLINE : handler must be of boost::AcceptHandler type. 
+ boost::bind(&PlasmaStore::ProcessMessage, this, ph::_1, ph::_2, ph::_3), + std::move(socket_)); } // We're ready to accept another client. DoAccept(); diff --git a/src/ray/object_manager/pull_manager.h b/src/ray/object_manager/pull_manager.h index 9b7f20c14a478..6c5108f111abe 100644 --- a/src/ray/object_manager/pull_manager.h +++ b/src/ray/object_manager/pull_manager.h @@ -16,7 +16,7 @@ #include #include -#include +#include #include #include "absl/container/flat_hash_map.h" diff --git a/src/ray/protobuf/agent_manager.proto b/src/ray/protobuf/agent_manager.proto index f573f53766525..cbbd127004536 100644 --- a/src/ray/protobuf/agent_manager.proto +++ b/src/ray/protobuf/agent_manager.proto @@ -13,6 +13,7 @@ // limitations under the License. syntax = "proto3"; +option cc_enable_arenas = true; package ray.rpc; diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index dd9cf403c305c..1d3dd8124484d 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -13,6 +13,7 @@ // limitations under the License. syntax = "proto3"; +option cc_enable_arenas = true; package ray.rpc; @@ -155,10 +156,12 @@ message RayException { /// The runtime environment describes all the runtime packages needed to /// run some task or actor. message RuntimeEnv { - /// The raw json passed from user - string raw_json = 1; - /// Uris used in this runtime env + /// The serialized runtime env passed from the user. + string serialized_runtime_env = 1; + /// URIs used in this runtime env. These will be used for reference counting. repeated string uris = 2; + /// Indicates whether to install runtime env eagerly before the workers are leased. + bool runtime_env_eager_install = 3; } /// The task specification encapsulates all immutable information about the @@ -209,21 +212,19 @@ message TaskSpec { int64 placement_group_bundle_index = 19; // Whether or not this task should capture parent's placement group automatically. bool placement_group_capture_child_tasks = 20; - // Environment variables to override for this task - map override_environment_variables = 21; // Whether or not to skip the execution of this task. When it's true, // the receiver will not execute the task. This field is used by async actors // to guarantee task submission order after restart. - bool skip_execution = 22; + bool skip_execution = 21; // Breakpoint if this task should drop into the debugger when it starts executing // and "" if the task should not drop into the debugger. - bytes debugger_breakpoint = 23; - // Serialized JSON string of the parsed runtime environment dict for this task. - string serialized_runtime_env = 24; + bytes debugger_breakpoint = 22; + // Runtime environment for this task. + RuntimeEnv runtime_env = 23; // The concurrency group name in which this task will be performed. - string concurrency_group_name = 25; + string concurrency_group_name = 24; // Whether application-level errors (exceptions) should be retried. - bool retry_exceptions = 26; + bool retry_exceptions = 25; } message Bundle { diff --git a/src/ray/protobuf/core_worker.proto b/src/ray/protobuf/core_worker.proto index 9af0a87231326..81a8fbb5fd3d2 100644 --- a/src/ray/protobuf/core_worker.proto +++ b/src/ray/protobuf/core_worker.proto @@ -13,6 +13,7 @@ // limitations under the License. 
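The PlasmaStore hunk above pins boost::bind's placeholders to the boost::placeholders namespace, since Boost 1.73+ deprecates the unqualified global _1/_2/_3. A self-contained illustration of the same binding style, with illustrative types:

#include <boost/bind/bind.hpp>
#include <iostream>
#include <string>

namespace ph = boost::placeholders;

struct Store {
  void ProcessMessage(int fd, long type, const std::string &msg) {
    std::cout << fd << " " << type << " " << msg << std::endl;
  }
};

int main() {
  Store store;
  // Equivalent shape to the PlasmaStore handler: member function plus three
  // deferred arguments, bound through the qualified placeholders.
  auto handler =
      boost::bind(&Store::ProcessMessage, &store, ph::_1, ph::_2, ph::_3);
  handler(3, 7L, std::string("seal"));
}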
syntax = "proto3"; +option cc_enable_arenas = true; package ray.rpc; diff --git a/src/ray/protobuf/event.proto b/src/ray/protobuf/event.proto index 2edc202776f6b..5ec8ee9402492 100644 --- a/src/ray/protobuf/event.proto +++ b/src/ray/protobuf/event.proto @@ -1,4 +1,5 @@ syntax = "proto3"; +option cc_enable_arenas = true; package ray.rpc; diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto index ec1f3e7380d53..5f35c1a21e4d5 100644 --- a/src/ray/protobuf/gcs.proto +++ b/src/ray/protobuf/gcs.proto @@ -13,6 +13,7 @@ // limitations under the License. syntax = "proto3"; +option cc_enable_arenas = true; package ray.rpc; @@ -149,19 +150,17 @@ message ActorTableData { RayException creation_task_exception = 18; // The actor's namespace. Named `ray_namespace` to avoid confusions when invoked in c++. string ray_namespace = 19; - // Runtime required to run this actor - // It'll only be set if it's a detached actor and the original job has this field - RuntimeEnv runtime_env = 20; // The unix ms timestamp the actor was started at. - uint64 start_time = 21; + uint64 start_time = 20; // The unix ms timestamp the actor was ended at. - uint64 end_time = 22; + uint64 end_time = 21; + // Serialized runtime_env used to report in the dashboard snapshot. We need to populate + // it here instead of grabbing it from the task spec because the task spec is cleared + // for deleted actors: https://github.com/ray-project/ray/pull/11149. + string serialized_runtime_env = 22; // The actor's class name. This is necessary because the task spec's lifetime // is shorter than the ActorTableData. string class_name = 23; - // The actor's serialized runtime environment. This is necessary because the - // task spec's lifetime is shorter than the ActorTableData. - string serialized_runtime_env = 24; } message ErrorTableData { @@ -278,24 +277,20 @@ message TaskLeaseData { } message JobConfig { - // Environment variables to be set on worker processes. - map worker_env = 1; // The number of java workers per worker process. - uint32 num_java_workers_per_process = 2; + uint32 num_java_workers_per_process = 1; // The jvm options for java workers of the job. - repeated string jvm_options = 3; + repeated string jvm_options = 2; // A list of directories or files (jar files or dynamic libraries) that specify the // search path for user code. This will be used as `CLASSPATH` in Java, and `PYTHONPATH` // in Python. In C++, libraries under these paths will be loaded by 'dlopen'. - repeated string code_search_path = 4; + repeated string code_search_path = 3; // Runtime environment to run the code - RuntimeEnv runtime_env = 5; + RuntimeEnv runtime_env = 4; // The job's namespace. Named `ray_namespace` to avoid confusions when invoked in c++. - string ray_namespace = 6; - // Serialized JSON string of the parsed runtime environment dict for this job. - string serialized_runtime_env = 7; + string ray_namespace = 5; // An opaque kv store for job related metadata. - map metadata = 8; + map metadata = 6; } message JobTableData { diff --git a/src/ray/protobuf/gcs_service.proto b/src/ray/protobuf/gcs_service.proto index 308083f201208..65e9bbad13bc3 100644 --- a/src/ray/protobuf/gcs_service.proto +++ b/src/ray/protobuf/gcs_service.proto @@ -13,7 +13,7 @@ // limitations under the License. 
syntax = "proto3"; - +option cc_enable_arenas = true; package ray.rpc; import "src/ray/protobuf/common.proto"; diff --git a/src/ray/protobuf/job_agent.proto b/src/ray/protobuf/job_agent.proto index 07355a0a8f7c0..e187de67ae0f5 100644 --- a/src/ray/protobuf/job_agent.proto +++ b/src/ray/protobuf/job_agent.proto @@ -15,6 +15,7 @@ syntax = "proto3"; package ray.rpc; +option cc_enable_arenas = true; import "src/ray/protobuf/agent_manager.proto"; diff --git a/src/ray/protobuf/node_manager.proto b/src/ray/protobuf/node_manager.proto index 0c56bb7832b3a..788358fb0394e 100644 --- a/src/ray/protobuf/node_manager.proto +++ b/src/ray/protobuf/node_manager.proto @@ -13,6 +13,7 @@ // limitations under the License. syntax = "proto3"; +option cc_enable_arenas = true; package ray.rpc; diff --git a/src/ray/protobuf/object_manager.proto b/src/ray/protobuf/object_manager.proto index 8bd6986f6b5b1..c212b18b266d1 100644 --- a/src/ray/protobuf/object_manager.proto +++ b/src/ray/protobuf/object_manager.proto @@ -13,6 +13,7 @@ // limitations under the License. syntax = "proto3"; +option cc_enable_arenas = true; package ray.rpc; diff --git a/src/ray/protobuf/pubsub.proto b/src/ray/protobuf/pubsub.proto index fc046afcf69c2..8181f886ffb3c 100644 --- a/src/ray/protobuf/pubsub.proto +++ b/src/ray/protobuf/pubsub.proto @@ -13,6 +13,7 @@ // limitations under the License. syntax = "proto3"; +option cc_enable_arenas = true; package ray.rpc; diff --git a/src/ray/protobuf/ray_client.proto b/src/ray/protobuf/ray_client.proto index e207263e515a7..5dab0499d7d56 100644 --- a/src/ray/protobuf/ray_client.proto +++ b/src/ray/protobuf/ray_client.proto @@ -13,6 +13,7 @@ // limitations under the License. syntax = "proto3"; +option cc_enable_arenas = true; package ray.rpc; @@ -61,8 +62,6 @@ message ClientTask { // A name parameter, if the payload can be called in more than one way // (like a method on a payload object). string name = 2; - // A namespace parameter. - string namespace = 9; // A reference to the payload. bytes payload_id = 3; // Positional parameters to pass to this call. @@ -76,6 +75,8 @@ message ClientTask { TaskOptions options = 7; // Options passed to create the default remote task excution environment. TaskOptions baseline_options = 8; + // A namespace parameter. + string namespace = 9; } message ClientTaskTicket { diff --git a/src/ray/protobuf/reporter.proto b/src/ray/protobuf/reporter.proto index 225c520481cc5..00849c0683960 100644 --- a/src/ray/protobuf/reporter.proto +++ b/src/ray/protobuf/reporter.proto @@ -13,6 +13,7 @@ // limitations under the License. syntax = "proto3"; +option cc_enable_arenas = true; package ray.rpc; diff --git a/src/ray/protobuf/runtime_env_agent.proto b/src/ray/protobuf/runtime_env_agent.proto index a7903f8939c91..f36adf38cdb2a 100644 --- a/src/ray/protobuf/runtime_env_agent.proto +++ b/src/ray/protobuf/runtime_env_agent.proto @@ -13,6 +13,7 @@ // limitations under the License. syntax = "proto3"; +option cc_enable_arenas = true; package ray.rpc; @@ -21,6 +22,10 @@ import "src/ray/protobuf/agent_manager.proto"; message CreateRuntimeEnvRequest { string serialized_runtime_env = 1; bytes job_id = 2; + // Serialized allocated resource instances. Key is resource type, value is allocated + // instances. For example,{"CPU":20000,"memory":40000,"GPU":[10000, 10000]} means 2 cpu + // cores, 2 Gi memory, GPU 0 and GPU 1. 
+ string serialized_allocated_resource_instances = 3; } message CreateRuntimeEnvReply { diff --git a/src/ray/protobuf/serialization.proto b/src/ray/protobuf/serialization.proto index e5fed8e4a3876..84da8dff1531c 100644 --- a/src/ray/protobuf/serialization.proto +++ b/src/ray/protobuf/serialization.proto @@ -13,6 +13,7 @@ // limitations under the License. syntax = "proto3"; +option cc_enable_arenas = true; package ray.serialization; diff --git a/src/ray/protobuf/serve.proto b/src/ray/protobuf/serve.proto index 24e755a0b883a..2636dcf685544 100644 --- a/src/ray/protobuf/serve.proto +++ b/src/ray/protobuf/serve.proto @@ -13,6 +13,7 @@ // limitations under the License. syntax = "proto3"; +option cc_enable_arenas = true; package ray.serve; @@ -31,19 +32,17 @@ message AutoscalingConfig { uint32 max_replicas = 2; // Target number of in flight requests per replicas. This is the primary configuration // knob for replica autoscaler. Lower the number, the more rapidly will the replicas - // being scaled up. Must be a non-negative inter. + // being scaled up. Must be a non-negative integer. uint32 target_num_ongoing_requests_per_replica = 3; // The frequency of how long does each replica sending metrics to autoscaler. double metrics_interval_s = 4; - // The interval (in seconds) of autoscaler evaluating metrics and performing scaling - // decision. - double loop_period_s = 5; + // The window (in seconds) for autoscaler to calculate rolling average of metrics on. - double look_back_period_s = 6; + double look_back_period_s = 5; // The multiplicative "gain" factor to limit scaling decisions. - double smoothing_factor = 7; + double smoothing_factor = 6; } // Configuration options for a backend, to be set by the user. @@ -62,11 +61,11 @@ message BackendConfig { // Duration that backend workers will wait until there is no more work to be done before // shutting down. Defaults to 2s. - double experimental_graceful_shutdown_wait_loop_s = 4; + double graceful_shutdown_wait_loop_s = 4; // Controller waits for this duration to forcefully kill the replica for shutdown. // Defaults to 20s. - double experimental_graceful_shutdown_timeout_s = 5; + double graceful_shutdown_timeout_s = 5; // Is the construction of backend is cross language? 
bool is_cross_language = 6;
@@ -95,3 +94,35 @@ message RequestMetadata {
 message RequestWrapper {
   bytes body = 1;
 }
+
+message UpdatedObject {
+  bytes object_snapshot = 1;
+  int32 snapshot_id = 2;
+}
+
+message LongPollRequest {
+  map<string, int32> keys_to_snapshot_ids = 1;
+}
+
+message LongPollResult {
+  map<string, UpdatedObject> updated_objects = 1;
+}
+
+message EndpointInfo {
+  string endpoint_name = 1;
+  string route = 2;
+  map<string, string> config = 3;
+}
+
+message EndpointSet {
+  map<string, EndpointInfo> endpoints = 1;
+}
+
+message ActorSet {
+  repeated string names = 1;
+}
+
+message BackendVersion {
+  string code_version = 1;
+  bytes user_config = 2;
+}
diff --git a/src/ray/ray_version_script.lds b/src/ray/ray_version_script.lds
index 6d53de5ed92d1..b18b99d675dfa 100644
--- a/src/ray/ray_version_script.lds
+++ b/src/ray/ray_version_script.lds
@@ -39,7 +39,6 @@ VERSION_1.0 {
   *ray*streaming*;
   *aligned_free*;
   *aligned_malloc*;
-  *absl*;
   *grpc*;
   local: *;
 };
diff --git a/src/ray/raylet/agent_manager.cc b/src/ray/raylet/agent_manager.cc
index 55fe1392f6686..ec15c27a85be4 100644
--- a/src/ray/raylet/agent_manager.cc
+++ b/src/ray/raylet/agent_manager.cc
@@ -36,6 +36,8 @@ void AgentManager::HandleRegisterAgent(const rpc::RegisterAgentRequest &request,
   RAY_LOG(INFO) << "HandleRegisterAgent, ip: " << agent_ip_address_
                 << ", port: " << agent_port_ << ", pid: " << agent_pid_;
   reply->set_status(rpc::AGENT_RPC_STATUS_OK);
+  // Reset the restart count after registration is done.
+  agent_restart_count_ = 0;
   send_reply_callback(ray::Status::OK(), nullptr, nullptr);
 }
 
@@ -65,14 +67,16 @@ void AgentManager::StartAgent() {
   ProcessEnvironment env;
   env.insert({"RAY_NODE_ID", options_.node_id.Hex()});
   env.insert({"RAY_RAYLET_PID", std::to_string(getpid())});
+  // Report the restart count to the agent so that we can decide whether or not to
+  // report the error message to drivers.
+  env.insert({"RESTART_COUNT", std::to_string(agent_restart_count_)});
+  env.insert({"MAX_RESTART_COUNT",
+              std::to_string(RayConfig::instance().agent_max_restart_count())});
   Process child(argv.data(), nullptr, ec, false, env);
   if (!child.IsValid() || ec) {
     // The worker failed to start. This is a fatal error.
     RAY_LOG(FATAL) << "Failed to start agent with return value " << ec << ": "
                    << ec.message();
-    RAY_UNUSED(delay_executor_([this] { StartAgent(); },
-                               RayConfig::instance().agent_restart_interval_ms()));
-    return;
   }
 
   std::thread monitor_thread([this, child]() mutable {
@@ -101,22 +105,39 @@ void AgentManager::StartAgent() {
         .WithField("pid", agent_pid_)
         << "Agent process with pid " << child.GetId() << " exit, return value "
         << exit_code;
-    RAY_UNUSED(delay_executor_([this] { StartAgent(); },
-                               RayConfig::instance().agent_restart_interval_ms()));
+    if (agent_restart_count_ < RayConfig::instance().agent_max_restart_count()) {
+      RAY_UNUSED(delay_executor_(
+          [this] {
+            agent_restart_count_++;
+            StartAgent();
+          },
+          // Retrying with exponential backoff
+          RayConfig::instance().agent_restart_interval_ms() *
+              std::pow(2, (agent_restart_count_ + 1))));
+    } else {
+      RAY_LOG(INFO) << "Agent has failed "
+                    << RayConfig::instance().agent_max_restart_count()
+                    << " times in a row without registering the agent. It is highly "
+                       "likely that there's a bug in the dashboard agent. 
Please check out " + "the dashboard_agent.log file."; + } }); monitor_thread.detach(); } -void AgentManager::CreateRuntimeEnv(const JobID &job_id, - const std::string &serialized_runtime_env, - CreateRuntimeEnvCallback callback) { +void AgentManager::CreateRuntimeEnv( + const JobID &job_id, const std::string &serialized_runtime_env, + const std::string &serialized_allocated_resource_instances, + CreateRuntimeEnvCallback callback) { if (runtime_env_agent_client_ == nullptr) { RAY_LOG(INFO) << "Runtime env agent is not registered yet. Will retry CreateRuntimeEnv later: " << serialized_runtime_env; delay_executor_( - [this, job_id, serialized_runtime_env, callback] { - CreateRuntimeEnv(job_id, serialized_runtime_env, callback); + [this, job_id, serialized_runtime_env, serialized_allocated_resource_instances, + callback] { + CreateRuntimeEnv(job_id, serialized_runtime_env, + serialized_allocated_resource_instances, callback); }, RayConfig::instance().agent_manager_retry_interval_ms()); return; @@ -124,9 +145,12 @@ void AgentManager::CreateRuntimeEnv(const JobID &job_id, rpc::CreateRuntimeEnvRequest request; request.set_job_id(job_id.Hex()); request.set_serialized_runtime_env(serialized_runtime_env); + request.set_serialized_allocated_resource_instances( + serialized_allocated_resource_instances); runtime_env_agent_client_->CreateRuntimeEnv( - request, [this, job_id, serialized_runtime_env, callback]( - Status status, const rpc::CreateRuntimeEnvReply &reply) { + request, + [this, job_id, serialized_runtime_env, serialized_allocated_resource_instances, + callback](const Status &status, const rpc::CreateRuntimeEnvReply &reply) { if (status.ok()) { if (reply.status() == rpc::AGENT_RPC_STATUS_OK) { callback(true, reply.serialized_runtime_env_context()); @@ -142,8 +166,10 @@ void AgentManager::CreateRuntimeEnv(const JobID &job_id, << ", status = " << status << ", maybe there are some network problems, will retry it later."; delay_executor_( - [this, job_id, serialized_runtime_env, callback] { - CreateRuntimeEnv(job_id, serialized_runtime_env, callback); + [this, job_id, serialized_runtime_env, + serialized_allocated_resource_instances, callback] { + CreateRuntimeEnv(job_id, serialized_runtime_env, + serialized_allocated_resource_instances, callback); }, RayConfig::instance().agent_manager_retry_interval_ms()); } diff --git a/src/ray/raylet/agent_manager.h b/src/ray/raylet/agent_manager.h index bb12df0f64da4..ba81454b84536 100644 --- a/src/ray/raylet/agent_manager.h +++ b/src/ray/raylet/agent_manager.h @@ -64,9 +64,10 @@ class AgentManager : public rpc::AgentManagerServiceHandler { /// Request agent to create a runtime env. /// \param[in] runtime_env The runtime env. - virtual void CreateRuntimeEnv(const JobID &job_id, - const std::string &serialized_runtime_env, - CreateRuntimeEnvCallback callback); + virtual void CreateRuntimeEnv( + const JobID &job_id, const std::string &serialized_runtime_env, + const std::string &serialized_allocated_resource_instances, + CreateRuntimeEnvCallback callback); /// Request agent to delete a list of URIs. /// \param[in] URIs The list of URIs to delete. @@ -80,6 +81,8 @@ class AgentManager : public rpc::AgentManagerServiceHandler { Options options_; pid_t agent_pid_ = 0; int agent_port_ = 0; + /// The number of times the agent is restarted. 
+ std::atomic agent_restart_count_ = 0; std::string agent_ip_address_; DelayExecutorFn delay_executor_; RuntimeEnvAgentClientFactoryFn runtime_env_agent_client_factory_; diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index aa096b3f1e86b..bdceb344cbab6 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -212,6 +212,7 @@ int main(int argc, char *argv[]) { // Configuration for the object manager. ray::ObjectManagerConfig object_manager_config; + object_manager_config.object_manager_address = node_ip_address; object_manager_config.object_manager_port = object_manager_port; object_manager_config.store_socket_name = store_socket_name; diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 4260542319060..5970d06a2e7db 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -252,7 +252,8 @@ NodeManager::NodeManager(instrumented_io_context &io_service, const NodeID &self temp_dir_(config.temp_dir), initial_config_(config), dependency_manager_(object_manager_), - node_manager_server_("NodeManager", config.node_manager_port), + node_manager_server_("NodeManager", config.node_manager_port, + config.node_manager_address == "127.0.0.1"), node_manager_service_(io_service, *this), agent_manager_service_handler_( new DefaultAgentManagerServiceHandler(agent_manager_)), @@ -372,7 +373,8 @@ NodeManager::NodeManager(instrumented_io_context &io_service, const NodeID &self }, /*runtime_env_agent_factory=*/ [this](const std::string &ip_address, int port) { - RAY_CHECK(!ip_address.empty() && port != 0); + RAY_CHECK(!ip_address.empty() && port != 0) + << "ip_address: " << ip_address << " port: " << port; return std::shared_ptr( new rpc::RuntimeEnvAgentClient(ip_address, port, client_call_manager_)); }); @@ -525,7 +527,7 @@ void NodeManager::DestroyWorker(std::shared_ptr worker, } void NodeManager::HandleJobStarted(const JobID &job_id, const JobTableData &job_data) { - RAY_LOG(DEBUG) << "HandleJobStarted " << job_id; + RAY_LOG(DEBUG) << "HandleJobStarted for job " << job_id; worker_pool_.HandleJobStarted(job_id, job_data.config()); // NOTE: Technically `HandleJobStarted` isn't idempotent because we'll // increment the ref count multiple times. This is fine because @@ -1868,7 +1870,8 @@ void NodeManager::FinishAssignedActorCreationTask(WorkerInterface &worker, auto job_id = task.GetTaskSpecification().JobId(); auto job_config = worker_pool_.GetJobConfig(job_id); RAY_CHECK(job_config); - runtime_env_manager_.AddURIReference(actor_id.Hex(), job_config->runtime_env()); + runtime_env_manager_.AddURIReference(actor_id.Hex(), + task.GetTaskSpecification().RuntimeEnv()); } } diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index a699635c439f7..fed0ce1012290 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -48,7 +48,6 @@ namespace ray { namespace raylet { -using rpc::ActorTableData; using rpc::ErrorType; using rpc::GcsNodeInfo; using rpc::HeartbeatTableData; @@ -273,13 +272,6 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// returned to idle. bool FinishAssignedTask(const std::shared_ptr &worker_ptr); - /// Helper function to produce actor table data for a newly created actor. - /// - /// \param task_spec RayTask specification of the actor creation task that created the - /// actor. - /// \param worker The port that the actor is listening on. 
- std::shared_ptr CreateActorTableDataFromCreationTask( - const TaskSpecification &task_spec, int port, const WorkerID &worker_id); /// Handle a worker finishing an assigned actor creation task. /// \param worker The worker that finished the task. /// \param task The actor task or actor creation task. diff --git a/src/ray/raylet/placement_group_resource_manager.cc b/src/ray/raylet/placement_group_resource_manager.cc index 8639689edb949..d9ccfd1ac0574 100644 --- a/src/ray/raylet/placement_group_resource_manager.cc +++ b/src/ray/raylet/placement_group_resource_manager.cc @@ -152,6 +152,9 @@ void NewPlacementGroupResourceManager::ReturnBundle( // will be resource leak. cluster_resource_scheduler_->DeleteLocalResource(resource.first); deleted.push_back(resource.first); + } else { + RAY_LOG(DEBUG) << "Available bundle resource:[" << resource.first + << "] is not empty. Resources are not deleted from the local node."; } } pg_bundles_.erase(it); diff --git a/src/ray/raylet/raylet.cc b/src/ray/raylet/raylet.cc index b8040b6f8acdc..424b83def75c7 100644 --- a/src/ray/raylet/raylet.cc +++ b/src/ray/raylet/raylet.cc @@ -15,7 +15,7 @@ #include "ray/raylet/raylet.h" #include -#include +#include #include #include diff --git a/src/ray/raylet/scheduling/cluster_resource_data.cc b/src/ray/raylet/scheduling/cluster_resource_data.cc index ea4ae6621f6b5..f19287d0915f5 100644 --- a/src/ray/raylet/scheduling/cluster_resource_data.cc +++ b/src/ray/raylet/scheduling/cluster_resource_data.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "ray/raylet/scheduling/cluster_resource_data.h" + #include "ray/common/bundle_spec.h" #include "ray/common/task/scheduling_resources.h" @@ -536,7 +537,7 @@ bool TaskResourceInstances::IsEmpty() const { return true; } -std::string TaskResourceInstances::DebugString() const { +std::string TaskResourceInstances::DebugString(const StringIdMap &string_id_map) const { std::stringstream buffer; buffer << std::endl << " Allocation: {"; for (size_t i = 0; i < this->predefined_resources.size(); i++) { @@ -547,7 +548,7 @@ std::string TaskResourceInstances::DebugString() const { buffer << " ["; for (auto it = this->custom_resources.begin(); it != this->custom_resources.end(); ++it) { - buffer << it->first << ":" << VectorToString(it->second) << ", "; + buffer << string_id_map.Get(it->first) << ":" << VectorToString(it->second) << ", "; } buffer << "]" << std::endl; diff --git a/src/ray/raylet/scheduling/cluster_resource_data.h b/src/ray/raylet/scheduling/cluster_resource_data.h index 0398726f39d42..783ab12da9eee 100644 --- a/src/ray/raylet/scheduling/cluster_resource_data.h +++ b/src/ray/raylet/scheduling/cluster_resource_data.h @@ -138,7 +138,7 @@ class TaskResourceInstances { /// Check whether there are no resource instances. bool IsEmpty() const; /// Returns human-readable string for these resources. - std::string DebugString() const; + [[nodiscard]] std::string DebugString(const StringIdMap &string_id_map) const; }; /// Total and available capacities of each resource of a node. @@ -189,7 +189,7 @@ class NodeResourceInstances { /// Returns if this equals another node resources. bool operator==(const NodeResourceInstances &other); /// Returns human-readable string for these resources. 
- std::string DebugString(StringIdMap string_to_int_map) const; + [[nodiscard]] std::string DebugString(StringIdMap string_to_int_map) const; }; struct Node { diff --git a/src/ray/raylet/scheduling/cluster_resource_scheduler.cc b/src/ray/raylet/scheduling/cluster_resource_scheduler.cc index 6fcff8a501c55..1174f138395e0 100644 --- a/src/ray/raylet/scheduling/cluster_resource_scheduler.cc +++ b/src/ray/raylet/scheduling/cluster_resource_scheduler.cc @@ -456,8 +456,7 @@ void ClusterResourceScheduler::AddLocalResourceInstances( for (size_t i = 0; i < instances.size(); i++) { node_instances->available[i] += instances[i]; - node_instances->total[i] = - std::max(node_instances->total[i], node_instances->available[i]); + node_instances->total[i] += instances[i]; } UpdateLocalAvailableResourcesFromResourceInstances(); } diff --git a/src/ray/raylet/scheduling/cluster_task_manager.cc b/src/ray/raylet/scheduling/cluster_task_manager.cc index 1b90e93fb1bf4..323675b827a57 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager.cc +++ b/src/ray/raylet/scheduling/cluster_task_manager.cc @@ -716,31 +716,34 @@ void ClusterTaskManager::FillResourceUsage( TaskSpecification::GetSchedulingClass(one_cpu_resource_set)); { num_reported++; - int count = 0; + int ready_count = 0; auto it = tasks_to_schedule_.find(one_cpu_scheduling_cls); if (it != tasks_to_schedule_.end()) { - count += it->second.size(); + ready_count += it->second.size(); } it = tasks_to_dispatch_.find(one_cpu_scheduling_cls); if (it != tasks_to_dispatch_.end()) { - count += it->second.size(); + ready_count += it->second.size(); } - - if (count > 0) { + int infeasible_count = 0; + it = infeasible_tasks_.find(one_cpu_scheduling_cls); + if (it != infeasible_tasks_.end()) { + infeasible_count += it->second.size(); + } + const int total_count = ready_count + infeasible_count; + if (total_count > 0) { auto by_shape_entry = resource_load_by_shape->Add(); - for (const auto &resource : one_cpu_resource_set.GetResourceMap()) { + for (const auto &[label, quantity] : one_cpu_resource_set.GetResourceMap()) { // Add to `resource_loads`. - const auto &label = resource.first; - const auto &quantity = resource.second; - (*resource_loads)[label] += quantity * count; + (*resource_loads)[label] += quantity * total_count; // Add to `resource_load_by_shape`. (*by_shape_entry->mutable_shape())[label] = quantity; } - int num_ready = by_shape_entry->num_ready_requests_queued(); - by_shape_entry->set_num_ready_requests_queued(num_ready + count); + by_shape_entry->set_num_ready_requests_queued(ready_count); + by_shape_entry->set_num_infeasible_requests_queued(infeasible_count); auto backlog_it = backlog_tracker_.find(one_cpu_scheduling_cls); if (backlog_it != backlog_tracker_.end()) { @@ -1196,8 +1199,6 @@ void ClusterTaskManager::ScheduleAndDispatchTasks() { } void ClusterTaskManager::SpillWaitingTasks() { - RAY_LOG(DEBUG) << "Attempting to spill back from waiting task queue, num waiting: " - << waiting_task_queue_.size(); // Try to spill waiting tasks to a remote node, prioritizing those at the end // of the queue. 
Waiting tasks are spilled if there are enough remote // resources AND (we have no resources available locally OR their diff --git a/src/ray/raylet/scheduling/cluster_task_manager_test.cc b/src/ray/raylet/scheduling/cluster_task_manager_test.cc index 78fe7320c8631..3ef8906d40afa 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager_test.cc +++ b/src/ray/raylet/scheduling/cluster_task_manager_test.cc @@ -48,8 +48,7 @@ class MockWorkerPool : public WorkerPoolInterface { void PopWorker(const TaskSpecification &task_spec, const PopWorkerCallback &callback, const std::string &allocated_instances_serialized_json) { num_pops++; - const WorkerCacheKey env = { - task_spec.OverrideEnvironmentVariables(), task_spec.SerializedRuntimeEnv(), {}}; + const WorkerCacheKey env = {task_spec.SerializedRuntimeEnv(), {}}; const int runtime_env_hash = env.IntHash(); callbacks[runtime_env_hash].push_back(callback); } @@ -101,10 +100,11 @@ class MockWorkerPool : public WorkerPoolInterface { int num_pops; }; -std::shared_ptr CreateSingleNodeScheduler( - const std::string &id, double num_gpus = 0.0) { +std::shared_ptr CreateSingleNodeScheduler(const std::string &id, + double num_cpus, + double num_gpus) { std::unordered_map local_node_resources; - local_node_resources[ray::kCPU_ResourceLabel] = 8; + local_node_resources[ray::kCPU_ResourceLabel] = num_cpus; local_node_resources[ray::kGPU_ResourceLabel] = num_gpus; local_node_resources[ray::kMemory_ResourceLabel] = 128; @@ -116,16 +116,18 @@ std::shared_ptr CreateSingleNodeScheduler( RayTask CreateTask(const std::unordered_map &required_resources, int num_args = 0, std::vector args = {}, - std::string serialized_runtime_env = "{}") { + const std::string &serialized_runtime_env = "{}", + const std::vector &runtime_env_uris = {}) { TaskSpecBuilder spec_builder; TaskID id = RandomTaskId(); JobID job_id = RandomJobId(); rpc::Address address; - spec_builder.SetCommonTaskSpec( - id, "dummy_task", Language::PYTHON, - FunctionDescriptorBuilder::BuildPython("", "", "", ""), job_id, TaskID::Nil(), 0, - TaskID::Nil(), address, 0, required_resources, {}, - std::make_pair(PlacementGroupID::Nil(), -1), true, "", serialized_runtime_env); + spec_builder.SetCommonTaskSpec(id, "dummy_task", Language::PYTHON, + FunctionDescriptorBuilder::BuildPython("", "", "", ""), + job_id, TaskID::Nil(), 0, TaskID::Nil(), address, 0, + required_resources, {}, + std::make_pair(PlacementGroupID::Nil(), -1), true, "", + serialized_runtime_env, runtime_env_uris); if (!args.empty()) { for (auto &arg : args) { @@ -177,39 +179,41 @@ class MockTaskDependencyManager : public TaskDependencyManagerInterface { class ClusterTaskManagerTest : public ::testing::Test { public: - ClusterTaskManagerTest(double num_gpus_at_head = 0.0) + ClusterTaskManagerTest(double num_cpus_at_head = 8.0, double num_gpus_at_head = 0.0) : id_(NodeID::FromRandom()), - scheduler_(CreateSingleNodeScheduler(id_.Binary(), num_gpus_at_head)), + scheduler_( + CreateSingleNodeScheduler(id_.Binary(), num_cpus_at_head, num_gpus_at_head)), is_owner_alive_(true), node_info_calls_(0), announce_infeasible_task_calls_(0), dependency_manager_(missing_objects_), - task_manager_(id_, scheduler_, dependency_manager_, - /* is_owner_alive= */ - [this](const WorkerID &worker_id, const NodeID &node_id) { - return is_owner_alive_; - }, - /* get_node_info= */ - [this](const NodeID &node_id) { - node_info_calls_++; - return node_info_[node_id]; - }, - /* announce_infeasible_task= */ - [this](const RayTask &task) { announce_infeasible_task_calls_++; }, - 
pool_, leased_workers_,
-                     /* get_task_arguments= */
-                     [this](const std::vector<ObjectID> &object_ids,
-                            std::vector<std::unique_ptr<RayObject>> *results) {
-                       for (auto &obj_id : object_ids) {
-                         if (missing_objects_.count(obj_id) == 0) {
-                           results->emplace_back(MakeDummyArg());
-                         } else {
-                           results->emplace_back(nullptr);
-                         }
-                       }
-                       return true;
-                     },
-                     /*max_pinned_task_arguments_bytes=*/1000) {}
+        task_manager_(
+            id_, scheduler_, dependency_manager_,
+            /* is_owner_alive= */
+            [this](const WorkerID &worker_id, const NodeID &node_id) {
+              return is_owner_alive_;
+            },
+            /* get_node_info= */
+            [this](const NodeID &node_id) {
+              node_info_calls_++;
+              return node_info_[node_id];
+            },
+            /* announce_infeasible_task= */
+            [this](const RayTask &task) { announce_infeasible_task_calls_++; }, pool_,
+            leased_workers_,
+            /* get_task_arguments= */
+            [this](const std::vector<ObjectID> &object_ids,
+                   std::vector<std::unique_ptr<RayObject>> *results) {
+              for (auto &obj_id : object_ids) {
+                if (missing_objects_.count(obj_id) == 0) {
+                  results->emplace_back(MakeDummyArg());
+                } else {
+                  results->emplace_back(nullptr);
+                }
+              }
+              return true;
+            },
+            /*max_pinned_task_arguments_bytes=*/1000) {}
 
   RayObject *MakeDummyArg() {
     std::vector<uint8_t> data;
@@ -287,7 +291,15 @@ class ClusterTaskManagerTest : public ::testing::Test {
 // Same as ClusterTaskManagerTest, but the head node starts with 4.0 num gpus.
 class ClusterTaskManagerTestWithGPUsAtHead : public ClusterTaskManagerTest {
  public:
-  ClusterTaskManagerTestWithGPUsAtHead() : ClusterTaskManagerTest(4.0) {}
+  ClusterTaskManagerTestWithGPUsAtHead()
+      : ClusterTaskManagerTest(/*num_cpus_at_head=*/8.0, /*num_gpus_at_head=*/4.0) {}
+};
+
+// Same as ClusterTaskManagerTest, but the head node starts with 0.0 num cpus.
+class ClusterTaskManagerTestWithoutCPUsAtHead : public ClusterTaskManagerTest {
+ public:
+  ClusterTaskManagerTestWithoutCPUsAtHead()
+      : ClusterTaskManagerTest(/*num_cpus_at_head=*/0.0) {}
 };
 
 TEST_F(ClusterTaskManagerTest, BasicTest) {
@@ -367,8 +379,7 @@ TEST_F(ClusterTaskManagerTest, DispatchQueueNonBlockingTest) {
   pool_.TriggerCallbacks();
 
   // Push a worker that can only run task A.
-  const WorkerCacheKey env_A = {
-      /*override_environment_variables=*/{}, serialized_runtime_env_A, {}};
+  const WorkerCacheKey env_A = {serialized_runtime_env_A, {}};
   const int runtime_env_hash_A = env_A.IntHash();
   std::shared_ptr<MockWorker> worker_A =
       std::make_shared<MockWorker>(WorkerID::FromRandom(), 1234, runtime_env_hash_A);
@@ -1525,6 +1536,53 @@ TEST_F(ClusterTaskManagerTest, PopWorkerExactlyOnce) {
   AssertNoLeaks();
 }
 
+// Regression test for https://github.com/ray-project/ray/issues/16935:
+// When a task requires 1 CPU and is infeasible because head node has 0 CPU,
+// make sure the task's resource demand is reported.
+TEST_F(ClusterTaskManagerTestWithoutCPUsAtHead, OneCpuInfeasibleTask) {
+  rpc::RequestWorkerLeaseReply reply;
+  bool callback_occurred = false;
+  bool *callback_occurred_ptr = &callback_occurred;
+  auto callback = [callback_occurred_ptr](const Status &, const std::function<void()> &,
+                                          const std::function<void()> &) {
+    *callback_occurred_ptr = true;
+  };
+
+  constexpr int num_cases = 5;
+  // Create 5 tasks with different CPU requests.
+  const std::array<int, num_cases> cpu_request = {1, 2, 1, 3, 1};
+  // Each type of CPU request corresponds to one type of resource demand.
+  const std::array<int, num_cases> demand_types = {1, 2, 2, 3, 3};
+  // Number of infeasible 1-CPU requests.
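+  // For example, after the third task is queued the manager holds two 1-CPU tasks
+  // and one 2-CPU task, so two demand shapes are reported and the {CPU: 1} shape
+  // carries two infeasible requests.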
+ const std::array num_infeasible_1cpu = {1, 1, 2, 2, 3}; + + for (int i = 0; i < num_cases; ++i) { + RayTask task = CreateTask({{ray::kCPU_ResourceLabel, cpu_request[i]}}); + task_manager_.QueueAndScheduleTask(task, &reply, callback); + pool_.TriggerCallbacks(); + + // The task cannot run because there is only 1 node (head) with 0 CPU. + ASSERT_FALSE(callback_occurred); + ASSERT_EQ(leased_workers_.size(), 0); + ASSERT_EQ(pool_.workers.size(), 0); + ASSERT_EQ(node_info_calls_, 0); + + rpc::ResourcesData data; + task_manager_.FillResourceUsage(data); + const auto &resource_load_by_shape = data.resource_load_by_shape(); + ASSERT_EQ(resource_load_by_shape.resource_demands().size(), demand_types[i]); + + // 1 CPU demand currently is always the 1st. + const auto &demand = resource_load_by_shape.resource_demands()[0]; + EXPECT_EQ(demand.num_infeasible_requests_queued(), num_infeasible_1cpu[i]); + ASSERT_EQ(demand.shape().size(), 1); + ASSERT_EQ(demand.shape().at("CPU"), 1); + } +} + int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); diff --git a/src/ray/raylet/scheduling/fixed_point.cc b/src/ray/raylet/scheduling/fixed_point.cc deleted file mode 100644 index ec0b3ed9af16d..0000000000000 --- a/src/ray/raylet/scheduling/fixed_point.cc +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright 2020-2021 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "ray/raylet/scheduling/fixed_point.h" - -#include - -FixedPoint::FixedPoint(double d) { i_ = (uint64_t)(d * RESOURCE_UNIT_SCALING); } - -FixedPoint::FixedPoint(int i) { i_ = (i * RESOURCE_UNIT_SCALING); } - -FixedPoint::FixedPoint(uint32_t i) { i_ = (i * RESOURCE_UNIT_SCALING); } - -FixedPoint::FixedPoint(int64_t i) : FixedPoint((double)i) {} - -FixedPoint::FixedPoint(uint64_t i) : FixedPoint((double)i) {} - -FixedPoint FixedPoint::operator+(FixedPoint const &ru) const { - FixedPoint res; - res.i_ = i_ + ru.i_; - return res; -} - -FixedPoint FixedPoint::operator+=(FixedPoint const &ru) { - i_ += ru.i_; - return *this; -} - -FixedPoint FixedPoint::operator-(FixedPoint const &ru) const { - FixedPoint res; - res.i_ = i_ - ru.i_; - return res; -} - -FixedPoint FixedPoint::operator-=(FixedPoint const &ru) { - i_ -= ru.i_; - return *this; -} - -FixedPoint FixedPoint::operator-() const { - FixedPoint res; - res.i_ = -i_; - return res; -} - -FixedPoint FixedPoint::operator+(double const d) const { - FixedPoint res; - res.i_ = i_ + (int64_t)(d * RESOURCE_UNIT_SCALING); - return res; -} - -FixedPoint FixedPoint::operator-(double const d) const { - FixedPoint res; - res.i_ = i_ - (int64_t)(d * RESOURCE_UNIT_SCALING); - return res; -} - -FixedPoint FixedPoint::operator=(double const d) { - i_ = (int64_t)(d * RESOURCE_UNIT_SCALING); - return *this; -} - -FixedPoint FixedPoint::operator+=(double const d) { - i_ += (int64_t)(d * RESOURCE_UNIT_SCALING); - return *this; -} - -FixedPoint FixedPoint::operator+=(int64_t const ru) { - *this += (double)ru; - return *this; -} - -bool FixedPoint::operator<(FixedPoint const &ru1) const { return (i_ < ru1.i_); }; -bool FixedPoint::operator>(FixedPoint const &ru1) const { return (i_ > ru1.i_); }; -bool FixedPoint::operator<=(FixedPoint const &ru1) const { return (i_ <= ru1.i_); }; -bool FixedPoint::operator>=(FixedPoint const &ru1) const { return (i_ >= ru1.i_); }; -bool FixedPoint::operator==(FixedPoint const &ru1) const { return (i_ == ru1.i_); }; -bool FixedPoint::operator!=(FixedPoint const &ru1) const { return (i_ != ru1.i_); }; - -std::ostream &operator<<(std::ostream &out, FixedPoint const &ru1) { - out << ru1.i_; - return out; -} - -double FixedPoint::Double() const { return round(i_) / RESOURCE_UNIT_SCALING; }; diff --git a/src/ray/raylet/scheduling/fixed_point.h b/src/ray/raylet/scheduling/fixed_point.h index f133397ec6251..a18ffd1873218 100644 --- a/src/ray/raylet/scheduling/fixed_point.h +++ b/src/ray/raylet/scheduling/fixed_point.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include @@ -25,41 +26,85 @@ class FixedPoint { int64_t i_ = 0; public: - FixedPoint() = default; - FixedPoint(double d); - FixedPoint(int i); - FixedPoint(uint32_t i); - FixedPoint(int64_t i); - FixedPoint(uint64_t i); - - FixedPoint operator+(FixedPoint const &ru) const; - - FixedPoint operator+=(FixedPoint const &ru); - - FixedPoint operator-(FixedPoint const &ru) const; - - FixedPoint operator-=(FixedPoint const &ru); - - FixedPoint operator-() const; - - FixedPoint operator+(double const d) const; - - FixedPoint operator-(double const d) const; - - FixedPoint operator=(double const d); - - FixedPoint operator+=(double const d); - - FixedPoint operator+=(int64_t const ru); - - bool operator<(FixedPoint const &ru1) const; - bool operator>(FixedPoint const &ru1) const; - bool operator<=(FixedPoint const &ru1) const; - bool operator>=(FixedPoint const &ru1) const; - bool operator==(FixedPoint const &ru1) const; - bool operator!=(FixedPoint const &ru1) const; - - 
double Double() const;
+  FixedPoint() : FixedPoint(0.0) {}
+  FixedPoint(double d) { i_ = (uint64_t)(d * RESOURCE_UNIT_SCALING); }  // NOLINT
+
+  FixedPoint(int i) { i_ = (i * RESOURCE_UNIT_SCALING); }  // NOLINT
+
+  FixedPoint(uint32_t i) { i_ = (i * RESOURCE_UNIT_SCALING); }  // NOLINT
+
+  FixedPoint(int64_t i) : FixedPoint((double)i) {}  // NOLINT
+
+  FixedPoint(uint64_t i) : FixedPoint((double)i) {}  // NOLINT
+
+  FixedPoint operator+(FixedPoint const &ru) const {
+    FixedPoint res;
+    res.i_ = i_ + ru.i_;
+    return res;
+  }
+
+  FixedPoint &operator+=(FixedPoint const &ru) {
+    i_ += ru.i_;
+    return *this;
+  }
+
+  FixedPoint operator-(FixedPoint const &ru) const {
+    FixedPoint res;
+    res.i_ = i_ - ru.i_;
+    return res;
+  }
+
+  FixedPoint &operator-=(FixedPoint const &ru) {
+    i_ -= ru.i_;
+    return *this;
+  }
+
+  FixedPoint operator-() const {
+    FixedPoint res;
+    res.i_ = -i_;
+    return res;
+  }
+
+  FixedPoint operator+(double const d) const {
+    FixedPoint res;
+    res.i_ = i_ + static_cast<int64_t>(d * RESOURCE_UNIT_SCALING);
+    return res;
+  }
+
+  FixedPoint operator-(double const d) const {
+    FixedPoint res;
+    res.i_ = i_ - static_cast<int64_t>(d * RESOURCE_UNIT_SCALING);
+    return res;
+  }
+
+  FixedPoint operator=(double const d) {
+    i_ = static_cast<int64_t>(d * RESOURCE_UNIT_SCALING);
+    return *this;
+  }
+
+  FixedPoint operator+=(double const d) {
+    i_ += static_cast<int64_t>(d * RESOURCE_UNIT_SCALING);
+    return *this;
+  }
+
+  FixedPoint operator+=(int64_t const ru) {
+    *this += static_cast<double>(ru);
+    return *this;
+  }
+
+  bool operator<(FixedPoint const &ru1) const { return (i_ < ru1.i_); };
+  bool operator>(FixedPoint const &ru1) const { return (i_ > ru1.i_); };
+  bool operator<=(FixedPoint const &ru1) const { return (i_ <= ru1.i_); };
+  bool operator>=(FixedPoint const &ru1) const { return (i_ >= ru1.i_); };
+  bool operator==(FixedPoint const &ru1) const { return (i_ == ru1.i_); };
+  bool operator!=(FixedPoint const &ru1) const { return (i_ != ru1.i_); };
+
+  [[nodiscard]] double Double() const { return round(i_) / RESOURCE_UNIT_SCALING; };
 
   friend std::ostream &operator<<(std::ostream &out, FixedPoint const &ru1);
 };
+
+inline std::ostream &operator<<(std::ostream &out, FixedPoint const &ru1) {
+  out << ru1.i_;
+  return out;
+}
diff --git a/src/ray/raylet/scheduling/scheduling_policy.cc b/src/ray/raylet/scheduling/scheduling_policy.cc
index 4bf28bdb75a21..40c1ca39605d8 100644
--- a/src/ray/raylet/scheduling/scheduling_policy.cc
+++ b/src/ray/raylet/scheduling/scheduling_policy.cc
@@ -57,7 +57,7 @@ int64_t HybridPolicyWithFilter(const ResourceRequest &resource_request,
     if (node_filter == NodeFilter::kGPU) {
       return has_gpu;
     }
-    RAY_CHECK(node_filter == NodeFilter::kCPUOnly);
+    RAY_CHECK(node_filter == NodeFilter::kNonGpu);
     return !has_gpu;
   };
 
@@ -149,16 +149,18 @@ int64_t HybridPolicy(const ResourceRequest &resource_request, const int64_t loca
                                   spread_threshold, force_spillback, require_available);
   }
 
-  // Try schedule on CPU-only nodes.
-  const auto node_id =
-      HybridPolicyWithFilter(resource_request, local_node_id, nodes, spread_threshold,
-                             force_spillback, require_available, NodeFilter::kCPUOnly);
-  if (node_id != -1) {
-    return node_id;
+  // Try to schedule on non-GPU nodes first.
+  auto best_node_id = HybridPolicyWithFilter(
+      resource_request, local_node_id, nodes, spread_threshold, force_spillback,
+      /*require_available*/ true, NodeFilter::kNonGpu);
+  if (best_node_id != -1) {
+    return best_node_id;
   }
-  // Could not schedule on CPU-only nodes, schedule on GPU nodes as a last resort.
+
+  // If we cannot find any available node among the non-GPU nodes, fall back to the
+  // original scheduling policy.
   return HybridPolicyWithFilter(resource_request, local_node_id, nodes, spread_threshold,
-                                force_spillback, require_available, NodeFilter::kGPU);
+                                force_spillback, require_available);
 }
 
 }  // namespace raylet_scheduling_policy
diff --git a/src/ray/raylet/scheduling/scheduling_policy.h b/src/ray/raylet/scheduling/scheduling_policy.h
index b6f382ff1d078..b137491576690 100644
--- a/src/ray/raylet/scheduling/scheduling_policy.h
+++ b/src/ray/raylet/scheduling/scheduling_policy.h
@@ -62,8 +62,15 @@ int64_t HybridPolicy(
     bool force_spillback, bool require_available,
     bool scheduler_avoid_gpu_nodes = RayConfig::instance().scheduler_avoid_gpu_nodes());
 
-//
-enum class NodeFilter { kAny, kGPU, kCPUOnly };
+enum class NodeFilter {
+  /// Default scheduling.
+  kAny,
+  /// Schedule on GPU nodes only.
+  kGPU,
+  /// Schedule on nodes that don't have a GPU. Since GPUs are scarcer resources, they
+  /// need special handling.
+  kNonGpu
+};
 
 /// \param resource_request: The resource request we're attempting to schedule.
 /// \param local_node_id: The id of the local node, which is needed for traversal order.
@@ -72,7 +79,7 @@ enum class NodeFilter { kAny, kGPU, kCPUOnly };
 /// truncated to 0.
 /// \param node_filter: defines the subset of nodes we are allowed to schedule on.
 /// can be one of kAny (can schedule on all nodes), kGPU (can only schedule on GPU
-/// nodes), kCPUOnly (can only schedule on non-GPU nodes.
+/// nodes), kNonGpu (can only schedule on non-GPU nodes).
 ///
 /// \return -1 if the task is unfeasible, otherwise the node id (key in `nodes`) to
 /// schedule on.
diff --git a/src/ray/raylet/scheduling/scheduling_policy_test.cc b/src/ray/raylet/scheduling/scheduling_policy_test.cc
index fb51d7f4c8711..6a834db1966e9 100644
--- a/src/ray/raylet/scheduling/scheduling_policy_test.cc
+++ b/src/ray/raylet/scheduling/scheduling_policy_test.cc
@@ -338,6 +338,44 @@ TEST_F(SchedulingPolicyTest, ForceSpillbackOnlyFeasibleLocallyTest) {
   ASSERT_EQ(to_schedule, -1);
 }
 
+TEST_F(SchedulingPolicyTest, NonGpuNodePreferredSchedulingTest) {
+  // Prefer to schedule on non-GPU nodes first.
+  // GPU nodes should be used only as a last resort.
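+  // Expected behavior: the 1-CPU and 3-CPU requests land on the non-GPU remote nodes,
+  // and the local GPU node is chosen only when the request actually needs a GPU.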
+ StringIdMap map; + int64_t local_node = 0; + int64_t remote_node_1 = 1; + int64_t remote_node_2 = 2; + + // local {CPU:2, GPU:1} + // Remote {CPU: 2} + absl::flat_hash_map nodes; + nodes.emplace(local_node, CreateNodeResources(2, 2, 0, 0, 1, 1)); + nodes.emplace(remote_node_1, CreateNodeResources(2, 2, 0, 0, 0, 0)); + nodes.emplace(remote_node_2, CreateNodeResources(3, 3, 0, 0, 0, 0)); + + ResourceRequest req = ResourceMapToResourceRequest(map, {{"CPU", 1}}, false); + int to_schedule = raylet_scheduling_policy::HybridPolicy( + req, local_node, nodes, 0.51, false, true, /*gpu_avoid_scheduling*/ true); + ASSERT_EQ(to_schedule, remote_node_1); + + req = ResourceMapToResourceRequest(map, {{"CPU", 3}}, false); + to_schedule = raylet_scheduling_policy::HybridPolicy( + req, local_node, nodes, 0.51, false, true, /*gpu_avoid_scheduling*/ true); + ASSERT_EQ(to_schedule, remote_node_2); + + req = ResourceMapToResourceRequest(map, {{"CPU", 1}, {"GPU", 1}}, false); + to_schedule = raylet_scheduling_policy::HybridPolicy( + req, local_node, nodes, 0.51, false, true, /*gpu_avoid_scheduling*/ true); + ASSERT_EQ(to_schedule, local_node); + + req = ResourceMapToResourceRequest(map, {{"CPU", 2}}, false); + to_schedule = raylet_scheduling_policy::HybridPolicy( + req, local_node, nodes, 0.51, false, true, /*gpu_avoid_scheduling*/ true); + ASSERT_EQ(to_schedule, remote_node_1); +} + int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); diff --git a/src/ray/raylet/worker.cc b/src/ray/raylet/worker.cc index 08331a75f176d..fd2b7b723f755 100644 --- a/src/ray/raylet/worker.cc +++ b/src/ray/raylet/worker.cc @@ -14,7 +14,7 @@ #include "ray/raylet/worker.h" -#include +#include #include "ray/raylet/format/node_manager_generated.h" #include "ray/raylet/raylet.h" diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index f0268021280f8..959cc551f0dbc 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -176,7 +176,6 @@ Process WorkerPool::StartWorkerProcess( const Language &language, const rpc::WorkerType worker_type, const JobID &job_id, PopWorkerStatus *status, const std::vector &dynamic_options, const int runtime_env_hash, const std::string &serialized_runtime_env, - std::unordered_map override_environment_variables, const std::string &serialized_runtime_env_context, const std::string &allocated_instances_serialized_json) { rpc::JobConfig *job_config = nullptr; @@ -313,39 +312,41 @@ Process WorkerPool::StartWorkerProcess( // need to add a new CLI parameter for both Python and Java workers. env.emplace(kEnvVarKeyJobId, job_id.Hex()); } - if (job_config) { - env.insert(job_config->worker_env().begin(), job_config->worker_env().end()); - } - - for (const auto &pair : override_environment_variables) { - env[pair.first] = pair.second; - } - if (language == Language::PYTHON) { + if (language == Language::PYTHON || language == Language::JAVA) { if (serialized_runtime_env != "{}" && serialized_runtime_env != "") { worker_command_args.push_back("--serialized-runtime-env=" + serialized_runtime_env); // Allocated_resource_json is only used in "shim process". 
worker_command_args.push_back("--allocated-instances-serialized-json=" + allocated_instances_serialized_json); + + worker_command_args.push_back("--language=" + Language_Name(language)); + + worker_command_args.push_back("--runtime-env-hash=" + + std::to_string(runtime_env_hash)); + + if (serialized_runtime_env_context != "{}" && + !serialized_runtime_env_context.empty()) { + worker_command_args.push_back("--serialized-runtime-env-context=" + + serialized_runtime_env_context); + } } else { // The "shim process" setup worker is not needed, so do not run it. // Check that the arg really is the path to the setup worker before erasing it, to // prevent breaking tests that mock out the worker command args. if (worker_command_args.size() >= 2 && worker_command_args[1].find(kSetupWorkerFilename) != std::string::npos) { - worker_command_args.erase(worker_command_args.begin() + 1, - worker_command_args.begin() + 2); + if (language == Language::PYTHON) { + worker_command_args.erase(worker_command_args.begin() + 1, + worker_command_args.begin() + 2); + } else { + // Erase the python executable as well for other languages. + worker_command_args.erase(worker_command_args.begin(), + worker_command_args.begin() + 2); + } } } - worker_command_args.push_back("--runtime-env-hash=" + - std::to_string(runtime_env_hash)); - - if (serialized_runtime_env_context != "{}" && serialized_runtime_env_context != "") { - worker_command_args.push_back("--serialized-runtime-env-context=" + - serialized_runtime_env_context); - } - if (ray_debugger_external) { worker_command_args.push_back("--ray-debugger-external"); } @@ -483,6 +484,24 @@ void WorkerPool::MarkPortAsFree(int port) { void WorkerPool::HandleJobStarted(const JobID &job_id, const rpc::JobConfig &job_config) { all_jobs_[job_id] = job_config; + if (job_config.runtime_env().runtime_env_eager_install() && + job_config.has_runtime_env()) { + auto const &runtime_env = job_config.runtime_env().serialized_runtime_env(); + RAY_LOG(INFO) << "[Eagerly] Start install runtime environment for job " << job_id + << ". The runtime environment was " << runtime_env << "."; + CreateRuntimeEnv( + runtime_env, job_id, + [job_id](bool successful, const std::string &serialized_runtime_env_context) { + if (successful) { + RAY_LOG(INFO) << "[Eagerly] Create runtime env successful for job " << job_id + << ". The result context was " << serialized_runtime_env_context + << "."; + } else { + RAY_LOG(ERROR) << "[Eagerly] Couldn't create a runtime environment for job " + << job_id << "."; + } + }); + } } void WorkerPool::HandleJobFinished(const JobID &job_id) { @@ -749,7 +768,7 @@ void WorkerPool::PushWorker(const std::shared_ptr &worker) { // The worker is used for the actor creation task with dynamic options. if (!used) { // Put it into idle dedicated worker pool. - // TODO(guyang.sgy): This worker will not be used forever. We should kill it. + // TODO(SongGuyang): This worker will not be used forever. We should kill it. 
state.idle_dedicated_workers[task_id] = worker; } return; @@ -921,7 +940,8 @@ void WorkerPool::TryKillingIdleWorkers() { void WorkerPool::PopWorker(const TaskSpecification &task_spec, const PopWorkerCallback &callback, const std::string &allocated_instances_serialized_json) { - RAY_LOG(DEBUG) << "Pop worker for task " << task_spec.TaskId(); + RAY_LOG(DEBUG) << "Pop worker for task " << task_spec.TaskId() << " task name " + << task_spec.FunctionDescriptor()->ToString(); auto &state = GetStateForLanguage(task_spec.GetLanguage()); std::shared_ptr worker = nullptr; @@ -936,8 +956,7 @@ void WorkerPool::PopWorker(const TaskSpecification &task_spec, Process proc = StartWorkerProcess( task_spec.GetLanguage(), rpc::WorkerType::WORKER, task_spec.JobId(), &status, dynamic_options, task_spec.GetRuntimeEnvHash(), serialized_runtime_env, - task_spec.OverrideEnvironmentVariables(), serialized_runtime_env_context, - allocated_instances_serialized_json); + serialized_runtime_env_context, allocated_instances_serialized_json); if (status == PopWorkerStatus::OK) { RAY_CHECK(proc.IsValid()); WarnAboutSize(); @@ -948,7 +967,7 @@ void WorkerPool::PopWorker(const TaskSpecification &task_spec, state.starting_workers_to_tasks[proc] = std::move(task_info); } } else { - // TODO(guyang.sgy): Wait until a worker is pushed or a worker can be started If + // TODO(SongGuyang): Wait until a worker is pushed or a worker can be started If // startup concurrency maxed out or job not started. PopWorkerCallbackAsync(callback, nullptr, status); } @@ -976,24 +995,24 @@ void WorkerPool::PopWorker(const TaskSpecification &task_spec, dynamic_options = task_spec.DynamicWorkerOptions(); } - // create runtime env. if (task_spec.HasRuntimeEnv()) { - agent_manager_->CreateRuntimeEnv( - task_spec.JobId(), task_spec.SerializedRuntimeEnv(), - [start_worker_process_fn, callback, &state, task_spec, dynamic_options, - allocated_instances_serialized_json]( - bool success, const std::string &serialized_runtime_env_context) { - if (success) { + // create runtime env. + CreateRuntimeEnv( + task_spec.SerializedRuntimeEnv(), task_spec.JobId(), + [start_worker_process_fn, callback, &state, task_spec, dynamic_options]( + bool successful, const std::string &serialized_runtime_env_context) { + if (successful) { start_worker_process_fn(task_spec, state, dynamic_options, true, task_spec.SerializedRuntimeEnv(), serialized_runtime_env_context, callback); } else { - RAY_LOG(WARNING) << "Couldn't create a runtime environment for task " - << task_spec.TaskId() << ". The runtime environment was " - << task_spec.SerializedRuntimeEnv() << "."; callback(nullptr, PopWorkerStatus::RuntimeEnvCreationFailed); + RAY_LOG(WARNING) + << "Create runtime env failed for task " << task_spec.TaskId() + << " and couldn't create the dedicated worker."; } - }); + }, + allocated_instances_serialized_json); } else { start_worker_process_fn(task_spec, state, dynamic_options, true, "", "", callback); @@ -1036,8 +1055,8 @@ void WorkerPool::PopWorker(const TaskSpecification &task_spec, // Start a new worker process. if (task_spec.HasRuntimeEnv()) { // create runtime env. 
- agent_manager_->CreateRuntimeEnv( - task_spec.JobId(), task_spec.SerializedRuntimeEnv(), + CreateRuntimeEnv( + task_spec.SerializedRuntimeEnv(), task_spec.JobId(), [start_worker_process_fn, callback, &state, task_spec]( bool successful, const std::string &serialized_runtime_env_context) { if (successful) { @@ -1045,12 +1064,13 @@ void WorkerPool::PopWorker(const TaskSpecification &task_spec, task_spec.SerializedRuntimeEnv(), serialized_runtime_env_context, callback); } else { - RAY_LOG(WARNING) << "Couldn't create a runtime environment for task " - << task_spec.TaskId() << ". The runtime environment was " - << task_spec.SerializedRuntimeEnv() << "."; callback(nullptr, PopWorkerStatus::RuntimeEnvCreationFailed); + RAY_LOG(WARNING) + << "Create runtime env failed for task " << task_spec.TaskId() + << " and couldn't create the worker."; } - }); + }, + allocated_instances_serialized_json); } else { start_worker_process_fn(task_spec, state, {}, false, "", "", callback); } @@ -1067,7 +1087,7 @@ void WorkerPool::PrestartWorkers(const TaskSpecification &task_spec, int64_t bac int64_t num_available_cpus) { // Code path of task that needs a dedicated worker. if ((task_spec.IsActorCreationTask() && !task_spec.DynamicWorkerOptions().empty()) || - task_spec.OverrideEnvironmentVariables().size() > 0 || task_spec.HasRuntimeEnv()) { + task_spec.HasRuntimeEnv()) { return; // Not handled. // TODO(architkulkarni): We'd eventually like to prestart workers with the same // runtime env to improve initial startup performance. @@ -1324,6 +1344,26 @@ WorkerPool::IOWorkerState &WorkerPool::GetIOWorkerStateFromWorkerType( UNREACHABLE; } +void WorkerPool::CreateRuntimeEnv( + const std::string &serialized_runtime_env, const JobID &job_id, + const std::function &callback, + const std::string &serialized_allocated_resource_instances) { + // create runtime env. + agent_manager_->CreateRuntimeEnv( + job_id, serialized_runtime_env, serialized_allocated_resource_instances, + [job_id, serialized_runtime_env, callback]( + bool successful, const std::string &serialized_runtime_env_context) { + if (successful) { + callback(true, serialized_runtime_env_context); + } else { + RAY_LOG(WARNING) << "Couldn't create a runtime environment for job " << job_id + << ". The runtime environment was " << serialized_runtime_env + << "."; + callback(false, ""); + } + }); +} + } // namespace raylet } // namespace ray diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h index 7991600cfd6c6..92c19329c17dc 100644 --- a/src/ray/raylet/worker_pool.h +++ b/src/ray/raylet/worker_pool.h @@ -397,7 +397,6 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface { PopWorkerStatus *status /*output*/, const std::vector &dynamic_options = {}, const int runtime_env_hash = 0, const std::string &serialized_runtime_env = "{}", - std::unordered_map override_environment_variables = {}, const std::string &serialized_runtime_env_context = "{}", const std::string &allocated_instances_serialized_json = "{}"); @@ -589,6 +588,12 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface { const PopWorkerStatus &status, bool *found /* output */, bool *worker_used /* output */, TaskID *task_id /* output */); + /// Create runtime env asynchronously by runtime env agent. 
+ void CreateRuntimeEnv( + const std::string &serialized_runtime_env, const JobID &job_id, + const std::function &callback, + const std::string &serialized_allocated_resource_instances = "{}"); + /// For Process class for managing subprocesses (e.g. reaping zombies). instrumented_io_context *io_service_; /// Node ID of the current node. diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index 9a28520700a8e..37fb903b4a7ab 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -103,9 +103,10 @@ class WorkerPoolMock : public WorkerPool { const WorkerCommandMap &worker_commands, absl::flat_hash_map> &mock_worker_rpc_clients) - : WorkerPool(io_service, NodeID::FromRandom(), "", POOL_SIZE_SOFT_LIMIT, 0, - MAXIMUM_STARTUP_CONCURRENCY, 0, 0, {}, nullptr, worker_commands, - []() {}, 0, [this]() { return current_time_ms_; }), + : WorkerPool( + io_service, NodeID::FromRandom(), "", POOL_SIZE_SOFT_LIMIT, 0, + MAXIMUM_STARTUP_CONCURRENCY, 0, 0, {}, nullptr, worker_commands, []() {}, 0, + [this]() { return current_time_ms_; }), last_worker_process_(), instrumented_io_service_(io_service), error_message_type_(1), @@ -257,7 +258,7 @@ class WorkerPoolMock : public WorkerPool { is_java = true; } } - // TODO(guyang.sgy): support C++ language workers. + // TODO(SongGuyang): support C++ language workers. int num_workers = is_java ? NUM_WORKERS_PER_PROCESS_JAVA : 1; for (int i = 0; i < num_workers; i++) { auto worker = @@ -458,7 +459,7 @@ static inline TaskSpecification ExampleTaskSpec( } else { message.set_type(TaskType::NORMAL_TASK); } - message.set_serialized_runtime_env(serialized_runtime_env); + message.mutable_runtime_env()->set_serialized_runtime_env(serialized_runtime_env); return TaskSpecification(std::move(message)); } @@ -1257,8 +1258,7 @@ TEST_F(WorkerPoolTest, CacheWorkersByRuntimeEnvHash) { ActorID::Nil(), Language::PYTHON, JOB_ID, ActorID::Nil(), /*dynamic_options=*/{}, TaskID::ForFakeTask(), "mock_runtime_env_2"); - const WorkerCacheKey env1 = { - /*override_environment_variables=*/{}, "mock_runtime_env_1", {}}; + const WorkerCacheKey env1 = {"mock_runtime_env_1", {}}; const int runtime_env_hash_1 = env1.IntHash(); // Push worker with runtime env 1. diff --git a/src/ray/raylet_client/raylet_client.cc b/src/ray/raylet_client/raylet_client.cc index 41e9611491c7d..290c5bd068898 100644 --- a/src/ray/raylet_client/raylet_client.cc +++ b/src/ray/raylet_client/raylet_client.cc @@ -296,13 +296,20 @@ Status raylet::RayletClient::FreeObjects(const std::vector &object_ids } void raylet::RayletClient::RequestWorkerLease( - const TaskSpecification &resource_spec, + const rpc::TaskSpec &task_spec, const rpc::ClientCallback &callback, const int64_t backlog_size) { - rpc::RequestWorkerLeaseRequest request; - request.mutable_resource_spec()->CopyFrom(resource_spec.GetMessage()); - request.set_backlog_size(backlog_size); - grpc_client_->RequestWorkerLease(request, callback); + google::protobuf::Arena arena; + auto request = + google::protobuf::Arena::CreateMessage(&arena); + // The unsafe allocating here is actually safe because the life-cycle of + // task_spec is longer than request. + // Request will be sent before the end of this call, and after that, it won't be + // used any more. + request->unsafe_arena_set_allocated_resource_spec( + const_cast(&task_spec)); + request->set_backlog_size(backlog_size); + grpc_client_->RequestWorkerLease(*request, callback); } /// Spill objects to external storage. 
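For clarity, here is a minimal sketch of the lifetime contract that the arena-based
RequestWorkerLease above relies on. This is illustrative only, not part of the patch;
it assumes the raylet_client.cc context where `grpc_client_`, `task_spec`,
`backlog_size`, and `callback` are in scope:

    // Lend a caller-owned TaskSpec to an arena-allocated request.
    google::protobuf::Arena arena;
    auto *request =
        google::protobuf::Arena::CreateMessage<rpc::RequestWorkerLeaseRequest>(&arena);
    // unsafe_arena_set_allocated_* stores the pointer without taking ownership:
    // neither the request nor the arena will free task_spec. In exchange,
    // task_spec must outlive the request, which holds here because gRPC
    // serializes the request inside RequestWorkerLease, before this scope ends.
    request->unsafe_arena_set_allocated_resource_spec(
        const_cast<rpc::TaskSpec *>(&task_spec));
    request->set_backlog_size(backlog_size);
    grpc_client_->RequestWorkerLease(*request, callback);
    // Destroying `arena` frees the request in bulk and leaves task_spec untouched.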
@@ -373,7 +380,7 @@ void raylet::RayletClient::CommitBundleResources( } void raylet::RayletClient::CancelResourceReserve( - BundleSpecification &bundle_spec, + const BundleSpecification &bundle_spec, const ray::rpc::ClientCallback &callback) { rpc::CancelResourceReserveRequest request; request.mutable_bundle_spec()->CopyFrom(bundle_spec.GetMessage()); diff --git a/src/ray/raylet_client/raylet_client.h b/src/ray/raylet_client/raylet_client.h index 558fed24b24cf..323837e513fe1 100644 --- a/src/ray/raylet_client/raylet_client.h +++ b/src/ray/raylet_client/raylet_client.h @@ -68,6 +68,10 @@ class WorkerLeaseInterface { const ray::TaskSpecification &resource_spec, const ray::rpc::ClientCallback &callback, const int64_t backlog_size = -1) = 0; + virtual void RequestWorkerLease( + const rpc::TaskSpec &task_spec, + const ray::rpc::ClientCallback &callback, + const int64_t backlog_size = -1) = 0; /// Returns a worker to the raylet. /// \param worker_port The local port of the worker on the raylet node. @@ -117,7 +121,7 @@ class ResourceReserveInterface { const ray::rpc::ClientCallback &callback) = 0; virtual void CancelResourceReserve( - BundleSpecification &bundle_spec, + const BundleSpecification &bundle_spec, const ray::rpc::ClientCallback &callback) = 0; virtual void ReleaseUnusedBundles( @@ -360,6 +364,13 @@ class RayletClient : public RayletClientInterface { void RequestWorkerLease( const ray::TaskSpecification &resource_spec, const ray::rpc::ClientCallback &callback, + const int64_t backlog_size) override { + RequestWorkerLease(resource_spec.GetMessage(), callback, backlog_size); + } + + void RequestWorkerLease( + const rpc::TaskSpec &resource_spec, + const ray::rpc::ClientCallback &callback, const int64_t backlog_size) override; /// Implements WorkerLeaseInterface. @@ -389,7 +400,7 @@ class RayletClient : public RayletClientInterface { /// Implements CancelResourceReserveInterface. void CancelResourceReserve( - BundleSpecification &bundle_spec, + const BundleSpecification &bundle_spec, const ray::rpc::ClientCallback &callback) override; diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc index 3840527fb5a9a..8f3f98b67445c 100644 --- a/src/ray/rpc/grpc_server.cc +++ b/src/ray/rpc/grpc_server.cc @@ -36,11 +36,13 @@ DEFINE_stats(grpc_server_req_finished, "Finished request number in grpc server", namespace ray { namespace rpc { -GrpcServer::GrpcServer(std::string name, const uint32_t port, int num_threads, - bool use_tls, int64_t keepalive_time_ms) +GrpcServer::GrpcServer(std::string name, const uint32_t port, + bool listen_to_localhost_only, int num_threads, + int64_t keepalive_time_ms, bool use_tls) : name_(std::move(name)), port_(port), use_tls_(use_tls), + listen_to_localhost_only_(listen_to_localhost_only), is_closed_(true), num_threads_(num_threads), keepalive_time_ms_(keepalive_time_ms) { @@ -49,7 +51,8 @@ GrpcServer::GrpcServer(std::string name, const uint32_t port, int num_threads, void GrpcServer::Run() { uint32_t specified_port = port_; - std::string server_address("0.0.0.0:" + std::to_string(port_)); + std::string server_address((listen_to_localhost_only_ ? "127.0.0.1:" : "0.0.0.0:") + + std::to_string(port_)); grpc::ServerBuilder builder; // Disable the SO_REUSEPORT option. We don't need it in ray. 
If the option is enabled
// (default behavior in grpc), we may see multiple workers listen on the same port and
diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h
index 826efbdf260bb..6d39f77bf5287 100644
--- a/src/ray/rpc/grpc_server.h
+++ b/src/ray/rpc/grpc_server.h
@@ -61,9 +61,10 @@ class GrpcServer {
/// \param[in] name Name of this server, used for logging and debugging purpose.
/// \param[in] port The port to bind this server to. If it's 0, a random available port
/// will be chosen.
- GrpcServer(std::string name, const uint32_t port, int num_threads = 1,
- bool use_tls = false,
- int64_t keepalive_time_ms = 7200000 /*2 hours, grpc default*/);
+ GrpcServer(std::string name, const uint32_t port, bool listen_to_localhost_only,
+ int num_threads = 1,
+ int64_t keepalive_time_ms = 7200000, /*2 hours, grpc default*/
+ bool use_tls = false);
/// Destruct this gRPC server.
~GrpcServer() { Shutdown(); }
@@ -112,8 +113,11 @@ class GrpcServer {
const std::string name_;
/// Port of this server.
int port_;
/// Whether to use TLS.
bool use_tls_;
+ /// Listen to localhost (127.0.0.1) only if it's true, otherwise listen to all network
+ /// interfaces (0.0.0.0).
+ const bool listen_to_localhost_only_;
/// Indicates whether this server has been closed.
bool is_closed_;
/// The `grpc::Service` objects which should be registered to `ServerBuilder`.
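The two members above, `use_tls_` and `listen_to_localhost_only_`, control independent aspects of server start-up: which credentials to present and which interface to bind. A condensed, self-contained sketch (not the patch itself) of how they can combine in a `Run()`-style path, using only standard gRPC APIs; `StartServer` and its parameters are illustrative, and the PEM arguments are file contents rather than paths:

    #include <grpcpp/grpcpp.h>

    #include <memory>
    #include <string>

    std::unique_ptr<grpc::Server> StartServer(bool listen_to_localhost_only,
                                              bool use_tls, int port,
                                              const std::string &ca_cert,
                                              const std::string &server_cert,
                                              const std::string &server_key) {
      // Bind to loopback only when requested, otherwise all interfaces.
      std::string address((listen_to_localhost_only ? "127.0.0.1:" : "0.0.0.0:") +
                          std::to_string(port));
      std::shared_ptr<grpc::ServerCredentials> creds;
      if (use_tls) {
        grpc::SslServerCredentialsOptions ssl_opts;
        ssl_opts.pem_root_certs = ca_cert;
        // PemKeyCertPair is {private_key, cert_chain}.
        ssl_opts.pem_key_cert_pairs.push_back({server_key, server_cert});
        creds = grpc::SslServerCredentials(ssl_opts);
      } else {
        creds = grpc::InsecureServerCredentials();
      }
      grpc::ServerBuilder builder;
      builder.AddListeningPort(address, creds, &port);
      return builder.BuildAndStart();
    }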
diff --git a/src/ray/rpc/server_call.h b/src/ray/rpc/server_call.h
index d3c199b50c6bb..9e2d50e8324e4 100644
--- a/src/ray/rpc/server_call.h
+++ b/src/ray/rpc/server_call.h
@@ -14,10 +14,11 @@
#pragma once
+#include
#include
-#include
#include
+
#include "ray/common/asio/instrumented_io_context.h"
#include "ray/common/grpc_util.h"
#include "ray/common/status.h"
@@ -145,6 +146,7 @@ class ServerCallImpl : public ServerCall {
response_writer_(&context_),
io_service_(io_service),
call_name_(std::move(call_name)) {
+ reply_ = google::protobuf::Arena::CreateMessage<Reply>(&arena_);
// TODO call_name_ sometimes get corrunpted due to memory issues.
RAY_CHECK(!call_name_.empty()) << "Call name is empty";
STATS_grpc_server_req_new.Record(1.0, call_name_);
@@ -187,7 +189,7 @@ class ServerCallImpl : public ServerCall {
factory.CreateCall();
}
(service_handler_.*handle_request_function_)(
- request_, &reply_,
+ request_, reply_,
[this](Status status, std::function<void()> success, std::function<void()> failure) {
// These two callbacks must be set before `SendReply`, because `SendReply`
@@ -222,9 +224,13 @@ class ServerCallImpl : public ServerCall {
/// Tell gRPC to finish this request and send reply asynchronously.
void SendReply(const Status &status) {
state_ = ServerCallState::SENDING_REPLY;
- response_writer_.Finish(reply_, RayStatusToGrpcStatus(status), this);
+ response_writer_.Finish(*reply_, RayStatusToGrpcStatus(status), this);
}
+ /// The memory pool for this request. It's used for the reply.
+ /// With an arena, we can set up the reply without copying some fields.
+ google::protobuf::Arena arena_;
+
/// State of this call.
ServerCallState state_;
@@ -250,8 +256,9 @@ class ServerCallImpl : public ServerCall {
/// The request message.
Request request_;
- /// The reply message.
- Reply reply_;
+ /// The reply message. It is owned by the arena and is not valid beyond
+ /// the life-cycle of this call.
+ Reply *reply_;
/// Human-readable name for this RPC call.
std::string call_name_;
diff --git a/src/ray/rpc/test/grpc_server_client_test.cc b/src/ray/rpc/test/grpc_server_client_test.cc
index e7b602e6b316f..3bd86f5a24f63 100644
--- a/src/ray/rpc/test/grpc_server_client_test.cc
+++ b/src/ray/rpc/test/grpc_server_client_test.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include
+
#include "gtest/gtest.h"
#include "ray/rpc/grpc_client.h"
#include "ray/rpc/grpc_server.h"
@@ -35,13 +36,14 @@ class TestServiceHandler {
RAY_LOG(INFO) << "No reply!";
return;
}
- send_reply_callback(ray::Status::OK(),
- /*reply_success=*/[]() { RAY_LOG(INFO) << "Reply success."; },
- /*reply_failure=*/
- [this]() {
- RAY_LOG(INFO) << "Reply failed.";
- reply_failure_count++;
- });
+ send_reply_callback(
+ ray::Status::OK(),
+ /*reply_success=*/[]() { RAY_LOG(INFO) << "Reply success."; },
+ /*reply_failure=*/
+ [this]() {
+ RAY_LOG(INFO) << "Reply failed.";
+ reply_failure_count++;
+ });
}
std::atomic request_count{0};
std::atomic reply_failure_count{0};
@@ -83,7 +85,7 @@ class TestGrpcServerClientFixture : public ::testing::Test {
handler_io_service_.run();
});
test_service_.reset(new TestGrpcService(handler_io_service_, test_service_handler_));
- grpc_server_.reset(new GrpcServer("test", 0));
+ grpc_server_.reset(new GrpcServer("test", 0, true));
grpc_server_->RegisterService(*test_service_);
grpc_server_->Run();
diff --git a/src/ray/util/event.h b/src/ray/util/event.h
index 9caed946f3af1..4f2e98a4427c3 100644
--- a/src/ray/util/event.h
+++ b/src/ray/util/event.h
@@ -13,6 +13,8 @@
// limitations under the License.
#pragma once
+#include
+
#include
#include
#include
@@ -22,6 +24,8 @@
#include
#include
#include
+
+#include "nlohmann/json.hpp"
#include "ray/util/logging.h"
#include "ray/util/util.h"
#include "spdlog/sinks/basic_file_sink.h"
@@ -29,10 +33,6 @@
#include "spdlog/spdlog.h"
#include "src/ray/protobuf/event.pb.h"
-#include "nlohmann/json.hpp"
-
-#include
-
using json = nlohmann::json;
namespace ray {
@@ -102,7 +102,7 @@ class EventManager final {
// We added `const json &custom_fields` here because we need to support typed custom
// fields.
- // TODO(guyang.sgy): Remove the protobuf `rpc::Event` and use an internal struct
+ // TODO(SongGuyang): Remove the protobuf `rpc::Event` and use an internal struct
// instead.
void Publish(const rpc::Event &event, const json &custom_fields);
diff --git a/src/ray/util/util.h b/src/ray/util/util.h
index 9b2e3f443dbac..95500e91694a7 100644
--- a/src/ray/util/util.h
+++ b/src/ray/util/util.h
@@ -21,7 +21,6 @@
#include
#include
#include
-
#include
#include "ray/util/logging.h"
@@ -167,7 +166,7 @@ class InitShutdownRAII {
/// \param shutdown_func The shutdown function.
/// \param args The arguments for the init function.
template
- InitShutdownRAII(InitFunc init_func, ShutdownFunc shutdown_func, Args &&... args)
+ InitShutdownRAII(InitFunc init_func, ShutdownFunc shutdown_func, Args &&...args)
: shutdown_(shutdown_func) {
init_func(args...);
}
@@ -259,7 +258,7 @@ template class ThreadPrivate {
public:
template
- ThreadPrivate(Ts &&... ts) : t_(std::forward<Ts>(ts)...) {}
+ explicit ThreadPrivate(Ts &&...ts) : t_(std::forward<Ts>(ts)...)
{}
T &operator*() {
ThreadCheck();
@@ -312,4 +311,43 @@ class ThreadPrivate {
mutable std::mutex mutex_;
};
+class ExponentialBackOff {
+ public:
+ ExponentialBackOff() = default;
+ ExponentialBackOff(const ExponentialBackOff &) = default;
+ ExponentialBackOff(ExponentialBackOff &&) = default;
+ ExponentialBackOff &operator=(const ExponentialBackOff &) = default;
+ ExponentialBackOff &operator=(ExponentialBackOff &&) = default;
+
+ /// Construct an exponential back off counter.
+ ///
+ /// \param[in] initial_value The start value for this counter.
+ /// \param[in] multiplier The multiplier for this counter.
+ /// \param[in] max_value The maximum value for this counter. By default it's
+ /// the maximum value of uint64_t.
+ ExponentialBackOff(uint64_t initial_value, double multiplier,
+ uint64_t max_value = std::numeric_limits<uint64_t>::max())
+ : curr_value_(initial_value),
+ initial_value_(initial_value),
+ max_value_(max_value),
+ multiplier_(multiplier) {
+ RAY_CHECK(multiplier > 0.0) << "Multiplier must be greater than 0";
+ }
+
+ uint64_t Next() {
+ auto ret = curr_value_;
+ curr_value_ = curr_value_ * multiplier_;
+ curr_value_ = std::min(curr_value_, max_value_);
+ return ret;
+ }
+
+ void Reset() { curr_value_ = initial_value_; }
+
+ private:
+ uint64_t curr_value_;
+ uint64_t initial_value_;
+ uint64_t max_value_;
+ double multiplier_;
+};
+
} // namespace ray
diff --git a/src/ray/util/util_test.cc b/src/ray/util/util_test.cc
index 435f1598f4f69..3e13dedb10bf9 100644
--- a/src/ray/util/util_test.cc
+++ b/src/ray/util/util_test.cc
@@ -102,6 +102,23 @@ TEST(UtilTest, ParseCommandLineTest) {
ASSERT_EQ(ParseCommandLine(R"(x' a \b')", win32), ArgList({R"(x')", R"(a)", R"(\b')"}));
}
+TEST(UtilTest, ExponentialBackOffTest) {
+ auto exp = ExponentialBackOff(1, 2, 9);
+ ASSERT_EQ(1, exp.Next());
+ ASSERT_EQ(2, exp.Next());
+ ASSERT_EQ(4, exp.Next());
+ ASSERT_EQ(8, exp.Next());
+ ASSERT_EQ(9, exp.Next());
+ ASSERT_EQ(9, exp.Next());
+ exp.Reset();
+ ASSERT_EQ(1, exp.Next());
+ ASSERT_EQ(2, exp.Next());
+ ASSERT_EQ(4, exp.Next());
+ ASSERT_EQ(8, exp.Next());
+ ASSERT_EQ(9, exp.Next());
+ ASSERT_EQ(9, exp.Next());
+}
+
TEST(UtilTest, ParseURLTest) {
const std::string url = "http://abc?num_objects=9&offset=8388878&size=8388878";
auto parsed_url = *ParseURL(url);
diff --git a/streaming/src/queue/queue_handler.h b/streaming/src/queue/queue_handler.h
index e04e34b359804..26bd863e85ecc 100644
--- a/streaming/src/queue/queue_handler.h
+++ b/streaming/src/queue/queue_handler.h
@@ -1,7 +1,7 @@
#pragma once
#include
-#include
+#include
#include
#include
diff --git a/streaming/src/test/mock_actor.cc b/streaming/src/test/mock_actor.cc
index c51b1a8a11a5b..5e5b575223a6b 100644
--- a/streaming/src/test/mock_actor.cc
+++ b/streaming/src/test/mock_actor.cc
@@ -639,7 +639,7 @@ class StreamingWorker {
} // namespace ray
int main(int argc, char **argv) {
- RAY_CHECK(argc == 5);
+ RAY_CHECK(argc >= 4);
auto store_socket = std::string(argv[1]);
auto raylet_socket = std::string(argv[2]);
auto node_manager_port = std::stoi(std::string(argv[3]));
diff --git a/thirdparty/patches/prometheus-windows-pollfd.patch b/thirdparty/patches/prometheus-windows-pollfd.patch
index 1941b6cb247c0..3b30942bb85f2 100644
--- a/thirdparty/patches/prometheus-windows-pollfd.patch
+++ b/thirdparty/patches/prometheus-windows-pollfd.patch
@@ -6,17 +6,46 @@ Windows Vista and later SDKs define struct pollfd for WSAPoll(), but it has a pe
civetweb provides its own implementation of poll, but it has a conflicting definition for pollfd.
Hence we block Windows from defining pollfd (which this project doesn't use). --- - bazel/civetweb.BUILD | 1 + - 1 file changed, 1 insertion(+) + bazel/civetweb.BUILD | 7 +++++++ + 1 file changed, 7 insertions(+) diff --git bazel/civetweb.BUILD bazel/civetweb.BUILD --- bazel/civetweb.BUILD +++ bazel/civetweb.BUILD -@@ -34,5 +34,6 @@ cc_library( +@@ -9,6 +9,11 @@ config_setting( + values = {"cpu": "darwin_x86_64"}, + ) + ++config_setting( ++ name = "darwin_arm64", ++ values = {"cpu": "darwin_arm64"}, ++) ++ + config_setting( + name = "windows", + values = { "cpu": "x64_windows" }, +@@ -34,6 +39,7 @@ cc_library( "-DNO_CACHING", "-DNO_SSL", "-DNO_FILES", + "-D_WIN32_WINNT=0x0502", "-UDEBUG", ], --- + includes = [ +@@ -46,6 +52,7 @@ cc_library( + }) + select({ + ":darwin": [], + ":darwin_x86_64": [], ++ ":darwin_arm64": [], + ":windows": [], + ":windows_msvc": [], + "//conditions:default": ["-lrt"], +@@ -86,6 +93,7 @@ cc_library( + }) + select({ + ":darwin": [], + ":darwin_x86_64": [], ++ ":darwin_arm64": [], + ":windows": [], + ":windows_msvc": [], + "//conditions:default": ["-lrt"], +-- diff --git a/thirdparty/patches/rules_boost-undefine-boost_fallthrough.patch b/thirdparty/patches/rules_boost-undefine-boost_fallthrough.patch deleted file mode 100644 index 9cd53fe60f842..0000000000000 --- a/thirdparty/patches/rules_boost-undefine-boost_fallthrough.patch +++ /dev/null @@ -1,8 +0,0 @@ -diff --git BUILD.boost BUILD.boost ---- BUILD.boost -+++ BUILD.boost -@@ -1356,3 +1356,2 @@ boost_library( - defines = [ -- "BOOST_FALLTHROUGH", - ], --- diff --git a/thirdparty/patches/rules_boost-windows-linkopts.patch b/thirdparty/patches/rules_boost-windows-linkopts.patch index 28bda4eb06939..204443d3c7186 100644 --- a/thirdparty/patches/rules_boost-windows-linkopts.patch +++ b/thirdparty/patches/rules_boost-windows-linkopts.patch @@ -1,15 +1,12 @@ diff --git BUILD.boost BUILD.boost --- BUILD.boost +++ BUILD.boost -@@ -313,1 +313,9 @@ boost_library(name = "asio", -- linkopts = ["-lpthread"], -+ linkopts = select({ -+ ":linux": [ -+ "-lpthread", -+ ], -+ ":osx_x86_64": [ -+ "-lpthread", -+ ], -+ "//conditions:default": [], -+ }), --- +@@ -428,6 +428,7 @@ boost_library( + }), + linkopts = select({ + ":android": [], ++ ":windows": [], + "//conditions:default": ["-lpthread"], + }), + deps = [ +-- From 78bbb341c605b6ffda2c66c6b25c99ac08d58917 Mon Sep 17 00:00:00 2001 From: Oscar Knagg Date: Tue, 12 Oct 2021 16:35:19 +0100 Subject: [PATCH 42/56] Squashed commit of the following: commit 1593350efe1e9520171eb52ade25fd1022c512f6 Merge: 504399b4f 2c9370832 Author: Oscar Knagg Date: Tue Oct 12 16:05:41 2021 +0100 Merge remote-tracking branch 'origin' into tls-working commit 504399b4fa895c2720719dbe7213e688caafbdee Author: Oscar Knagg Date: Tue Oct 12 14:56:17 2021 +0100 format.sh changes commit 7b23f9e51f061b47bc2f94fc3bbea902d493c785 Author: Oscar Knagg Date: Tue Oct 12 14:54:13 2021 +0100 Fix tests commit fdbe8eb93e9b66f4b99baeb45955c483be675b01 Author: Oscar Knagg Date: Tue Oct 12 13:46:17 2021 +0100 Move functions around commit 36ce6ac6079e61644cd69af56d5887d10aec99cc Merge: a33e32f23 8241a03d3 Author: Oscar Knagg Date: Tue Oct 12 13:40:10 2021 +0100 Merge branch 'master' of https://github.com/ray-project/ray into tls-working commit a33e32f2316a48e0c0018652dc14935b6b033dac Author: Oscar Knagg Date: Tue Oct 12 13:33:51 2021 +0100 Fix bad import commit 263e8f66c2d081eb6a57d577ab2a7d8f4564934f Author: Oscar Knagg Date: Tue Oct 12 12:57:41 2021 +0100 Add TLS configuration to ray_config_def.h commit 
425ce874e2a18b460f29059b88323351c883e1e6 Author: Oscar Knagg Date: Tue Oct 12 12:57:17 2021 +0100 Formatting commit b510a7b4c2b96b495a81db2293f1340ef4da4bbc Author: Oscar Knagg Date: Mon Oct 11 15:48:54 2021 +0100 Move tests into separate file commit 97df18536f12cdc7928e296f415cdd0b16e6b26a Author: Oscar Knagg Date: Mon Oct 11 14:09:25 2021 +0100 load_certs_from_env -> tls_utils commit fb1b05c3a216df5f9e4f861139b459f6994395cf Author: Oscar Knagg Date: Mon Oct 11 14:08:09 2021 +0100 Docs v1 commit 9e95bb12600becb29db932095326af31c5a22afc Author: Oscar Knagg Date: Mon Oct 11 11:52:17 2021 +0100 tls_utils file commit 85968934db01436983475ad3bfd838073ed0f211 Merge: d04fe6de3 ab55b808c Author: Oscar Knagg Date: Mon Oct 11 11:46:07 2021 +0100 Merge branch 'master' of https://github.com/ray-project/ray into tls --- dashboard/agent.py | 57 +++++----------- dashboard/head.py | 2 +- doc/source/configure.rst | 22 ++++++ python/ray/_private/test_utils.py | 68 +++---------------- python/ray/_private/tls_utils.py | 85 ++++++++++++++++++++++++ python/ray/_private/utils.py | 28 +------- python/ray/tests/test_basic.py | 15 ++--- python/ray/tests/test_tls_auth.py | 66 ++++++++++++++++-- python/ray/util/client/server/proxier.py | 4 +- python/ray/util/client/server/server.py | 2 +- python/ray/util/client/worker.py | 4 +- src/ray/common/ray_config_def.h | 8 +++ src/ray/rpc/grpc_server.h | 2 + 13 files changed, 215 insertions(+), 148 deletions(-) create mode 100644 python/ray/_private/tls_utils.py diff --git a/dashboard/agent.py b/dashboard/agent.py index 0046528a44f27..f56e76f61fff9 100644 --- a/dashboard/agent.py +++ b/dashboard/agent.py @@ -11,7 +11,6 @@ import traceback from grpc.experimental import aio as aiogrpc -from distutils.version import LooseVersion import ray import ray.dashboard.consts as dashboard_consts @@ -84,7 +83,7 @@ def __init__(self, assert self.ppid > 0 logger.info("Parent pid is %s", self.ppid) self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0), )) - self.grpc_port = ray._private.utils.add_port_to_grpc_server( + self.grpc_port = ray._private.tls_utils.add_port_to_grpc_server( self.server, f"[::]:{self.dashboard_agent_port}") logger.info("Dashboard agent grpc address: %s:%s", self.ip, self.grpc_port) @@ -144,12 +143,8 @@ async def _check_parent(): sys.exit(-1) # Create a http session for all modules. - # aiohttp<4.0.0 uses a 'loop' variable, aiohttp>=4.0.0 doesn't anymore - if LooseVersion(aiohttp.__version__) < LooseVersion("4.0.0"): - self.http_session = aiohttp.ClientSession( - loop=asyncio.get_event_loop()) - else: - self.http_session = aiohttp.ClientSession() + self.http_session = aiohttp.ClientSession( + loop=asyncio.get_event_loop()) # Start a grpc asyncio server. await self.server.start() @@ -342,8 +337,8 @@ async def _check_parent(): # https://github.com/ray-project/ray/issues/14026. if sys.platform == "win32": logger.warning( - "The dashboard is currently disabled on windows. " - "See https://github.com/ray-project/ray/issues/14026 " + "The dashboard is currently disabled on windows." + "See https://github.com/ray-project/ray/issues/14026" "for more details") while True: time.sleep(999) @@ -367,34 +362,14 @@ async def _check_parent(): loop = asyncio.get_event_loop() loop.run_until_complete(agent.run()) except Exception as e: - # All these env vars should be available because - # they are provided by the parent raylet. 
- restart_count = os.environ["RESTART_COUNT"]
- max_restart_count = os.environ["MAX_RESTART_COUNT"]
- raylet_pid = os.environ["RAY_RAYLET_PID"]
- node_ip = args.node_ip_address
- if restart_count >= max_restart_count:
- # Agent is failed to be started many times.
- # Push an error to all drivers, so that users can know the
- # impact of the issue.
- redis_client = ray._private.services.create_redis_client(
- args.redis_address, password=args.redis_password)
- traceback_str = ray._private.utils.format_error_message(
- traceback.format_exc())
- message = (
- f"(ip={node_ip}) "
- f"The agent on node {platform.uname()[1]} failed to "
- f"be restarted {max_restart_count} "
- "times. There are 3 possible problems if you see this error."
- "\n 1. The dashboard might not display correct "
- "information on this node."
- "\n 2. Metrics on this node won't be reported."
- "\n 3. runtime_env APIs won't work."
- "\nCheck out the `dashboard_agent.log` to see the "
- "detailed failure messages.")
- ray._private.utils.push_error_to_driver_through_redis(
- redis_client, ray_constants.DASHBOARD_AGENT_DIED_ERROR,
- message)
- logger.error(message)
- logger.exception(e)
- exit(1)
+ # Something went wrong, so push an error to all drivers.
+ redis_client = ray._private.services.create_redis_client(
+ args.redis_address, password=args.redis_password)
+ traceback_str = ray._private.utils.format_error_message(
+ traceback.format_exc())
+ message = ("The agent on node {} failed with the following "
+ "error:\n{}".format(platform.uname()[1], traceback_str))
+ ray._private.utils.push_error_to_driver_through_redis(
+ redis_client, ray_constants.DASHBOARD_AGENT_DIED_ERROR, message)
+ logger.exception(message)
+ raise e
diff --git a/dashboard/head.py b/dashboard/head.py
index 5c52d86cd5aad..c7cc857c5c787 100644
--- a/dashboard/head.py
+++ b/dashboard/head.py
@@ -121,7 +121,7 @@ def __init__(self, http_host, http_port, http_port_retries, redis_address,
ip, port = redis_address.split(":")
self.gcs_client = connect_to_gcs(ip, int(port), redis_password)
self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0), ))
- self.grpc_port = ray._private.utils.add_port_to_grpc_server(
+ self.grpc_port = ray._private.tls_utils.add_port_to_grpc_server(
self.server, "[::]:0")
logger.info("Dashboard head grpc address: %s:%s", self.ip, self.grpc_port)
diff --git a/doc/source/configure.rst b/doc/source/configure.rst
index 5e93b2c6e4f82..186255d855373 100644
--- a/doc/source/configure.rst
+++ b/doc/source/configure.rst
@@ -234,6 +234,28 @@ to localhost when the ray is started using ``ray.init``.
See the `Redis security documentation `__
for more information.
+TLS Authentication
+------------------
+
+Ray can be configured to use TLS on its gRPC channels.
+This means that connecting to the Ray client on the head node will
+require an appropriate set of credentials and also that data exchanged between
+various processes (client, head, workers) will be encrypted.
+
+Enabling TLS will cause a performance hit due to the extra overhead of mutual
+authentication and encryption.
+Testing has shown that this overhead is large for small workloads and becomes
+relatively smaller for large workloads.
+The exact overhead will depend on the nature of your workload.
+
+TLS is enabled by setting environment variables.
+
+- ``RAY_USE_TLS``: Either 1 or 0 to use/not-use TLS. If this is set to 1 then all of the environment variables below must be set. Default: 0.
+- ``RAY_TLS_SERVER_CERT``: Location of a `certificate file` which is presented to other endpoints so as to achieve mutual authentication.
+- ``RAY_TLS_SERVER_KEY``: Location of a `private key file` which is the cryptographic means to prove to other endpoints that you are the authorized user of a given certificate.
+- ``RAY_TLS_CA_CERT``: Location of a `CA certificate file` which allows TLS to decide whether an endpoint's certificate has been signed by the correct authority.
+
+
Java Applications
-----------------
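In the C++ core the same three variables feed directly into gRPC's SSL credentials. A minimal sketch of a client channel built from them, mirroring the `BuildChannel` logic elsewhere in this series; the names `ReadPem` and `MakeChannel` are illustrative, and it assumes the three file variables are set whenever `RAY_USE_TLS` is enabled, as the documentation above requires:

    #include <grpcpp/grpcpp.h>

    #include <cstdlib>
    #include <fstream>
    #include <sstream>
    #include <string>

    // Read a whole PEM file into a string, like the ReadCert helper in this series.
    std::string ReadPem(const std::string &path) {
      std::ifstream in(path);
      std::stringstream buf;
      buf << in.rdbuf();
      return buf.str();
    }

    std::shared_ptr<grpc::Channel> MakeChannel(const std::string &target) {
      const char *use_tls = std::getenv("RAY_USE_TLS");
      if (use_tls != nullptr && std::string(use_tls) != "0") {
        // Precondition (per the docs above): the three cert/key variables are set.
        grpc::SslCredentialsOptions opts;
        opts.pem_root_certs = ReadPem(std::getenv("RAY_TLS_CA_CERT"));
        opts.pem_private_key = ReadPem(std::getenv("RAY_TLS_SERVER_KEY"));
        opts.pem_cert_chain = ReadPem(std::getenv("RAY_TLS_SERVER_CERT"));
        return grpc::CreateChannel(target, grpc::SslCredentials(opts));
      }
      return grpc::CreateChannel(target, grpc::InsecureChannelCredentials());
    }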
diff --git a/python/ray/_private/test_utils.py b/python/ray/_private/test_utils.py
index 50bb3d13c008b..8da4ac9f03e69 100644
--- a/python/ray/_private/test_utils.py
+++ b/python/ray/_private/test_utils.py
@@ -7,24 +7,25 @@ import pathlib
import subprocess
import sys
-import tempfile
import time
import timeit
import math
import traceback
from typing import Optional, Any, List, Dict
from contextlib import redirect_stdout, redirect_stderr
import yaml
import socket
import pytest
+import tempfile
import ray
import ray._private.services
import ray._private.utils
import ray._private.gcs_utils as gcs_utils
+from ray._private.tls_utils import generate_self_signed_tls_certs
from ray.util.queue import Queue, _QueueActor, Empty
from ray.scripts.scripts import main as ray_main
+
try:
from prometheus_client.parser import text_string_to_metric_families
except (ImportError, ModuleNotFoundError):
@@ -690,57 +691,11 @@ async def get_batch(self,
return batch
-def generate_self_signed_tls_certs():
- """Create self-signed key/cert pair for testing.
-
- This method requires the library ``cryptography`` be installed.
- """
- try:
- from cryptography import x509
- from cryptography.hazmat.backends import default_backend
- from cryptography.hazmat.primitives import hashes, serialization
- from cryptography.hazmat.primitives.asymmetric import rsa
- from cryptography.x509.oid import NameOID
- except ImportError:
- raise ImportError(
- "Using `Security.temporary` requires `cryptography`, please "
- "install it using either pip or conda")
- key = rsa.generate_private_key(
- public_exponent=65537, key_size=2048, backend=default_backend())
- key_contents = key.private_bytes(
- encoding=serialization.Encoding.PEM,
- format=serialization.PrivateFormat.PKCS8,
- encryption_algorithm=serialization.NoEncryption(),
- ).decode()
-
- ray_interal = x509.Name(
- [x509.NameAttribute(NameOID.COMMON_NAME, "ray-internal")])
- # This is the same logic used by the GCS server to acquire a
- # private/interal IP address to listen on.
If we just use localhost + - # 127.0.0.1 then we won't be able to connect to the GCS and will get - # an error like "No match found for server name: 192.168.X.Y" - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - s.connect(("8.8.8.8", 80)) - private_ip_address = s.getsockname()[0] - s.close() - altnames = x509.SubjectAlternativeName([ - x509.DNSName(socket.gethostbyname( - socket.gethostname())), # Probably 127.0.0.1 - x509.DNSName("127.0.0.1"), - x509.DNSName(private_ip_address), # 192.168.*.* - x509.DNSName("localhost"), - ]) - now = datetime.datetime.utcnow() - cert = (x509.CertificateBuilder() - .subject_name(ray_interal).issuer_name(ray_interal).add_extension( - altnames, critical=False).public_key(key.public_key()) - .serial_number(x509.random_serial_number()).not_valid_before(now) - .not_valid_after(now + datetime.timedelta(days=365)).sign( - key, hashes.SHA256(), default_backend())) - - cert_contents = cert.public_bytes(serialization.Encoding.PEM).decode() - - return cert_contents, key_contents +def is_placement_group_removed(pg): + table = ray.util.placement_group_table(pg) + if "state" not in table: + return False + return table["state"] == "REMOVED" def setup_tls(): @@ -772,10 +727,3 @@ def teardown_tls(key_filepath, cert_filepath, temp_dir): del os.environ["RAY_TLS_SERVER_CERT"] del os.environ["RAY_TLS_SERVER_KEY"] del os.environ["RAY_TLS_CA_CERT"] - - -def is_placement_group_removed(pg): - table = ray.util.placement_group_table(pg) - if "state" not in table: - return False - return table["state"] == "REMOVED" diff --git a/python/ray/_private/tls_utils.py b/python/ray/_private/tls_utils.py new file mode 100644 index 0000000000000..8344d86c30c4b --- /dev/null +++ b/python/ray/_private/tls_utils.py @@ -0,0 +1,85 @@ +import datetime +import os +import socket + +import grpc + + +def generate_self_signed_tls_certs(): + """Create self-signed key/cert pair for testing. + + This method requires the library ``cryptography`` be installed. + """ + try: + from cryptography import x509 + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import hashes, serialization + from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.x509.oid import NameOID + except ImportError: + raise ImportError( + "Using `Security.temporary` requires `cryptography`, please " + "install it using either pip or conda") + key = rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend()) + key_contents = key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ).decode() + + ray_interal = x509.Name( + [x509.NameAttribute(NameOID.COMMON_NAME, "ray-internal")]) + # This is the same logic used by the GCS server to acquire a + # private/interal IP address to listen on. 
If we just use localhost + + # 127.0.0.1 then we won't be able to connect to the GCS and will get + # an error like "No match found for server name: 192.168.X.Y" + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.connect(("8.8.8.8", 80)) + private_ip_address = s.getsockname()[0] + s.close() + altnames = x509.SubjectAlternativeName([ + x509.DNSName(socket.gethostbyname( + socket.gethostname())), # Probably 127.0.0.1 + x509.DNSName("127.0.0.1"), + x509.DNSName(private_ip_address), # 192.168.*.* + x509.DNSName("localhost"), + ]) + now = datetime.datetime.utcnow() + cert = (x509.CertificateBuilder().subject_name(ray_interal).issuer_name( + ray_interal).add_extension(altnames, critical=False).public_key( + key.public_key()).serial_number( + x509.random_serial_number()).not_valid_before(now) + .not_valid_after(now + datetime.timedelta(days=365)).sign( + key, hashes.SHA256(), default_backend())) + + cert_contents = cert.public_bytes(serialization.Encoding.PEM).decode() + + return cert_contents, key_contents + + +def add_port_to_grpc_server(server, address): + if os.environ.get("RAY_USE_TLS", "0") == "1": + server_cert_chain, private_key, ca_cert = load_certs_from_env() + credentials = grpc.ssl_server_credentials( + [(private_key, server_cert_chain)], + root_certificates=ca_cert, + require_client_auth=ca_cert is not None) + return server.add_secure_port(address, credentials) + else: + return server.add_insecure_port(address) + + +def load_certs_from_env(): + if os.environ.get("RAY_USE_TLS", "0") == "1": + with open(os.environ["RAY_TLS_SERVER_CERT"], "rb") as f: + server_cert_chain = f.read() + with open(os.environ["RAY_TLS_SERVER_KEY"], "rb") as f: + private_key = f.read() + if "RAY_TLS_CA_CERT" in os.environ: + with open(os.environ["RAY_TLS_CA_CERT"], "rb") as f: + ca_cert = f.read() + else: + ca_cert = None + + return server_cert_chain, private_key, ca_cert diff --git a/python/ray/_private/utils.py b/python/ray/_private/utils.py index 37430d928dd92..50fe38ed65f74 100644 --- a/python/ray/_private/utils.py +++ b/python/ray/_private/utils.py @@ -27,6 +27,7 @@ import ray import ray._private.gcs_utils as gcs_utils import ray.ray_constants as ray_constants +from ray._private.tls_utils import load_certs_from_env # Import psutil after ray so the packaged version is used. 
import psutil @@ -1111,21 +1112,6 @@ def validate_namespace(namespace: str): "Pass None to not specify a namespace.") -def load_certs_from_env(): - if os.environ.get("RAY_USE_TLS", "0") == "1": - with open(os.environ["RAY_TLS_SERVER_CERT"], "rb") as f: - server_cert_chain = f.read() - with open(os.environ["RAY_TLS_SERVER_KEY"], "rb") as f: - private_key = f.read() - if "RAY_TLS_CA_CERT" in os.environ: - with open(os.environ["RAY_TLS_CA_CERT"], "rb") as f: - ca_cert = f.read() - else: - ca_cert = None - - return server_cert_chain, private_key, ca_cert - - def init_grpc_channel(address: str, options: Optional[Sequence[Tuple[str, Any]]] = None, asynchronous: bool = False): @@ -1142,15 +1128,3 @@ def init_grpc_channel(address: str, channel = grpc_module.insecure_channel(address, options=options) return channel - - -def add_port_to_grpc_server(server, address): - if os.environ.get("RAY_USE_TLS", "0") == "1": - server_cert_chain, private_key, ca_cert = load_certs_from_env() - credentials = grpc.ssl_server_credentials( - [(private_key, server_cert_chain)], - root_certificates=ca_cert, - require_client_auth=ca_cert is not None) - return server.add_secure_port(address, credentials) - else: - return server.add_insecure_port(address) diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index d5b73ece9bf54..ad4d844b7c304 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -76,8 +76,7 @@ def test_omp_threads_set(shutdown_only): assert os.environ["OMP_NUM_THREADS"] == "1" -@pytest.mark.parametrize("use_tls", [False, True], indirect=True) -def test_submit_api(shutdown_only, use_tls): +def test_submit_api(shutdown_only): ray.init(num_cpus=2, num_gpus=1, resources={"Custom": 1}) @ray.remote @@ -141,8 +140,7 @@ def method(self): assert ray.get([id1, id2, id3, id4]) == [0, 1, "test", 2] -@pytest.mark.parametrize("use_tls", [False, True], indirect=True) -def test_invalid_arguments(shutdown_only, use_tls): +def test_invalid_arguments(shutdown_only): ray.init(num_cpus=2) for opt in [np.random.randint(-100, -1), np.random.uniform(0, 1)]: @@ -238,8 +236,7 @@ def check(): {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"}) -@pytest.mark.parametrize("use_tls", [False, True], indirect=True) -def test_put_get(shutdown_only, use_tls): +def test_put_get(shutdown_only): ray.init(num_cpus=0) for i in range(100): @@ -268,8 +265,7 @@ def test_put_get(shutdown_only, use_tls): @pytest.mark.skipif(sys.platform != "linux", reason="Failing on Windows") -@pytest.mark.parametrize("use_tls", [False, True], indirect=True) -def test_wait_timing(shutdown_only, use_tls): +def test_wait_timing(shutdown_only): ray.init(num_cpus=2) @ray.remote @@ -303,8 +299,7 @@ def test_function_descriptor(): assert d.get(python_descriptor2) == 123 -@pytest.mark.parametrize("use_tls", [False, True], indirect=True) -def test_ray_options(shutdown_only, use_tls): +def test_ray_options(shutdown_only): ray.init(num_cpus=10, num_gpus=10, resources={"custom1": 2}) @ray.remote( diff --git a/python/ray/tests/test_tls_auth.py b/python/ray/tests/test_tls_auth.py index 057c2e0b2ae32..01b234ceb8315 100644 --- a/python/ray/tests/test_tls_auth.py +++ b/python/ray/tests/test_tls_auth.py @@ -1,23 +1,81 @@ # coding: utf-8 +import logging import os import sys import pytest -import logging +import ray logger = logging.getLogger(__name__) +@pytest.mark.parametrize("use_tls", [True], indirect=True) +def test_put_get_with_tls(shutdown_only, use_tls): + ray.init(num_cpus=0) + + for i in range(100): + value_before = i * 
10**6 + object_ref = ray.put(value_before) + value_after = ray.get(object_ref) + assert value_before == value_after + + for i in range(100): + value_before = i * 10**6 * 1.0 + object_ref = ray.put(value_before) + value_after = ray.get(object_ref) + assert value_before == value_after + + for i in range(100): + value_before = "h" * i + object_ref = ray.put(value_before) + value_after = ray.get(object_ref) + assert value_before == value_after + + for i in range(100): + value_before = [1] * i + object_ref = ray.put(value_before) + value_after = ray.get(object_ref) + assert value_before == value_after + + +@pytest.mark.parametrize("use_tls", [True], indirect=True) +def test_submit_with_tls(shutdown_only, use_tls): + ray.init(num_cpus=2, num_gpus=1, resources={"Custom": 1}) + + @ray.remote + def f(n): + return list(range(n)) + + id1, id2, id3 = f._remote(args=[3], num_returns=3) + assert ray.get([id1, id2, id3]) == [0, 1, 2] + + @ray.remote + class Actor: + def __init__(self, x, y=0): + self.x = x + self.y = y + + def method(self, a, b=0): + return self.x, self.y, a, b + + a = Actor._remote( + args=[0], kwargs={"y": 1}, num_gpus=1, resources={"Custom": 1}) + + id1, id2, id3, id4 = a.method._remote( + args=["test"], kwargs={"b": 2}, num_returns=4) + assert ray.get([id1, id2, id3, id4]) == [0, 1, "test", 2] + + @pytest.mark.skipif( sys.platform == "darwin", reason=("Cryptography doesn't install in Mac build pipeline")) @pytest.mark.parametrize("use_tls", [True], indirect=True) def test_client_connect_to_tls_server(use_tls, init_and_serve): - from ray.util.client import ray + from ray.util.client import ray as ray_client os.environ["RAY_USE_TLS"] = "0" with pytest.raises(ConnectionError): - ray.connect("localhost:50051") + ray_client.connect("localhost:50051") os.environ["RAY_USE_TLS"] = "1" - ray.connect("localhost:50051") + ray_client.connect("localhost:50051") diff --git a/python/ray/util/client/server/proxier.py b/python/ray/util/client/server/proxier.py index 165fbbcabe8a9..98ad26c93d8b4 100644 --- a/python/ray/util/client/server/proxier.py +++ b/python/ray/util/client/server/proxier.py @@ -29,8 +29,8 @@ from ray._private.parameter import RayParams from ray._private.runtime_env.context import RuntimeEnvContext from ray._private.services import ProcessInfo, start_ray_client_server -from ray._private.utils import (detect_fate_sharing_support, - add_port_to_grpc_server) +from ray._private.utils import detect_fate_sharing_support +from ray._private.tls_utils import add_port_to_grpc_server # Import psutil after ray so the packaged version is used. 
import psutil diff --git a/python/ray/util/client/server/server.py b/python/ray/util/client/server/server.py index 351b981d0a17c..27a10d18e3b11 100644 --- a/python/ray/util/client/server/server.py +++ b/python/ray/util/client/server/server.py @@ -35,7 +35,7 @@ from ray.ray_constants import env_integer from ray.util.placement_group import PlacementGroup from ray._private.client_mode_hook import disable_client_hook -from ray._private.utils import add_port_to_grpc_server +from ray._private.tls_utils import add_port_to_grpc_server logger = logging.getLogger(__name__) diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index e07f74c6c50c9..4b45ac0c761ee 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -362,8 +362,8 @@ def get(self, vals, *, timeout: Optional[float] = None) -> Any: logger.debug("Internal retry for get {}".format(to_get)) if len(to_get) != len(res): raise Exception( - "Mismatched number of items in request ({}) and response ({})" - .format(len(to_get), len(res))) + "Mismatched number of items in request ({}) and response ({})". + format(len(to_get), len(res))) if isinstance(vals, ClientObjectRef): res = res[0] return res diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index 33ce0975a364d..0a6a61357b79f 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -488,3 +488,11 @@ RAY_CONFIG(bool, scheduler_avoid_gpu_nodes, true) /// Whether to skip running local GC in runtime env. RAY_CONFIG(bool, runtime_env_skip_local_gc, false) + +/// Whether or not use TLS. +RAY_CONFIG(int64_t, USE_TLS, 0) + +/// Location of TLS credentials +RAY_CONFIG(std::string, TLS_SERVER_CERT, "") +RAY_CONFIG(std::string, TLS_SERVER_KEY, "") +RAY_CONFIG(std::string, TLS_CA_CERT, "") diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h index 2538257f6c631..c83628b72b2e8 100644 --- a/src/ray/rpc/grpc_server.h +++ b/src/ray/rpc/grpc_server.h @@ -60,6 +60,8 @@ class GrpcServer { /// /// \param[in] name Name of this server, used for logging and debugging purpose. /// \param[in] port The port to bind this server to. If it's 0, a random available port + /// will be chosen. 
+ GrpcServer(std::string name, const uint32_t port, bool listen_to_localhost_only, int num_threads = 1, int64_t keepalive_time_ms = 7200000, /*2 hours, grpc default*/ From 7639a6567649da01278c9c960ad3efde48ad098b Mon Sep 17 00:00:00 2001 From: Oscar Knagg Date: Tue, 12 Oct 2021 17:12:01 +0100 Subject: [PATCH 43/56] Replace getenv with RayConfig --- src/ray/rpc/grpc_client.h | 17 ++++------------- src/ray/rpc/grpc_server.cc | 22 ++++++---------------- src/ray/rpc/grpc_server.h | 6 +----- 3 files changed, 11 insertions(+), 34 deletions(-) diff --git a/src/ray/rpc/grpc_client.h b/src/ray/rpc/grpc_client.h index cdde388f7fb6e..12daab630e939 100644 --- a/src/ray/rpc/grpc_client.h +++ b/src/ray/rpc/grpc_client.h @@ -54,7 +54,6 @@ class GrpcClient { argument.SetMaxSendMessageSize(::RayConfig::instance().max_grpc_message_size()); argument.SetMaxReceiveMessageSize(::RayConfig::instance().max_grpc_message_size()); - CheckTlSEnvironmentVariables(); std::shared_ptr channel = BuildChannel(argument, address, port); stub_ = GrpcService::NewStub(channel); @@ -71,7 +70,6 @@ class GrpcClient { argument.SetMaxSendMessageSize(::RayConfig::instance().max_grpc_message_size()); argument.SetMaxReceiveMessageSize(::RayConfig::instance().max_grpc_message_size()); - CheckTlSEnvironmentVariables(); std::shared_ptr channel = BuildChannel(argument, address, port); stub_ = GrpcService::NewStub(channel); @@ -108,10 +106,10 @@ class GrpcClient { std::shared_ptr BuildChannel(const grpc::ChannelArguments &argument, const std::string &address, int port) { std::shared_ptr channel; - if (use_tls_) { - std::string server_cert_file = std::string(std::getenv("RAY_TLS_SERVER_CERT")); - std::string server_key_file = std::string(std::getenv("RAY_TLS_SERVER_KEY")); - std::string root_cert_file = std::string(std::getenv("RAY_TLS_CA_CERT")); + if (::RayConfig::instance().USE_TLS()) { + std::string server_cert_file = std::string(::RayConfig::instance().TLS_SERVER_CERT()); + std::string server_key_file = std::string(::RayConfig::instance().TLS_SERVER_KEY()); + std::string root_cert_file = std::string(::RayConfig::instance().TLS_CA_CERT()); std::string server_cert_chain = ReadCert(server_cert_file); std::string private_key = ReadCert(server_key_file); std::string cacert = ReadCert(root_cert_file); @@ -130,13 +128,6 @@ class GrpcClient { return channel; }; - void CheckTlSEnvironmentVariables() { - if (std::getenv("RAY_USE_TLS")) { - use_tls_ = std::strcmp(std::getenv("RAY_USE_TLS"), "0") != 0; - } else { - use_tls_ = false; - }; - } }; } // namespace rpc diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc index 8f3f98b67445c..edc2e6a156047 100644 --- a/src/ray/rpc/grpc_server.cc +++ b/src/ray/rpc/grpc_server.cc @@ -38,10 +38,9 @@ namespace rpc { GrpcServer::GrpcServer(std::string name, const uint32_t port, bool listen_to_localhost_only, int num_threads, - int64_t keepalive_time_ms, bool use_tls) + int64_t keepalive_time_ms) : name_(std::move(name)), port_(port), - use_tls_(use_tls), listen_to_localhost_only_(listen_to_localhost_only), is_closed_(true), num_threads_(num_threads), @@ -67,20 +66,11 @@ void GrpcServer::Run() { RayConfig::instance().grpc_keepalive_timeout_ms()); builder.AddChannelArgument(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 0); - if (std::getenv("RAY_USE_TLS")) { - use_tls_ = std::strcmp(std::getenv("RAY_USE_TLS"), "0") != 0; - } else { - use_tls_ = false; - } - if (use_tls_) { - std::string server_cert_file = std::string(std::getenv("RAY_TLS_SERVER_CERT")); - std::string server_key_file = 
std::string(std::getenv("RAY_TLS_SERVER_KEY")); - std::string root_cert_file = std::string(std::getenv("RAY_TLS_CA_CERT")); - - // Create credentials from hardcoded location - std::string rootcert = ReadCert(root_cert_file); - std::string servercert = ReadCert(server_cert_file); - std::string serverkey = ReadCert(server_key_file); + if (RayConfig::instance().USE_TLS()) { + // Create credentials from locations specified in config + std::string rootcert = ReadCert(RayConfig::instance().TLS_CA_CERT()); + std::string servercert = ReadCert(RayConfig::instance().TLS_SERVER_CERT()); + std::string serverkey = ReadCert(RayConfig::instance().TLS_SERVER_KEY()); grpc::SslServerCredentialsOptions::PemKeyCertPair pkcp = {serverkey.c_str(), servercert.c_str()}; grpc::SslServerCredentialsOptions ssl_opts( diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h index c83628b72b2e8..6d3ad9eb203d5 100644 --- a/src/ray/rpc/grpc_server.h +++ b/src/ray/rpc/grpc_server.h @@ -61,11 +61,9 @@ class GrpcServer { /// \param[in] name Name of this server, used for logging and debugging purpose. /// \param[in] port The port to bind this server to. If it's 0, a random available port /// will be chosen. - GrpcServer(std::string name, const uint32_t port, bool listen_to_localhost_only, int num_threads = 1, - int64_t keepalive_time_ms = 7200000, /*2 hours, grpc default*/ - bool use_tls = false); + int64_t keepalive_time_ms = 7200000 /*2 hours, grpc default*/); /// Destruct this gRPC server. ~GrpcServer() { Shutdown(); } @@ -114,8 +112,6 @@ class GrpcServer { const std::string name_; /// Port of this server. int port_; - /// Whether to use TLS. - bool use_tls_; /// Listen to localhost (127.0.0.1) only if it's true, otherwise listen to all network /// interfaces (0.0.0.0) const bool listen_to_localhost_only_; From d95419a27058401c302e1b86057d91c517976d0a Mon Sep 17 00:00:00 2001 From: Oscar Knagg Date: Tue, 12 Oct 2021 17:12:46 +0100 Subject: [PATCH 44/56] Remove lingering errors from earlier merge --- dashboard/agent.py | 55 ++++++++++++++++++------- python/ray/_private/test_utils.py | 2 +- python/ray/tests/conftest.py | 15 +------ python/ray/tests/test_client_builder.py | 8 ++-- 4 files changed, 46 insertions(+), 34 deletions(-) diff --git a/dashboard/agent.py b/dashboard/agent.py index f56e76f61fff9..3b6c7c98dcb12 100644 --- a/dashboard/agent.py +++ b/dashboard/agent.py @@ -11,6 +11,7 @@ import traceback from grpc.experimental import aio as aiogrpc +from distutils.version import LooseVersion import ray import ray.dashboard.consts as dashboard_consts @@ -143,8 +144,12 @@ async def _check_parent(): sys.exit(-1) # Create a http session for all modules. - self.http_session = aiohttp.ClientSession( - loop=asyncio.get_event_loop()) + # aiohttp<4.0.0 uses a 'loop' variable, aiohttp>=4.0.0 doesn't anymore + if LooseVersion(aiohttp.__version__) < LooseVersion("4.0.0"): + self.http_session = aiohttp.ClientSession( + loop=asyncio.get_event_loop()) + else: + self.http_session = aiohttp.ClientSession() # Start a grpc asyncio server. await self.server.start() @@ -337,8 +342,8 @@ async def _check_parent(): # https://github.com/ray-project/ray/issues/14026. if sys.platform == "win32": logger.warning( - "The dashboard is currently disabled on windows." - "See https://github.com/ray-project/ray/issues/14026" + "The dashboard is currently disabled on windows. 
" + "See https://github.com/ray-project/ray/issues/14026 " "for more details") while True: time.sleep(999) @@ -362,14 +367,34 @@ async def _check_parent(): loop = asyncio.get_event_loop() loop.run_until_complete(agent.run()) except Exception as e: - # Something went wrong, so push an error to all drivers. - redis_client = ray._private.services.create_redis_client( - args.redis_address, password=args.redis_password) - traceback_str = ray._private.utils.format_error_message( - traceback.format_exc()) - message = ("The agent on node {} failed with the following " - "error:\n{}".format(platform.uname()[1], traceback_str)) - ray._private.utils.push_error_to_driver_through_redis( - redis_client, ray_constants.DASHBOARD_AGENT_DIED_ERROR, message) - logger.exception(message) - raise e + # All these env vars should be available because + # they are provided by the parent raylet. + restart_count = os.environ["RESTART_COUNT"] + max_restart_count = os.environ["MAX_RESTART_COUNT"] + raylet_pid = os.environ["RAY_RAYLET_PID"] + node_ip = args.node_ip_address + if restart_count >= max_restart_count: + # Agent is failed to be started many times. + # Push an error to all drivers, so that users can know the + # impact of the issue. + redis_client = ray._private.services.create_redis_client( + args.redis_address, password=args.redis_password) + traceback_str = ray._private.utils.format_error_message( + traceback.format_exc()) + message = ( + f"(ip={node_ip}) " + f"The agent on node {platform.uname()[1]} failed to " + f"be restarted {max_restart_count} " + "times. There are 3 possible problems if you see this error." + "\n 1. The dashboard might not display correct " + "information on this node." + "\n 2. Metrics on this node won't be reported." + "\n 3. runtime_env APIs won't work." 
+ "\nCheck out the `dashboard_agent.log` to see the " + "detailed failure messages.") + ray._private.utils.push_error_to_driver_through_redis( + redis_client, ray_constants.DASHBOARD_AGENT_DIED_ERROR, + message) + logger.error(message) + logger.exception(e) + exit(1) diff --git a/python/ray/_private/test_utils.py b/python/ray/_private/test_utils.py index 8da4ac9f03e69..92df94abc7a60 100644 --- a/python/ray/_private/test_utils.py +++ b/python/ray/_private/test_utils.py @@ -9,12 +9,12 @@ import sys import time import timeit +import socket import math import traceback from typing import Optional, Any, List, Dict from contextlib import redirect_stdout, redirect_stderr import yaml -import socket import pytest import tempfile diff --git a/python/ray/tests/conftest.py b/python/ray/tests/conftest.py index 25edefd60fd1d..50007ea56e59f 100644 --- a/python/ray/tests/conftest.py +++ b/python/ray/tests/conftest.py @@ -94,15 +94,11 @@ def ray_start_regular_shared(request): "local_mode": True }, { "local_mode": False - }, { - "local_mode": False, - "use_tls": True }]) def ray_start_shared_local_modes(request): param = getattr(request, "param", {}) use_tls = param.pop("use_tls", False) - with manage_tls(use_tls): - with _ray_start(**param) as res: + with _ray_start(**param) as res: yield res @@ -303,15 +299,6 @@ def log_pubsub(): p.close() -@contextmanager -def manage_tls(use_tls): - if use_tls: - key_filepath, cert_filepath, temp_dir = setup_tls() - yield use_tls - if use_tls: - teardown_tls(key_filepath, cert_filepath, temp_dir) - - @pytest.fixture def use_tls(request): if request.param: diff --git a/python/ray/tests/test_client_builder.py b/python/ray/tests/test_client_builder.py index 832ebe478a78a..918e165955e23 100644 --- a/python/ray/tests/test_client_builder.py +++ b/python/ray/tests/test_client_builder.py @@ -55,7 +55,7 @@ def test_namespace(): put in the same namespace. This test checks that: - * When two drivers don't specify a namespace, they are placed in different + RayConfig::instance().RAY_USE_TLS() * When two drivers don't specify a namespace, they are placed in different anonymous namespaces. * When two drivers specify a namespace, they collide. * The namespace name (as provided by the runtime context) is correct. @@ -78,13 +78,13 @@ def ping(self): print(ray.get_runtime_context().namespace) """ anon_driver = template.format(namespace="None") - run_string_as_driver(anon_driver, dict(os.environ)) + run_string_as_driver(anon_driver) # This second run will fail if the actors don't run in separate anonymous # namespaces. - run_string_as_driver(anon_driver, dict(os.environ)) + run_string_as_driver(anon_driver) run_in_namespace = template.format(namespace="'namespace'") - script_namespace = run_string_as_driver(run_in_namespace, dict(os.environ)) + script_namespace = run_string_as_driver(run_in_namespace) # The second run fails because the actors are run in the same namespace. 
with pytest.raises(subprocess.CalledProcessError): run_string_as_driver(run_in_namespace) From 1c92af2068c48c9c4c954fb14e5fe26d9e715dcd Mon Sep 17 00:00:00 2001 From: Oscar Knagg Date: Thu, 14 Oct 2021 10:11:51 +0100 Subject: [PATCH 45/56] Address comments pt2 --- python/ray/_private/test_utils.py | 2 +- python/ray/_private/tls_utils.py | 11 ++- python/ray/_private/utils.py | 2 +- python/ray/tests/conftest.py | 6 +- python/ray/tests/test_client_builder.py | 2 +- python/ray/tests/test_tls_auth.py | 118 +++++++++++++++++------- python/ray/util/client/worker.py | 9 +- src/ray/common/ray_config_def.h | 19 +++- src/ray/rpc/common.cc | 6 +- src/ray/rpc/common.h | 2 + src/ray/rpc/grpc_client.h | 6 +- 11 files changed, 129 insertions(+), 54 deletions(-) diff --git a/python/ray/_private/test_utils.py b/python/ray/_private/test_utils.py index 92df94abc7a60..8da4ac9f03e69 100644 --- a/python/ray/_private/test_utils.py +++ b/python/ray/_private/test_utils.py @@ -9,12 +9,12 @@ import sys import time import timeit -import socket import math import traceback from typing import Optional, Any, List, Dict from contextlib import redirect_stdout, redirect_stderr import yaml +import socket import pytest import tempfile diff --git a/python/ray/_private/tls_utils.py b/python/ray/_private/tls_utils.py index 8344d86c30c4b..0dcecf4512c94 100644 --- a/python/ray/_private/tls_utils.py +++ b/python/ray/_private/tls_utils.py @@ -59,7 +59,7 @@ def generate_self_signed_tls_certs(): def add_port_to_grpc_server(server, address): - if os.environ.get("RAY_USE_TLS", "0") == "1": + if os.environ.get("RAY_USE_TLS", "0").lower() in ("1", "true"): server_cert_chain, private_key, ca_cert = load_certs_from_env() credentials = grpc.ssl_server_credentials( [(private_key, server_cert_chain)], @@ -71,7 +71,14 @@ def add_port_to_grpc_server(server, address): def load_certs_from_env(): - if os.environ.get("RAY_USE_TLS", "0") == "1": + if os.environ.get("RAY_USE_TLS", "0").lower() in ("1", "true"): + if ("RAY_TLS_SERVER_CERT" not in os.environ) or \ + ("RAY_TLS_SERVER_KEY" not in os.environ): + raise RuntimeError( + "If the environment variable RAY_USE_TLS is set to true" + "then both RAY_TLS_SERVER_CERT and RAY_TLS_SERVER_KEY must " + "also be set.") + with open(os.environ["RAY_TLS_SERVER_CERT"], "rb") as f: server_cert_chain = f.read() with open(os.environ["RAY_TLS_SERVER_KEY"], "rb") as f: diff --git a/python/ray/_private/utils.py b/python/ray/_private/utils.py index 50fe38ed65f74..f4e441c7fecdc 100644 --- a/python/ray/_private/utils.py +++ b/python/ray/_private/utils.py @@ -1116,7 +1116,7 @@ def init_grpc_channel(address: str, options: Optional[Sequence[Tuple[str, Any]]] = None, asynchronous: bool = False): grpc_module = aiogrpc if asynchronous else grpc - if os.environ.get("RAY_USE_TLS", "0") == "1": + if os.environ.get("RAY_USE_TLS", "0").lower() in ("1", "true"): server_cert_chain, private_key, ca_cert = load_certs_from_env() credentials = grpc.ssl_channel_credentials( certificate_chain=server_cert_chain, diff --git a/python/ray/tests/conftest.py b/python/ray/tests/conftest.py index 50007ea56e59f..62a2a81edbb88 100644 --- a/python/ray/tests/conftest.py +++ b/python/ray/tests/conftest.py @@ -89,17 +89,15 @@ def ray_start_regular_shared(request): @pytest.fixture( - scope="module", - params=[{ + scope="module", params=[{ "local_mode": True }, { "local_mode": False }]) def ray_start_shared_local_modes(request): param = getattr(request, "param", {}) - use_tls = param.pop("use_tls", False) with _ray_start(**param) as res: - yield res + 
yield res @pytest.fixture diff --git a/python/ray/tests/test_client_builder.py b/python/ray/tests/test_client_builder.py index 918e165955e23..406933ced5522 100644 --- a/python/ray/tests/test_client_builder.py +++ b/python/ray/tests/test_client_builder.py @@ -55,7 +55,7 @@ def test_namespace(): put in the same namespace. This test checks that: - RayConfig::instance().RAY_USE_TLS() * When two drivers don't specify a namespace, they are placed in different + * When two drivers don't specify a namespace, they are placed in different anonymous namespaces. * When two drivers specify a namespace, they collide. * The namespace name (as provided by the runtime context) is correct. diff --git a/python/ray/tests/test_tls_auth.py b/python/ray/tests/test_tls_auth.py index 01b234ceb8315..909951f0036c7 100644 --- a/python/ray/tests/test_tls_auth.py +++ b/python/ray/tests/test_tls_auth.py @@ -2,18 +2,37 @@ import logging import os import sys +import subprocess import pytest -import ray +from ray._private.test_utils import run_string_as_driver logger = logging.getLogger(__name__) @pytest.mark.parametrize("use_tls", [True], indirect=True) -def test_put_get_with_tls(shutdown_only, use_tls): - ray.init(num_cpus=0) +def test_init_with_tls(use_tls): + # Run as a new process to pick up environment variables set + # in the use_tls fixture + run_string_as_driver( + """ +import ray +try: + ray.init() +finally: + ray.shutdown() + """, + env=os.environ) + +@pytest.mark.parametrize("use_tls", [True], indirect=True) +def test_put_get_with_tls(use_tls): + run_string_as_driver( + """ +import ray +ray.init() +try: for i in range(100): value_before = i * 10**6 object_ref = ray.put(value_before) @@ -37,45 +56,82 @@ def test_put_get_with_tls(shutdown_only, use_tls): object_ref = ray.put(value_before) value_after = ray.get(object_ref) assert value_before == value_after +finally: + ray.shutdown() + """, + env=os.environ) -@pytest.mark.parametrize("use_tls", [True], indirect=True) -def test_submit_with_tls(shutdown_only, use_tls): - ray.init(num_cpus=2, num_gpus=1, resources={"Custom": 1}) +@pytest.mark.parametrize("use_tls", [True], indirect=True, scope="module") +def test_submit_with_tls(use_tls): + run_string_as_driver( + """ +import ray +ray.init(num_cpus=2, num_gpus=1, resources={"Custom": 1}) - @ray.remote - def f(n): - return list(range(n)) +@ray.remote +def f(n): + return list(range(n)) - id1, id2, id3 = f._remote(args=[3], num_returns=3) - assert ray.get([id1, id2, id3]) == [0, 1, 2] +id1, id2, id3 = f._remote(args=[3], num_returns=3) +assert ray.get([id1, id2, id3]) == [0, 1, 2] - @ray.remote - class Actor: - def __init__(self, x, y=0): - self.x = x - self.y = y +@ray.remote +class Actor: + def __init__(self, x, y=0): + self.x = x + self.y = y - def method(self, a, b=0): - return self.x, self.y, a, b + def method(self, a, b=0): + return self.x, self.y, a, b - a = Actor._remote( - args=[0], kwargs={"y": 1}, num_gpus=1, resources={"Custom": 1}) +a = Actor._remote( + args=[0], kwargs={"y": 1}, num_gpus=1, resources={"Custom": 1}) - id1, id2, id3, id4 = a.method._remote( - args=["test"], kwargs={"b": 2}, num_returns=4) - assert ray.get([id1, id2, id3, id4]) == [0, 1, "test", 2] +id1, id2, id3, id4 = a.method._remote( + args=["test"], kwargs={"b": 2}, num_returns=4) +assert ray.get([id1, id2, id3, id4]) == [0, 1, "test", 2] + """, + env=os.environ) @pytest.mark.skipif( sys.platform == "darwin", reason=("Cryptography doesn't install in Mac build pipeline")) @pytest.mark.parametrize("use_tls", [True], indirect=True) -def 
test_client_connect_to_tls_server(use_tls, init_and_serve): - from ray.util.client import ray as ray_client - os.environ["RAY_USE_TLS"] = "0" - with pytest.raises(ConnectionError): - ray_client.connect("localhost:50051") - - os.environ["RAY_USE_TLS"] = "1" - ray_client.connect("localhost:50051") +def test_client_connect_to_tls_server(use_tls, call_ray_start): + for k, v in os.environ.items(): + if k.startswith("RAY_"): + print("export {}={}".format(k, v)) + + tls_env = os.environ.copy( + ) # use_tls fixture sets TLS environment variables + without_tls_env = {} + + # Attempt to connect without TLS + try: + out = run_string_as_driver( + """ +from ray.util.client import ray as ray_client +ray_client.connect("localhost:10001") + """, + env=without_tls_env) + except subprocess.CalledProcessError as e: + assert "ConnectionError" in e.output.decode("utf-8") + + # Attempt to connect with TLS + out = run_string_as_driver( + """ +import ray +from ray.util.client import ray as ray_client +ray_client.connect("localhost:10001") +print(ray.is_initialized()) + """, + env=tls_env) + assert out == "True\n" + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index 4b45ac0c761ee..f6bcf47fe8bc2 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -101,7 +101,8 @@ def __init__( self.server = None self._conn_state = grpc.ChannelConnectivity.IDLE self._converted: Dict[str, ClientStub] = {} - self._secure = secure or os.environ.get("RAY_USE_TLS", "0") == "1" + self._secure = secure or os.environ.get("RAY_USE_TLS", + "0").lower() in ("1", "true") self._conn_str = conn_str self._connection_retries = connection_retries @@ -160,7 +161,7 @@ def _connect_channel(self, reconnecting=False) -> None: if self._secure: if self._credentials is not None: credentials = self._credentials - elif os.environ.get("RAY_USE_TLS", "0") == "1": + elif os.environ.get("RAY_USE_TLS", "0").lower() in ("1", "true"): server_cert_chain, private_key, ca_cert = ray._private.utils \ .load_certs_from_env() credentials = grpc.ssl_channel_credentials( @@ -362,8 +363,8 @@ def get(self, vals, *, timeout: Optional[float] = None) -> Any: logger.debug("Internal retry for get {}".format(to_get)) if len(to_get) != len(res): raise Exception( - "Mismatched number of items in request ({}) and response ({})". - format(len(to_get), len(res))) + "Mismatched number of items in request ({}) and response ({})" + .format(len(to_get), len(res))) if isinstance(vals, ClientObjectRef): res = res[0] return res diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index 0a6a61357b79f..4dffc5f25260b 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -490,9 +490,20 @@ RAY_CONFIG(bool, scheduler_avoid_gpu_nodes, true) RAY_CONFIG(bool, runtime_env_skip_local_gc, false) /// Whether or not use TLS. -RAY_CONFIG(int64_t, USE_TLS, 0) +RAY_CONFIG(bool, USE_TLS, + std::getenv("RAY_USE_TLS") != nullptr && + (std::getenv("RAY_USE_TLS") == std::string("true") || + std::getenv("RAY_USE_TLS") == std::string("1"))) /// Location of TLS credentials -RAY_CONFIG(std::string, TLS_SERVER_CERT, "") -RAY_CONFIG(std::string, TLS_SERVER_KEY, "") -RAY_CONFIG(std::string, TLS_CA_CERT, "") +RAY_CONFIG(std::string, TLS_SERVER_CERT, + std::getenv("RAY_TLS_SERVER_CERT") != nullptr + ? 
std::getenv("RAY_TLS_SERVER_CERT") + : "") +RAY_CONFIG(std::string, TLS_SERVER_KEY, + std::getenv("RAY_TLS_SERVER_KEY") != nullptr + ? std::getenv("RAY_TLS_SERVER_KEY") + : "") +RAY_CONFIG(std::string, TLS_CA_CERT, + std::getenv("RAY_TLS_CA_CERT") != nullptr ? std::getenv("RAY_TLS_CA_CERT") + : "") diff --git a/src/ray/rpc/common.cc b/src/ray/rpc/common.cc index eef01f3e1e2f5..7526c1e6efc6f 100644 --- a/src/ray/rpc/common.cc +++ b/src/ray/rpc/common.cc @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "ray/rpc/common.h" + #include #include -#include "ray/rpc/common.h" - namespace ray::rpc { std::string ReadCert(const std::string &cert_filepath) { @@ -26,4 +26,4 @@ std::string ReadCert(const std::string &cert_filepath) { return buffer.str(); }; -} // namespace rpc::ray +} // namespace ray::rpc diff --git a/src/ray/rpc/common.h b/src/ray/rpc/common.h index 929a555a942f6..314e1eccf382c 100644 --- a/src/ray/rpc/common.h +++ b/src/ray/rpc/common.h @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include + namespace ray::rpc { // Utility to read cert file from a particular location diff --git a/src/ray/rpc/grpc_client.h b/src/ray/rpc/grpc_client.h index 12daab630e939..2670bc0674cde 100644 --- a/src/ray/rpc/grpc_client.h +++ b/src/ray/rpc/grpc_client.h @@ -45,7 +45,7 @@ template class GrpcClient { public: GrpcClient(const std::string &address, const int port, ClientCallManager &call_manager, - bool use_tls = true) + bool use_tls = false) : client_call_manager_(call_manager), use_tls_(use_tls) { grpc::ChannelArguments argument; // Disable http proxy since it disrupts local connections. TODO(ekl) we should make @@ -107,7 +107,8 @@ class GrpcClient { const std::string &address, int port) { std::shared_ptr channel; if (::RayConfig::instance().USE_TLS()) { - std::string server_cert_file = std::string(::RayConfig::instance().TLS_SERVER_CERT()); + std::string server_cert_file = + std::string(::RayConfig::instance().TLS_SERVER_CERT()); std::string server_key_file = std::string(::RayConfig::instance().TLS_SERVER_KEY()); std::string root_cert_file = std::string(::RayConfig::instance().TLS_CA_CERT()); std::string server_cert_chain = ReadCert(server_cert_file); @@ -127,7 +128,6 @@ class GrpcClient { } return channel; }; - }; } // namespace rpc From 7c3f7b2e0bdd507b641a5cf18210dd851581a7ea Mon Sep 17 00:00:00 2001 From: Oscar Knagg Date: Thu, 14 Oct 2021 12:36:16 +0100 Subject: [PATCH 46/56] Tidy up --- python/ray/_private/test_utils.py | 5 ++--- python/ray/_private/tls_utils.py | 30 ++++++++++++------------- python/ray/internal/internal_api.py | 6 +++-- python/ray/tests/test_client_builder.py | 1 + python/ray/tests/test_tls_auth.py | 19 +++++++++++----- 5 files changed, 35 insertions(+), 26 deletions(-) diff --git a/python/ray/_private/test_utils.py b/python/ray/_private/test_utils.py index 8da4ac9f03e69..5d6d12d5ea94f 100644 --- a/python/ray/_private/test_utils.py +++ b/python/ray/_private/test_utils.py @@ -9,12 +9,12 @@ import sys import time import timeit +import socket import math import traceback from typing import Optional, Any, List, Dict from contextlib import redirect_stdout, redirect_stderr import yaml -import socket import pytest import tempfile @@ -25,7 +25,6 @@ from ray._private.tls_utils import generate_self_signed_tls_certs from ray.util.queue import Queue, _QueueActor, Empty from ray.scripts.scripts import main as ray_main - try: from 
prometheus_client.parser import text_string_to_metric_families except (ImportError, ModuleNotFoundError): @@ -723,7 +722,7 @@ def teardown_tls(key_filepath, cert_filepath, temp_dir): os.remove(key_filepath) os.remove(cert_filepath) os.removedirs(temp_dir) - os.environ["RAY_USE_TLS"] = "0" + del os.environ["RAY_USE_TLS"] del os.environ["RAY_TLS_SERVER_CERT"] del os.environ["RAY_TLS_SERVER_KEY"] del os.environ["RAY_TLS_CA_CERT"] diff --git a/python/ray/_private/tls_utils.py b/python/ray/_private/tls_utils.py index 0dcecf4512c94..0e4746201bc73 100644 --- a/python/ray/_private/tls_utils.py +++ b/python/ray/_private/tls_utils.py @@ -71,22 +71,20 @@ def add_port_to_grpc_server(server, address): def load_certs_from_env(): - if os.environ.get("RAY_USE_TLS", "0").lower() in ("1", "true"): - if ("RAY_TLS_SERVER_CERT" not in os.environ) or \ - ("RAY_TLS_SERVER_KEY" not in os.environ): - raise RuntimeError( - "If the environment variable RAY_USE_TLS is set to true" - "then both RAY_TLS_SERVER_CERT and RAY_TLS_SERVER_KEY must " - "also be set.") + tls_env_vars = [ + "RAY_TLS_SERVER_CERT", "RAY_TLS_SERVER_KEY", "RAY_TLS_CA_CERT" + ] + if any(v not in os.environ for v in tls_env_vars): + raise RuntimeError( + "If the environment variable RAY_USE_TLS is set to true " + "then RAY_TLS_SERVER_CERT, RAY_TLS_SERVER_KEY and " + "RAY_TLS_CA_CERT must also be set.") - with open(os.environ["RAY_TLS_SERVER_CERT"], "rb") as f: - server_cert_chain = f.read() - with open(os.environ["RAY_TLS_SERVER_KEY"], "rb") as f: - private_key = f.read() - if "RAY_TLS_CA_CERT" in os.environ: - with open(os.environ["RAY_TLS_CA_CERT"], "rb") as f: - ca_cert = f.read() - else: - ca_cert = None + with open(os.environ["RAY_TLS_SERVER_CERT"], "rb") as f: + server_cert_chain = f.read() + with open(os.environ["RAY_TLS_SERVER_KEY"], "rb") as f: + private_key = f.read() + with open(os.environ["RAY_TLS_CA_CERT"], "rb") as f: + ca_cert = f.read() return server_cert_chain, private_key, ca_cert diff --git a/python/ray/internal/internal_api.py b/python/ray/internal/internal_api.py index e81637078956c..7df4016e1a982 100644 --- a/python/ray/internal/internal_api.py +++ b/python/ray/internal/internal_api.py @@ -66,7 +66,8 @@ def get_store_stats(state, node_manager_address=None, node_manager_port=None): options=[ ("grpc.max_send_message_length", MAX_MESSAGE_LENGTH), ("grpc.max_receive_message_length", MAX_MESSAGE_LENGTH), - ]) + ], + ) stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel) reply = stub.FormatGlobalMemoryInfo( @@ -92,7 +93,8 @@ def node_stats(node_manager_address=None, options=[ ("grpc.max_send_message_length", MAX_MESSAGE_LENGTH), ("grpc.max_receive_message_length", MAX_MESSAGE_LENGTH), - ]) + ], + ) stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel) node_stats = stub.GetNodeStats( diff --git a/python/ray/tests/test_client_builder.py b/python/ray/tests/test_client_builder.py index 406933ced5522..c325a7188b04a 100644 --- a/python/ray/tests/test_client_builder.py +++ b/python/ray/tests/test_client_builder.py @@ -77,6 +77,7 @@ def ping(self): ray.get(a.ping.remote()) print(ray.get_runtime_context().namespace) """ + anon_driver = template.format(namespace="None") run_string_as_driver(anon_driver) # This second run will fail if the actors don't run in separate anonymous diff --git a/python/ray/tests/test_tls_auth.py b/python/ray/tests/test_tls_auth.py index 909951f0036c7..cc208255836e2 100644 --- a/python/ray/tests/test_tls_auth.py +++ b/python/ray/tests/test_tls_auth.py @@ -11,6 +11,10 @@ logger = 
logging.getLogger(__name__)
 
 
+@pytest.mark.skipif(
+    sys.platform == "darwin",
+    reason=(
+        "Cryptography (TLS dependency) doesn't install in Mac build pipeline"))
 @pytest.mark.parametrize("use_tls", [True], indirect=True)
 def test_init_with_tls(use_tls):
     # Run as a new process to pick up environment variables set
@@ -26,6 +30,10 @@ def test_init_with_tls(use_tls):
         env=os.environ)
 
 
+@pytest.mark.skipif(
+    sys.platform == "darwin",
+    reason=(
+        "Cryptography (TLS dependency) doesn't install in Mac build pipeline"))
 @pytest.mark.parametrize("use_tls", [True], indirect=True)
 def test_put_get_with_tls(use_tls):
     run_string_as_driver(
@@ -62,6 +70,10 @@ def test_put_get_with_tls(use_tls):
         env=os.environ)
 
 
+@pytest.mark.skipif(
+    sys.platform == "darwin",
+    reason=(
+        "Cryptography (TLS dependency) doesn't install in Mac build pipeline"))
 @pytest.mark.parametrize("use_tls", [True], indirect=True, scope="module")
 def test_submit_with_tls(use_tls):
     run_string_as_driver(
@@ -97,13 +109,10 @@ def method(self, a, b=0):
 
 @pytest.mark.skipif(
     sys.platform == "darwin",
-    reason=("Cryptography doesn't install in Mac build pipeline"))
+    reason=(
+        "Cryptography (TLS dependency) doesn't install in Mac build pipeline"))
 @pytest.mark.parametrize("use_tls", [True], indirect=True)
 def test_client_connect_to_tls_server(use_tls, call_ray_start):
-    for k, v in os.environ.items():
-        if k.startswith("RAY_"):
-            print("export {}={}".format(k, v))
-
     tls_env = os.environ.copy(
     )  # use_tls fixture sets TLS environment variables
     without_tls_env = {}

From 8d204c59e7d76dbb2fc2b8c7135790b9c265c2a0 Mon Sep 17 00:00:00 2001
From: Oscar Knagg
Date: Fri, 15 Oct 2021 09:51:21 +0100
Subject: [PATCH 47/56] Hopefully fix lint

---
 src/ray/rpc/grpc_server.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc
index edc2e6a156047..ea081ca790e35 100644
--- a/src/ray/rpc/grpc_server.cc
+++ b/src/ray/rpc/grpc_server.cc
@@ -71,8 +71,8 @@ void GrpcServer::Run() {
     std::string rootcert = ReadCert(RayConfig::instance().TLS_CA_CERT());
     std::string servercert = ReadCert(RayConfig::instance().TLS_SERVER_CERT());
     std::string serverkey = ReadCert(RayConfig::instance().TLS_SERVER_KEY());
-    grpc::SslServerCredentialsOptions::PemKeyCertPair pkcp = {serverkey.c_str(),
-                                                              servercert.c_str()};
+    grpc::SslServerCredentialsOptions::PemKeyCertPair pkcp = {serverkey,
+                                                              servercert};
     grpc::SslServerCredentialsOptions ssl_opts(
         GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY);
     ssl_opts.pem_root_certs = rootcert;

From 74d1652eeca19743718e4b045a28ced9113caa4f Mon Sep 17 00:00:00 2001
From: Oscar Knagg
Date: Mon, 18 Oct 2021 09:58:55 +0100
Subject: [PATCH 48/56] Lint

---
 src/ray/rpc/grpc_server.cc | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc
index ea081ca790e35..7c69bee606646 100644
--- a/src/ray/rpc/grpc_server.cc
+++ b/src/ray/rpc/grpc_server.cc
@@ -71,8 +71,7 @@ void GrpcServer::Run() {
     std::string rootcert = ReadCert(RayConfig::instance().TLS_CA_CERT());
     std::string servercert = ReadCert(RayConfig::instance().TLS_SERVER_CERT());
     std::string serverkey = ReadCert(RayConfig::instance().TLS_SERVER_KEY());
-    grpc::SslServerCredentialsOptions::PemKeyCertPair pkcp = {serverkey,
-                                                              servercert};
+    grpc::SslServerCredentialsOptions::PemKeyCertPair pkcp = {serverkey, servercert};
     grpc::SslServerCredentialsOptions ssl_opts(
         GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY);
     ssl_opts.pem_root_certs = rootcert;
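
A note for readers following the server-side changes: the credentials assembled in
GrpcServer::Run() above — one PemKeyCertPair plus
GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY — amount to mutual TLS:
the server presents its own certificate and also rejects any client that does not
present one signed by the configured CA. A minimal Python sketch of the equivalent
server setup, mirroring add_port_to_grpc_server in python/ray/_private/tls_utils.py
(the file paths, the port, and the read_bytes helper are illustrative placeholders,
not values taken from these patches):

    import grpc
    from concurrent import futures

    def read_bytes(path):
        # Illustrative helper: read a PEM file as bytes.
        with open(path, "rb") as f:
            return f.read()

    # Placeholder paths standing in for the three RAY_TLS_* settings.
    server_cert = read_bytes("/etc/ray/tls/server.crt")  # TLS_SERVER_CERT
    server_key = read_bytes("/etc/ray/tls/server.key")   # TLS_SERVER_KEY
    ca_cert = read_bytes("/etc/ray/tls/ca.crt")          # TLS_CA_CERT

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=2))
    credentials = grpc.ssl_server_credentials(
        [(server_key, server_cert)],  # counterpart of the C++ PemKeyCertPair
        root_certificates=ca_cert,    # counterpart of ssl_opts.pem_root_certs
        # counterpart of GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY:
        require_client_auth=True)
    server.add_secure_port("0.0.0.0:50051", credentials)
    server.start()
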
From 50c2da2dd951b00e7f77c3a86c7548b32d8827ea Mon Sep 17 00:00:00 2001
From: Oscar Knagg
Date: Tue, 19 Oct 2021 10:23:44 +0100
Subject: [PATCH 49/56] Remove unnecessary logic in ray_config_def.h

---
 src/ray/common/ray_config_def.h | 19 ++++---------------
 1 file changed, 4 insertions(+), 15 deletions(-)

diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h
index d402e0e054889..bd69c4ea3bcf2 100644
--- a/src/ray/common/ray_config_def.h
+++ b/src/ray/common/ray_config_def.h
@@ -494,20 +494,9 @@ RAY_CONFIG(bool, scheduler_avoid_gpu_nodes, true)
 RAY_CONFIG(bool, runtime_env_skip_local_gc, false)
 
 /// Whether or not use TLS.
-RAY_CONFIG(bool, USE_TLS,
-           std::getenv("RAY_USE_TLS") != nullptr &&
-               (std::getenv("RAY_USE_TLS") == std::string("true") ||
-                std::getenv("RAY_USE_TLS") == std::string("1")))
+RAY_CONFIG(bool, USE_TLS, false)
 
 /// Location of TLS credentials
-RAY_CONFIG(std::string, TLS_SERVER_CERT,
-           std::getenv("RAY_TLS_SERVER_CERT") != nullptr
-               ? std::getenv("RAY_TLS_SERVER_CERT")
-               : "")
-RAY_CONFIG(std::string, TLS_SERVER_KEY,
-           std::getenv("RAY_TLS_SERVER_KEY") != nullptr
-               ? std::getenv("RAY_TLS_SERVER_KEY")
-               : "")
-RAY_CONFIG(std::string, TLS_CA_CERT,
-           std::getenv("RAY_TLS_CA_CERT") != nullptr ? std::getenv("RAY_TLS_CA_CERT")
-                                                     : "")
+RAY_CONFIG(std::string, TLS_SERVER_CERT, "")
+RAY_CONFIG(std::string, TLS_SERVER_KEY, "")
+RAY_CONFIG(std::string, TLS_CA_CERT, "")

From 859985420d50a443e8f96f7719a6e4c1e37e0ce4 Mon Sep 17 00:00:00 2001
From: Oscar Knagg
Date: Tue, 19 Oct 2021 10:28:28 +0100
Subject: [PATCH 50/56] Actually check for ConnectionError in
 test_client_connect_to_tls_server

---
 python/ray/tests/test_tls_auth.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/python/ray/tests/test_tls_auth.py b/python/ray/tests/test_tls_auth.py
index cc208255836e2..06ee4361f50a1 100644
--- a/python/ray/tests/test_tls_auth.py
+++ b/python/ray/tests/test_tls_auth.py
@@ -118,15 +118,14 @@ def test_client_connect_to_tls_server(use_tls, call_ray_start):
     without_tls_env = {}
 
     # Attempt to connect without TLS
-    try:
-        out = run_string_as_driver(
+    with pytest.raises(subprocess.CalledProcessError) as exc_info:
+        run_string_as_driver(
             """
 from ray.util.client import ray as ray_client
 ray_client.connect("localhost:10001")
     """,
             env=without_tls_env)
-    except subprocess.CalledProcessError as e:
-        assert "ConnectionError" in e.output.decode("utf-8")
+    assert "ConnectionError" in exc_info.value.output.decode("utf-8")
 
     # Attempt to connect with TLS
     out = run_string_as_driver(

From 9dfd1065b31fee06a0c3e18b97186e3a65b85d08 Mon Sep 17 00:00:00 2001
From: Oscar Knagg
Date: Tue, 19 Oct 2021 11:18:03 +0100
Subject: [PATCH 51/56] Remove unused ReadFile declaration

---
 src/ray/rpc/grpc_server.h | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h
index 6d3ad9eb203d5..58795ed2ee4f9 100644
--- a/src/ray/rpc/grpc_server.h
+++ b/src/ray/rpc/grpc_server.h
@@ -89,8 +89,6 @@ class GrpcServer {
     }
   }
 
-  /// Read a file
-  std::string ReadFile(std::string filename);
   /// Get the port of this gRPC server.
   int GetPort() const { return port_; }

From 4feae4504aeb47cde4c0788e9e55af3fbd5b69bd Mon Sep 17 00:00:00 2001
From: Oscar Knagg
Date: Tue, 19 Oct 2021 11:57:44 +0100
Subject: [PATCH 52/56] Lint

---
 src/ray/rpc/grpc_server.h | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h
index 58795ed2ee4f9..843c0acbacf81 100644
--- a/src/ray/rpc/grpc_server.h
+++ b/src/ray/rpc/grpc_server.h
@@ -89,7 +89,6 @@ class GrpcServer {
     }
   }
 
-
   /// Get the port of this gRPC server.
   int GetPort() const { return port_; }

From 5b57d7d84fefe6b452636ee51977c5287cba622b Mon Sep 17 00:00:00 2001
From: Oscar Knagg
Date: Tue, 19 Oct 2021 14:33:29 +0100
Subject: [PATCH 53/56] Replace grpc.insecure_channel with
 ray._private.utils.init_grpc_channel in recent code

---
 dashboard/head.py                        | 2 +-
 python/ray/util/client/server/proxier.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/dashboard/head.py b/dashboard/head.py
index 2d2060224c10a..fc7b12d3a6c27 100644
--- a/dashboard/head.py
+++ b/dashboard/head.py
@@ -57,7 +57,7 @@ async def get_gcs_address_with_retry(redis_client) -> str:
 
 class GCSHealthCheckThread(threading.Thread):
     def __init__(self, gcs_address: str):
-        self.grpc_gcs_channel = grpc.insecure_channel(
+        self.grpc_gcs_channel = ray._private.utils.init_grpc_channel(
             gcs_address, options=GRPC_CHANNEL_OPTIONS)
         self.gcs_heartbeat_info_stub = (
             gcs_service_pb2_grpc.HeartbeatInfoGcsServiceStub(
diff --git a/python/ray/util/client/server/proxier.py b/python/ray/util/client/server/proxier.py
index 98ad26c93d8b4..9eebfeb12678b 100644
--- a/python/ray/util/client/server/proxier.py
+++ b/python/ray/util/client/server/proxier.py
@@ -119,7 +119,7 @@ def __init__(self,
         self._free_ports: List[int] = list(
             range(MIN_SPECIFIC_SERVER_PORT, MAX_SPECIFIC_SERVER_PORT))
 
-        self._runtime_env_channel = grpc.insecure_channel(
+        self._runtime_env_channel = ray._private.utils.init_grpc_channel(
             f"localhost:{runtime_env_agent_port}")
         self._runtime_env_stub = runtime_env_agent_pb2_grpc.RuntimeEnvServiceStub(  # noqa: E501
             self._runtime_env_channel)

From 67d32b78fb7793f9b723bf2186c003c7ae184814 Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Tue, 19 Oct 2021 14:30:02 -0700
Subject: [PATCH 54/56] Trigger retest

From f4032f15019a1783321076bcc2234b9c28fb0d07 Mon Sep 17 00:00:00 2001
From: Oscar Knagg
Date: Wed, 20 Oct 2021 22:22:21 +0100
Subject: [PATCH 55/56] Attempt to fix windows build

---
 python/ray/tests/test_tls_auth.py | 20 ++++++++++++++------
 1 file changed, 14 insertions(+), 6 deletions(-)

diff --git a/python/ray/tests/test_tls_auth.py b/python/ray/tests/test_tls_auth.py
index 06ee4361f50a1..9b3d418c70d19 100644
--- a/python/ray/tests/test_tls_auth.py
+++ b/python/ray/tests/test_tls_auth.py
@@ -11,6 +11,14 @@
 logger = logging.getLogger(__name__)
 
 
+def build_env():
+    env = os.environ.copy()
+    if sys.platform == "win32" and "SYSTEMROOT" not in env:
+        env["SYSTEMROOT"] = r"C:\Windows"
+
+    return env
+
+
 @pytest.mark.skipif(
     sys.platform == "darwin",
     reason=(
@@ -27,7 +35,7 @@ def test_init_with_tls(use_tls):
 finally:
     ray.shutdown()
     """,
-        env=os.environ)
+        env=build_env())
 
 
 @pytest.mark.skipif(
@@ -67,7 +75,7 @@ def test_put_get_with_tls(use_tls):
 finally:
     ray.shutdown()
     """,
-        env=os.environ)
+        env=build_env())
 
 
 @pytest.mark.skipif(
@@ -104,7 +112,7 @@ def method(self, a, b=0):
     args=["test"], kwargs={"b": 2}, num_returns=4)
 assert ray.get([id1, id2, id3, id4]) == [0, 1, "test", 2]
     """,
-        env=os.environ)
+        env=build_env())
 
 
 @pytest.mark.skipif(
@@ -113,9 +121,8 @@ def method(self, a, b=0):
         "Cryptography (TLS dependency) doesn't install in Mac build pipeline"))
 @pytest.mark.parametrize("use_tls", [True], indirect=True)
 def test_client_connect_to_tls_server(use_tls, call_ray_start):
-    tls_env = os.environ.copy(
-    )  # use_tls fixture sets TLS environment variables
-    without_tls_env = {}
+    tls_env = build_env()  # use_tls fixture sets TLS environment variables
+    without_tls_env = {k: v for k, v in tls_env.items() if "TLS" not in k}
 
     # Attempt to connect without TLS
     with pytest.raises(subprocess.CalledProcessError) as exc_info:
@@ -142,4 +149,5 @@ def test_client_connect_to_tls_server(use_tls, call_ray_start):
 if __name__ == "__main__":
     import pytest
     import sys
+
     sys.exit(pytest.main(["-v", __file__]))
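
On the client side, the secure path these tests exercise is the branch that
init_grpc_channel in python/ray/_private/utils.py and Worker._connect_channel take
when RAY_USE_TLS is truthy. A condensed sketch of that branch (make_channel is an
illustrative helper, not a function from these patches; localhost:10001 is the Ray
client server port the tests above connect to):

    import os

    import grpc

    from ray._private.tls_utils import load_certs_from_env

    def make_channel(address):
        # Secure only when RAY_USE_TLS is truthy, as in init_grpc_channel.
        if os.environ.get("RAY_USE_TLS", "0").lower() in ("1", "true"):
            server_cert_chain, private_key, ca_cert = load_certs_from_env()
            credentials = grpc.ssl_channel_credentials(
                root_certificates=ca_cert,
                private_key=private_key,
                certificate_chain=server_cert_chain)
            return grpc.secure_channel(address, credentials)
        return grpc.insecure_channel(address)

    channel = make_channel("localhost:10001")

Note that the client presents the same certificate/key pair as the server: the
series uses one shared identity for all Ray processes rather than per-role
certificates.
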
From e74d7079cf64d661db636f2b2b220777ce0d4c3d Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Wed, 20 Oct 2021 18:06:27 -0700
Subject: [PATCH 56/56] Update worker.py

---
 python/ray/util/client/worker.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py
index 5cc3a55a9762d..5b441350cd67f 100644
--- a/python/ray/util/client/worker.py
+++ b/python/ray/util/client/worker.py
@@ -35,6 +35,7 @@
 from ray.util.client.dataclient import DataClient
 from ray.util.client.logsclient import LogstreamClient
 from ray.util.debug import log_once
+import ray._private.utils
 from ray._private.runtime_env.working_dir import upload_working_dir_if_needed
 
 if TYPE_CHECKING:
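
Taken together, the series keys everything off four environment variables:
RAY_USE_TLS, RAY_TLS_SERVER_CERT, RAY_TLS_SERVER_KEY and RAY_TLS_CA_CERT. A minimal
end-to-end sketch of enabling TLS on a single node, patterned on the TLS test
fixtures in python/ray/_private/test_utils.py (the assumption that
generate_self_signed_tls_certs() returns PEM-encoded certificate and key strings
follows its use in those fixtures, which the diffs above import but do not show in
full):

    import os
    import tempfile

    import ray
    from ray._private.tls_utils import generate_self_signed_tls_certs

    # Write a self-signed certificate and key to disk. Because the certificate
    # is self-signed, it can double as the CA root that peers verify against.
    cert, key = generate_self_signed_tls_certs()
    tmpdir = tempfile.mkdtemp()
    cert_path = os.path.join(tmpdir, "server.crt")
    key_path = os.path.join(tmpdir, "server.key")
    with open(cert_path, "w") as f:
        f.write(cert)
    with open(key_path, "w") as f:
        f.write(key)

    os.environ["RAY_USE_TLS"] = "1"
    os.environ["RAY_TLS_SERVER_CERT"] = cert_path
    os.environ["RAY_TLS_SERVER_KEY"] = key_path
    os.environ["RAY_TLS_CA_CERT"] = cert_path

    ray.init()  # gRPC traffic between Ray processes now requires mutual TLS
    ray.shutdown()
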