Skip to content

Commit

Permalink
[Java] Release actor instance reference when Ray.exitActor() is inv…
Browse files Browse the repository at this point in the history
  • Loading branch information
kfstorm committed Oct 14, 2020
1 parent c926838 commit abc6126
Show file tree
Hide file tree
Showing 11 changed files with 71 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,11 @@ public ObjectStore getObjectStore() {
return objectStore;
}

@Override
public TaskExecutor getTaskExecutor() {
return taskExecutor;
}

@Override
public FunctionManager getFunctionManager() {
return functionManager;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import io.ray.runtime.functionmanager.FunctionManager;
import io.ray.runtime.gcs.GcsClient;
import io.ray.runtime.object.ObjectStore;
import io.ray.runtime.task.TaskExecutor;

/**
* This interface is required to make {@link RayRuntimeProxy} work.
Expand All @@ -21,6 +22,8 @@ public interface RayRuntimeInternal extends RayRuntime {

ObjectStore getObjectStore();

TaskExecutor getTaskExecutor();

FunctionManager getFunctionManager();

RayConfig getRayConfig();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ protected NativeActorContext createActorContext() {
return new NativeActorContext();
}

public void onWorkerShutdown(byte[] workerIdBytes) {
// This is to make sure no memory leak when `Ray.exitActor()` is called.
removeActorContext(new UniqueId(workerIdBytes));
}

@Override
protected void maybeSaveCheckpoint(Object actor, ActorId actorId) {
if (!(actor instanceof Checkpointable)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ void setActorContext(T actorContext) {
this.actorContextMap.put(runtime.getWorkerContext().getCurrentWorkerId(), actorContext);
}

protected void removeActorContext(UniqueId workerId) {
this.actorContextMap.remove(workerId);
}

private RayFunction getRayFunction(List<String> rayFunctionInfo) {
JobId jobId = runtime.getWorkerContext().getCurrentJobId();
JavaFunctionDescriptor functionDescriptor = parseFunctionDescriptor(rayFunctionInfo);
Expand Down
22 changes: 22 additions & 0 deletions java/test/src/main/java/io/ray/test/ExitActorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@
import io.ray.api.id.ActorId;
import io.ray.api.id.UniqueId;
import io.ray.runtime.exception.RayActorException;
import io.ray.runtime.task.TaskExecutor;
import io.ray.runtime.util.SystemUtil;
import java.io.IOException;
import java.lang.reflect.Field;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import org.testng.Assert;
import org.testng.annotations.Test;
Expand All @@ -31,6 +34,17 @@ public int getPid() {
return pid();
}

public int getSizeOfActorContextMap() {
TaskExecutor taskExecutor = TestUtils.getRuntime().getTaskExecutor();
try {
Field field = TaskExecutor.class.getDeclaredField("actorContextMap");
field.setAccessible(true);
return ((Map<?, ?>)field.get(taskExecutor)).size();
} catch (Exception e) {
throw new RuntimeException(e);
}
}

@Override
public boolean shouldCheckpoint(CheckpointContext checkpointContext) {
return true;
Expand Down Expand Up @@ -77,6 +91,8 @@ public void testExitActorInMultiWorker() {
ActorHandle<ExitingActor> actor1 = Ray.actor(ExitingActor::new)
.setMaxRestarts(10000).remote();
int pid = actor1.task(ExitingActor::getPid).remote().get();
Assert.assertEquals(
1, (int) actor1.task(ExitingActor::getSizeOfActorContextMap).remote().get());
ActorHandle<ExitingActor> actor2;
while (true) {
// Create another actor which share the same process of actor 1.
Expand All @@ -86,11 +102,17 @@ public void testExitActorInMultiWorker() {
break;
}
}
Assert.assertEquals(
2, (int) actor1.task(ExitingActor::getSizeOfActorContextMap).remote().get());
Assert.assertEquals(
2, (int) actor2.task(ExitingActor::getSizeOfActorContextMap).remote().get());
ObjectRef<Boolean> obj1 = actor1.task(ExitingActor::exit).remote();
Assert.assertThrows(RayActorException.class, obj1::get);
Assert.assertTrue(SystemUtil.isProcessAlive(pid));
// Actor 2 shouldn't exit or be reconstructed.
Assert.assertEquals(1, (int) actor2.task(ExitingActor::incr).remote().get());
Assert.assertEquals(
1, (int) actor2.task(ExitingActor::getSizeOfActorContextMap).remote().get());
Assert.assertEquals(pid, (int) actor2.task(ExitingActor::getPid).remote().get());
Assert.assertTrue(SystemUtil.isProcessAlive(pid));
}
Expand Down
2 changes: 2 additions & 0 deletions python/ray/includes/libcoreworker.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ from ray.includes.unique_ids cimport (
CTaskID,
CObjectID,
CPlacementGroupID,
CWorkerID,
)
from ray.includes.common cimport (
CAddress,
Expand Down Expand Up @@ -227,6 +228,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
const c_vector[CObjectID] &return_ids,
c_vector[shared_ptr[CRayObject]] *returns) nogil
) task_execution_callback
(void(const CWorkerID &) nogil) on_worker_shutdown
(CRayStatus() nogil) check_signals
(void() nogil) gc_collect
(c_vector[c_string](const c_vector[CObjectID]&) nogil) spill_objects
Expand Down
3 changes: 3 additions & 0 deletions src/ray/core_worker/core_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,9 @@ void CoreWorker::Shutdown() {
if (options_.worker_type == WorkerType::WORKER) {
task_execution_service_.stop();
}
if (options_.on_worker_shutdown) {
options_.on_worker_shutdown(GetWorkerID());
}
}

void CoreWorker::Disconnect() {
Expand Down
2 changes: 2 additions & 0 deletions src/ray/core_worker/core_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ struct CoreWorkerOptions {
std::string stderr_file;
/// Language worker callback to execute tasks.
TaskExecutionCallback task_execution_callback;
/// The callback to be called when shutting down a `CoreWorker` instance.
std::function<void(const WorkerID &)> on_worker_shutdown;
/// Application-language callback to check for signals that have been received
/// since calling into C++. This will be called periodically (at least every
/// 1s) during long-running operations. If the function returns anything but StatusOK,
Expand Down
12 changes: 12 additions & 0 deletions src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
if (throwable &&
env->IsInstanceOf(throwable,
java_ray_intentional_system_exit_exception_class)) {
env->ExceptionClear();
return ray::Status::IntentionalSystemExit();
}
RAY_CHECK_JAVA_EXCEPTION(env);
Expand Down Expand Up @@ -211,6 +212,16 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
}
};

auto on_worker_shutdown = [](const ray::WorkerID &worker_id) {
JNIEnv *env = GetJNIEnv();
auto worker_id_bytes = IdToJavaByteArray<ray::WorkerID>(env, worker_id);
if (java_task_executor) {
env->CallVoidMethod(java_task_executor,
java_native_task_executor_on_worker_shutdown, worker_id_bytes);
RAY_CHECK_JAVA_EXCEPTION(env);
}
};

std::string serialized_job_config =
(jobConfig == nullptr ? "" : JavaByteArrayToNativeString(env, jobConfig));
ray::CoreWorkerOptions options;
Expand All @@ -229,6 +240,7 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
options.raylet_ip_address = JavaStringToNativeString(env, nodeIpAddress);
options.driver_name = JavaStringToNativeString(env, driverName);
options.task_execution_callback = task_execution_callback;
options.on_worker_shutdown = on_worker_shutdown;
options.gc_collect = gc_collect;
options.ref_counting_enabled = true;
options.num_workers = static_cast<int>(numWorkersPerProcess);
Expand Down
8 changes: 8 additions & 0 deletions src/ray/core_worker/lib/java/jni_init.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ jclass java_task_executor_class;
jmethodID java_task_executor_parse_function_arguments;
jmethodID java_task_executor_execute;

jclass java_native_task_executor_class;
jmethodID java_native_task_executor_on_worker_shutdown;

jclass java_placement_group_class;
jfieldID java_placement_group_id;

Expand Down Expand Up @@ -267,6 +270,10 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) {
java_task_executor_execute =
env->GetMethodID(java_task_executor_class, "execute",
"(Ljava/util/List;Ljava/util/List;)Ljava/util/List;");
java_native_task_executor_class =
LoadClass(env, "io/ray/runtime/task/NativeTaskExecutor");
java_native_task_executor_on_worker_shutdown =
env->GetMethodID(java_native_task_executor_class, "onWorkerShutdown", "([B)V");
return CURRENT_JNI_VERSION;
}

Expand Down Expand Up @@ -298,4 +305,5 @@ void JNI_OnUnload(JavaVM *vm, void *reserved) {
env->DeleteGlobalRef(java_actor_creation_options_class);
env->DeleteGlobalRef(java_native_ray_object_class);
env->DeleteGlobalRef(java_task_executor_class);
env->DeleteGlobalRef(java_native_task_executor_class);
}
5 changes: 5 additions & 0 deletions src/ray/core_worker/lib/java/jni_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,11 @@ extern jmethodID java_task_executor_parse_function_arguments;
/// execute method of TaskExecutor class
extern jmethodID java_task_executor_execute;

/// NativeTaskExecutor class
extern jclass java_native_task_executor_class;
/// onWorkerShutdown method of NativeTaskExecutor class
extern jmethodID java_native_task_executor_on_worker_shutdown;

/// PlacementGroup class
extern jclass java_placement_group_class;
/// id field of PlacementGroup class
Expand Down

0 comments on commit abc6126

Please sign in to comment.