diff --git a/dashboard/state_aggregator.py b/dashboard/state_aggregator.py
index e4e38c9f323b1..b7cfd20b5c9cf 100644
--- a/dashboard/state_aggregator.py
+++ b/dashboard/state_aggregator.py
@@ -377,16 +377,10 @@ async def list_tasks(self, *, option: ListApiOptions) -> ListApiResponse:
             {task_id -> task_data_in_dict}
             task_data_in_dict's schema is in TaskState
         """
-        job_id = None
-        for filter in option.filters:
-            if filter[0] == "job_id" and filter[1] == "=":
-                # Filtering by job_id == xxxx, pass it to source side filtering.
-                # tuple consists of (job_id, predicate, value)
-                job_id = filter[2]
         try:
             reply = await self._client.get_all_task_info(
                 timeout=option.timeout,
-                job_id=job_id,
+                filters=option.filters,
                 exclude_driver=option.exclude_driver,
             )
         except DataSourceUnavailable:
diff --git a/python/ray/experimental/state/state_manager.py b/python/ray/experimental/state/state_manager.py
index 11ea98b89c4c9..19e1fa318e381 100644
--- a/python/ray/experimental/state/state_manager.py
+++ b/python/ray/experimental/state/state_manager.py
@@ -12,7 +12,7 @@
 from ray._private import ray_constants
 from ray._private.gcs_utils import GcsAioClient
 from ray._private.utils import hex_to_binary
-from ray._raylet import ActorID, JobID
+from ray._raylet import ActorID, JobID, TaskID
 from ray.core.generated import gcs_service_pb2_grpc
 from ray.core.generated.gcs_pb2 import ActorTableData
 from ray.core.generated.gcs_service_pb2 import (
@@ -262,16 +262,49 @@ async def get_all_task_info(
         self,
         timeout: int = None,
         limit: int = None,
-        job_id: Optional[str] = None,
-        exclude_driver: bool = True,
+        filters: Optional[List[Tuple[str, PredicateType, SupportedFilterType]]] = None,
+        exclude_driver: bool = False,
     ) -> Optional[GetTaskEventsReply]:
+        """Fetch task events from GCS, pushing supported filters to the source.
+
+        Equality ("=") filters on actor_id, job_id, name and task_id are copied
+        into the request and removed from `filters` in place, so that only the
+        filters left in the list still need to be applied afterwards.
+        """
         if not limit:
             limit = RAY_MAX_LIMIT_FROM_DATA_SOURCE
-        if job_id:
-            job_id = JobID(hex_to_binary(job_id)).binary()
-        request = GetTaskEventsRequest(
-            limit=limit, exclude_driver=exclude_driver, job_id=job_id
-        )
+
+        if filters is None:
+            filters = []
+
+        req_filters = GetTaskEventsRequest.Filters()
+        # Iterate over a snapshot: pushed-down filters are removed from
+        # `filters` below, and removing from a list while iterating it
+        # makes the loop skip the entry that follows each removal.
+        for source_filter in list(filters):
+            key, predicate, value = source_filter
+            if predicate != "=":
+                # We only support EQUAL predicate for source side filtering.
+                continue
+
+            if key == "actor_id":
+                req_filters.actor_id = ActorID(hex_to_binary(value)).binary()
+            elif key == "job_id":
+                req_filters.job_id = JobID(hex_to_binary(value)).binary()
+            elif key == "name":
+                req_filters.name = value
+            elif key == "task_id":
+                req_filters.task_ids.append(TaskID(hex_to_binary(value)).binary())
+            else:
+                continue
+
+            # Remove the filter from the list so that we don't have to
+            # filter it again later.
+            filters.remove(source_filter)
+
+        req_filters.exclude_driver = exclude_driver
+
+        request = GetTaskEventsRequest(limit=limit, filters=req_filters)
         reply = await self._gcs_task_info_stub.GetTaskEvents(request, timeout=timeout)
         return reply
 
diff --git a/python/ray/tests/test_state_api.py b/python/ray/tests/test_state_api.py
index 64882421c8562..38d097ae5d9c1 100644
--- a/python/ray/tests/test_state_api.py
+++ b/python/ray/tests/test_state_api.py
@@ -2282,7 +2282,7 @@ def g(dep):
     def impossible():
         pass
 
-    out = [f.remote() for _ in range(2)]  # noqa
+    out = [f.options(name=f"f_{i}").remote() for i in range(2)]  # noqa
     g_out = g.remote(f.remote())  # noqa
     im = impossible.remote()  # noqa
 
@@ -2350,6 +2350,9 @@ def verify():
         for task in tasks:
             assert task["job_id"] == job_id
 
+        tasks = list_tasks(filters=[("name", "=", "f_0")])
+        assert len(tasks) == 1
+
         return True
 
     wait_for_condition(verify)
@@ -2540,7 +2543,6 @@ def verify():
         for task in tasks:
             assert task["job_id"] == job_id
         for task in tasks:
-            print(task)
             assert task["actor_id"] == actor_id
         # Actor.__init__: 1 finished
         # Actor.call: 1 running, 9 waiting for execution (queued).
@@ -2590,6 +2592,10 @@ def verify():
             == 1
         )
 
+        # Filters with actor id.
+ assert len(list_tasks(filters=[("actor_id", "=", actor_id)])) == 11 + assert len(list_tasks(filters=[("actor_id", "!=", actor_id)])) == 0 + return True wait_for_condition(verify) diff --git a/src/ray/gcs/gcs_server/gcs_task_manager.cc b/src/ray/gcs/gcs_server/gcs_task_manager.cc index 6771e042bb24e..e733856b8ee5e 100644 --- a/src/ray/gcs/gcs_server/gcs_task_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_task_manager.cc @@ -313,16 +313,17 @@ void GcsTaskManager::HandleGetTaskEvents(rpc::GetTaskEventsRequest request, rpc::SendReplyCallback send_reply_callback) { RAY_LOG(DEBUG) << "Getting task status:" << request.ShortDebugString(); - // Select candidate events by indexing. + // Select candidate events by indexing if possible. std::vector task_events; - if (request.has_task_ids()) { + const auto &filters = request.filters(); + if (filters.task_ids_size() > 0) { absl::flat_hash_set task_ids; - for (const auto &task_id_str : request.task_ids().vals()) { + for (const auto &task_id_str : filters.task_ids()) { task_ids.insert(TaskID::FromBinary(task_id_str)); } task_events = task_event_storage_->GetTaskEvents(task_ids); - } else if (request.has_job_id()) { - task_events = task_event_storage_->GetTaskEvents(JobID::FromBinary(request.job_id())); + } else if (filters.has_job_id()) { + task_events = task_event_storage_->GetTaskEvents(JobID::FromBinary(filters.job_id())); } else { task_events = task_event_storage_->GetTaskEvents(); } @@ -334,15 +335,34 @@ void GcsTaskManager::HandleGetTaskEvents(rpc::GetTaskEventsRequest request, int32_t num_profile_event_limit = 0; int32_t num_status_event_limit = 0; - for (auto itr = task_events.rbegin(); itr != task_events.rend(); ++itr) { - auto &task_event = *itr; + // A lambda filter fn, where it returns true for task events to be included in the + // result. Task ids and job ids are already filtered by the storage with indexing above. 
+ auto filter_fn = [&filters](const rpc::TaskEvents &task_event) { if (!task_event.has_task_info()) { // Skip task events w/o task info. - continue; + return false; } - - if (request.exclude_driver() && + if (filters.exclude_driver() && task_event.task_info().type() == rpc::TaskType::DRIVER_TASK) { + return false; + } + + if (filters.has_actor_id() && task_event.task_info().has_actor_id() && + ActorID::FromBinary(task_event.task_info().actor_id()) != + ActorID::FromBinary(filters.actor_id())) { + return false; + } + + if (filters.has_name() && task_event.task_info().name() != filters.name()) { + return false; + } + + return true; + }; + + for (auto itr = task_events.rbegin(); itr != task_events.rend(); ++itr) { + auto &task_event = *itr; + if (!filter_fn(task_event)) { continue; } diff --git a/src/ray/gcs/gcs_server/test/gcs_task_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_task_manager_test.cc index 91070fe1cf352..d60ea97f100f1 100644 --- a/src/ray/gcs/gcs_server/test/gcs_task_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_task_manager_test.cc @@ -115,26 +115,36 @@ class GcsTaskManagerTest : public ::testing::Test { rpc::GetTaskEventsReply SyncGetTaskEvents(absl::flat_hash_set task_ids, absl::optional job_id = absl::nullopt, int64_t limit = -1, - bool exclude_driver = true) { + bool exclude_driver = true, + const std::string &name = "", + const ActorID &actor_id = ActorID::Nil()) { rpc::GetTaskEventsRequest request; rpc::GetTaskEventsReply reply; std::promise promise; if (!task_ids.empty()) { for (const auto &task_id : task_ids) { - request.mutable_task_ids()->add_vals(task_id.Binary()); + request.mutable_filters()->add_task_ids(task_id.Binary()); } } + if (!name.empty()) { + request.mutable_filters()->set_name(name); + } + + if (!actor_id.IsNil()) { + request.mutable_filters()->set_actor_id(actor_id.Binary()); + } + if (job_id) { - request.set_job_id(job_id->Binary()); + request.mutable_filters()->set_job_id(job_id->Binary()); } if (limit >= 0) { 
request.set_limit(limit); } - request.set_exclude_driver(exclude_driver); + request.mutable_filters()->set_exclude_driver(exclude_driver); task_manager->GetIoContext().dispatch( [this, &promise, &request, &reply]() { task_manager->HandleGetTaskEvents( @@ -155,11 +165,15 @@ class GcsTaskManagerTest : public ::testing::Test { static rpc::TaskInfoEntry GenTaskInfo( JobID job_id, TaskID parent_task_id = TaskID::Nil(), - rpc::TaskType task_type = rpc::TaskType::NORMAL_TASK) { + rpc::TaskType task_type = rpc::TaskType::NORMAL_TASK, + const ActorID actor_id = ActorID::Nil(), + const std::string name = "") { rpc::TaskInfoEntry task_info; task_info.set_job_id(job_id.Binary()); task_info.set_parent_task_id(parent_task_id.Binary()); task_info.set_type(task_type); + task_info.set_actor_id(actor_id.Binary()); + task_info.set_name(name); return task_info; } @@ -490,6 +504,66 @@ TEST_F(GcsTaskManagerTest, TestGetTaskEventsByJob) { reply_job2.mutable_events_by_task()); } +TEST_F(GcsTaskManagerTest, TestGetTaskEventsFilters) { + // Generate task events + + // A task event with actor id + ActorID actor_id = ActorID::Of(JobID::FromInt(1), TaskID::Nil(), 1); + { + auto task_ids = GenTaskIDs(1); + auto task_info_actor_id = + GenTaskInfo(JobID::FromInt(1), TaskID::Nil(), rpc::ACTOR_TASK, actor_id); + auto events = GenTaskEvents(task_ids, + /* attempt_number */ + 0, + /* job_id */ 1, + absl::nullopt, + absl::nullopt, + task_info_actor_id); + auto data = Mocker::GenTaskEventsData(events); + SyncAddTaskEventData(data); + } + + // A task event with name. 
+  {
+    auto task_ids = GenTaskIDs(1);
+    auto task_info_name = GenTaskInfo(
+        JobID::FromInt(1), TaskID::Nil(), rpc::NORMAL_TASK, ActorID::Nil(), "task_name");
+    auto events = GenTaskEvents(task_ids,
+                                /* attempt_number */
+                                0,
+                                /* job_id */ 1,
+                                absl::nullopt,
+                                absl::nullopt,
+                                task_info_name);
+    auto data = Mocker::GenTaskEventsData(events);
+    SyncAddTaskEventData(data);
+  }
+
+  auto reply_name = SyncGetTaskEvents({},
+                                      /* job_id */ absl::nullopt,
+                                      /* limit */ -1,
+                                      /* exclude_driver */ false,
+                                      "task_name");
+  EXPECT_EQ(reply_name.events_by_task_size(), 1);
+
+  auto reply_actor_id = SyncGetTaskEvents({},
+                                          /* job_id */ absl::nullopt,
+                                          /* limit */ -1,
+                                          /* exclude_driver */ false,
+                                          /* name */ "",
+                                          actor_id);
+  EXPECT_EQ(reply_actor_id.events_by_task_size(), 1);
+
+  auto reply_both_and = SyncGetTaskEvents({},
+                                          /* job_id */ absl::nullopt,
+                                          /* limit */ -1,
+                                          /* exclude_driver */ false,
+                                          "task_name",
+                                          actor_id);
+  EXPECT_EQ(reply_both_and.events_by_task_size(), 0);
+}
+
 TEST_F(GcsTaskManagerTest, TestMarkTaskAttemptFailedIfNeeded) {
   auto tasks = GenTaskIDs(3);
   auto tasks_running = tasks[0];
diff --git a/src/ray/protobuf/gcs_service.proto b/src/ray/protobuf/gcs_service.proto
index 38280e48d3f63..7bc382bc08425 100644
--- a/src/ray/protobuf/gcs_service.proto
+++ b/src/ray/protobuf/gcs_service.proto
@@ -644,22 +644,27 @@ message AddTaskEventDataReply {
 }
 
 message GetTaskEventsRequest {
-  message TaskIDs {
-    repeated string vals = 1;
-  }
-  oneof select_by {
+  // Filter object where predicates are AND together.
+  message Filters {
     // Get task events from a job.
-    string job_id = 1;
+    optional bytes job_id = 1;
     // Get task events from a set of tasks.
-    TaskIDs task_ids = 2;
+    repeated bytes task_ids = 2;
+    // Get the task events with an actor id.
+    optional bytes actor_id = 3;
+    // Get the task events of task with names.
+    optional string name = 4;
+    // True if task events from driver (only profiling events) should be excluded.
+ optional bool exclude_driver = 5; } // Maximum number of TaskEvents to return. // If set, the exact `limit` TaskEvents returned do not have any ordering or selection // guarantee. optional int64 limit = 3; - // True if task events from driver (only profiling events) should be excluded. - bool exclude_driver = 4; + + // Filters to apply to the get query. + optional Filters filters = 4; } message GetTaskEventsReply {