Skip to content

Commit

Permalink
[runtime env][core] Use Proto message RuntimeEnvInfo between user c…
Browse files Browse the repository at this point in the history
…ode and core_worker (#22856)
  • Loading branch information
Catch-Bull authored Mar 11, 2022
1 parent 965d609 commit 0cbbb8c
Show file tree
Hide file tree
Showing 15 changed files with 249 additions and 158 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,32 @@ public RuntimeEnvImpl(Map<String, String> envVars) {

@Override
public String toJsonBytes() {
// Get serializedRuntimeEnv
String serializedRuntimeEnv = "{}";
if (!envVars.isEmpty()) {
RuntimeEnvCommon.RuntimeEnv.Builder protoRuntimeEnvBuilder =
RuntimeEnvCommon.RuntimeEnv.newBuilder();
protoRuntimeEnvBuilder.putAllEnvVars(envVars);
JsonFormat.Printer printer = JsonFormat.printer();
try {
return printer.print(protoRuntimeEnvBuilder);
serializedRuntimeEnv = printer.print(protoRuntimeEnvBuilder);
} catch (InvalidProtocolBufferException e) {
throw new RuntimeException(e);
}
}
return "{}";

// Get serializedRuntimeEnvInfo
if (serializedRuntimeEnv.equals("{}") || serializedRuntimeEnv.isEmpty()) {
return "{}";
}
RuntimeEnvCommon.RuntimeEnvInfo.Builder protoRuntimeEnvInfoBuilder =
RuntimeEnvCommon.RuntimeEnvInfo.newBuilder();
protoRuntimeEnvInfoBuilder.setSerializedRuntimeEnv(serializedRuntimeEnv);
JsonFormat.Printer printer = JsonFormat.printer();
try {
return printer.print(protoRuntimeEnvInfoBuilder);
} catch (InvalidProtocolBufferException e) {
throw new RuntimeException(e);
}
}
}
8 changes: 4 additions & 4 deletions python/ray/_raylet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1495,7 +1495,7 @@ cdef class CoreWorker:
c_bool retry_exceptions,
scheduling_strategy,
c_string debugger_breakpoint,
c_string serialized_runtime_env,
c_string serialized_runtime_env_info,
):
cdef:
unordered_map[c_string, double] c_resources
Expand Down Expand Up @@ -1523,7 +1523,7 @@ cdef class CoreWorker:
ray_function, args_vector, CTaskOptions(
name, num_returns, c_resources,
b"",
serialized_runtime_env),
serialized_runtime_env_info),
max_retries, retry_exceptions,
c_scheduling_strategy,
debugger_breakpoint)
Expand Down Expand Up @@ -1555,7 +1555,7 @@ cdef class CoreWorker:
c_string ray_namespace,
c_bool is_asyncio,
c_string extension_data,
c_string serialized_runtime_env,
c_string serialized_runtime_env_info,
concurrency_groups_dict,
int32_t max_pending_calls,
scheduling_strategy,
Expand Down Expand Up @@ -1600,7 +1600,7 @@ cdef class CoreWorker:
ray_namespace,
is_asyncio,
c_scheduling_strategy,
serialized_runtime_env,
serialized_runtime_env_info,
c_concurrency_groups,
# execute out of order for
# async or threaded actors.
Expand Down
55 changes: 13 additions & 42 deletions python/ray/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import ray.ray_constants as ray_constants
import ray._raylet
import ray._private.signature as signature
from ray.runtime_env import RuntimeEnv
from ray.utils import get_runtime_env_info, parse_runtime_env
import ray.worker
from ray.util.annotations import PublicAPI
from ray.util.placement_group import configure_placement_group_based_on_context
Expand Down Expand Up @@ -464,16 +464,7 @@ def __init__(self, *args, **kwargs):
modified_class.__ray_actor_class__
)

# Parse local pip/conda config files here. If we instead did it in
# .remote(), it would get run in the Ray Client server, which runs on
# a remote node where the files aren't available.
if runtime_env:
if isinstance(runtime_env, str):
new_runtime_env = runtime_env
else:
new_runtime_env = RuntimeEnv(**runtime_env).serialize()
else:
new_runtime_env = None
new_runtime_env = parse_runtime_env(runtime_env)

