Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[runtime env][core] Use Proto message RuntimeEnvInfo between user code and core_worker #22856

Merged
Merged
Prev Previous commit
Next Next commit
Fix lint
  • Loading branch information
Catch-Bull committed Mar 11, 2022
commit 59281a3de47506785d787f3b44884e410c46fa97
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import io.ray.api.concurrencygroup.ConcurrencyGroup;
import io.ray.api.placementgroup.PlacementGroup;
import io.ray.api.runtimeenv.RuntimeEnv;
import io.ray.api.runtimeenv.RuntimeEnvInfo;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -183,6 +184,10 @@ public Builder setPlacementGroup(PlacementGroup group, int bundleIndex) {
}

public ActorCreationOptions build() {
RuntimeEnvInfo runtimeEnvInfo =
new RuntimeEnvInfo.Builder()
.setSerializedRuntimeEnv(runtimeEnv != null ? runtimeEnv.toJsonBytes() : "{}")
.build();
return new ActorCreationOptions(
name,
lifetime,
Expand All @@ -193,7 +198,7 @@ public ActorCreationOptions build() {
group,
bundleIndex,
concurrencyGroups,
runtimeEnv != null ? runtimeEnv.toJsonBytes() : "",
runtimeEnvInfo.toJsonBytes(),
maxPendingCalls);
}

Expand Down
4 changes: 4 additions & 0 deletions java/api/src/main/java/io/ray/api/runtime/RayRuntime.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.ray.api.runtimecontext.ResourceValue;
import io.ray.api.runtimecontext.RuntimeContext;
import io.ray.api.runtimeenv.RuntimeEnv;
import io.ray.api.runtimeenv.RuntimeEnvInfo;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -284,4 +285,7 @@ <T> ActorHandle<T> createActor(

/** Create runtime env instance at runtime. */
RuntimeEnv createRuntimeEnv(Map<String, String> envVars);

/** Create runtime env info instance at runtime. */
RuntimeEnvInfo createRuntimeEnvInfo(String serializedRuntimeEnv);
}
23 changes: 23 additions & 0 deletions java/api/src/main/java/io/ray/api/runtimeenv/RuntimeEnvInfo.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package io.ray.api.runtimeenv;

import io.ray.api.Ray;

/** This is an experimental API to let you set runtime environments info for your actors. */
public interface RuntimeEnvInfo {
Catch-Bull marked this conversation as resolved.
Show resolved Hide resolved

String toJsonBytes();

public static class Builder {

String serializedRuntimeEnv;

public Builder setSerializedRuntimeEnv(String serializedRuntimeEnv) {
this.serializedRuntimeEnv = serializedRuntimeEnv;
return this;
}

public RuntimeEnvInfo build() {
return Ray.internal().createRuntimeEnvInfo(serializedRuntimeEnv);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import io.ray.api.placementgroup.PlacementGroup;
import io.ray.api.runtimecontext.RuntimeContext;
import io.ray.api.runtimeenv.RuntimeEnv;
import io.ray.api.runtimeenv.RuntimeEnvInfo;
import io.ray.runtime.config.RayConfig;
import io.ray.runtime.config.RunMode;
import io.ray.runtime.context.RuntimeContextImpl;
Expand All @@ -35,6 +36,7 @@
import io.ray.runtime.object.ObjectRefImpl;
import io.ray.runtime.object.ObjectStore;
import io.ray.runtime.runtimeenv.RuntimeEnvImpl;
import io.ray.runtime.runtimeenv.RuntimeEnvInfoImpl;
import io.ray.runtime.task.ArgumentsBuilder;
import io.ray.runtime.task.FunctionArg;
import io.ray.runtime.task.TaskExecutor;
Expand Down Expand Up @@ -293,6 +295,11 @@ public RuntimeEnv createRuntimeEnv(Map<String, String> envVars) {
return new RuntimeEnvImpl(envVars);
}

@Override
public RuntimeEnvInfo createRuntimeEnvInfo(String serializedRuntimeEnv) {
return new RuntimeEnvInfoImpl(serializedRuntimeEnv);
}

private ObjectRef callNormalFunction(
FunctionDescriptor functionDescriptor,
Object[] args,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package io.ray.runtime.runtimeenv;

import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.util.JsonFormat;
import io.ray.api.runtimeenv.RuntimeEnvInfo;
import io.ray.runtime.generated.RuntimeEnvCommon;

public class RuntimeEnvInfoImpl implements RuntimeEnvInfo {

private String serializedRuntimeEnv = "{}";

public RuntimeEnvInfoImpl(String serializedRuntimeEnv) {
this.serializedRuntimeEnv = serializedRuntimeEnv;
}

@Override
public String toJsonBytes() {
if (serializedRuntimeEnv.equals("{}") || serializedRuntimeEnv.isEmpty()) {
return "{}";
}
RuntimeEnvCommon.RuntimeEnvInfo.Builder protoRuntimeEnvInfoBuilder =
RuntimeEnvCommon.RuntimeEnvInfo.newBuilder();
protoRuntimeEnvInfoBuilder.setSerializedRuntimeEnv(serializedRuntimeEnv);
JsonFormat.Printer printer = JsonFormat.printer();
try {
return printer.print(protoRuntimeEnvInfoBuilder);
} catch (InvalidProtocolBufferException e) {
throw new RuntimeException(e);
}
}
}
2 changes: 2 additions & 0 deletions python/ray/tests/test_runtime_env_env_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,12 @@ def get(self, key):
def test_environment_variables_nested_task(ray_start_regular):
@ray.remote
def get_env(key):
print(os.environ)
return os.environ.get(key)

@ray.remote
def get_env_wrapper(key):
assert os.environ.get(key) == "b"
Catch-Bull marked this conversation as resolved.
Show resolved Hide resolved
return ray.get(get_env.remote(key))

assert (
Expand Down
19 changes: 4 additions & 15 deletions src/ray/core_worker/core_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1448,21 +1448,12 @@ std::shared_ptr<rpc::RuntimeEnvInfo> CoreWorker::OverrideTaskOrActorRuntimeEnvIn
std::shared_ptr<rpc::RuntimeEnvInfo> runtime_env_info = nullptr;
runtime_env_info.reset(new rpc::RuntimeEnvInfo());

if (IsRuntimeEnvInfoEmpty(serialized_runtime_env_info)) {
runtime_env_info->set_serialized_runtime_env(
job_config_->runtime_env_info().serialized_runtime_env());
runtime_env_info->clear_uris();
for (const std::string &uri : GetUrisFromRuntimeEnv(job_runtime_env_.get())) {
runtime_env_info->add_uris(uri);
}

return runtime_env_info;
if (!IsRuntimeEnvInfoEmpty(serialized_runtime_env_info)) {
RAY_CHECK(google::protobuf::util::JsonStringToMessage(serialized_runtime_env_info,
runtime_env_info.get())
.ok());
}

RAY_CHECK(google::protobuf::util::JsonStringToMessage(serialized_runtime_env_info,
runtime_env_info.get())
.ok());

if (options_.worker_type == WorkerType::DRIVER) {
if (IsRuntimeEnvEmpty(runtime_env_info->serialized_runtime_env())) {
runtime_env_info->set_serialized_runtime_env(
Expand Down Expand Up @@ -1507,13 +1498,11 @@ std::shared_ptr<rpc::RuntimeEnvInfo> CoreWorker::OverrideTaskOrActorRuntimeEnvIn
RAY_CHECK(google::protobuf::util::MessageToJsonString(
override_runtime_env, &serialized_override_runtime_env)
.ok());

runtime_env_info->set_serialized_runtime_env(serialized_override_runtime_env);
runtime_env_info->clear_uris();
for (const std::string &uri : GetUrisFromRuntimeEnv(&override_runtime_env)) {
runtime_env_info->add_uris(uri);
}

return runtime_env_info;
} else {
return runtime_env_info;
Expand Down
32 changes: 21 additions & 11 deletions src/ray/raylet/scheduling/cluster_task_manager_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,17 +137,15 @@ std::shared_ptr<ClusterResourceScheduler> CreateSingleNodeScheduler(

RayTask CreateTask(const std::unordered_map<std::string, double> &required_resources,
int num_args = 0, std::vector<ObjectID> args = {},
const std::string &serialized_runtime_env = "{}",
const std::vector<std::string> &runtime_env_uris = {}) {
std::shared_ptr<rpc::RuntimeEnvInfo> runtime_env_info = nullptr) {
Catch-Bull marked this conversation as resolved.
Show resolved Hide resolved
TaskSpecBuilder spec_builder;
TaskID id = RandomTaskId();
JobID job_id = RandomJobId();
rpc::Address address;
spec_builder.SetCommonTaskSpec(id, "dummy_task", Language::PYTHON,
FunctionDescriptorBuilder::BuildPython("", "", "", ""),
job_id, TaskID::Nil(), 0, TaskID::Nil(), address, 0,
required_resources, {}, "", 0, serialized_runtime_env,
runtime_env_uris);
required_resources, {}, "", 0, runtime_env_info);

if (!args.empty()) {
for (auto &arg : args) {
Expand Down Expand Up @@ -474,8 +472,12 @@ TEST_F(ClusterTaskManagerTest, DispatchQueueNonBlockingTest) {
{ray::kCPU_ResourceLabel, 4}};

std::string serialized_runtime_env_A = "mock_env_A";
RayTask task_A = CreateTask(required_resources, /*num_args=*/0, /*args=*/{},
serialized_runtime_env_A);
std::shared_ptr<rpc::RuntimeEnvInfo> runtime_env_info_A = nullptr;
runtime_env_info_A.reset(new rpc::RuntimeEnvInfo());
runtime_env_info_A->set_serialized_runtime_env(serialized_runtime_env_A);

RayTask task_A =
CreateTask(required_resources, /*num_args=*/0, /*args=*/{}, runtime_env_info_A);
rpc::RequestWorkerLeaseReply reply_A;
bool callback_occurred = false;
bool *callback_occurred_ptr = &callback_occurred;
Expand All @@ -485,10 +487,14 @@ TEST_F(ClusterTaskManagerTest, DispatchQueueNonBlockingTest) {
};

std::string serialized_runtime_env_B = "mock_env_B";
RayTask task_B_1 = CreateTask(required_resources, /*num_args=*/0, /*args=*/{},
serialized_runtime_env_B);
RayTask task_B_2 = CreateTask(required_resources, /*num_args=*/0, /*args=*/{},
serialized_runtime_env_B);
std::shared_ptr<rpc::RuntimeEnvInfo> runtime_env_info_B = nullptr;
runtime_env_info_B.reset(new rpc::RuntimeEnvInfo());
runtime_env_info_B->set_serialized_runtime_env(serialized_runtime_env_B);

RayTask task_B_1 =
CreateTask(required_resources, /*num_args=*/0, /*args=*/{}, runtime_env_info_B);
RayTask task_B_2 =
CreateTask(required_resources, /*num_args=*/0, /*args=*/{}, runtime_env_info_B);
rpc::RequestWorkerLeaseReply reply_B_1;
rpc::RequestWorkerLeaseReply reply_B_2;
auto empty_callback = [](Status, std::function<void()>, std::function<void()>) {};
Expand Down Expand Up @@ -1785,8 +1791,12 @@ TEST_F(ClusterTaskManagerTest, TestResourceDiff) {
TEST_F(ClusterTaskManagerTest, PopWorkerExactlyOnce) {
// Create and queue one task.
std::string serialized_runtime_env = "mock_env";
std::shared_ptr<rpc::RuntimeEnvInfo> runtime_env_info = nullptr;
runtime_env_info.reset(new rpc::RuntimeEnvInfo());
runtime_env_info->set_serialized_runtime_env(serialized_runtime_env);

RayTask task = CreateTask({{ray::kCPU_ResourceLabel, 4}}, /*num_args=*/0, /*args=*/{},
serialized_runtime_env);
runtime_env_info);
auto runtime_env_hash = task.GetTaskSpecification().GetRuntimeEnvHash();
rpc::RequestWorkerLeaseReply reply;
bool callback_occurred = false;
Expand Down