Skip to content

Commit

Permalink
[Core][C++ Worker]Add timeout_ms param for ray.get() in c++ worker (r…
Browse files Browse the repository at this point in the history
…ay-project#38603)

## Why are these changes needed?
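Add a `timeout_ms` parameter to `ray::Get()` and `ObjectRef::Get()` in the C++ worker. As the example below shows, a get that does not complete within the timeout throws `ray::internal::RayTimeoutException`: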
```
  auto actor2 =
      ray::Actor(ActorConcurrentCall::FactoryCreate).SetMaxConcurrency(2).Remote();
  auto object2_1 = actor2.Task(&ActorConcurrentCall::CountDown).Remote();
  auto object2_2 = actor2.Task(&ActorConcurrentCall::CountDown).Remote();
  auto object2_3 = actor2.Task(&ActorConcurrentCall::CountDown).Remote();

  EXPECT_THROW(object2_1.Get(2), ray::internal::RayTimeoutException);
  EXPECT_THROW(object2_2.Get(2), ray::internal::RayTimeoutException);
  EXPECT_THROW(object2_3.Get(2), ray::internal::RayTimeoutException);
```
  • Loading branch information
larrylian committed Aug 22, 2023
1 parent ce25b40 commit c21389f
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 12 deletions.
54 changes: 48 additions & 6 deletions cpp/include/ray/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,27 @@ std::shared_ptr<T> Get(const ray::ObjectRef<T> &object);
template <typename T>
std::vector<std::shared_ptr<T>> Get(const std::vector<ray::ObjectRef<T>> &objects);

/// Get a single object from the object store, waiting at most `timeout_ms`.
/// This method blocks until the object is ready or the timeout elapses.
///
/// \param[in] object The object reference which should be returned.
/// \param[in] timeout_ms The maximum amount of time in milliseconds to wait before
/// returning. The overload without a timeout forwards -1 for this argument.
/// \return shared pointer of the result.
/// NOTE(review): NativeObjectStore::GetRaw throws ray::internal::RayTimeoutException
/// when the core worker reports a timeout; presumably that propagates to this
/// call -- confirm for other runtimes.
template <typename T>
std::shared_ptr<T> Get(const ray::ObjectRef<T> &object, const int &timeout_ms);

/// Get a list of objects from the object store, waiting at most `timeout_ms`.
/// This method blocks until all the objects are ready or the timeout elapses.
///
/// \param[in] objects The object array which should be got.
/// \param[in] timeout_ms The maximum amount of time in milliseconds to wait before
/// returning. The overload without a timeout forwards -1 for this argument.
/// \return shared pointer array of the result.
/// NOTE(review): NativeObjectStore::GetRaw throws ray::internal::RayTimeoutException
/// when the core worker reports a timeout; presumably that propagates to this
/// call -- confirm for other runtimes.
template <typename T>
std::vector<std::shared_ptr<T>> Get(const std::vector<ray::ObjectRef<T>> &objects,
const int &timeout_ms);

/// Wait for a list of objects to be locally available,
/// until specified number of objects are ready, or specified timeout has passed.
///
Expand Down Expand Up @@ -141,6 +162,10 @@ void ExitActor();
template <typename T>
std::vector<std::shared_ptr<T>> Get(const std::vector<std::string> &ids);

/// Get a list of objects identified by their object ids.
///
/// \param[in] ids The ids of the objects to fetch.
/// \param[in] timeout_ms Timeout forwarded to the runtime's Get(ids, timeout_ms);
/// the overload without a timeout forwards -1 for this argument.
/// \return shared pointer array of the result.
template <typename T>
std::vector<std::shared_ptr<T>> Get(const std::vector<std::string> &ids,
const int &timeout_ms);

/// Create a placement group on remote nodes.
///
/// \param[in] create_options Creation options of the placement group.
Expand Down Expand Up @@ -193,13 +218,14 @@ inline ray::ObjectRef<T> Put(const T &obj) {
}

template <typename T>
inline std::shared_ptr<T> Get(const ray::ObjectRef<T> &object) {
return GetFromRuntime(object);
inline std::shared_ptr<T> Get(const ray::ObjectRef<T> &object, const int &timeout_ms) {
return GetFromRuntime(object, timeout_ms);
}

template <typename T>
inline std::vector<std::shared_ptr<T>> Get(const std::vector<std::string> &ids) {
auto result = ray::internal::GetRayRuntime()->Get(ids);
inline std::vector<std::shared_ptr<T>> Get(const std::vector<std::string> &ids,
const int &timeout_ms) {
auto result = ray::internal::GetRayRuntime()->Get(ids, timeout_ms);
std::vector<std::shared_ptr<T>> return_objects;
return_objects.reserve(result.size());
for (auto it = result.begin(); it != result.end(); it++) {
Expand All @@ -211,9 +237,25 @@ inline std::vector<std::shared_ptr<T>> Get(const std::vector<std::string> &ids)
}

template <typename T>
inline std::vector<std::shared_ptr<T>> Get(const std::vector<ray::ObjectRef<T>> &ids) {
inline std::vector<std::shared_ptr<T>> Get(const std::vector<ray::ObjectRef<T>> &ids,
const int &timeout_ms) {
auto object_ids = ObjectRefsToObjectIDs<T>(ids);
return Get<T>(object_ids);
return Get<T>(object_ids, timeout_ms);
}

/// Overload without a timeout: forwards to Get(object, timeout_ms) with -1.
/// NOTE(review): -1 presumably means "wait indefinitely" -- confirm against the
/// object store implementation.
template <typename T>
inline std::shared_ptr<T> Get(const ray::ObjectRef<T> &object) {
  const int no_timeout = -1;
  return Get<T>(object, no_timeout);
}

/// Overload without a timeout: forwards to Get(ids, timeout_ms) with -1.
/// NOTE(review): -1 presumably means "wait indefinitely" -- confirm against the
/// object store implementation.
template <typename T>
inline std::vector<std::shared_ptr<T>> Get(const std::vector<std::string> &ids) {
  const int no_timeout = -1;
  return Get<T>(ids, no_timeout);
}

/// Overload without a timeout: forwards to Get(ids, timeout_ms) with -1.
/// NOTE(review): -1 presumably means "wait indefinitely" -- confirm against the
/// object store implementation.
template <typename T>
inline std::vector<std::shared_ptr<T>> Get(const std::vector<ray::ObjectRef<T>> &ids) {
  const int no_timeout = -1;
  return Get<T>(ids, no_timeout);
}

template <typename T>
Expand Down
31 changes: 28 additions & 3 deletions cpp/include/ray/api/object_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,14 @@ class ObjectRef {
/// \return shared pointer of the result.
std::shared_ptr<T> Get() const;

/// Get the object from the object store, waiting at most `timeout_ms`.
/// This method blocks until the object is ready or the timeout elapses.
///
/// \param timeout_ms The maximum amount of time in milliseconds to wait before
/// returning. The overload without a timeout passes -1 for this argument.
/// \return shared pointer of the result.
/// NOTE(review): the cluster-mode test expects ray::internal::RayTimeoutException
/// when the timeout is exceeded -- confirm for other runtimes.
std::shared_ptr<T> Get(const int &timeout_ms) const;

/// Make ObjectRef serializable
MSGPACK_DEFINE(id_);

Expand All @@ -112,8 +120,9 @@ class ObjectRef {

// ---------- implementation ----------
template <typename T>
inline static std::shared_ptr<T> GetFromRuntime(const ObjectRef<T> &object) {
auto packed_object = internal::GetRayRuntime()->Get(object.ID());
inline static std::shared_ptr<T> GetFromRuntime(const ObjectRef<T> &object,
const int &timeout_ms) {
auto packed_object = internal::GetRayRuntime()->Get(object.ID(), timeout_ms);
CheckResult(packed_object);

if (ray::internal::Serializer::IsXLang(packed_object->data(), packed_object->size())) {
Expand Down Expand Up @@ -156,7 +165,12 @@ const std::string &ObjectRef<T>::ID() const {

template <typename T>
inline std::shared_ptr<T> ObjectRef<T>::Get() const {
return GetFromRuntime(*this);
return GetFromRuntime(*this, -1);
}

/// Fetch the object this reference points to, passing `timeout_ms` through to
/// GetFromRuntime.
template <typename T>
inline std::shared_ptr<T> ObjectRef<T>::Get(const int &timeout_ms) const {
  const ObjectRef<T> &self = *this;
  return GetFromRuntime(self, timeout_ms);
}

template <>
Expand Down Expand Up @@ -190,6 +204,17 @@ class ObjectRef<void> {
CheckResult(packed_object);
}

/// Wait for the object in the object store, for at most `timeout_ms`.
/// This method blocks until the object is ready or the timeout elapses.
///
/// \param timeout_ms The maximum amount of time in milliseconds to wait before
/// returning.
/// Nothing is returned: the fetched buffer is only passed to CheckResult.
void Get(const int &timeout_ms) const {
  auto fetched = internal::GetRayRuntime()->Get(id_, timeout_ms);
  CheckResult(fetched);
}

/// Make ObjectRef serializable
MSGPACK_DEFINE(id_);

Expand Down
6 changes: 6 additions & 0 deletions cpp/include/ray/api/ray_exception.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,11 @@ class RayRuntimeEnvException : public RayException {
public:
RayRuntimeEnvException(const std::string &msg) : RayException(msg){};
};

/// Exception raised when fetching an object times out; NativeObjectStore::GetRaw
/// throws it when the core worker status reports a timeout.
class RayTimeoutException : public RayException {
 public:
  RayTimeoutException(const std::string &msg) : RayException(msg) {}
};

} // namespace internal
} // namespace ray
6 changes: 6 additions & 0 deletions cpp/include/ray/api/ray_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ class RayRuntime {
virtual std::vector<std::shared_ptr<msgpack::sbuffer>> Get(
const std::vector<std::string> &ids) = 0;

/// Fetch a single object by id with a timeout.
/// \param object_id Id of the object to fetch.
/// \param timeout_ms Timeout in milliseconds; the no-timeout overload of
/// AbstractRayRuntime forwards -1.
virtual std::shared_ptr<msgpack::sbuffer> Get(const std::string &object_id,
const int &timeout_ms) = 0;

/// Fetch several objects by id with a timeout.
/// \param ids Ids of the objects to fetch.
/// \param timeout_ms Timeout in milliseconds; the no-timeout overload of
/// AbstractRayRuntime forwards -1.
virtual std::vector<std::shared_ptr<msgpack::sbuffer>> Get(
const std::vector<std::string> &ids, const int &timeout_ms) = 0;

virtual std::vector<bool> Wait(const std::vector<std::string> &ids,
int num_objects,
int timeout_ms) = 0;
Expand Down
14 changes: 12 additions & 2 deletions cpp/src/ray/runtime/abstract_ray_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ std::string AbstractRayRuntime::Put(std::shared_ptr<msgpack::sbuffer> data) {
}

std::shared_ptr<msgpack::sbuffer> AbstractRayRuntime::Get(const std::string &object_id) {
return object_store_->Get(ObjectID::FromBinary(object_id), -1);
return Get(object_id, -1);
}

inline static std::vector<ObjectID> StringIDsToObjectIDs(
Expand All @@ -107,7 +107,17 @@ inline static std::vector<ObjectID> StringIDsToObjectIDs(

std::vector<std::shared_ptr<msgpack::sbuffer>> AbstractRayRuntime::Get(
const std::vector<std::string> &ids) {
return object_store_->Get(StringIDsToObjectIDs(ids), -1);
return Get(ids, -1);
}

// Fetch a single object, passing `timeout_ms` straight through to the object store.
// The no-timeout overload calls this with -1.
std::shared_ptr<msgpack::sbuffer> AbstractRayRuntime::Get(const std::string &object_id,
const int &timeout_ms) {
// The string id is turned into an ObjectID via ObjectID::FromBinary before the lookup.
return object_store_->Get(ObjectID::FromBinary(object_id), timeout_ms);
}

// Fetch several objects, passing `timeout_ms` straight through to the object store.
// The no-timeout overload calls this with -1.
std::vector<std::shared_ptr<msgpack::sbuffer>> AbstractRayRuntime::Get(
const std::vector<std::string> &ids, const int &timeout_ms) {
// The string ids are converted with StringIDsToObjectIDs before the lookup.
return object_store_->Get(StringIDsToObjectIDs(ids), timeout_ms);
}

std::vector<bool> AbstractRayRuntime::Wait(const std::vector<std::string> &ids,
Expand Down
6 changes: 6 additions & 0 deletions cpp/src/ray/runtime/abstract_ray_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ class AbstractRayRuntime : public RayRuntime {

std::vector<std::shared_ptr<msgpack::sbuffer>> Get(const std::vector<std::string> &ids);

// Single-object Get with a timeout; matches the pure virtual declared in RayRuntime.
std::shared_ptr<msgpack::sbuffer> Get(const std::string &object_id,
const int &timeout_ms);

// Multi-object Get with a timeout; matches the pure virtual declared in RayRuntime.
std::vector<std::shared_ptr<msgpack::sbuffer>> Get(const std::vector<std::string> &ids,
const int &timeout_ms);

std::vector<bool> Wait(const std::vector<std::string> &ids,
int num_objects,
int timeout_ms);
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/ray/runtime/object/native_object_store.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ std::vector<std::shared_ptr<msgpack::sbuffer>> NativeObjectStore::GetRaw(
std::vector<std::shared_ptr<::ray::RayObject>> results;
::ray::Status status = core_worker.Get(ids, timeout_ms, &results);
if (!status.ok()) {
if (status.IsTimedOut()) {
throw RayTimeoutException("Get object error:" + status.message());
}
throw RayException("Get object error: " + status.ToString());
}
RAY_CHECK(results.size() == ids.size());
Expand Down
19 changes: 18 additions & 1 deletion cpp/src/ray/test/cluster/cluster_mode_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ TEST(RayClusterModeTest, FullTest) {
auto get_result = *(ray::Get(obj));
EXPECT_EQ(12345, get_result);

EXPECT_EQ(12345, *(ray::Get(obj, 5)));

auto named_obj =
ray::Task(Return1).SetName("named_task").SetResources({{"CPU", 1.0}}).Remote();
EXPECT_EQ(1, *named_obj.Get());
Expand Down Expand Up @@ -168,6 +170,11 @@ TEST(RayClusterModeTest, FullTest) {
EXPECT_EQ(result1, 31);
EXPECT_EQ(result2, 25);

result_vector = ray::Get(objects, 5);
EXPECT_EQ(*(result_vector[0]), 1);
EXPECT_EQ(*(result_vector[1]), 31);
EXPECT_EQ(*(result_vector[2]), 25);

/// general function remote call(args passed by reference)
auto r3 = ray::Task(Return1).Remote();
auto r4 = ray::Task(Plus1).Remote(r3);
Expand Down Expand Up @@ -231,7 +238,7 @@ TEST(RayClusterModeTest, FullTest) {
EXPECT_EQ(arr, *(ray::Get(r17)));
EXPECT_EQ(arr, *(ray::Get(r18)));

uint64_t pid = *actor1.Task(&Counter::GetPid).Remote().Get();
uint64_t pid = *actor1.Task(&Counter::GetPid).Remote().Get(5);
EXPECT_TRUE(Counter::IsProcessAlive(pid));

auto actor_object4 = actor1.Task(&Counter::Exit).Remote();
Expand Down Expand Up @@ -305,6 +312,16 @@ TEST(RayClusterModeTest, MaxConcurrentTest) {
EXPECT_EQ(*object1.Get(), "ok");
EXPECT_EQ(*object2.Get(), "ok");
EXPECT_EQ(*object3.Get(), "ok");

auto actor2 =
ray::Actor(ActorConcurrentCall::FactoryCreate).SetMaxConcurrency(2).Remote();
auto object2_1 = actor2.Task(&ActorConcurrentCall::CountDown).Remote();
auto object2_2 = actor2.Task(&ActorConcurrentCall::CountDown).Remote();
auto object2_3 = actor2.Task(&ActorConcurrentCall::CountDown).Remote();

EXPECT_THROW(object2_1.Get(2), ray::internal::RayTimeoutException);
EXPECT_THROW(object2_2.Get(2), ray::internal::RayTimeoutException);
EXPECT_THROW(object2_3.Get(2), ray::internal::RayTimeoutException);
}

TEST(RayClusterModeTest, ResourcesManagementTest) {
Expand Down

0 comments on commit c21389f

Please sign in to comment.