Skip to content

Commit

Permalink
[Java] some small improvements (ray-project#14565)
Browse files Browse the repository at this point in the history
  • Loading branch information
kfstorm authored Mar 12, 2021
1 parent 9cf328d commit f60bd3a
Show file tree
Hide file tree
Showing 13 changed files with 121 additions and 149 deletions.
4 changes: 2 additions & 2 deletions java/api/src/main/java/io/ray/api/ActorHandle.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
* }
* }
* // Create an actor, and get a handle.
* ActorHandle<MyActor> myActor = Ray.createActor(MyActor::new);
* ActorHandle<MyActor> myActor = Ray.actor(MyActor::new).remote();
* // Call the `echo` method remotely.
* ObjectRef<Integer> result = myActor.call(MyActor::echo, 1);
* ObjectRef<Integer> result = myActor.task(MyActor::echo, 1).remote();
* // Get the result of the remote `echo` method.
* Assert.assertEqual(result.get(), 1);
* }</pre>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import io.ray.api.id.ActorId;
import io.ray.api.id.JobId;
import io.ray.api.id.TaskId;
import java.util.List;

/** A class used for getting information of Ray runtime. */
Expand All @@ -10,6 +11,9 @@ public interface RuntimeContext {
/** Get the current Job ID. */
JobId getCurrentJobId();

/** Get current task ID. */
TaskId getCurrentTaskId();

/**
* Get the current actor ID.
*
Expand Down
38 changes: 35 additions & 3 deletions java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.ray.api.placementgroup.PlacementGroup;
import io.ray.api.runtimecontext.RuntimeContext;
import io.ray.runtime.config.RayConfig;
import io.ray.runtime.config.RunMode;
import io.ray.runtime.context.RuntimeContextImpl;
import io.ray.runtime.context.WorkerContext;
import io.ray.runtime.functionmanager.FunctionDescriptor;
Expand Down Expand Up @@ -71,6 +72,9 @@ public AbstractRayRuntime(RayConfig rayConfig) {

@Override
public <T> ObjectRef<T> put(T obj) {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Putting Object in Task {}.", workerContext.getCurrentTaskId());
}
ObjectId objectId = objectStore.put(obj);
return new ObjectRefImpl<T>(objectId, (Class<T>) (obj == null ? Object.class : obj.getClass()));
}
Expand All @@ -90,21 +94,30 @@ public <T> List<T> get(List<ObjectRef<T>> objectRefs) {
objectIds.add(objectRefImpl.getId());
objectType = objectRefImpl.getType();
}
LOGGER.debug("Getting Objects {}.", objectIds);
return objectStore.get(objectIds, objectType);
}

@Override
public void free(List<ObjectRef<?>> objectRefs, boolean localOnly) {
objectStore.delete(
List<ObjectId> objectIds =
objectRefs.stream()
.map(ref -> ((ObjectRefImpl<?>) ref).getId())
.collect(Collectors.toList()),
localOnly);
.collect(Collectors.toList());
LOGGER.debug("Freeing Objects {}, localOnly = {}.", objectIds, localOnly);
objectStore.delete(objectIds, localOnly);
}

@Override
public <T> WaitResult<T> wait(
List<ObjectRef<T>> waitList, int numReturns, int timeoutMs, boolean fetchLocal) {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug(
"Waiting Objects {} with minimum number {} within {} ms.",
waitList,
numReturns,
timeoutMs);
}
return objectStore.wait(waitList, numReturns, timeoutMs, fetchLocal);
}

Expand Down Expand Up @@ -259,6 +272,9 @@ private ObjectRef callNormalFunction(
CallOptions options) {
int numReturns = returnType.isPresent() ? 1 : 0;
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args, functionDescriptor.getLanguage());
if (options == null) {
options = new CallOptions.Builder().build();
}
List<ObjectId> returnIds =
taskSubmitter.submitTask(functionDescriptor, functionArgs, numReturns, options);
Preconditions.checkState(returnIds.size() == numReturns);
Expand All @@ -275,6 +291,9 @@ private ObjectRef callActorFunction(
Object[] args,
Optional<Class<?>> returnType) {
int numReturns = returnType.isPresent() ? 1 : 0;
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Submitting Actor Task {}.", functionDescriptor);
}
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args, functionDescriptor.getLanguage());
List<ObjectId> returnIds =
taskSubmitter.submitActorTask(rayActor, functionDescriptor, functionArgs, numReturns, null);
Expand All @@ -288,6 +307,19 @@ private ObjectRef callActorFunction(

private BaseActorHandle createActorImpl(
FunctionDescriptor functionDescriptor, Object[] args, ActorCreationOptions options) {
if (LOGGER.isDebugEnabled()) {
if (options == null) {
LOGGER.debug("Creating Actor {} with default options.", functionDescriptor);
} else {
LOGGER.debug("Creating Actor {}, jvmOptions = {}.", functionDescriptor, options.jvmOptions);
}
}
if (rayConfig.runMode == RunMode.SINGLE_PROCESS
&& functionDescriptor.getLanguage() != Language.JAVA) {
throw new IllegalArgumentException(
"Ray doesn't support cross-language invocation in local mode.");
}

List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args, functionDescriptor.getLanguage());
if (functionDescriptor.getLanguage() != Language.JAVA && options != null) {
Preconditions.checkState(Strings.isNullOrEmpty(options.jvmOptions));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.ray.runtime.context;

import com.google.common.base.Preconditions;
import com.google.protobuf.ByteString;
import io.ray.api.id.ActorId;
import io.ray.api.id.JobId;
import io.ray.api.id.TaskId;
Expand All @@ -9,6 +10,7 @@
import io.ray.runtime.generated.Common.TaskSpec;
import io.ray.runtime.generated.Common.TaskType;
import io.ray.runtime.task.LocalModeTaskSubmitter;
import java.util.Random;

/** Worker context for local mode. */
public class LocalModeWorkerContext implements WorkerContext {
Expand All @@ -19,6 +21,14 @@ public class LocalModeWorkerContext implements WorkerContext {

public LocalModeWorkerContext(JobId jobId) {
this.jobId = jobId;

// Create a dummy driver task with a random task id, so that we can call
// `getCurrentTaskId` from a driver.
byte[] driverTaskId = new byte[TaskId.LENGTH];
new Random().nextBytes(driverTaskId);
TaskSpec dummyDriverTask =
TaskSpec.newBuilder().setTaskId(ByteString.copyFrom(driverTaskId)).build();
currentTask.set(dummyDriverTask);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.google.common.base.Preconditions;
import io.ray.api.id.ActorId;
import io.ray.api.id.JobId;
import io.ray.api.id.TaskId;
import io.ray.api.runtimecontext.NodeInfo;
import io.ray.api.runtimecontext.RuntimeContext;
import io.ray.runtime.RayRuntimeInternal;
Expand Down Expand Up @@ -30,6 +31,11 @@ public ActorId getCurrentActorId() {
return actorId;
}

@Override
public TaskId getCurrentTaskId() {
return runtime.getWorkerContext().getCurrentTaskId();
}

@Override
public boolean wasCurrentActorRestarted() {
if (isSingleProcess()) {
Expand Down
4 changes: 0 additions & 4 deletions java/runtime/src/main/java/io/ray/runtime/gcs/GcsClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,11 @@
import io.ray.api.runtimecontext.NodeInfo;
import io.ray.runtime.generated.Gcs;
import io.ray.runtime.generated.Gcs.GcsNodeInfo;
import io.ray.runtime.generated.Gcs.TablePrefix;
import io.ray.runtime.placementgroup.PlacementGroupUtils;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.ArrayUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -132,8 +130,6 @@ public boolean actorExists(ActorId actorId) {
}

public boolean wasCurrentActorRestarted(ActorId actorId) {
byte[] key = ArrayUtils.addAll(TablePrefix.ACTOR.toString().getBytes(), actorId.getBytes());

// TODO(ZhuSenlin): Get the actor table data from CoreWorker later.
byte[] value = globalStateAccessor.getActorInfo(actorId);
if (value == null) {
Expand Down
18 changes: 0 additions & 18 deletions java/runtime/src/main/java/io/ray/runtime/gcs/RedisClient.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package io.ray.runtime.gcs;

import com.google.common.base.Strings;
import java.util.List;
import java.util.Map;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPool;
Expand Down Expand Up @@ -55,12 +54,6 @@ public String hmset(String key, Map<String, String> hash) {
}
}

public Map<byte[], byte[]> hgetAll(byte[] key) {
try (Jedis jedis = jedisPool.getResource()) {
return jedis.hgetAll(key);
}
}

public String get(final String key, final String field) {
try (Jedis jedis = jedisPool.getResource()) {
if (field == null) {
Expand All @@ -85,17 +78,6 @@ public byte[] get(byte[] key, byte[] field) {
}
}

/**
* Return the specified elements of the list stored at the specified key.
*
* @return Multi bulk reply, specifically a list of elements in the specified range.
*/
public List<byte[]> lrange(byte[] key, long start, long end) {
try (Jedis jedis = jedisPool.getResource()) {
return jedis.lrange(key, start, end);
}
}

/** Whether the key exists in Redis. */
public boolean exists(byte[] key) {
try (Jedis jedis = jedisPool.getResource()) {
Expand Down
60 changes: 37 additions & 23 deletions java/runtime/src/main/java/io/ray/runtime/task/TaskExecutor.java
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ protected List<NativeRayObject> execute(List<String> rayFunctionInfo, List<Objec
runtime.setIsContextSet(true);
TaskType taskType = runtime.getWorkerContext().getCurrentTaskType();
TaskId taskId = runtime.getWorkerContext().getCurrentTaskId();
LOGGER.debug("Executing task {}", taskId);
LOGGER.debug("Executing task {} {}", taskId, rayFunctionInfo);

T actorContext = null;
if (taskType == TaskType.ACTOR_CREATION_TASK) {
Expand All @@ -103,6 +103,8 @@ protected List<NativeRayObject> execute(List<String> rayFunctionInfo, List<Objec

List<NativeRayObject> returnObjects = new ArrayList<>();
ClassLoader oldLoader = Thread.currentThread().getContextClassLoader();
// Find the executable object.

RayFunction rayFunction = localRayFunction.get();
try {
// Find the executable object.
Expand Down Expand Up @@ -133,6 +135,7 @@ protected List<NativeRayObject> execute(List<String> rayFunctionInfo, List<Objec
result = rayFunction.getConstructor().newInstance(args);
}
} catch (InvocationTargetException e) {
LOGGER.error("Execute rayFunction {} failed. actor {}, args {}", rayFunction, actor, args);
if (e.getCause() != null) {
throw e.getCause();
} else {
Expand All @@ -156,30 +159,41 @@ protected List<NativeRayObject> execute(List<String> rayFunctionInfo, List<Objec
throw (RayIntentionalSystemExitException) e;
}
LOGGER.error("Error executing task " + taskId, e);

if (taskType != TaskType.ACTOR_CREATION_TASK) {
boolean hasReturn = rayFunction != null && rayFunction.hasReturn();
boolean isCrossLanguage = parseFunctionDescriptor(rayFunctionInfo).signature.equals("");
if (hasReturn || isCrossLanguage) {
NativeRayObject serializedException;
try {
serializedException =
ObjectSerializer.serialize(
new RayTaskException("Error executing task " + taskId, e));
} catch (Exception unserializable) {
// We should try-catch `ObjectSerializer.serialize` here. Because otherwise if the
// application-level exception is not serializable. `ObjectSerializer.serialize`
// will throw an exception and crash the worker.
// Refer to the case `TaskExceptionTest.java` for more details.
LOGGER.warn("Failed to serialize the exception to a RayObject.", unserializable);
serializedException =
ObjectSerializer.serialize(
new RayTaskException(
String.format(
"Error executing task %s with the exception: %s",
taskId, ExceptionUtils.getStackTrace(e))));
if (rayFunction != null) {
boolean hasReturn = rayFunction != null && rayFunction.hasReturn();
boolean isCrossLanguage = parseFunctionDescriptor(rayFunctionInfo).signature.equals("");
if (hasReturn || isCrossLanguage) {
NativeRayObject serializedException;
try {
serializedException =
ObjectSerializer.serialize(
new RayTaskException("Error executing task " + taskId, e));
} catch (Exception unserializable) {
// We should try-catch `ObjectSerializer.serialize` here. Because otherwise if the
// application-level exception is not serializable. `ObjectSerializer.serialize`
// will throw an exception and crash the worker.
// Refer to the case `TaskExceptionTest.java` for more details.
LOGGER.warn("Failed to serialize the exception to a RayObject.", unserializable);
serializedException =
ObjectSerializer.serialize(
new RayTaskException(
String.format(
"Error executing task %s with the exception: %s",
taskId, ExceptionUtils.getStackTrace(e))));
}
Preconditions.checkNotNull(serializedException);
returnObjects.add(serializedException);
}
Preconditions.checkNotNull(serializedException);
returnObjects.add(serializedException);
} else {
returnObjects.add(
ObjectSerializer.serialize(
new RayTaskException(
String.format(
"Function %s of task %s doesn't exist",
String.join(".", rayFunctionInfo), taskId),
e)));
}
} else {
actorContext.actorCreationException = e;
Expand Down
32 changes: 0 additions & 32 deletions java/runtime/src/main/java/io/ray/runtime/util/NetworkUtil.java
Original file line number Diff line number Diff line change
@@ -1,24 +1,17 @@
package io.ray.runtime.util;

import com.google.common.base.Strings;
import java.io.IOException;
import java.net.DatagramSocket;
import java.net.Inet6Address;
import java.net.InetAddress;
import java.net.NetworkInterface;
import java.net.ServerSocket;
import java.util.Enumeration;
import java.util.concurrent.ThreadLocalRandom;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class NetworkUtil {

private static final Logger LOGGER = LoggerFactory.getLogger(NetworkUtil.class);

private static final int MIN_PORT = 10000;
private static final int MAX_PORT = 65535;

public static String getIpAddress(String interfaceName) {
try {
Enumeration<NetworkInterface> interfaces = NetworkInterface.getNetworkInterfaces();
Expand Down Expand Up @@ -50,29 +43,4 @@ public static String getIpAddress(String interfaceName) {

return "127.0.0.1";
}

public static int getUnusedPort() {
while (true) {
int port = ThreadLocalRandom.current().nextInt(MAX_PORT - MIN_PORT) + MIN_PORT;
if (isPortAvailable(port)) {
return port;
}
}
}

public static boolean isPortAvailable(int port) {
if (port < 1 || port > 65535) {
throw new IllegalArgumentException("Invalid start port: " + port);
}

try (ServerSocket ss = new ServerSocket(port);
DatagramSocket ds = new DatagramSocket(port)) {
ss.setReuseAddress(true);
ds.setReuseAddress(true);
return true;
} catch (IOException ignored) {
/* should not be thrown */
return false;
}
}
}
Loading

0 comments on commit f60bd3a

Please sign in to comment.