Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core][state] Push down filtering to GCS for listing/getting task from state api #34433

Merged
merged 3 commits into from
May 5, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions dashboard/state_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,16 +375,10 @@ async def list_tasks(self, *, option: ListApiOptions) -> ListApiResponse:
{task_id -> task_data_in_dict}
task_data_in_dict's schema is in TaskState
"""
job_id = None
for filter in option.filters:
if filter[0] == "job_id" and filter[1] == "=":
# Filtering by job_id == xxxx, pass it to source side filtering.
# tuple consists of (job_id, predicate, value)
job_id = filter[2]
try:
reply = await self._client.get_all_task_info(
timeout=option.timeout,
job_id=job_id,
filters=option.filters,
exclude_driver=option.exclude_driver,
)
except DataSourceUnavailable:
Expand Down
48 changes: 38 additions & 10 deletions python/ray/experimental/state/state_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from collections import defaultdict
from functools import wraps
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple

import grpc
from grpc.aio._call import UnaryStreamCall
Expand All @@ -12,7 +12,7 @@
from ray._private import ray_constants
from ray._private.gcs_utils import GcsAioClient
from ray._private.utils import hex_to_binary
from ray._raylet import JobID
from ray._raylet import ActorID, JobID, TaskID
from ray.core.generated import gcs_service_pb2_grpc
from ray.core.generated.gcs_service_pb2 import (
GetAllActorInfoReply,
Expand Down Expand Up @@ -47,7 +47,11 @@
from ray.dashboard.datacenter import DataSource
from ray.dashboard.modules.job.common import JobInfo, JobInfoStorageClient
from ray.dashboard.utils import Dict as Dictionary
from ray.experimental.state.common import RAY_MAX_LIMIT_FROM_DATA_SOURCE
from ray.experimental.state.common import (
RAY_MAX_LIMIT_FROM_DATA_SOURCE,
PredicateType,
SupportedFilterType,
)
from ray.experimental.state.exception import DataSourceUnavailable

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -237,16 +241,40 @@ async def get_all_task_info(
self,
timeout: int = None,
limit: int = None,
job_id: Optional[str] = None,
exclude_driver: bool = True,
filters: Optional[List[Tuple[str, PredicateType, SupportedFilterType]]] = None,
exclude_driver: bool = False,
) -> Optional[GetTaskEventsReply]:
if not limit:
limit = RAY_MAX_LIMIT_FROM_DATA_SOURCE
if job_id:
job_id = JobID(hex_to_binary(job_id)).binary()
request = GetTaskEventsRequest(
limit=limit, exclude_driver=exclude_driver, job_id=job_id
)

if filters is None:
    filters = []

# Translate the equality filters on supported keys into source-side
# (GCS) filters. Anything else is left in `filters` untouched.
req_filters = GetTaskEventsRequest.Filters()
# Iterate over a snapshot: entries are removed from `filters` inside the
# loop, and removing from a list while iterating over it makes the
# iterator skip the entry that follows each removal, so some eligible
# filters would never be pushed down.
for entry in list(filters):
    key, predicate, value = entry
    if predicate != "=":
        # We only support EQUAL predicate for source side filtering.
        continue

    if key == "actor_id":
        req_filters.actor_id = ActorID(hex_to_binary(value)).binary()
    elif key == "job_id":
        req_filters.job_id = JobID(hex_to_binary(value)).binary()
    elif key == "name":
        req_filters.name = value
    elif key == "task_id":
        req_filters.task_ids.append(TaskID(hex_to_binary(value)).binary())
    else:
        continue

    # Remove the filter from the list so that we don't have to
    # filter it again later. This mutates the list passed in by the caller.
    filters.remove(entry)

req_filters.exclude_driver = exclude_driver

request = GetTaskEventsRequest(limit=limit, filters=req_filters)
reply = await self._gcs_task_info_stub.GetTaskEvents(request, timeout=timeout)
return reply

Expand Down
10 changes: 8 additions & 2 deletions python/ray/tests/test_state_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2184,7 +2184,7 @@ def g(dep):
def impossible():
pass

out = [f.remote() for _ in range(2)] # noqa
out = [f.options(name=f"f_{i}").remote() for i in range(2)] # noqa
g_out = g.remote(f.remote()) # noqa
im = impossible.remote() # noqa

Expand Down Expand Up @@ -2252,6 +2252,9 @@ def verify():
for task in tasks:
assert task["job_id"] == job_id

tasks = list_tasks(filters=[("name", "=", "f_0")])
assert len(tasks) == 1

return True

wait_for_condition(verify)
Expand Down Expand Up @@ -2442,7 +2445,6 @@ def verify():
for task in tasks:
assert task["job_id"] == job_id
for task in tasks:
print(task)
assert task["actor_id"] == actor_id
# Actor.__init__: 1 finished
# Actor.call: 1 running, 9 waiting for execution (queued).
Expand Down Expand Up @@ -2492,6 +2494,10 @@ def verify():
== 1
)

# Filters with actor id.
assert len(list_tasks(filters=[("actor_id", "=", actor_id)])) == 11
assert len(list_tasks(filters=[("actor_id", "!=", actor_id)])) == 0

return True

wait_for_condition(verify)
Expand Down
40 changes: 30 additions & 10 deletions src/ray/gcs/gcs_server/gcs_task_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -369,16 +369,17 @@ void GcsTaskManager::HandleGetTaskEvents(rpc::GetTaskEventsRequest request,
rpc::SendReplyCallback send_reply_callback) {
RAY_LOG(DEBUG) << "Getting task status:" << request.ShortDebugString();

// Select candidate events by indexing.
// Select candidate events by indexing if possible.
std::vector<rpc::TaskEvents> task_events;
if (request.has_task_ids()) {
const auto &filters = request.filters();
if (filters.task_ids_size() > 0) {
absl::flat_hash_set<TaskID> task_ids;
for (const auto &task_id_str : request.task_ids().vals()) {
for (const auto &task_id_str : filters.task_ids()) {
task_ids.insert(TaskID::FromBinary(task_id_str));
}
task_events = task_event_storage_->GetTaskEvents(task_ids);
} else if (request.has_job_id()) {
task_events = task_event_storage_->GetTaskEvents(JobID::FromBinary(request.job_id()));
} else if (filters.has_job_id()) {
task_events = task_event_storage_->GetTaskEvents(JobID::FromBinary(filters.job_id()));
} else {
task_events = task_event_storage_->GetTaskEvents();
}
Expand All @@ -390,15 +391,34 @@ void GcsTaskManager::HandleGetTaskEvents(rpc::GetTaskEventsRequest request,
int32_t num_profile_event_limit = 0;
int32_t num_status_event_limit = 0;

for (auto itr = task_events.rbegin(); itr != task_events.rend(); ++itr) {
auto &task_event = *itr;
// A lambda filter fn, where it returns true for task events to be included in the
// result.
rickyyx marked this conversation as resolved.
Show resolved Hide resolved
auto filter_fn = [&filters](const rpc::TaskEvents &task_event) {
if (!task_event.has_task_info()) {
// Skip task events w/o task info.
continue;
return false;
}

if (request.exclude_driver() &&
if (filters.exclude_driver() &&
task_event.task_info().type() == rpc::TaskType::DRIVER_TASK) {
return false;
}

if (filters.has_actor_id() && task_event.task_info().has_actor_id() &&
ActorID::FromBinary(task_event.task_info().actor_id()) !=
ActorID::FromBinary(filters.actor_id())) {
return false;
}

if (filters.has_name() && task_event.task_info().name() != filters.name()) {
return false;
}

return true;
};

for (auto itr = task_events.rbegin(); itr != task_events.rend(); ++itr) {
auto &task_event = *itr;
if (!filter_fn(task_event)) {
continue;
}

Expand Down
84 changes: 79 additions & 5 deletions src/ray/gcs/gcs_server/test/gcs_task_manager_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,26 +107,36 @@ class GcsTaskManagerTest : public ::testing::Test {
rpc::GetTaskEventsReply SyncGetTaskEvents(absl::flat_hash_set<TaskID> task_ids,
absl::optional<JobID> job_id = absl::nullopt,
int64_t limit = -1,
bool exclude_driver = true) {
bool exclude_driver = true,
const std::string &name = "",
const ActorID &actor_id = ActorID::Nil()) {
rpc::GetTaskEventsRequest request;
rpc::GetTaskEventsReply reply;
std::promise<bool> promise;

if (!task_ids.empty()) {
for (const auto &task_id : task_ids) {
request.mutable_task_ids()->add_vals(task_id.Binary());
request.mutable_filters()->add_task_ids(task_id.Binary());
}
}

if (!name.empty()) {
request.mutable_filters()->set_name(name);
}

if (!actor_id.IsNil()) {
request.mutable_filters()->set_actor_id(actor_id.Binary());
}

if (job_id) {
request.set_job_id(job_id->Binary());
request.mutable_filters()->set_job_id(job_id->Binary());
}

if (limit >= 0) {
request.set_limit(limit);
}

request.set_exclude_driver(exclude_driver);
request.mutable_filters()->set_exclude_driver(exclude_driver);
task_manager->GetIoContext().dispatch(
[this, &promise, &request, &reply]() {
task_manager->HandleGetTaskEvents(
Expand All @@ -147,11 +157,15 @@ class GcsTaskManagerTest : public ::testing::Test {
static rpc::TaskInfoEntry GenTaskInfo(
JobID job_id,
TaskID parent_task_id = TaskID::Nil(),
rpc::TaskType task_type = rpc::TaskType::NORMAL_TASK) {
rpc::TaskType task_type = rpc::TaskType::NORMAL_TASK,
const ActorID actor_id = ActorID::Nil(),
const std::string name = "") {
rpc::TaskInfoEntry task_info;
task_info.set_job_id(job_id.Binary());
task_info.set_parent_task_id(parent_task_id.Binary());
task_info.set_type(task_type);
task_info.set_actor_id(actor_id.Binary());
task_info.set_name(name);
return task_info;
}

Expand Down Expand Up @@ -478,6 +492,66 @@ TEST_F(GcsTaskManagerTest, TestGetTaskEventsByJob) {
reply_job2.mutable_events_by_task());
}

// Checks the `name` and `actor_id` filters of GetTaskEvents individually, and
// that supplying both combines them with AND.
TEST_F(GcsTaskManagerTest, TestGetTaskEventsFilters) {
  // Generate task events

  // A task event with actor id
  ActorID actor_id = ActorID::Of(JobID::FromInt(1), TaskID::Nil(), 1);
  {
    auto task_ids = GenTaskIDs(1);
    auto task_info_actor_id =
        GenTaskInfo(JobID::FromInt(1), TaskID::Nil(), rpc::ACTOR_TASK, actor_id);
    auto events = GenTaskEvents(task_ids,
                                /* attempt_number */
                                0,
                                /* job_id */ 1,
                                absl::nullopt,
                                absl::nullopt,
                                task_info_actor_id);
    auto data = Mocker::GenTaskEventsData(events);
    SyncAddTaskEventData(data);
  }

  // A task event with name (stored with a nil actor id).
  {
    auto task_ids = GenTaskIDs(1);
    auto task_info_name = GenTaskInfo(
        JobID::FromInt(1), TaskID::Nil(), rpc::NORMAL_TASK, ActorID::Nil(), "task_name");
    auto events = GenTaskEvents(task_ids,
                                /* attempt_number */
                                0,
                                /* job_id */ 1,
                                absl::nullopt,
                                absl::nullopt,
                                task_info_name);
    auto data = Mocker::GenTaskEventsData(events);
    SyncAddTaskEventData(data);
  }

  // Filtering by name: only the second task was stored with "task_name".
  auto reply_name = SyncGetTaskEvents({},
                                      /* job_id */ absl::nullopt,
                                      /* limit */ -1,
                                      /* exclude_driver */ false,
                                      "task_name");
  EXPECT_EQ(reply_name.events_by_task_size(), 1);

  // Filtering by actor id: only the first task was stored with `actor_id`.
  auto reply_actor_id = SyncGetTaskEvents({},
                                          /* job_id */ absl::nullopt,
                                          /* limit */ -1,
                                          /* exclude_driver */ false,
                                          /* name */ "",
                                          actor_id);
  // Assert on the actor-id reply itself; re-checking `reply_name` here would
  // leave the actor-id filter unverified.
  EXPECT_EQ(reply_actor_id.events_by_task_size(), 1);

  // Filtering by both: no stored task has both the name and the actor id.
  auto reply_both_and = SyncGetTaskEvents({},
                                          /* job_id */ absl::nullopt,
                                          /* limit */ -1,
                                          /* exclude_driver */ false,
                                          "task_name",
                                          actor_id);
  EXPECT_EQ(reply_both_and.events_by_task_size(), 0);
}

TEST_F(GcsTaskManagerTest, TestFailingParentFailChildren) {
// This tests that a failed parent task should automatically fail its children tasks.
auto task_ids = GenTaskIDs(3);
Expand Down
21 changes: 13 additions & 8 deletions src/ray/protobuf/gcs_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -622,22 +622,27 @@ message AddTaskEventDataReply {
}

message GetTaskEventsRequest {
message TaskIDs {
repeated string vals = 1;
}
oneof select_by {
// Filter object whose predicates are ANDed together.
message Filters {
// Get task events from a job.
string job_id = 1;
optional bytes job_id = 1;
// Get task events from a set of tasks.
TaskIDs task_ids = 2;
repeated bytes task_ids = 2;
// Get the task events with an actor id.
optional bytes actor_id = 3;
// Get the task events of tasks with the given name.
optional string name = 4;
// True if task events from driver (only profiling events) should be excluded.
optional bool exclude_driver = 5;
}

// Maximum number of TaskEvents to return.
// If set, the exact `limit` TaskEvents returned do not have any ordering or selection
// guarantee.
optional int64 limit = 3;
// True if task events from driver (only profiling events) should be excluded.
bool exclude_driver = 4;

// Filters to apply to the get query.
optional Filters filters = 4;
}

message GetTaskEventsReply {
Expand Down