Skip to content

Commit

Permalink
[core][state] Push down filtering to GCS for listing/getting task from state api (#34433)
Browse files Browse the repository at this point in the history

Similar to #34348

This pushes the following filters down to GCS for source-side filtering:

actor_id
task id
task name
job id
  • Loading branch information
rickyyx committed May 5, 2023
1 parent 28a7412 commit d8321a7
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 40 deletions.
8 changes: 1 addition & 7 deletions dashboard/state_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,16 +377,10 @@ async def list_tasks(self, *, option: ListApiOptions) -> ListApiResponse:
{task_id -> task_data_in_dict}
task_data_in_dict's schema is in TaskState
"""
job_id = None
for filter in option.filters:
if filter[0] == "job_id" and filter[1] == "=":
# Filtering by job_id == xxxx, pass it to source side filtering.
# tuple consists of (job_id, predicate, value)
job_id = filter[2]
try:
reply = await self._client.get_all_task_info(
timeout=option.timeout,
job_id=job_id,
filters=option.filters,
exclude_driver=option.exclude_driver,
)
except DataSourceUnavailable:
Expand Down
40 changes: 32 additions & 8 deletions python/ray/experimental/state/state_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ray._private import ray_constants
from ray._private.gcs_utils import GcsAioClient
from ray._private.utils import hex_to_binary
from ray._raylet import ActorID, JobID
from ray._raylet import ActorID, JobID, TaskID
from ray.core.generated import gcs_service_pb2_grpc
from ray.core.generated.gcs_pb2 import ActorTableData
from ray.core.generated.gcs_service_pb2 import (
Expand Down Expand Up @@ -262,16 +262,40 @@ async def get_all_task_info(
self,
timeout: int = None,
limit: int = None,
job_id: Optional[str] = None,
exclude_driver: bool = True,
filters: Optional[List[Tuple[str, PredicateType, SupportedFilterType]]] = None,
exclude_driver: bool = False,
) -> Optional[GetTaskEventsReply]:
if not limit:
limit = RAY_MAX_LIMIT_FROM_DATA_SOURCE
if job_id:
job_id = JobID(hex_to_binary(job_id)).binary()
request = GetTaskEventsRequest(
limit=limit, exclude_driver=exclude_driver, job_id=job_id
)

if filters is None:
filters = []

req_filters = GetTaskEventsRequest.Filters()
# Iterate over a snapshot of `filters`: entries that get pushed down are removed
# from `filters` inside the loop, and removing from a list while iterating over
# that same list makes the iterator skip the element after each removal.
for filter in list(filters):
    key, predicate, value = filter
    if predicate != "=":
        # We only support EQUAL predicate for source side filtering.
        continue

    if key == "actor_id":
        req_filters.actor_id = ActorID(hex_to_binary(value)).binary()
    elif key == "job_id":
        req_filters.job_id = JobID(hex_to_binary(value)).binary()
    elif key == "name":
        req_filters.name = value
    elif key == "task_id":
        req_filters.task_ids.append(TaskID(hex_to_binary(value)).binary())
    else:
        # Not a key that can be filtered at the source; leave it in `filters`.
        continue

    # Remove the filter from the list so that we don't have to
    # filter it again later.
    filters.remove(filter)

req_filters.exclude_driver = exclude_driver

request = GetTaskEventsRequest(limit=limit, filters=req_filters)
reply = await self._gcs_task_info_stub.GetTaskEvents(request, timeout=timeout)
return reply

Expand Down
10 changes: 8 additions & 2 deletions python/ray/tests/test_state_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2282,7 +2282,7 @@ def g(dep):
def impossible():
pass

out = [f.remote() for _ in range(2)] # noqa
out = [f.options(name=f"f_{i}").remote() for i in range(2)] # noqa
g_out = g.remote(f.remote()) # noqa
im = impossible.remote() # noqa

Expand Down Expand Up @@ -2350,6 +2350,9 @@ def verify():
for task in tasks:
assert task["job_id"] == job_id

tasks = list_tasks(filters=[("name", "=", "f_0")])
assert len(tasks) == 1

return True

wait_for_condition(verify)
Expand Down Expand Up @@ -2540,7 +2543,6 @@ def verify():
for task in tasks:
assert task["job_id"] == job_id
for task in tasks:
print(task)
assert task["actor_id"] == actor_id
# Actor.__init__: 1 finished
# Actor.call: 1 running, 9 waiting for execution (queued).
Expand Down Expand Up @@ -2590,6 +2592,10 @@ def verify():
== 1
)

# Filters with actor id.
assert len(list_tasks(filters=[("actor_id", "=", actor_id)])) == 11
assert len(list_tasks(filters=[("actor_id", "!=", actor_id)])) == 0

return True

wait_for_condition(verify)
Expand Down
40 changes: 30 additions & 10 deletions src/ray/gcs/gcs_server/gcs_task_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -313,16 +313,17 @@ void GcsTaskManager::HandleGetTaskEvents(rpc::GetTaskEventsRequest request,
rpc::SendReplyCallback send_reply_callback) {
RAY_LOG(DEBUG) << "Getting task status:" << request.ShortDebugString();

// Select candidate events by indexing.
// Select candidate events by indexing if possible.
std::vector<rpc::TaskEvents> task_events;
if (request.has_task_ids()) {
const auto &filters = request.filters();
if (filters.task_ids_size() > 0) {
absl::flat_hash_set<TaskID> task_ids;
for (const auto &task_id_str : request.task_ids().vals()) {
for (const auto &task_id_str : filters.task_ids()) {
task_ids.insert(TaskID::FromBinary(task_id_str));
}
task_events = task_event_storage_->GetTaskEvents(task_ids);
} else if (request.has_job_id()) {
task_events = task_event_storage_->GetTaskEvents(JobID::FromBinary(request.job_id()));
} else if (filters.has_job_id()) {
task_events = task_event_storage_->GetTaskEvents(JobID::FromBinary(filters.job_id()));
} else {
task_events = task_event_storage_->GetTaskEvents();
}
Expand All @@ -334,15 +335,34 @@ void GcsTaskManager::HandleGetTaskEvents(rpc::GetTaskEventsRequest request,
int32_t num_profile_event_limit = 0;
int32_t num_status_event_limit = 0;

for (auto itr = task_events.rbegin(); itr != task_events.rend(); ++itr) {
auto &task_event = *itr;
// A lambda filter fn, where it returns true for task events to be included in the
// result. Task ids and job ids are already filtered by the storage with indexing above.
auto filter_fn = [&filters](const rpc::TaskEvents &task_event) {
if (!task_event.has_task_info()) {
// Skip task events w/o task info.
continue;
return false;
}

if (request.exclude_driver() &&
if (filters.exclude_driver() &&
task_event.task_info().type() == rpc::TaskType::DRIVER_TASK) {
return false;
}

if (filters.has_actor_id() && task_event.task_info().has_actor_id() &&
ActorID::FromBinary(task_event.task_info().actor_id()) !=
ActorID::FromBinary(filters.actor_id())) {
return false;
}

if (filters.has_name() && task_event.task_info().name() != filters.name()) {
return false;
}

return true;
};

for (auto itr = task_events.rbegin(); itr != task_events.rend(); ++itr) {
auto &task_event = *itr;
if (!filter_fn(task_event)) {
continue;
}

Expand Down
84 changes: 79 additions & 5 deletions src/ray/gcs/gcs_server/test/gcs_task_manager_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,26 +115,36 @@ class GcsTaskManagerTest : public ::testing::Test {
rpc::GetTaskEventsReply SyncGetTaskEvents(absl::flat_hash_set<TaskID> task_ids,
absl::optional<JobID> job_id = absl::nullopt,
int64_t limit = -1,
bool exclude_driver = true) {
bool exclude_driver = true,
const std::string &name = "",
const ActorID &actor_id = ActorID::Nil()) {
rpc::GetTaskEventsRequest request;
rpc::GetTaskEventsReply reply;
std::promise<bool> promise;

if (!task_ids.empty()) {
for (const auto &task_id : task_ids) {
request.mutable_task_ids()->add_vals(task_id.Binary());
request.mutable_filters()->add_task_ids(task_id.Binary());
}
}

if (!name.empty()) {
request.mutable_filters()->set_name(name);
}

if (!actor_id.IsNil()) {
request.mutable_filters()->set_actor_id(actor_id.Binary());
}

if (job_id) {
request.set_job_id(job_id->Binary());
request.mutable_filters()->set_job_id(job_id->Binary());
}

if (limit >= 0) {
request.set_limit(limit);
}

request.set_exclude_driver(exclude_driver);
request.mutable_filters()->set_exclude_driver(exclude_driver);
task_manager->GetIoContext().dispatch(
[this, &promise, &request, &reply]() {
task_manager->HandleGetTaskEvents(
Expand All @@ -155,11 +165,15 @@ class GcsTaskManagerTest : public ::testing::Test {
static rpc::TaskInfoEntry GenTaskInfo(
JobID job_id,
TaskID parent_task_id = TaskID::Nil(),
rpc::TaskType task_type = rpc::TaskType::NORMAL_TASK) {
rpc::TaskType task_type = rpc::TaskType::NORMAL_TASK,
const ActorID actor_id = ActorID::Nil(),
const std::string name = "") {
rpc::TaskInfoEntry task_info;
task_info.set_job_id(job_id.Binary());
task_info.set_parent_task_id(parent_task_id.Binary());
task_info.set_type(task_type);
task_info.set_actor_id(actor_id.Binary());
task_info.set_name(name);
return task_info;
}

Expand Down Expand Up @@ -490,6 +504,66 @@ TEST_F(GcsTaskManagerTest, TestGetTaskEventsByJob) {
reply_job2.mutable_events_by_task());
}

// Verifies the `name` and `actor_id` filters of GetTaskEvents, and that multiple
// filters are combined with AND.
//
// Two task events are stored: one carrying an actor id (with the default empty
// name), and one carrying the name "task_name" (with a nil actor id). Each filter
// on its own must therefore match exactly one event, and both together none.
TEST_F(GcsTaskManagerTest, TestGetTaskEventsFilters) {
  // Generate task events

  // A task event with actor id
  ActorID actor_id = ActorID::Of(JobID::FromInt(1), TaskID::Nil(), 1);
  {
    auto task_ids = GenTaskIDs(1);
    auto task_info_actor_id =
        GenTaskInfo(JobID::FromInt(1), TaskID::Nil(), rpc::ACTOR_TASK, actor_id);
    auto events = GenTaskEvents(task_ids,
                                /* attempt_number */
                                0,
                                /* job_id */ 1,
                                absl::nullopt,
                                absl::nullopt,
                                task_info_actor_id);
    auto data = Mocker::GenTaskEventsData(events);
    SyncAddTaskEventData(data);
  }

  // A task event with name.
  {
    auto task_ids = GenTaskIDs(1);
    auto task_info_name = GenTaskInfo(
        JobID::FromInt(1), TaskID::Nil(), rpc::NORMAL_TASK, ActorID::Nil(), "task_name");
    auto events = GenTaskEvents(task_ids,
                                /* attempt_number */
                                0,
                                /* job_id */ 1,
                                absl::nullopt,
                                absl::nullopt,
                                task_info_name);
    auto data = Mocker::GenTaskEventsData(events);
    SyncAddTaskEventData(data);
  }

  // Filtering by name only: matches just the named task.
  auto reply_name = SyncGetTaskEvents({},
                                      /* job_id */ absl::nullopt,
                                      /* limit */ -1,
                                      /* exclude_driver */ false,
                                      "task_name");
  EXPECT_EQ(reply_name.events_by_task_size(), 1);

  // Filtering by actor id only: matches just the actor task. The assertion must
  // inspect `reply_actor_id`; checking `reply_name` again would leave the actor id
  // filter untested.
  auto reply_actor_id = SyncGetTaskEvents({},
                                          /* job_id */ absl::nullopt,
                                          /* limit */ -1,
                                          /* exclude_driver */ false,
                                          /* name */ "",
                                          actor_id);
  EXPECT_EQ(reply_actor_id.events_by_task_size(), 1);

  // Filtering by name AND actor id: no stored event satisfies both predicates.
  auto reply_both_and = SyncGetTaskEvents({},
                                          /* job_id */ absl::nullopt,
                                          /* limit */ -1,
                                          /* exclude_driver */ false,
                                          "task_name",
                                          actor_id);
  EXPECT_EQ(reply_both_and.events_by_task_size(), 0);
}

TEST_F(GcsTaskManagerTest, TestMarkTaskAttemptFailedIfNeeded) {
auto tasks = GenTaskIDs(3);
auto tasks_running = tasks[0];
Expand Down
21 changes: 13 additions & 8 deletions src/ray/protobuf/gcs_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -644,22 +644,27 @@ message AddTaskEventDataReply {
}

message GetTaskEventsRequest {
message TaskIDs {
repeated string vals = 1;
}
oneof select_by {
// Filter object where predicates are AND together.
message Filters {
// Get task events from a job.
string job_id = 1;
optional bytes job_id = 1;
// Get task events from a set of tasks.
TaskIDs task_ids = 2;
repeated bytes task_ids = 2;
// Get the task events with an actor id.
optional bytes actor_id = 3;
// Get the task events of task with names.
optional string name = 4;
// True if task events from driver (only profiling events) should be excluded.
optional bool exclude_driver = 5;
}

// Maximum number of TaskEvents to return.
// If set, the exact `limit` TaskEvents returned do not have any ordering or selection
// guarantee.
optional int64 limit = 3;
// True if task events from driver (only profiling events) should be excluded.
bool exclude_driver = 4;

// Filters to apply to the get query.
optional Filters filters = 4;
}

message GetTaskEventsReply {
Expand Down

0 comments on commit d8321a7

Please sign in to comment.