Skip to content

Commit

Permalink
[Serve] Make the checkpint and recover only from GCS
Browse files Browse the repository at this point in the history
  • Loading branch information
sihanwang41 committed Jul 20, 2022
1 parent 0f6beca commit 9abca9d
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 89 deletions.
15 changes: 2 additions & 13 deletions python/ray/serve/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def start(
detached: bool = False,
http_options: Optional[Union[dict, HTTPOptions]] = None,
dedicated_cpu: bool = False,
_checkpoint_path: str = DEFAULT_CHECKPOINT_PATH,
**kwargs,
) -> ServeControllerClient:
"""Initialize a serve instance.
Expand Down Expand Up @@ -121,7 +120,7 @@ def start(
f'Connecting to existing Serve app in namespace "{SERVE_NAMESPACE}".'
)

_check_http_and_checkpoint_options(client, http_options, _checkpoint_path)
_check_http_and_checkpoint_options(client, http_options)
return client
except RayServeException:
pass
Expand Down Expand Up @@ -154,7 +153,6 @@ def start(
).remote(
controller_name,
http_config=http_options,
checkpoint_path=_checkpoint_path,
head_node_id=head_node_id,
detached=detached,
)
Expand Down Expand Up @@ -643,17 +641,8 @@ def build(target: Union[ClassNode, FunctionNode]) -> Application:


def _check_http_and_checkpoint_options(
client: ServeControllerClient,
http_options: Union[dict, HTTPOptions],
checkpoint_path: str,
client: ServeControllerClient, http_options: Union[dict, HTTPOptions]
) -> None:
if checkpoint_path and checkpoint_path != client.checkpoint_path:
logger.warning(
f"The new client checkpoint path '{checkpoint_path}' "
f"is different from the existing one '{client.checkpoint_path}'. "
"The new checkpoint path is ignored."
)

if http_options:
client_http_options = client.http_config
new_http_options = (
Expand Down
8 changes: 5 additions & 3 deletions python/ray/serve/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from ray.serve.config import DeploymentConfig, HTTPOptions, ReplicaConfig
from ray.serve.constants import (
CONTROL_LOOP_PERIOD_S,
DEFAULT_CHECKPOINT_PATH,
SERVE_LOGGER_NAME,
CONTROLLER_MAX_CONCURRENCY,
SERVE_ROOT_URL_ENV_KEY,
Expand Down Expand Up @@ -86,7 +87,6 @@ async def __init__(
controller_name: str,
*,
http_config: HTTPOptions,
checkpoint_path: str,
head_node_id: str,
detached: bool = False,
):
Expand All @@ -97,9 +97,11 @@ async def __init__(
# Used to read/write checkpoints.
self.ray_worker_namespace = ray.get_runtime_context().namespace
self.controller_name = controller_name
self.checkpoint_path = checkpoint_path
self.checkpoint_path = DEFAULT_CHECKPOINT_PATH
kv_store_namespace = f"{self.controller_name}-{self.ray_worker_namespace}"
self.kv_store = make_kv_store(checkpoint_path, namespace=kv_store_namespace)
self.kv_store = make_kv_store(
self.checkpoint_path, namespace=kv_store_namespace
)
self.snapshot_store = RayInternalKVStore(namespace=kv_store_namespace)

# Dictionary of deployment_name -> proxy_name -> queue length.
Expand Down
11 changes: 1 addition & 10 deletions python/ray/serve/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,19 +79,11 @@ def cli():
help="Location of the HTTP servers. Defaults to HeadOnly.",
)
@click.option(
"--checkpoint-path",
default=DEFAULT_CHECKPOINT_PATH,
required=False,
type=str,
hidden=True,
)
def start(
address,
http_host,
http_port,
http_location,
checkpoint_path,
):
def start(address, http_host, http_port, http_location):
ray.init(
address=address,
namespace=SERVE_NAMESPACE,
Expand All @@ -103,7 +95,6 @@ def start(
port=http_port,
location=http_location,
),
_checkpoint_path=checkpoint_path,
)


Expand Down
66 changes: 3 additions & 63 deletions python/ray/serve/tests/test_standalone.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,69 +568,11 @@ class A:
)


def test_local_store_recovery(ray_shutdown):
_, tmp_path = mkstemp()

@serve.deployment
def hello(_):
return "hello"

# https://github.com/ray-project/ray/issues/19987
@serve.deployment
def world(_):
return "world"

def check(name, raise_error=False):
try:
resp = requests.get(f"http:https://localhost:8000/{name}")
assert resp.text == name
return True
except Exception as e:
if raise_error:
raise e
return False

# https://github.com/ray-project/ray/issues/20159
# https://github.com/ray-project/ray/issues/20158
def clean_up_leaked_processes():
import psutil

for proc in psutil.process_iter():
try:
cmdline = " ".join(proc.cmdline())
if "ray::" in cmdline:
print(f"Kill {proc} {cmdline}")
proc.kill()
except Exception:
pass

def crash():
subprocess.call(["ray", "stop", "--force"])
clean_up_leaked_processes()
ray.shutdown()
serve.shutdown()

serve.start(detached=True, _checkpoint_path=f"file:https://{tmp_path}")
hello.deploy()
world.deploy()
assert check("hello", raise_error=True)
assert check("world", raise_error=True)
crash()

# Simulate a crash

serve.start(detached=True, _checkpoint_path=f"file:https://{tmp_path}")
wait_for_condition(lambda: check("hello"))
# wait_for_condition(lambda: check("world"))
crash()


@pytest.mark.parametrize("ray_start_with_dashboard", [{"num_cpus": 4}], indirect=True)
def test_snapshot_always_written_to_internal_kv(
ray_start_with_dashboard, ray_shutdown # noqa: F811
):
# https://github.com/ray-project/ray/issues/19752
_, tmp_path = mkstemp()

@serve.deployment()
def hello(_):
Expand All @@ -644,7 +586,7 @@ def check():
except Exception:
return False

serve.start(detached=True, _checkpoint_path=f"file:https://{tmp_path}")
serve.start(detached=True)
serve.run(hello.bind())
check()

Expand Down Expand Up @@ -687,12 +629,10 @@ def emit(self, record):

# create a different config
test_http = dict(host="127.1.1.8", port=new_port())
_, tmp_path = mkstemp()
test_ckpt = f"file:https://{tmp_path}"

serve.start(detached=True, http_options=test_http, _checkpoint_path=test_ckpt)
serve.start(detached=True, http_options=test_http)

for test_config, msg in zip([[test_ckpt], ["host", "port"]], warning_msg):
for test_config, msg in zip([["host", "port"]], warning_msg):
for test_msg in test_config:
if "Autoscaling metrics pusher thread" in msg:
continue
Expand Down

0 comments on commit 9abca9d

Please sign in to comment.