From 53575b33ad8ed50500a45be1543cb59274d33467 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Tue, 19 Dec 2023 14:10:19 -0600 Subject: [PATCH] [serve] Refactor `RayServeWrappedReplica` to not be dynamically generated (#42014) * fix Signed-off-by: Edward Oakes * oops Signed-off-by: Edward Oakes --------- Signed-off-by: Edward Oakes --- python/ray/serve/_private/constants.py | 4 + python/ray/serve/_private/deployment_info.py | 18 +- python/ray/serve/_private/replica.py | 756 +++++++++---------- 3 files changed, 379 insertions(+), 399 deletions(-) diff --git a/python/ray/serve/_private/constants.py b/python/ray/serve/_private/constants.py index 684608acd3198..529d3237ce895 100644 --- a/python/ray/serve/_private/constants.py +++ b/python/ray/serve/_private/constants.py @@ -36,6 +36,10 @@ #: Max concurrency ASYNC_CONCURRENCY = int(1e6) +# Concurrency group used for replica operations that cannot be blocked by user code +# (e.g., health checks and fetching queue length). +REPLICA_CONTROL_PLANE_CONCURRENCY_GROUP = "control_plane" + # How often to call the control loop on the controller. CONTROL_LOOP_PERIOD_S = 0.1 diff --git a/python/ray/serve/_private/deployment_info.py b/python/ray/serve/_private/deployment_info.py index 453d7b840c92a..0832bde5be946 100644 --- a/python/ray/serve/_private/deployment_info.py +++ b/python/ray/serve/_private/deployment_info.py @@ -4,6 +4,7 @@ from ray.serve._private.autoscaling_policy import BasicAutoscalingPolicy from ray.serve._private.common import TargetCapacityDirection from ray.serve._private.config import DeploymentConfig, ReplicaConfig +from ray.serve._private.constants import REPLICA_CONTROL_PLANE_CONCURRENCY_GROUP from ray.serve.generated.serve_pb2 import DeploymentInfo as DeploymentInfoProto from ray.serve.generated.serve_pb2 import ( TargetCapacityDirection as TargetCapacityDirectionProto, @@ -11,9 +12,8 @@ # Concurrency group used for operations that cannot be blocked by user code # (e.g., health checks and fetching queue length). -CONTROL_PLANE_CONCURRENCY_GROUP = "control_plane" REPLICA_DEFAULT_ACTOR_OPTIONS = { - "concurrency_groups": {CONTROL_PLANE_CONCURRENCY_GROUP: 1} + "concurrency_groups": {REPLICA_CONTROL_PLANE_CONCURRENCY_GROUP: 1} } @@ -101,14 +101,20 @@ def set_target_capacity( @property def actor_def(self): - # Delayed import as replica depends on this file. - from ray.serve._private.replica import create_replica_wrapper - if self._cached_actor_def is None: assert self.actor_name is not None + # Break circular import :(. + from ray.serve._private.replica import RayServeWrappedReplica + + # Dynamically create a new class with custom name here so Ray picks it up + # correctly in actor metadata table and observability stack. self._cached_actor_def = ray.remote(**REPLICA_DEFAULT_ACTOR_OPTIONS)( - create_replica_wrapper(self.actor_name) + type( + self.actor_name, + (RayServeWrappedReplica,), + dict(RayServeWrappedReplica.__dict__), + ) ) return self._cached_actor_def diff --git a/python/ray/serve/_private/replica.py b/python/ray/serve/_private/replica.py index 4dfa61a746d77..3b91a99298a63 100644 --- a/python/ray/serve/_private/replica.py +++ b/python/ray/serve/_private/replica.py @@ -38,11 +38,11 @@ RAY_SERVE_GAUGE_METRIC_SET_PERIOD_S, RAY_SERVE_REPLICA_AUTOSCALING_METRIC_RECORD_PERIOD_S, RECONFIGURE_METHOD, + REPLICA_CONTROL_PLANE_CONCURRENCY_GROUP, SERVE_CONTROLLER_NAME, SERVE_LOGGER_NAME, SERVE_NAMESPACE, ) -from ray.serve._private.deployment_info import CONTROL_PLANE_CONCURRENCY_GROUP from ray.serve._private.http_util import ( ASGIAppReplicaWrapper, ASGIMessageQueue, @@ -72,435 +72,405 @@ logger = logging.getLogger(SERVE_LOGGER_NAME) -def create_replica_wrapper(actor_class_name: str): - """Creates a replica class wrapping the provided function or class. - - This approach is picked over inheritance to avoid conflict between user - provided class and the RayServeReplica class. - """ - - # TODO(architkulkarni): Add type hints after upgrading cloudpickle - class RayServeWrappedReplica(object): - async def __init__( - self, - deployment_name, - replica_tag, - serialized_deployment_def: bytes, - serialized_init_args: bytes, - serialized_init_kwargs: bytes, - deployment_config_proto_bytes: bytes, - version: DeploymentVersion, - app_name: str = None, - ): - self._replica_tag = replica_tag - deployment_config = DeploymentConfig.from_proto_bytes( - deployment_config_proto_bytes +class RayServeWrappedReplica: + async def __init__( + self, + deployment_name: str, + replica_tag: str, + serialized_deployment_def: bytes, + serialized_init_args: bytes, + serialized_init_kwargs: bytes, + deployment_config_proto_bytes: bytes, + version: DeploymentVersion, + app_name: str = None, + ): + self._replica_tag = replica_tag + deployment_config = DeploymentConfig.from_proto_bytes( + deployment_config_proto_bytes + ) + if deployment_config.logging_config is None: + logging_config = LoggingConfig() + else: + logging_config = LoggingConfig(**deployment_config.logging_config) + + self._configure_logger_and_profilers(replica_tag, logging_config) + + self._event_loop = get_or_create_event_loop() + + deployment_def = cloudpickle.loads(serialized_deployment_def) + + if isinstance(deployment_def, str): + import_path = deployment_def + module_name, attr_name = parse_import_path(import_path) + deployment_def = getattr(import_module(module_name), attr_name) + # For ray or serve decorated class or function, strip to return + # original body + if isinstance(deployment_def, RemoteFunction): + deployment_def = deployment_def._function + elif isinstance(deployment_def, ActorClass): + deployment_def = deployment_def.__ray_metadata__.modified_class + elif isinstance(deployment_def, Deployment): + logger.warning( + f'The import path "{import_path}" contains a ' + "decorated Serve deployment. The decorator's settings " + "are ignored when deploying via import path." + ) + deployment_def = deployment_def.func_or_class + + init_args = cloudpickle.loads(serialized_init_args) + init_kwargs = cloudpickle.loads(serialized_init_kwargs) + + if inspect.isfunction(deployment_def): + is_function = True + elif inspect.isclass(deployment_def): + is_function = False + else: + assert False, ( + "deployment_def must be function, class, or " + "corresponding import path. Instead, it's type was " + f"{type(deployment_def)}." ) - if deployment_config.logging_config is None: - logging_config = LoggingConfig() - else: - logging_config = LoggingConfig(**deployment_config.logging_config) - - self._configure_logger_and_profilers(replica_tag, logging_config) - - self._event_loop = get_or_create_event_loop() - - deployment_def = cloudpickle.loads(serialized_deployment_def) - - if isinstance(deployment_def, str): - import_path = deployment_def - module_name, attr_name = parse_import_path(import_path) - deployment_def = getattr(import_module(module_name), attr_name) - # For ray or serve decorated class or function, strip to return - # original body - if isinstance(deployment_def, RemoteFunction): - deployment_def = deployment_def._function - elif isinstance(deployment_def, ActorClass): - deployment_def = deployment_def.__ray_metadata__.modified_class - elif isinstance(deployment_def, Deployment): - logger.warning( - f'The import path "{import_path}" contains a ' - "decorated Serve deployment. The decorator's settings " - "are ignored when deploying via import path." - ) - deployment_def = deployment_def.func_or_class - init_args = cloudpickle.loads(serialized_init_args) - init_kwargs = cloudpickle.loads(serialized_init_kwargs) + # Set the controller name so that serve.connect() in the user's + # code will connect to the instance that this deployment is running + # in. + ray.serve.context._set_internal_replica_context( + app_name=app_name, + deployment=deployment_name, + replica_tag=replica_tag, + servable_object=None, + ) + + controller_handle = ray.get_actor( + SERVE_CONTROLLER_NAME, namespace=SERVE_NAMESPACE + ) + + # Indicates whether the replica has finished initializing. + self._initialized = False + + # This closure initializes user code and finalizes replica + # startup. By splitting the initialization step like this, + # we can already access this actor before the user code + # has finished initializing. + # The supervising state manager can then wait + # for allocation of this replica by using the `is_allocated` + # method. After that, it calls `reconfigure` to trigger + # user code initialization. + async def initialize_replica(): + logger.info( + "Started initializing replica.", + extra={"log_to_stderr": False}, + ) - if inspect.isfunction(deployment_def): - is_function = True - elif inspect.isclass(deployment_def): - is_function = False + if is_function: + _callable = deployment_def else: - assert False, ( - "deployment_def must be function, class, or " - "corresponding import path. Instead, it's type was " - f"{type(deployment_def)}." - ) + # This allows deployments to define an async __init__ + # method (mostly used for testing). + _callable = deployment_def.__new__(deployment_def) + await sync_to_async(_callable.__init__)(*init_args, **init_kwargs) + + if isinstance(_callable, ASGIAppReplicaWrapper): + await _callable._run_asgi_lifespan_startup() - # Set the controller name so that serve.connect() in the user's - # code will connect to the instance that this deployment is running - # in. + # Setting the context again to update the servable_object. ray.serve.context._set_internal_replica_context( app_name=app_name, deployment=deployment_name, replica_tag=replica_tag, - servable_object=None, + servable_object=_callable, ) - controller_handle = ray.get_actor( - SERVE_CONTROLLER_NAME, namespace=SERVE_NAMESPACE + self.replica = RayServeReplica( + _callable, + deployment_name, + replica_tag, + deployment_config.autoscaling_config, + version, + is_function, + controller_handle, + app_name, + ) + self._initialized = True + logger.info( + "Finished initializing replica.", + extra={"log_to_stderr": False}, ) - # Indicates whether the replica has finished initializing. - self._initialized = False - - # This closure initializes user code and finalizes replica - # startup. By splitting the initialization step like this, - # we can already access this actor before the user code - # has finished initializing. - # The supervising state manager can then wait - # for allocation of this replica by using the `is_allocated` - # method. After that, it calls `reconfigure` to trigger - # user code initialization. - async def initialize_replica(): - logger.info( - "Started initializing replica.", - extra={"log_to_stderr": False}, - ) + # Is it fine that replica is None here? + # Should we add a check in all methods that use self.replica + # or, alternatively, create an async get_replica() method? + self.replica = None + self._initialize_replica = initialize_replica - if is_function: - _callable = deployment_def - else: - # This allows deployments to define an async __init__ - # method (mostly used for testing). - _callable = deployment_def.__new__(deployment_def) - await sync_to_async(_callable.__init__)(*init_args, **init_kwargs) - - if isinstance(_callable, ASGIAppReplicaWrapper): - await _callable._run_asgi_lifespan_startup() - - # Setting the context again to update the servable_object. - ray.serve.context._set_internal_replica_context( - app_name=app_name, - deployment=deployment_name, - replica_tag=replica_tag, - servable_object=_callable, - ) + # Used to guard `initialize_replica` so that it isn't called twice. + self._replica_init_lock = asyncio.Lock() - self.replica = RayServeReplica( - _callable, - deployment_name, - replica_tag, - deployment_config.autoscaling_config, - version, - is_function, - controller_handle, - app_name, - ) - self._initialized = True - logger.info( - "Finished initializing replica.", - extra={"log_to_stderr": False}, - ) + def _configure_logger_and_profilers( + self, replica_tag: ReplicaTag, logging_config: LoggingConfig + ): + replica_name = ReplicaName.from_replica_tag(replica_tag) + if replica_name.app_name: + component_name = f"{replica_name.app_name}_{replica_name.deployment_name}" + else: + component_name = f"{replica_name.deployment_name}" + component_id = replica_name.replica_suffix + + configure_component_logger( + component_type=ServeComponentType.REPLICA, + component_name=component_name, + component_id=component_id, + logging_config=logging_config, + ) + configure_component_memory_profiler( + component_type=ServeComponentType.REPLICA, + component_name=component_name, + component_id=component_id, + ) + self.cpu_profiler, self.cpu_profiler_log = configure_component_cpu_profiler( + component_type=ServeComponentType.REPLICA, + component_name=component_name, + component_id=component_id, + ) - # Is it fine that replica is None here? - # Should we add a check in all methods that use self.replica - # or, alternatively, create an async get_replica() method? - self.replica = None - self._initialize_replica = initialize_replica - - # Used to guard `initialize_replica` so that it isn't called twice. - self._replica_init_lock = asyncio.Lock() - - def _configure_logger_and_profilers( - self, replica_tag: ReplicaTag, logging_config: LoggingConfig - ): - replica_name = ReplicaName.from_replica_tag(replica_tag) - if replica_name.app_name: - component_name = ( - f"{replica_name.app_name}_{replica_name.deployment_name}" - ) - else: - component_name = f"{replica_name.deployment_name}" - component_id = replica_name.replica_suffix + @ray.method(concurrency_group=REPLICA_CONTROL_PLANE_CONCURRENCY_GROUP) + def get_num_ongoing_requests(self) -> int: + """Fetch the number of ongoing requests at this replica (queue length). - configure_component_logger( - component_type=ServeComponentType.REPLICA, - component_name=component_name, - component_id=component_id, - logging_config=logging_config, + This runs on a separate thread (using a Ray concurrency group) so it will + not be blocked by user code. + """ + return self.replica.get_num_pending_and_running_requests() + + async def handle_request( + self, + pickled_request_metadata: bytes, + *request_args, + **request_kwargs, + ) -> Tuple[bytes, Any]: + request_metadata = pickle.loads(pickled_request_metadata) + if request_metadata.is_grpc_request: + # Ensure the request args are a single gRPCRequest object. + assert len(request_args) == 1 and isinstance(request_args[0], gRPCRequest) + result = await self.replica.call_user_method_grpc_unary( + request_metadata=request_metadata, request=request_args[0] ) - configure_component_memory_profiler( - component_type=ServeComponentType.REPLICA, - component_name=component_name, - component_id=component_id, + else: + result = await self.replica.call_user_method( + request_metadata, request_args, request_kwargs ) - self.cpu_profiler, self.cpu_profiler_log = configure_component_cpu_profiler( - component_type=ServeComponentType.REPLICA, - component_name=component_name, - component_id=component_id, + + return result + + async def _handle_http_request_generator( + self, + request_metadata: RequestMetadata, + request: StreamingHTTPRequest, + ) -> AsyncGenerator[Message, None]: + """Handle an HTTP request and stream ASGI messages to the caller. + + This is a generator that yields ASGI-compliant messages sent by user code + via an ASGI send interface. + """ + receiver_task = None + call_user_method_task = None + wait_for_message_task = None + try: + receiver = ASGIReceiveProxy( + request_metadata.request_id, request.http_proxy_handle + ) + receiver_task = self._event_loop.create_task( + receiver.fetch_until_disconnect() ) - @ray.method(concurrency_group=CONTROL_PLANE_CONCURRENCY_GROUP) - def get_num_ongoing_requests(self) -> int: - """Fetch the number of ongoing requests at this replica (queue length). - - This runs on a separate thread (using a Ray concurrency group) so it will - not be blocked by user code. - """ - return self.replica.get_num_pending_and_running_requests() - - async def handle_request( - self, - pickled_request_metadata: bytes, - *request_args, - **request_kwargs, - ) -> Tuple[bytes, Any]: - request_metadata = pickle.loads(pickled_request_metadata) - if request_metadata.is_grpc_request: - # Ensure the request args are a single gRPCRequest object. - assert len(request_args) == 1 and isinstance( - request_args[0], gRPCRequest - ) - result = await self.replica.call_user_method_grpc_unary( - request_metadata=request_metadata, request=request_args[0] - ) - else: - result = await self.replica.call_user_method( + scope = pickle.loads(request.pickled_asgi_scope) + asgi_queue_send = ASGIMessageQueue() + request_args = (scope, receiver, asgi_queue_send) + request_kwargs = {} + + # Handle the request in a background asyncio.Task. It's expected that + # this task will use the provided ASGI send interface to send its HTTP + # the response. We will poll for the sent messages and yield them back + # to the caller. + call_user_method_task = self._event_loop.create_task( + self.replica.call_user_method( request_metadata, request_args, request_kwargs ) + ) - return result - - async def _handle_http_request_generator( - self, - request_metadata: RequestMetadata, - request: StreamingHTTPRequest, - ) -> AsyncGenerator[Message, None]: - """Handle an HTTP request and stream ASGI messages to the caller. - - This is a generator that yields ASGI-compliant messages sent by user code - via an ASGI send interface. - """ - receiver_task = None - call_user_method_task = None - wait_for_message_task = None - try: - receiver = ASGIReceiveProxy( - request_metadata.request_id, request.http_proxy_handle - ) - receiver_task = self._event_loop.create_task( - receiver.fetch_until_disconnect() + while True: + wait_for_message_task = self._event_loop.create_task( + asgi_queue_send.wait_for_message() ) - - scope = pickle.loads(request.pickled_asgi_scope) - asgi_queue_send = ASGIMessageQueue() - request_args = (scope, receiver, asgi_queue_send) - request_kwargs = {} - - # Handle the request in a background asyncio.Task. It's expected that - # this task will use the provided ASGI send interface to send its HTTP - # the response. We will poll for the sent messages and yield them back - # to the caller. - call_user_method_task = self._event_loop.create_task( - self.replica.call_user_method( - request_metadata, request_args, request_kwargs - ) + done, _ = await asyncio.wait( + [call_user_method_task, wait_for_message_task], + return_when=asyncio.FIRST_COMPLETED, ) + # Consume and yield all available messages in the queue. + # The messages are batched into a list to avoid unnecessary RPCs and + # we use vanilla pickle because it's faster than cloudpickle and we + # know it's safe for these messages containing primitive types. + yield pickle.dumps(asgi_queue_send.get_messages_nowait()) + + # Exit once `call_user_method` has finished. In this case, all + # messages must have already been sent. + if call_user_method_task in done: + break + + e = call_user_method_task.exception() + if e is not None: + raise e from None + finally: + if receiver_task is not None: + receiver_task.cancel() - while True: - wait_for_message_task = self._event_loop.create_task( - asgi_queue_send.wait_for_message() - ) - done, _ = await asyncio.wait( - [call_user_method_task, wait_for_message_task], - return_when=asyncio.FIRST_COMPLETED, - ) - # Consume and yield all available messages in the queue. - # The messages are batched into a list to avoid unnecessary RPCs and - # we use vanilla pickle because it's faster than cloudpickle and we - # know it's safe for these messages containing primitive types. - yield pickle.dumps(asgi_queue_send.get_messages_nowait()) - - # Exit once `call_user_method` has finished. In this case, all - # messages must have already been sent. - if call_user_method_task in done: - break - - e = call_user_method_task.exception() - if e is not None: - raise e from None - finally: - if receiver_task is not None: - receiver_task.cancel() - - if ( - call_user_method_task is not None - and not call_user_method_task.done() - ): - call_user_method_task.cancel() + if call_user_method_task is not None and not call_user_method_task.done(): + call_user_method_task.cancel() - if ( - wait_for_message_task is not None - and not wait_for_message_task.done() - ): - wait_for_message_task.cancel() - - async def handle_request_streaming( - self, - pickled_request_metadata: bytes, - *request_args, - **request_kwargs, - ) -> AsyncGenerator[Any, None]: - """Generator that is the entrypoint for all `stream=True` handle calls.""" - request_metadata = pickle.loads(pickled_request_metadata) - if request_metadata.is_grpc_request: - # Ensure the request args are a single gRPCRequest object. - assert len(request_args) == 1 and isinstance( - request_args[0], gRPCRequest - ) - generator = self.replica.call_user_method_with_grpc_unary_stream( - request_metadata, request_args[0] - ) - elif request_metadata.is_http_request: - assert len(request_args) == 1 and isinstance( - request_args[0], StreamingHTTPRequest - ) - generator = self._handle_http_request_generator( - request_metadata, request_args[0] - ) - else: - generator = self.replica.call_user_method_generator( - request_metadata, request_args, request_kwargs - ) + if wait_for_message_task is not None and not wait_for_message_task.done(): + wait_for_message_task.cancel() - async for result in generator: - yield result - - async def handle_request_from_java( - self, - proto_request_metadata: bytes, - *request_args, - **request_kwargs, - ) -> Any: - from ray.serve.generated.serve_pb2 import ( - RequestMetadata as RequestMetadataProto, + async def handle_request_streaming( + self, + pickled_request_metadata: bytes, + *request_args, + **request_kwargs, + ) -> AsyncGenerator[Any, None]: + """Generator that is the entrypoint for all `stream=True` handle calls.""" + request_metadata = pickle.loads(pickled_request_metadata) + if request_metadata.is_grpc_request: + # Ensure the request args are a single gRPCRequest object. + assert len(request_args) == 1 and isinstance(request_args[0], gRPCRequest) + generator = self.replica.call_user_method_with_grpc_unary_stream( + request_metadata, request_args[0] ) - - proto = RequestMetadataProto.FromString(proto_request_metadata) - request_metadata: RequestMetadata = RequestMetadata( - proto.request_id, - proto.endpoint, - call_method=proto.call_method, - multiplexed_model_id=proto.multiplexed_model_id, - route=proto.route, + elif request_metadata.is_http_request: + assert len(request_args) == 1 and isinstance( + request_args[0], StreamingHTTPRequest ) - request_args = request_args[0] - return await self.replica.call_user_method( + generator = self._handle_http_request_generator( + request_metadata, request_args[0] + ) + else: + generator = self.replica.call_user_method_generator( request_metadata, request_args, request_kwargs ) - async def is_allocated(self) -> str: - """poke the replica to check whether it's alive. - - When calling this method on an ActorHandle, it will complete as - soon as the actor has started running. We use this mechanism to - detect when a replica has been allocated a worker slot. - At this time, the replica can transition from PENDING_ALLOCATION - to PENDING_INITIALIZATION startup state. - - Returns: - The PID, actor ID, node ID, node IP, and log filepath id of the replica. - """ - - return ( - os.getpid(), - ray.get_runtime_context().get_actor_id(), - ray.get_runtime_context().get_worker_id(), - ray.get_runtime_context().get_node_id(), - ray.util.get_node_ip_address(), - get_component_logger_file_path(), + async for result in generator: + yield result + + async def handle_request_from_java( + self, + proto_request_metadata: bytes, + *request_args, + **request_kwargs, + ) -> Any: + from ray.serve.generated.serve_pb2 import ( + RequestMetadata as RequestMetadataProto, + ) + + proto = RequestMetadataProto.FromString(proto_request_metadata) + request_metadata: RequestMetadata = RequestMetadata( + proto.request_id, + proto.endpoint, + call_method=proto.call_method, + multiplexed_model_id=proto.multiplexed_model_id, + route=proto.route, + ) + request_args = request_args[0] + return await self.replica.call_user_method( + request_metadata, request_args, request_kwargs + ) + + async def is_allocated(self) -> str: + """poke the replica to check whether it's alive. + + When calling this method on an ActorHandle, it will complete as + soon as the actor has started running. We use this mechanism to + detect when a replica has been allocated a worker slot. + At this time, the replica can transition from PENDING_ALLOCATION + to PENDING_INITIALIZATION startup state. + + Returns: + The PID, actor ID, node ID, node IP, and log filepath id of the replica. + """ + + return ( + os.getpid(), + ray.get_runtime_context().get_actor_id(), + ray.get_runtime_context().get_worker_id(), + ray.get_runtime_context().get_node_id(), + ray.util.get_node_ip_address(), + get_component_logger_file_path(), + ) + + async def initialize_and_get_metadata( + self, + deployment_config: DeploymentConfig = None, + _after: Optional[Any] = None, + ) -> Tuple[DeploymentConfig, DeploymentVersion]: + # Unused `_after` argument is for scheduling: passing an ObjectRef + # allows delaying this call until after the `_after` call has returned. + try: + # Ensure that initialization is only performed once. + # When controller restarts, it will call this method again. + async with self._replica_init_lock: + if not self._initialized: + await self._initialize_replica() + if deployment_config: + await self.replica.update_user_config(deployment_config.user_config) + + # A new replica should not be considered healthy until it passes + # an initial health check. If an initial health check fails, + # consider it an initialization failure. + await self.check_health() + return await self._get_metadata() + except Exception: + raise RuntimeError(traceback.format_exc()) from None + + async def reconfigure( + self, + deployment_config: DeploymentConfig, + ) -> Tuple[DeploymentConfig, DeploymentVersion]: + try: + await self.replica.reconfigure(deployment_config) + return await self._get_metadata() + except Exception: + raise RuntimeError(traceback.format_exc()) from None + + async def _get_metadata( + self, + ) -> Tuple[DeploymentConfig, DeploymentVersion]: + return self.replica.version.deployment_config, self.replica.version + + def _save_cpu_profile_data(self) -> str: + """Saves CPU profiling data, if CPU profiling is enabled. + + Logs a warning if CPU profiling is disabled. + """ + + if self.cpu_profiler is not None: + import marshal + + self.cpu_profiler.snapshot_stats() + with open(self.cpu_profiler_log, "wb") as f: + marshal.dump(self.cpu_profiler.stats, f) + logger.info(f'Saved CPU profile data to file "{self.cpu_profiler_log}"') + return self.cpu_profiler_log + else: + logger.error( + "Attempted to save CPU profile data, but failed because no " + "CPU profiler was running! Enable CPU profiling by enabling " + "the RAY_SERVE_ENABLE_CPU_PROFILING env var." ) - async def initialize_and_get_metadata( - self, - deployment_config: DeploymentConfig = None, - _after: Optional[Any] = None, - ) -> Tuple[DeploymentConfig, DeploymentVersion]: - # Unused `_after` argument is for scheduling: passing an ObjectRef - # allows delaying this call until after the `_after` call has returned. - try: - # Ensure that initialization is only performed once. - # When controller restarts, it will call this method again. - async with self._replica_init_lock: - if not self._initialized: - await self._initialize_replica() - if deployment_config: - await self.replica.update_user_config( - deployment_config.user_config - ) - - # A new replica should not be considered healthy until it passes - # an initial health check. If an initial health check fails, - # consider it an initialization failure. - await self.check_health() - return await self._get_metadata() - except Exception: - raise RuntimeError(traceback.format_exc()) from None - - async def reconfigure( - self, - deployment_config: DeploymentConfig, - ) -> Tuple[DeploymentConfig, DeploymentVersion]: - try: - await self.replica.reconfigure(deployment_config) - return await self._get_metadata() - except Exception: - raise RuntimeError(traceback.format_exc()) from None - - async def _get_metadata( - self, - ) -> Tuple[DeploymentConfig, DeploymentVersion]: - return self.replica.version.deployment_config, self.replica.version - - def _save_cpu_profile_data(self) -> str: - """Saves CPU profiling data, if CPU profiling is enabled. - - Logs a warning if CPU profiling is disabled. - """ - - if self.cpu_profiler is not None: - import marshal - - self.cpu_profiler.snapshot_stats() - with open(self.cpu_profiler_log, "wb") as f: - marshal.dump(self.cpu_profiler.stats, f) - logger.info(f'Saved CPU profile data to file "{self.cpu_profiler_log}"') - return self.cpu_profiler_log - else: - logger.error( - "Attempted to save CPU profile data, but failed because no " - "CPU profiler was running! Enable CPU profiling by enabling " - "the RAY_SERVE_ENABLE_CPU_PROFILING env var." - ) + async def prepare_for_shutdown(self): + if self.replica is not None: + return await self.replica.prepare_for_shutdown() - async def prepare_for_shutdown(self): - if self.replica is not None: - return await self.replica.prepare_for_shutdown() - - @ray.method(concurrency_group=CONTROL_PLANE_CONCURRENCY_GROUP) - async def check_health(self): - await self.replica.check_health() - - # Dynamically create a new class with custom name here so Ray picks it up - # correctly in actor metadata table and observability stack. - return type( - actor_class_name, - (RayServeWrappedReplica,), - dict(RayServeWrappedReplica.__dict__), - ) + @ray.method(concurrency_group=REPLICA_CONTROL_PLANE_CONCURRENCY_GROUP) + async def check_health(self): + await self.replica.check_health() class RayServeReplica: