Convert job_manager to be async (ray-project#27123)
Updates the jobs API
Updates the snapshot API
Updates the state API

Increases the jobs API version to 2

Signed-off-by: Alan Guo [email protected]

Why are these changes needed?
Follow-up to ray-project#25902 (comment).
alanwguo committed Aug 6, 2022
1 parent a82af86 commit 326b5bd
Showing 9 changed files with 211 additions and 159 deletions.
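
Before the per-file diffs, here is a minimal sketch of the calling convention this commit introduces: JobManager is handed a GcsAioClient, and the methods that used to be synchronous become coroutines. The GCS address, import path, and asyncio.run wiring below are illustrative assumptions rather than code from the diff, and the sketch presumes a reachable Ray cluster.

import asyncio

from ray._private.gcs_utils import GcsAioClient
from ray.dashboard.modules.job.job_manager import JobManager  # path assumed from repo layout


async def main() -> None:
    # JobManager.__init__ now schedules a background recovery task, so it
    # must be constructed while an event loop is running.
    gcs_aio_client = GcsAioClient(address="127.0.0.1:6379")  # assumed GCS address
    job_manager = JobManager(gcs_aio_client)

    # Calls that used to be synchronous are now awaited.
    submission_id = await job_manager.submit_job(entrypoint="echo hello")
    info = await job_manager.get_job_info(submission_id)
    print(submission_id, info.status if info else None)


asyncio.run(main())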
51 changes: 32 additions & 19 deletions dashboard/modules/job/common.py
@@ -1,3 +1,4 @@
import asyncio
import pickle
import time
from dataclasses import dataclass, replace
@@ -6,12 +7,10 @@
from typing import Any, Dict, Optional, Tuple

from ray._private import ray_constants
from ray._private.gcs_utils import GcsAioClient
from ray._private.runtime_env.packaging import parse_uri
from ray.experimental.internal_kv import (
_internal_kv_get,
_internal_kv_initialized,
_internal_kv_list,
_internal_kv_put,
)

# NOTE(edoakes): these constants should be considered a public API because
@@ -97,30 +96,34 @@ class JobInfoStorageClient:
JOB_DATA_KEY_PREFIX = f"{ray_constants.RAY_INTERNAL_NAMESPACE_PREFIX}job_info_"
JOB_DATA_KEY = f"{JOB_DATA_KEY_PREFIX}{{job_id}}"

def __init__(self):
def __init__(self, gcs_aio_client: GcsAioClient):
self._gcs_aio_client = gcs_aio_client
assert _internal_kv_initialized()

def put_info(self, job_id: str, data: JobInfo):
_internal_kv_put(
self.JOB_DATA_KEY.format(job_id=job_id),
async def put_info(self, job_id: str, data: JobInfo):
await self._gcs_aio_client.internal_kv_put(
self.JOB_DATA_KEY.format(job_id=job_id).encode(),
pickle.dumps(data),
True,
namespace=ray_constants.KV_NAMESPACE_JOB,
)

def get_info(self, job_id: str) -> Optional[JobInfo]:
pickled_info = _internal_kv_get(
self.JOB_DATA_KEY.format(job_id=job_id),
async def get_info(self, job_id: str) -> Optional[JobInfo]:
pickled_info = await self._gcs_aio_client.internal_kv_get(
self.JOB_DATA_KEY.format(job_id=job_id).encode(),
namespace=ray_constants.KV_NAMESPACE_JOB,
)
if pickled_info is None:
return None
else:
return pickle.loads(pickled_info)

def put_status(self, job_id: str, status: JobStatus, message: Optional[str] = None):
async def put_status(
self, job_id: str, status: JobStatus, message: Optional[str] = None
):
"""Puts or updates job status. Sets end_time if status is terminal."""

old_info = self.get_info(job_id)
old_info = await self.get_info(job_id)

if old_info is not None:
if status != old_info.status and old_info.status.is_terminal():
@@ -134,18 +137,18 @@ def put_status(self, job_id: str, status: JobStatus, message: Optional[str] = None):
if status.is_terminal():
new_info.end_time = int(time.time() * 1000)

self.put_info(job_id, new_info)
await self.put_info(job_id, new_info)

def get_status(self, job_id: str) -> Optional[JobStatus]:
job_info = self.get_info(job_id)
async def get_status(self, job_id: str) -> Optional[JobStatus]:
job_info = await self.get_info(job_id)
if job_info is None:
return None
else:
return job_info.status

def get_all_jobs(self) -> Dict[str, JobInfo]:
raw_job_ids_with_prefixes = _internal_kv_list(
self.JOB_DATA_KEY_PREFIX, namespace=ray_constants.KV_NAMESPACE_JOB
async def get_all_jobs(self) -> Dict[str, JobInfo]:
raw_job_ids_with_prefixes = await self._gcs_aio_client.internal_kv_keys(
self.JOB_DATA_KEY_PREFIX.encode(), namespace=ray_constants.KV_NAMESPACE_JOB
)
job_ids_with_prefixes = [
job_id.decode() for job_id in raw_job_ids_with_prefixes
@@ -156,7 +159,17 @@ def get_all_jobs(self) -> Dict[str, JobInfo]:
self.JOB_DATA_KEY_PREFIX
), "Unexpected format for internal_kv key for Job submission"
job_ids.append(job_id_with_prefix[len(self.JOB_DATA_KEY_PREFIX) :])
return {job_id: self.get_info(job_id) for job_id in job_ids}

async def get_job_info(job_id: str):
job_info = await self.get_info(job_id)
return job_id, job_info

return {
job_id: job_info
for job_id, job_info in await asyncio.gather(
*[get_job_info(job_id) for job_id in job_ids]
)
}


def uri_to_http_components(package_uri: str) -> Tuple[str, str]:
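
The notable part of the common.py change is get_all_jobs(): rather than issuing one blocking KV read per job, it now starts one coroutine per job and joins them with asyncio.gather. Below is a self-contained sketch of that fan-out pattern, where the hypothetical fetch_info() stands in for GcsAioClient.internal_kv_get.

import asyncio
from typing import Dict, List, Optional


async def fetch_info(job_id: str) -> Optional[str]:
    await asyncio.sleep(0.01)  # stand-in for one internal_kv_get round trip
    return f"info-for-{job_id}"


async def get_all_jobs(job_ids: List[str]) -> Dict[str, Optional[str]]:
    async def get_one(job_id: str):
        return job_id, await fetch_info(job_id)

    # All lookups are issued concurrently, so the wall-clock cost stays
    # near a single round trip instead of one per stored job.
    pairs = await asyncio.gather(*[get_one(job_id) for job_id in job_ids])
    return dict(pairs)


print(asyncio.run(get_all_jobs(["raysubmit_1", "raysubmit_2", "raysubmit_3"])))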
16 changes: 4 additions & 12 deletions dashboard/modules/job/job_head.py
@@ -1,5 +1,3 @@
import asyncio
import concurrent
import dataclasses
import json
import logging
@@ -54,7 +52,6 @@ def __init__(self, dashboard_head):
self._dashboard_head = dashboard_head
self._job_manager = None
self._gcs_job_info_stub = None
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=10)

async def _parse_and_validate_request(
self, req: Request, request_type: dataclass
@@ -95,9 +92,7 @@ async def find_job_by_ids(self, job_or_submission_id: str) -> Optional[JobDetails]:
# then lets try to search for a submission with given id
submission_id = job_or_submission_id

job_info = await asyncio.get_event_loop().run_in_executor(
self._executor, lambda: self._job_manager.get_job_info(submission_id)
)
job_info = await self._job_manager.get_job_info(submission_id)
if job_info:
driver = submission_job_drivers.get(submission_id)
job = JobDetails(
@@ -182,7 +177,7 @@ async def submit_job(self, req: Request) -> Response:
request_submission_id = submit_request.submission_id or submit_request.job_id

try:
submission_id = self._job_manager.submit_job(
submission_id = await self._job_manager.submit_job(
entrypoint=submit_request.entrypoint,
submission_id=request_submission_id,
runtime_env=submit_request.runtime_env,
@@ -257,10 +252,7 @@ async def get_job_info(self, req: Request) -> Response:
async def list_jobs(self, req: Request) -> Response:
driver_jobs, submission_job_drivers = await self._get_driver_jobs()

# TODO(aguo): convert _job_manager.list_jobs to an async function.
submission_jobs = await asyncio.get_event_loop().run_in_executor(
self._executor, self._job_manager.list_jobs
)
submission_jobs = await self._job_manager.list_jobs()
submission_jobs = [
JobDetails(
**dataclasses.asdict(job),
@@ -386,7 +378,7 @@ async def tail_job_logs(self, req: Request) -> Response:

async def run(self, server):
if not self._job_manager:
self._job_manager = JobManager()
self._job_manager = JobManager(self._dashboard_head.gcs_aio_client)

self._gcs_job_info_stub = gcs_service_pb2_grpc.JobInfoGcsServiceStub(
self._dashboard_head.aiogrpc_gcs_channel
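
To make the job_head.py deletions concrete, here is the old thread-pool indirection next to its replacement, with toy stand-ins for the real manager calls. Once get_job_info is itself a coroutine, the run_in_executor hop adds a thread switch for no benefit.

import asyncio
import concurrent.futures


def get_job_info_sync(submission_id: str) -> str:
    return f"info-{submission_id}"  # stand-in for blocking internal_kv reads


async def get_job_info_async(submission_id: str) -> str:
    return f"info-{submission_id}"  # stand-in for awaiting GcsAioClient


async def handler() -> None:
    # Before: push the blocking call onto a worker thread so it does not
    # stall the dashboard's event loop.
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=10)
    info = await asyncio.get_event_loop().run_in_executor(
        executor, lambda: get_job_info_sync("raysubmit_1")
    )
    executor.shutdown()

    # After: the manager method is a coroutine, so just await it.
    info = await get_job_info_async("raysubmit_1")
    print(info)


asyncio.run(handler())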
72 changes: 42 additions & 30 deletions dashboard/modules/job/job_manager.py
@@ -13,6 +13,7 @@
from typing import Any, Dict, Iterator, Optional, Tuple

import ray
from ray._private.gcs_utils import GcsAioClient
import ray._private.ray_constants as ray_constants
from ray._private.runtime_env.constants import RAY_JOB_CONFIG_JSON_ENV_VAR
from ray.actor import ActorHandle
@@ -103,9 +104,16 @@ class JobSupervisor:

SUBPROCESS_POLL_PERIOD_S = 0.1

def __init__(self, job_id: str, entrypoint: str, user_metadata: Dict[str, str]):
def __init__(
self,
job_id: str,
entrypoint: str,
user_metadata: Dict[str, str],
gcs_address: str,
):
self._job_id = job_id
self._job_info_client = JobInfoStorageClient()
gcs_aio_client = GcsAioClient(address=gcs_address)
self._job_info_client = JobInfoStorageClient(gcs_aio_client)
self._log_client = JobLogStorageClient()
self._driver_runtime_env = self._get_driver_runtime_env()
self._entrypoint = entrypoint
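
An aside on the constructor above: the supervisor is handed the GCS address rather than a client object, because a GcsAioClient is tied to the event loop of the process that created it and cannot be shipped into the JobSupervisor actor. A hedged sketch of that shape follows; the toy Supervisor and the runtime-context lookup are assumptions, not code from this diff.

import ray
from ray._private.gcs_utils import GcsAioClient


@ray.remote
class Supervisor:
    def __init__(self, gcs_address: str):
        # Build the aio client inside the actor, on the actor's own event
        # loop, mirroring JobSupervisor.__init__ above.
        self._gcs_aio_client = GcsAioClient(address=gcs_address)

    async def ready(self) -> bool:
        return True


ray.init()
gcs_address = ray.get_runtime_context().gcs_address  # assumed available (Ray 2.x)
supervisor = Supervisor.remote(gcs_address)
assert ray.get(supervisor.ready.remote())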
@@ -227,14 +235,14 @@ async def run(
variables.
3) Handle concurrent events of driver execution and
"""
curr_status = self._job_info_client.get_status(self._job_id)
curr_status = await self._job_info_client.get_status(self._job_id)
assert curr_status == JobStatus.PENDING, "Run should only be called once."

if _start_signal_actor:
# Block in PENDING state until start signal received.
await _start_signal_actor.wait.remote()

self._job_info_client.put_status(self._job_id, JobStatus.RUNNING)
await self._job_info_client.put_status(self._job_id, JobStatus.RUNNING)

try:
# Configure environment variables for the child process. These
Expand All @@ -257,15 +265,17 @@ async def run(
polling_task.cancel()
# TODO (jiaodong): Improve this with SIGTERM then SIGKILL
child_process.kill()
self._job_info_client.put_status(self._job_id, JobStatus.STOPPED)
await self._job_info_client.put_status(self._job_id, JobStatus.STOPPED)
else:
# Child process finished execution and no stop event is set
# at the same time
assert len(finished) == 1, "Should have only one coroutine done"
[child_process_task] = finished
return_code = child_process_task.result()
if return_code == 0:
self._job_info_client.put_status(self._job_id, JobStatus.SUCCEEDED)
await self._job_info_client.put_status(
self._job_id, JobStatus.SUCCEEDED
)
else:
log_tail = self._log_client.get_last_n_log_lines(self._job_id)
if log_tail is not None and log_tail != "":
@@ -275,7 +285,7 @@ async def run(
)
else:
message = None
self._job_info_client.put_status(
await self._job_info_client.put_status(
self._job_id, JobStatus.FAILED, message=message
)
except Exception:
@@ -307,20 +317,22 @@ class JobManager:
LOG_TAIL_SLEEP_S = 1
JOB_MONITOR_LOOP_PERIOD_S = 1

def __init__(self):
self._job_info_client = JobInfoStorageClient()
def __init__(self, gcs_aio_client: GcsAioClient):
self._gcs_aio_client = gcs_aio_client
self._job_info_client = JobInfoStorageClient(gcs_aio_client)
self._gcs_address = gcs_aio_client._channel._gcs_address
self._log_client = JobLogStorageClient()
self._supervisor_actor_cls = ray.remote(JobSupervisor)

self._recover_running_jobs()
create_task(self._recover_running_jobs())

def _recover_running_jobs(self):
async def _recover_running_jobs(self):
"""Recovers all running jobs from the status client.
For each job, we will spawn a coroutine to monitor it.
Each will be added to self._running_jobs and reconciled.
"""
all_jobs = self._job_info_client.get_all_jobs()
all_jobs = await self._job_info_client.get_all_jobs()
for job_id, job_info in all_jobs.items():
if not job_info.status.is_terminal():
create_task(self._monitor_job(job_id))
@@ -345,7 +357,7 @@ async def _monitor_job(

if job_supervisor is None:
logger.error(f"Failed to get job supervisor for job {job_id}.")
self._job_info_client.put_status(
await self._job_info_client.put_status(
job_id,
JobStatus.FAILED,
message="Unexpected error occurred: Failed to get job supervisor.",
@@ -358,13 +370,14 @@ async def _monitor_job(
await asyncio.sleep(self.JOB_MONITOR_LOOP_PERIOD_S)
except Exception as e:
is_alive = False
if self._job_info_client.get_status(job_id).is_terminal():
job_status = await self._job_info_client.get_status(job_id)
if job_status.is_terminal():
# If the job is already in a terminal state, then the actor
# exiting is expected.
pass
elif isinstance(e, RuntimeEnvSetupError):
logger.info(f"Failed to set up runtime_env for job {job_id}.")
self._job_info_client.put_status(
await self._job_info_client.put_status(
job_id,
JobStatus.FAILED,
message=f"runtime_env setup failed: {e}",
Expand All @@ -373,7 +386,7 @@ async def _monitor_job(
logger.warning(
f"Job supervisor for job {job_id} failed unexpectedly: {e}."
)
self._job_info_client.put_status(
await self._job_info_client.put_status(
job_id,
JobStatus.FAILED,
message=f"Unexpected error occurred: {e}",
@@ -413,7 +426,6 @@ def _handle_supervisor_startup(self, job_id: str, result: Optional[Exception]):
def _get_supervisor_runtime_env(
self, user_runtime_env: Dict[str, Any]
) -> Dict[str, Any]:

"""Configure and return the runtime_env for the supervisor actor."""

# Make a copy to avoid mutating passed runtime_env.
Expand All @@ -434,7 +446,7 @@ def _get_supervisor_runtime_env(
runtime_env["env_vars"] = env_vars
return runtime_env

def submit_job(
async def submit_job(
self,
*,
entrypoint: str,
Expand Down Expand Up @@ -473,7 +485,7 @@ def submit_job(
"""
if submission_id is None:
submission_id = generate_job_id()
elif self._job_info_client.get_status(submission_id) is not None:
elif await self._job_info_client.get_status(submission_id) is not None:
raise RuntimeError(f"Job {submission_id} already exists.")

logger.info(f"Starting job with submission_id: {submission_id}")
Expand All @@ -484,7 +496,7 @@ def submit_job(
metadata=metadata,
runtime_env=runtime_env,
)
self._job_info_client.put_info(submission_id, job_info)
await self._job_info_client.put_info(submission_id, job_info)

# Wait for the actor to start up asynchronously so this call always
# returns immediately and we can catch errors with the actor starting
Expand All @@ -500,14 +512,14 @@ def submit_job(
self._get_current_node_resource_key(): 0.001,
},
runtime_env=self._get_supervisor_runtime_env(runtime_env),
).remote(submission_id, entrypoint, metadata or {})
).remote(submission_id, entrypoint, metadata or {}, self._gcs_address)
supervisor.run.remote(_start_signal_actor=_start_signal_actor)

# Monitor the job in the background so we can detect errors without
# requiring a client to poll.
create_task(self._monitor_job(submission_id, job_supervisor=supervisor))
except Exception as e:
self._job_info_client.put_status(
await self._job_info_client.put_status(
submission_id,
JobStatus.FAILED,
message=f"Failed to start job supervisor: {e}.",
Expand All @@ -529,31 +541,31 @@ def stop_job(self, job_id) -> bool:
else:
return False

def get_job_status(self, job_id: str) -> Optional[JobStatus]:
async def get_job_status(self, job_id: str) -> Optional[JobStatus]:
"""Get latest status of a job."""
return self._job_info_client.get_status(job_id)
return await self._job_info_client.get_status(job_id)

def get_job_info(self, job_id: str) -> Optional[JobInfo]:
async def get_job_info(self, job_id: str) -> Optional[JobInfo]:
"""Get latest info of a job."""
return self._job_info_client.get_info(job_id)
return await self._job_info_client.get_info(job_id)

def list_jobs(self) -> Dict[str, JobInfo]:
async def list_jobs(self) -> Dict[str, JobInfo]:
"""Get info for all jobs."""
return self._job_info_client.get_all_jobs()
return await self._job_info_client.get_all_jobs()

def get_job_logs(self, job_id: str) -> str:
"""Get all logs produced by a job."""
return self._log_client.get_logs(job_id)

async def tail_job_logs(self, job_id: str) -> Iterator[str]:
"""Return an iterator following the logs of a job."""
if self.get_job_status(job_id) is None:
if await self.get_job_status(job_id) is None:
raise RuntimeError(f"Job '{job_id}' does not exist.")

for line in self._log_client.tail_logs(job_id):
if line is None:
# Return if the job has exited and there are no new log lines.
status = self.get_job_status(job_id)
status = await self.get_job_status(job_id)
if status not in {JobStatus.PENDING, JobStatus.RUNNING}:
return

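
Finally, since __init__ cannot await, JobManager now schedules _recover_running_jobs() with create_task instead of calling it synchronously. A toy sketch of that startup pattern, using plain asyncio.create_task where the diff uses the dashboard's create_task helper:

import asyncio


class Manager:
    def __init__(self) -> None:
        # Fire-and-forget: __init__ returns immediately while recovery of
        # non-terminal jobs proceeds on the already-running event loop.
        self._recovery_task = asyncio.create_task(self._recover_running_jobs())

    async def _recover_running_jobs(self) -> None:
        all_jobs = {"job-1": "RUNNING", "job-2": "SUCCEEDED"}  # stand-in KV scan
        for job_id, status in all_jobs.items():
            if status == "RUNNING":
                asyncio.create_task(self._monitor_job(job_id))

    async def _monitor_job(self, job_id: str) -> None:
        print(f"monitoring {job_id}")


async def main() -> None:
    manager = Manager()            # needs a running loop, as in JobHead.run()
    await manager._recovery_task   # let recovery finish...
    await asyncio.sleep(0)         # ...and give the spawned monitor a turn


asyncio.run(main())

One consequence of this pattern: JobManager can no longer be constructed outside an event loop, which is why JobHead.run(), itself a coroutine, is where the diff constructs it.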