self.__ray_metadata__ = ActorClassMetadata(
Language.PYTHON,
Expand Down Expand Up @@ -512,16 +503,7 @@ def _ray_from_function_descriptor(
):
self = ActorClass.__new__(ActorClass)

# Parse local pip/conda config files here. If we instead did it in
# .remote(), it would get run in the Ray Client server, which runs on
# a remote node where the files aren't available.
if runtime_env:
if isinstance(runtime_env, str):
new_runtime_env = runtime_env
else:
new_runtime_env = RuntimeEnv(**runtime_env).serialize()
else:
new_runtime_env = None
new_runtime_env = parse_runtime_env(runtime_env)

self.__ray_metadata__ = ActorClassMetadata(
language,
Expand Down Expand Up @@ -600,19 +582,7 @@ def method(self):

actor_cls = self

# Parse local pip/conda config files here. If we instead did it in
# .remote(), it would get run in the Ray Client server, which runs on
# a remote node where the files aren't available.
if runtime_env:
if isinstance(runtime_env, str):
new_runtime_env = runtime_env
else:
new_runtime_env = RuntimeEnv(**(runtime_env or {})).serialize()
else:
# Keep the new_runtime_env as None. In .remote(), we need to know
# if runtime_env is None to know whether or not to fall back to the
# runtime_env specified in the @ray.remote decorator.
new_runtime_env = None
new_runtime_env = parse_runtime_env(runtime_env)

cls_options = dict(
num_cpus=num_cpus,
Expand Down Expand Up @@ -966,15 +936,16 @@ def _remote(
scheduling_strategy = "DEFAULT"

if runtime_env:
if isinstance(runtime_env, str):
# Serialized protobuf runtime env from Ray client.
new_runtime_env = runtime_env
elif isinstance(runtime_env, RuntimeEnv):
new_runtime_env = runtime_env.serialize()
else:
raise TypeError(f"Error runtime env type {type(runtime_env)}")
new_runtime_env = parse_runtime_env(runtime_env)
else:
new_runtime_env = meta.runtime_env
serialized_runtime_env_info = None
if new_runtime_env is not None:
serialized_runtime_env_info = get_runtime_env_info(
new_runtime_env,
is_job_runtime_env=False,
serialize=True,
)

concurrency_groups_dict = {}
for cg_name in meta.concurrency_groups:
Expand Down Expand Up @@ -1021,7 +992,7 @@ def _remote(
is_asyncio,
# Store actor_method_cpu in actor handle's extension data.
extension_data=str(actor_method_cpu),
serialized_runtime_env=new_runtime_env or "{}",
serialized_runtime_env_info=serialized_runtime_env_info or "{}",
concurrency_groups_dict=concurrency_groups_dict or dict(),
max_pending_calls=max_pending_calls,
scheduling_strategy=scheduling_strategy,
Expand Down
30 changes: 18 additions & 12 deletions python/ray/job_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def set_runtime_env(
"""
self.runtime_env = runtime_env if runtime_env is not None else {}
if validate:
self.runtime_env = self._validate_runtime_env()[0]
self.runtime_env = self._validate_runtime_env()
self._cached_pb = None

def set_ray_namespace(self, ray_namespace: str) -> None:
Expand All @@ -91,15 +91,17 @@ def _validate_runtime_env(self):
# this dependency and pass in a validated runtime_env instead.
from ray.runtime_env import RuntimeEnv

eager_install = self.runtime_env.get("eager_install", True)
if not isinstance(eager_install, bool):
raise TypeError("eager_install must be a boolean.")
if isinstance(self.runtime_env, RuntimeEnv):
return self.runtime_env, eager_install
return RuntimeEnv(**self.runtime_env), eager_install
return self.runtime_env
return RuntimeEnv(**self.runtime_env)

def get_proto_job_config(self):
"""Return the protobuf structure of JobConfig."""
# TODO(edoakes): this is really unfortunate, but JobConfig is imported
# all over the place so this causes circular imports. We should remove
# this dependency and pass in a validated runtime_env instead.
from ray.utils import get_runtime_env_info

if self._cached_pb is None:
pb = gcs_utils.JobConfig()
if self.ray_namespace is None:
Expand All @@ -112,10 +114,14 @@ def get_proto_job_config(self):
for k, v in self.metadata.items():
pb.metadata[k] = v

parsed_env, eager_install = self._validate_runtime_env()
pb.runtime_env_info.uris[:] = parsed_env.get_uris()
pb.runtime_env_info.serialized_runtime_env = parsed_env.serialize()
pb.runtime_env_info.runtime_env_eager_install = eager_install
parsed_env = self._validate_runtime_env()
pb.runtime_env_info.CopyFrom(
get_runtime_env_info(
parsed_env,
is_job_runtime_env=True,
serialize=False,
)
)

if self._default_actor_lifetime is not None:
pb.default_actor_lifetime = self._default_actor_lifetime
Expand All @@ -125,11 +131,11 @@ def get_proto_job_config(self):

def runtime_env_has_uris(self):
"""Whether there are uris in runtime env or not"""
return self._validate_runtime_env()[0].has_uris()
return self._validate_runtime_env().has_uris()

def get_serialized_runtime_env(self) -> str:
"""Return the JSON-serialized parsed runtime env dict"""
return self._validate_runtime_env()[0].serialize()
return self._validate_runtime_env().serialize()

@classmethod
def from_json(cls, job_config_json):
Expand Down
39 changes: 13 additions & 26 deletions python/ray/remote_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ray._private.client_mode_hook import client_mode_should_convert
from ray.util.placement_group import configure_placement_group_based_on_context
import ray._private.signature
from ray.runtime_env import RuntimeEnv
from ray.utils import get_runtime_env_info, parse_runtime_env
from ray.util.tracing.tracing_helper import (
_tracing_task_invocation,
_inject_tracing_into_function,
Expand Down Expand Up @@ -139,16 +139,9 @@ def __init__(
if retry_exceptions is None
else retry_exceptions
)
# Parse local pip/conda config files here. If we instead did it in
# .remote(), it would get run in the Ray Client server, which runs on
# a remote node where the files aren't available.
if runtime_env:
if isinstance(runtime_env, str):
self._runtime_env = runtime_env
else:
self._runtime_env = RuntimeEnv(**(runtime_env or {})).serialize()
else:
self._runtime_env = None

self._runtime_env = parse_runtime_env(runtime_env)

self._placement_group = placement_group
self._decorator = getattr(function, "__ray_invocation_decorator__", None)
self._function_signature = ray._private.signature.extract_signature(
Expand Down Expand Up @@ -211,20 +204,7 @@ def f():
"""

func_cls = self
# Parse local pip/conda config files here. If we instead did it in
# .remote(), it would get run in the Ray Client server, which runs on
# a remote node where the files aren't available.
if runtime_env:
if isinstance(runtime_env, str):
# Serialized protobuf runtime env from Ray client.
new_runtime_env = runtime_env
else:
new_runtime_env = RuntimeEnv(**runtime_env).serialize()
else:
# Keep the runtime_env as None. In .remote(), we need to know if
# runtime_env is None to know whether or not to fall back to the
# runtime_env specified in the @ray.remote decorator.
new_runtime_env = None
new_runtime_env = parse_runtime_env(runtime_env)

options = dict(
num_returns=num_returns,
Expand Down Expand Up @@ -419,6 +399,13 @@ def _remote(

if not runtime_env or runtime_env == "{}":
runtime_env = self._runtime_env
serialized_runtime_env_info = None
if runtime_env is not None:
serialized_runtime_env_info = get_runtime_env_info(
runtime_env,
is_job_runtime_env=False,
serialize=True,
)

def invocation(args, kwargs):
if self._is_cross_language:
Expand All @@ -445,7 +432,7 @@ def invocation(args, kwargs):
retry_exceptions,
scheduling_strategy,
worker.debugger_breakpoint,
runtime_env or "{}",
serialized_runtime_env_info or "{}",
)
# Reset worker's debug context from the last "remote" command
# (which applies only to this .remote call).
Expand Down
10 changes: 10 additions & 0 deletions python/ray/runtime_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,3 +575,13 @@ def _build_proto_plugin_runtime_env(self, runtime_env: ProtoRuntimeEnv):
plugin = runtime_env.python_runtime_env.plugin_runtime_env.plugins.add()
plugin.class_path = class_path
plugin.config = plugin_field

def __getstate__(self):
# Pickle hook: return the state that gets pickled for this object.
# Only the object's own key/value items are copied into a plain dict;
# other fields, which pickle cannot serialize, are deliberately left out.
return dict(**self)

def __setstate__(self, state):
# Pickle hook: repopulate this object from the plain dict of items
# produced by __getstate__.
for k, v in state.items():
self[k] = v
# This attribute is not carried in the pickled state, so reset it to None.
# NOTE(review): presumably a cached proto object that is rebuilt on
# demand after unpickling -- confirm against the rest of the class.
self.__proto_runtime_env = None
1 change: 1 addition & 0 deletions python/ray/tests/test_runtime_env_env_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def get(self, key):
def test_environment_variables_nested_task(ray_start_regular):
@ray.remote
def get_env(key):
print(os.environ)
return os.environ.get(key)

@ray.remote
Expand Down
Loading

0 comments on commit 0cbbb8c

Please sign in to comment.