diff --git a/cpp/include/ray/api.h b/cpp/include/ray/api.h index 1d1c2b8b1b871..e7ac0ffad792f 100644 --- a/cpp/include/ray/api.h +++ b/cpp/include/ray/api.h @@ -97,6 +97,9 @@ ray::internal::TaskCaller Task(F func); template ray::internal::TaskCaller> Task(PyFunction func); +template +ray::internal::TaskCaller> Task(JavaFunction func); + /// Generic version of creating an actor /// It is used for creating an actor, such as: ActorCreator creator = /// ray::Actor(Counter::FactoryCreate).Remote(1); @@ -232,6 +235,13 @@ inline ray::internal::TaskCaller> Task(PyFunction func) { return {ray::internal::GetRayRuntime().get(), std::move(remote_func_holder)}; } +template +inline ray::internal::TaskCaller> Task(JavaFunction func) { + ray::internal::RemoteFunctionHolder remote_func_holder( + "", func.function_name, func.class_name, ray::internal::LangType::JAVA); + return {ray::internal::GetRayRuntime().get(), std::move(remote_func_holder)}; +} + inline ray::internal::ActorCreator Actor(JavaActorClass func) { ray::internal::RemoteFunctionHolder remote_func_holder(func.module_name, func.function_name, diff --git a/cpp/include/ray/api/task_caller.h b/cpp/include/ray/api/task_caller.h index 04d40aa3f14ee..519858ee0328e 100644 --- a/cpp/include/ray/api/task_caller.h +++ b/cpp/include/ray/api/task_caller.h @@ -74,7 +74,7 @@ ObjectRef> TaskCaller::Remote( Args &&...args) { CheckTaskOptions(task_options_.resources); - if constexpr (is_python_v) { + if constexpr (is_x_lang_v) { using ArgsTuple = std::tuple; Arguments::WrapArgs(remote_function_holder_.lang_type, &args_, diff --git a/cpp/include/ray/api/type_traits.h b/cpp/include/ray/api/type_traits.h index b7da5052a1641..92cd68ea71ef8 100644 --- a/cpp/include/ray/api/type_traits.h +++ b/cpp/include/ray/api/type_traits.h @@ -72,7 +72,7 @@ template auto constexpr is_java_v = is_java_t::value; template -auto constexpr is_x_lang_v = is_java_t::value || is_python_t::value; +auto constexpr is_x_lang_v = is_java_v || is_python_v; } // namespace internal } // namespace ray \ No newline at end of file diff --git a/cpp/include/ray/api/xlang_function.h b/cpp/include/ray/api/xlang_function.h index 30ca212009021..bb78f1ea49636 100644 --- a/cpp/include/ray/api/xlang_function.h +++ b/cpp/include/ray/api/xlang_function.h @@ -59,6 +59,14 @@ struct JavaActorMethod { std::string function_name; }; +template +struct JavaFunction { + bool IsJava() { return true; } + R operator()() { return {}; } + std::string class_name; + std::string function_name; +}; + namespace internal { enum class LangType { diff --git a/cpp/src/ray/test/cluster/cluster_mode_test.cc b/cpp/src/ray/test/cluster/cluster_mode_test.cc index 37e59297d8eac..c556716d87026 100644 --- a/cpp/src/ray/test/cluster/cluster_mode_test.cc +++ b/cpp/src/ray/test/cluster/cluster_mode_test.cc @@ -259,6 +259,12 @@ TEST(RayClusterModeTest, JavaInvocationTest) { auto java_actor_ret = java_actor_handle.Task(ray::JavaActorMethod{"increase"}).Remote(2); EXPECT_EQ(3, *java_actor_ret.Get()); + + auto java_task_ret = + ray::Task(ray::JavaFunction{"io.ray.test.CrossLanguageInvocationTest", + "returnInputString"}) + .Remote("helloworld"); + EXPECT_EQ("helloworld", *java_task_ret.Get()); } TEST(RayClusterModeTest, MaxConcurrentTest) {