Skip to content

Commit

Permalink
[client] Report number of currently active clients on connect (ray-pr…
Browse files Browse the repository at this point in the history
…oject#13326)

* wip

* update

* update

* reset worker

* fix conn

* fix

* disable pycodestyle
  • Loading branch information
ericl authored Jan 11, 2021
1 parent e2b2abb commit fbb9795
Show file tree
Hide file tree
Showing 8 changed files with 81 additions and 7 deletions.
2 changes: 2 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,6 @@ ignore =
W503
W504
W605
I
N
avoid-escape = no
26 changes: 26 additions & 0 deletions python/ray/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,36 @@
import sys
import logging

import ray.util.client.server.server as ray_client_server
from ray.util.client import RayAPIStub
from ray.util.client.common import ClientObjectRef
from ray.util.client.ray_client_helpers import ray_start_client_server


def test_num_clients(shutdown_only):
# Tests num clients reporting; useful if you want to build an app that
# load balances clients between Ray client servers.
server = ray_client_server.serve("localhost:50051")
try:
api1 = RayAPIStub()
info1 = api1.connect("localhost:50051")
assert info1["num_clients"] == 1, info1
api2 = RayAPIStub()
info2 = api2.connect("localhost:50051")
assert info2["num_clients"] == 2, info2

# Disconnect the first two clients.
api1.disconnect()
api2.disconnect()
time.sleep(1)

api3 = RayAPIStub()
info3 = api3.connect("localhost:50051")
assert info3["num_clients"] == 1, info3
finally:
server.stop(0)


def test_real_ray_fallback(ray_start_regular_shared):
with ray_start_client_server() as ray:

Expand Down
18 changes: 14 additions & 4 deletions python/ray/util/client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Tuple
from typing import List, Tuple, Dict, Any

import logging

Expand All @@ -24,13 +24,16 @@ def __init__(self):
def connect(self,
conn_str: str,
secure: bool = False,
metadata: List[Tuple[str, str]] = None) -> None:
metadata: List[Tuple[str, str]] = None) -> Dict[str, Any]:
"""Connect the Ray Client to a server.
Args:
conn_str: Connection string, in the form "[host]:port"
secure: Whether to use a TLS secured gRPC channel
metadata: gRPC metadata to send on connect
Returns:
Dictionary of connection info, e.g., {"num_clients": 1}.
"""
# Delay imports until connect to avoid circular imports.
from ray.util.client.worker import Worker
Expand All @@ -44,8 +47,15 @@ def connect(self,
# If we're calling a client connect specifically and we're not
# currently in client mode, ensure we are.
ray._private.client_mode_hook._explicitly_enable_client_mode()
self.client_worker = Worker(conn_str, secure=secure, metadata=metadata)
self.api.worker = self.client_worker

try:
self.client_worker = Worker(
conn_str, secure=secure, metadata=metadata)
self.api.worker = self.client_worker
return self.client_worker.connection_info()
except Exception:
self.disconnect()
raise

def disconnect(self):
"""Disconnect the Ray Client.
Expand Down
7 changes: 7 additions & 0 deletions python/ray/util/client/dataclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,13 @@ def _blocking_send(self, req: ray_client_pb2.DataRequest
del self.ready_data[req_id]
return data

def ConnectionInfo(self,
context=None) -> ray_client_pb2.ConnectionInfoResponse:
datareq = ray_client_pb2.DataRequest(
connection_info=ray_client_pb2.ConnectionInfoRequest())
resp = self._blocking_send(datareq)
return resp.connection_info

def GetObject(self, request: ray_client_pb2.GetRequest,
context=None) -> ray_client_pb2.GetResponse:
datareq = ray_client_pb2.DataRequest(get=request, )
Expand Down
13 changes: 13 additions & 0 deletions python/ray/util/client/server/dataservicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import grpc

from typing import TYPE_CHECKING
from threading import Lock

import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
Expand All @@ -15,6 +16,8 @@
class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
def __init__(self, basic_service: "RayletServicer"):
self.basic_service = basic_service
self._clients_lock = Lock()
self._num_clients = 0 # guarded by self._clients_lock

def Datapath(self, request_iterator, context):
metadata = {k: v for k, v in context.invocation_metadata()}
Expand All @@ -24,6 +27,8 @@ def Datapath(self, request_iterator, context):
return
logger.info(f"New data connection from client {client_id}")
try:
with self._clients_lock:
self._num_clients += 1
for req in request_iterator:
resp = None
req_type = req.WhichOneof("type")
Expand All @@ -42,6 +47,12 @@ def Datapath(self, request_iterator, context):
released.append(rel)
resp = ray_client_pb2.DataResponse(
release=ray_client_pb2.ReleaseResponse(ok=released))
elif req_type == "connection_info":
with self._clients_lock:
cur_num_clients = self._num_clients
info = ray_client_pb2.ConnectionInfoResponse(
num_clients=cur_num_clients)
resp = ray_client_pb2.DataResponse(connection_info=info)
else:
raise Exception(f"Unreachable code: Request type "
f"{req_type} not handled in Datapath")
Expand All @@ -52,3 +63,5 @@ def Datapath(self, request_iterator, context):
finally:
logger.info(f"Lost data connection from client {client_id}")
self.basic_service.release_all(client_id)
with self._clients_lock:
self._num_clients -= 1
7 changes: 7 additions & 0 deletions python/ray/util/client/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ def __init__(self,
self.log_client.set_logstream_level(logging.INFO)
self.closed = False

def connection_info(self):
try:
data = self.data_client.ConnectionInfo()
except grpc.RpcError as e:
raise e.details()
return {"num_clients": data.num_clients}

def get(self, vals, *, timeout: Optional[float] = None) -> Any:
to_get = []
single = False
Expand Down
5 changes: 2 additions & 3 deletions python/ray/util/client_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
from ray._private.client_mode_hook import _enable_client_hook
from ray._private.client_mode_hook import _explicitly_enable_client_mode

from typing import List
from typing import Tuple
from typing import List, Tuple, Dict, Any


def connect(conn_str: str,
secure: bool = False,
metadata: List[Tuple[str, str]] = None) -> None:
metadata: List[Tuple[str, str]] = None) -> Dict[str, Any]:
if ray.is_connected():
raise RuntimeError("Ray Client is already connected. "
"Maybe you called ray.util.connect twice by "
Expand Down
10 changes: 10 additions & 0 deletions src/ray/protobuf/ray_client.proto
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,14 @@ message ReleaseResponse {
repeated bool ok = 2;
}

message ConnectionInfoRequest {
}

message ConnectionInfoResponse {
// The number of data clients connected to the server, including the caller.
int32 num_clients = 1;
}

message DataRequest {
// An incrementing counter of request IDs on the Datapath,
// to match requests with responses asynchronously.
Expand All @@ -212,6 +220,7 @@ message DataRequest {
GetRequest get = 2;
PutRequest put = 3;
ReleaseRequest release = 4;
ConnectionInfoRequest connection_info = 5;
}
}

Expand All @@ -222,6 +231,7 @@ message DataResponse {
GetResponse get = 2;
PutResponse put = 3;
ReleaseResponse release = 4;
ConnectionInfoResponse connection_info = 5;
}
}

Expand Down

0 comments on commit fbb9795

Please sign in to comment.