diff --git a/java/runtime/src/main/java/io/ray/runtime/runtimeenv/RuntimeEnvImpl.java b/java/runtime/src/main/java/io/ray/runtime/runtimeenv/RuntimeEnvImpl.java index cdc7b7d7f444e..ce3d736fdbeef 100644 --- a/java/runtime/src/main/java/io/ray/runtime/runtimeenv/RuntimeEnvImpl.java +++ b/java/runtime/src/main/java/io/ray/runtime/runtimeenv/RuntimeEnvImpl.java @@ -17,17 +17,32 @@ public RuntimeEnvImpl(Map envVars) { @Override public String toJsonBytes() { + // Get serializedRuntimeEnv + String serializedRuntimeEnv = "{}"; if (!envVars.isEmpty()) { RuntimeEnvCommon.RuntimeEnv.Builder protoRuntimeEnvBuilder = RuntimeEnvCommon.RuntimeEnv.newBuilder(); protoRuntimeEnvBuilder.putAllEnvVars(envVars); JsonFormat.Printer printer = JsonFormat.printer(); try { - return printer.print(protoRuntimeEnvBuilder); + serializedRuntimeEnv = printer.print(protoRuntimeEnvBuilder); } catch (InvalidProtocolBufferException e) { throw new RuntimeException(e); } } - return "{}"; + + // Get serializedRuntimeEnvInfo + if (serializedRuntimeEnv.equals("{}") || serializedRuntimeEnv.isEmpty()) { + return "{}"; + } + RuntimeEnvCommon.RuntimeEnvInfo.Builder protoRuntimeEnvInfoBuilder = + RuntimeEnvCommon.RuntimeEnvInfo.newBuilder(); + protoRuntimeEnvInfoBuilder.setSerializedRuntimeEnv(serializedRuntimeEnv); + JsonFormat.Printer printer = JsonFormat.printer(); + try { + return printer.print(protoRuntimeEnvInfoBuilder); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } } } diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 75591cb45b77d..bd0ac678cd831 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -1495,7 +1495,7 @@ cdef class CoreWorker: c_bool retry_exceptions, scheduling_strategy, c_string debugger_breakpoint, - c_string serialized_runtime_env, + c_string serialized_runtime_env_info, ): cdef: unordered_map[c_string, double] c_resources @@ -1523,7 +1523,7 @@ cdef class CoreWorker: ray_function, args_vector, CTaskOptions( name, num_returns, c_resources, b"", - serialized_runtime_env), + serialized_runtime_env_info), max_retries, retry_exceptions, c_scheduling_strategy, debugger_breakpoint) @@ -1555,7 +1555,7 @@ cdef class CoreWorker: c_string ray_namespace, c_bool is_asyncio, c_string extension_data, - c_string serialized_runtime_env, + c_string serialized_runtime_env_info, concurrency_groups_dict, int32_t max_pending_calls, scheduling_strategy, @@ -1600,7 +1600,7 @@ cdef class CoreWorker: ray_namespace, is_asyncio, c_scheduling_strategy, - serialized_runtime_env, + serialized_runtime_env_info, c_concurrency_groups, # execute out of order for # async or threaded actors. diff --git a/python/ray/actor.py b/python/ray/actor.py index 4a462ee6891e4..416cc0e8c3f8b 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -5,7 +5,7 @@ import ray.ray_constants as ray_constants import ray._raylet import ray._private.signature as signature -from ray.runtime_env import RuntimeEnv +from ray.utils import get_runtime_env_info, parse_runtime_env import ray.worker from ray.util.annotations import PublicAPI from ray.util.placement_group import configure_placement_group_based_on_context @@ -464,16 +464,7 @@ def __init__(self, *args, **kwargs): modified_class.__ray_actor_class__ ) - # Parse local pip/conda config files here. If we instead did it in - # .remote(), it would get run in the Ray Client server, which runs on - # a remote node where the files aren't available. - if runtime_env: - if isinstance(runtime_env, str): - new_runtime_env = runtime_env - else: - new_runtime_env = RuntimeEnv(**runtime_env).serialize() - else: - new_runtime_env = None + new_runtime_env = parse_runtime_env(runtime_env) self.__ray_metadata__ = ActorClassMetadata( Language.PYTHON, @@ -512,16 +503,7 @@ def _ray_from_function_descriptor( ): self = ActorClass.__new__(ActorClass) - # Parse local pip/conda config files here. If we instead did it in - # .remote(), it would get run in the Ray Client server, which runs on - # a remote node where the files aren't available. - if runtime_env: - if isinstance(runtime_env, str): - new_runtime_env = runtime_env - else: - new_runtime_env = RuntimeEnv(**runtime_env).serialize() - else: - new_runtime_env = None + new_runtime_env = parse_runtime_env(runtime_env) self.__ray_metadata__ = ActorClassMetadata( language, @@ -600,19 +582,7 @@ def method(self): actor_cls = self - # Parse local pip/conda config files here. If we instead did it in - # .remote(), it would get run in the Ray Client server, which runs on - # a remote node where the files aren't available. - if runtime_env: - if isinstance(runtime_env, str): - new_runtime_env = runtime_env - else: - new_runtime_env = RuntimeEnv(**(runtime_env or {})).serialize() - else: - # Keep the new_runtime_env as None. In .remote(), we need to know - # if runtime_env is None to know whether or not to fall back to the - # runtime_env specified in the @ray.remote decorator. - new_runtime_env = None + new_runtime_env = parse_runtime_env(runtime_env) cls_options = dict( num_cpus=num_cpus, @@ -966,15 +936,16 @@ def _remote( scheduling_strategy = "DEFAULT" if runtime_env: - if isinstance(runtime_env, str): - # Serialzed protobuf runtime env from Ray client. - new_runtime_env = runtime_env - elif isinstance(runtime_env, RuntimeEnv): - new_runtime_env = runtime_env.serialize() - else: - raise TypeError(f"Error runtime env type {type(runtime_env)}") + new_runtime_env = parse_runtime_env(runtime_env) else: new_runtime_env = meta.runtime_env + serialized_runtime_env_info = None + if new_runtime_env is not None: + serialized_runtime_env_info = get_runtime_env_info( + new_runtime_env, + is_job_runtime_env=False, + serialize=True, + ) concurrency_groups_dict = {} for cg_name in meta.concurrency_groups: @@ -1021,7 +992,7 @@ def _remote( is_asyncio, # Store actor_method_cpu in actor handle's extension data. extension_data=str(actor_method_cpu), - serialized_runtime_env=new_runtime_env or "{}", + serialized_runtime_env_info=serialized_runtime_env_info or "{}", concurrency_groups_dict=concurrency_groups_dict or dict(), max_pending_calls=max_pending_calls, scheduling_strategy=scheduling_strategy, diff --git a/python/ray/job_config.py b/python/ray/job_config.py index d7a32b4c280f5..77c817feaa983 100644 --- a/python/ray/job_config.py +++ b/python/ray/job_config.py @@ -65,7 +65,7 @@ def set_runtime_env( """ self.runtime_env = runtime_env if runtime_env is not None else {} if validate: - self.runtime_env = self._validate_runtime_env()[0] + self.runtime_env = self._validate_runtime_env() self._cached_pb = None def set_ray_namespace(self, ray_namespace: str) -> None: @@ -91,15 +91,17 @@ def _validate_runtime_env(self): # this dependency and pass in a validated runtime_env instead. from ray.runtime_env import RuntimeEnv - eager_install = self.runtime_env.get("eager_install", True) - if not isinstance(eager_install, bool): - raise TypeError("eager_install must be a boolean.") if isinstance(self.runtime_env, RuntimeEnv): - return self.runtime_env, eager_install - return RuntimeEnv(**self.runtime_env), eager_install + return self.runtime_env + return RuntimeEnv(**self.runtime_env) def get_proto_job_config(self): """Return the protobuf structure of JobConfig.""" + # TODO(edoakes): this is really unfortunate, but JobConfig is imported + # all over the place so this causes circular imports. We should remove + # this dependency and pass in a validated runtime_env instead. + from ray.utils import get_runtime_env_info + if self._cached_pb is None: pb = gcs_utils.JobConfig() if self.ray_namespace is None: @@ -112,10 +114,14 @@ def get_proto_job_config(self): for k, v in self.metadata.items(): pb.metadata[k] = v - parsed_env, eager_install = self._validate_runtime_env() - pb.runtime_env_info.uris[:] = parsed_env.get_uris() - pb.runtime_env_info.serialized_runtime_env = parsed_env.serialize() - pb.runtime_env_info.runtime_env_eager_install = eager_install + parsed_env = self._validate_runtime_env() + pb.runtime_env_info.CopyFrom( + get_runtime_env_info( + parsed_env, + is_job_runtime_env=True, + serialize=False, + ) + ) if self._default_actor_lifetime is not None: pb.default_actor_lifetime = self._default_actor_lifetime @@ -125,11 +131,11 @@ def get_proto_job_config(self): def runtime_env_has_uris(self): """Whether there are uris in runtime env or not""" - return self._validate_runtime_env()[0].has_uris() + return self._validate_runtime_env().has_uris() def get_serialized_runtime_env(self) -> str: """Return the JSON-serialized parsed runtime env dict""" - return self._validate_runtime_env()[0].serialize() + return self._validate_runtime_env().serialize() @classmethod def from_json(cls, job_config_json): diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py index 65b736c6b812b..760db761a2d96 100644 --- a/python/ray/remote_function.py +++ b/python/ray/remote_function.py @@ -14,7 +14,7 @@ from ray._private.client_mode_hook import client_mode_should_convert from ray.util.placement_group import configure_placement_group_based_on_context import ray._private.signature -from ray.runtime_env import RuntimeEnv +from ray.utils import get_runtime_env_info, parse_runtime_env from ray.util.tracing.tracing_helper import ( _tracing_task_invocation, _inject_tracing_into_function, @@ -139,16 +139,9 @@ def __init__( if retry_exceptions is None else retry_exceptions ) - # Parse local pip/conda config files here. If we instead did it in - # .remote(), it would get run in the Ray Client server, which runs on - # a remote node where the files aren't available. - if runtime_env: - if isinstance(runtime_env, str): - self._runtime_env = runtime_env - else: - self._runtime_env = RuntimeEnv(**(runtime_env or {})).serialize() - else: - self._runtime_env = None + + self._runtime_env = parse_runtime_env(runtime_env) + self._placement_group = placement_group self._decorator = getattr(function, "__ray_invocation_decorator__", None) self._function_signature = ray._private.signature.extract_signature( @@ -211,20 +204,7 @@ def f(): """ func_cls = self - # Parse local pip/conda config files here. If we instead did it in - # .remote(), it would get run in the Ray Client server, which runs on - # a remote node where the files aren't available. - if runtime_env: - if isinstance(runtime_env, str): - # Serialzed protobuf runtime env from Ray client. - new_runtime_env = runtime_env - else: - new_runtime_env = RuntimeEnv(**runtime_env).serialize() - else: - # Keep the runtime_env as None. In .remote(), we need to know if - # runtime_env is None to know whether or not to fall back to the - # runtime_env specified in the @ray.remote decorator. - new_runtime_env = None + new_runtime_env = parse_runtime_env(runtime_env) options = dict( num_returns=num_returns, @@ -419,6 +399,13 @@ def _remote( if not runtime_env or runtime_env == "{}": runtime_env = self._runtime_env + serialized_runtime_env_info = None + if runtime_env is not None: + serialized_runtime_env_info = get_runtime_env_info( + runtime_env, + is_job_runtime_env=False, + serialize=True, + ) def invocation(args, kwargs): if self._is_cross_language: @@ -445,7 +432,7 @@ def invocation(args, kwargs): retry_exceptions, scheduling_strategy, worker.debugger_breakpoint, - runtime_env or "{}", + serialized_runtime_env_info or "{}", ) # Reset worker's debug context from the last "remote" command # (which applies only to this .remote call). diff --git a/python/ray/runtime_env.py b/python/ray/runtime_env.py index 6800fadaecb21..85442f7a482a3 100644 --- a/python/ray/runtime_env.py +++ b/python/ray/runtime_env.py @@ -575,3 +575,13 @@ def _build_proto_plugin_runtime_env(self, runtime_env: ProtoRuntimeEnv): plugin = runtime_env.python_runtime_env.plugin_runtime_env.plugins.add() plugin.class_path = class_path plugin.config = plugin_field + + def __getstate__(self): + # When pickle serialization, exclude some fields + # which can't be serialized by pickle + return dict(**self) + + def __setstate__(self, state): + for k, v in state.items(): + self[k] = v + self.__proto_runtime_env = None diff --git a/python/ray/tests/test_runtime_env_env_vars.py b/python/ray/tests/test_runtime_env_env_vars.py index 4991f11d84560..9a0f269f1905d 100644 --- a/python/ray/tests/test_runtime_env_env_vars.py +++ b/python/ray/tests/test_runtime_env_env_vars.py @@ -48,6 +48,7 @@ def get(self, key): def test_environment_variables_nested_task(ray_start_regular): @ray.remote def get_env(key): + print(os.environ) return os.environ.get(key) @ray.remote diff --git a/python/ray/utils.py b/python/ray/utils.py index aff4f519fc974..c01d6d304b8a4 100644 --- a/python/ray/utils.py +++ b/python/ray/utils.py @@ -1,4 +1,11 @@ +from typing import Dict, Union, Optional +from google.protobuf import json_format + import ray._private.utils as private_utils +from ray.runtime_env import RuntimeEnv +from ray.core.generated.runtime_env_common_pb2 import ( + RuntimeEnvInfo as ProtoRuntimeEnvInfo, +) deprecated = private_utils.deprecated( "If you need to use this function, open a feature request issue on " "GitHub.", @@ -7,3 +14,62 @@ ) get_system_memory = deprecated(private_utils.get_system_memory) + + +def get_runtime_env_info( + runtime_env: RuntimeEnv, + *, + is_job_runtime_env: bool = False, + serialize: bool = False, +): + """Create runtime env info from runtime env. + + In the user interface, the argument `runtime_env` contains some fields + which not contained in `ProtoRuntimeEnv` but in `ProtoRuntimeEnvInfo`, + such as `eager_install`. This function will extract those fields from + `RuntimeEnv` and create a new `ProtoRuntimeEnvInfo`, and serialize it. + """ + proto_runtime_env_info = ProtoRuntimeEnvInfo() + + proto_runtime_env_info.uris[:] = runtime_env.get_uris() + + # Normally, `RuntimeEnv` should guarantee the accuracy of field eager_install, + # but so far, the internal code has not completely prohibited direct + # modification of fields in RuntimeEnv, so we should check it for insurance. + # TODO(Catch-Bull): overload `__setitem__` for `RuntimeEnv`, change the + # runtime_env of all internal code from dict to RuntimeEnv. + + eager_install = runtime_env.get("eager_install") + if is_job_runtime_env or eager_install is not None: + if eager_install is None: + eager_install = True + elif not isinstance(eager_install, bool): + raise TypeError( + f"eager_install must be a boolean. got {type(eager_install)}" + ) + proto_runtime_env_info.runtime_env_eager_install = eager_install + + proto_runtime_env_info.serialized_runtime_env = runtime_env.serialize() + + if not serialize: + return proto_runtime_env_info + + return json_format.MessageToJson(proto_runtime_env_info) + + +def parse_runtime_env(runtime_env: Optional[Union[Dict, RuntimeEnv]]): + # Parse local pip/conda config files here. If we instead did it in + # .remote(), it would get run in the Ray Client server, which runs on + # a remote node where the files aren't available. + if runtime_env: + if isinstance(runtime_env, dict): + return RuntimeEnv(**(runtime_env or {})) + raise TypeError( + "runtime_env must be dict or RuntimeEnv, ", + f"but got: {type(runtime_env)}", + ) + else: + # Keep the new_runtime_env as None. In .remote(), we need to know + # if runtime_env is None to know whether or not to fall back to the + # runtime_env specified in the @ray.remote decorator. + return None diff --git a/src/ray/common/runtime_env_common.cc b/src/ray/common/runtime_env_common.cc index 891d390744758..a07c4e88cda90 100644 --- a/src/ray/common/runtime_env_common.cc +++ b/src/ray/common/runtime_env_common.cc @@ -19,4 +19,8 @@ bool IsRuntimeEnvEmpty(const std::string &serialized_runtime_env) { return serialized_runtime_env == "{}" || serialized_runtime_env == ""; } +bool IsRuntimeEnvInfoEmpty(const std::string &serialized_runtime_env_info) { + return serialized_runtime_env_info == "{}" || serialized_runtime_env_info == ""; +} + } // namespace ray diff --git a/src/ray/common/runtime_env_common.h b/src/ray/common/runtime_env_common.h index d025536615a59..89cc570fdf420 100644 --- a/src/ray/common/runtime_env_common.h +++ b/src/ray/common/runtime_env_common.h @@ -21,4 +21,9 @@ namespace ray { // or "{}" (from serializing an empty Python dict or a JSON file.) bool IsRuntimeEnvEmpty(const std::string &serialized_runtime_env); +// Return whether a string representation of a runtime env info represents an empty +// runtime env info. It could either be "" (from the default string value in protobuf), +// or "{}" (from serializing an empty Python dict or a JSON file.) +bool IsRuntimeEnvInfoEmpty(const std::string &serialized_runtime_env_info); + } // namespace ray diff --git a/src/ray/common/task/task_util.h b/src/ray/common/task/task_util.h index 9e837d7f2fe67..c043b42332749 100644 --- a/src/ray/common/task/task_util.h +++ b/src/ray/common/task/task_util.h @@ -104,8 +104,7 @@ class TaskSpecBuilder { const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources, const std::string &debugger_breakpoint, int64_t depth, - const std::string &serialized_runtime_env = "{}", - const std::vector &runtime_env_uris = {}, + const std::shared_ptr runtime_env_info = nullptr, const std::string &concurrency_group_name = "") { message_->set_type(TaskType::NORMAL_TASK); message_->set_name(name); @@ -124,10 +123,8 @@ class TaskSpecBuilder { required_placement_resources.begin(), required_placement_resources.end()); message_->set_debugger_breakpoint(debugger_breakpoint); message_->set_depth(depth); - message_->mutable_runtime_env_info()->set_serialized_runtime_env( - serialized_runtime_env); - for (const std::string &uri : runtime_env_uris) { - message_->mutable_runtime_env_info()->add_uris(uri); + if (runtime_env_info) { + message_->mutable_runtime_env_info()->CopyFrom(*runtime_env_info); } message_->set_concurrency_group_name(concurrency_group_name); return *this; diff --git a/src/ray/core_worker/common.h b/src/ray/core_worker/common.h index 30794494d81e8..64ce5fa559fe9 100644 --- a/src/ray/core_worker/common.h +++ b/src/ray/core_worker/common.h @@ -59,12 +59,12 @@ struct TaskOptions { TaskOptions(std::string name, int num_returns, std::unordered_map &resources, const std::string &concurrency_group_name = "", - const std::string &serialized_runtime_env = "{}") + const std::string &serialized_runtime_env_info = "{}") : name(name), num_returns(num_returns), resources(resources), concurrency_group_name(concurrency_group_name), - serialized_runtime_env(serialized_runtime_env) {} + serialized_runtime_env_info(serialized_runtime_env_info) {} /// The name of this task. std::string name; @@ -74,8 +74,10 @@ struct TaskOptions { std::unordered_map resources; /// The name of the concurrency group in which this task will be executed. std::string concurrency_group_name; - // Runtime Env used by this task. Propagated to child actors and tasks. - std::string serialized_runtime_env; + /// Runtime Env Info used by this task. It includes Runtime Env and some + /// fields which not contained in Runtime Env, such as eager_install. + /// Propagated to child actors and tasks. + std::string serialized_runtime_env_info; }; /// Options for actor creation tasks. @@ -89,7 +91,7 @@ struct ActorCreationOptions { std::optional is_detached, std::string &name, std::string &ray_namespace, bool is_asyncio, const rpc::SchedulingStrategy &scheduling_strategy, - const std::string &serialized_runtime_env = "{}", + const std::string &serialized_runtime_env_info = "{}", const std::vector &concurrency_groups = {}, bool execute_out_of_order = false, int32_t max_pending_calls = -1) : max_restarts(max_restarts), @@ -102,7 +104,7 @@ struct ActorCreationOptions { name(name), ray_namespace(ray_namespace), is_asyncio(is_asyncio), - serialized_runtime_env(serialized_runtime_env), + serialized_runtime_env_info(serialized_runtime_env_info), concurrency_groups(concurrency_groups.begin(), concurrency_groups.end()), execute_out_of_order(execute_out_of_order), max_pending_calls(max_pending_calls), @@ -138,8 +140,10 @@ struct ActorCreationOptions { const std::string ray_namespace; /// Whether to use async mode of direct actor call. const bool is_asyncio = false; - // Runtime Env used by this actor. Propagated to child actors and tasks. - std::string serialized_runtime_env; + /// Runtime Env Info used by this task. It includes Runtime Env and some + /// fields which not contained in Runtime Env, such as eager_install. + /// Propagated to child actors and tasks. + std::string serialized_runtime_env_info; /// The actor concurrency groups to indicate how this actor perform its /// methods concurrently. const std::vector concurrency_groups; diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 959de5c72d1c0..d301f28658abe 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -1440,56 +1440,72 @@ static std::vector GetUrisFromRuntimeEnv( return result; } -static std::vector GetUrisFromSerializedRuntimeEnv( - const std::string &serialized_runtime_env) { - rpc::RuntimeEnv runtime_env; - if (!google::protobuf::util::JsonStringToMessage(serialized_runtime_env, &runtime_env) - .ok()) { - RAY_LOG(WARNING) << "Parse runtime env failed for " << serialized_runtime_env; - // TODO(SongGuyang): We pass the raw string here and the task will fail after an - // exception raised in runtime env agent. Actually, we can fail the task here. - return {}; - } - return GetUrisFromRuntimeEnv(&runtime_env); -} - -std::string CoreWorker::OverrideTaskOrActorRuntimeEnv( - const std::string &serialized_runtime_env, - std::vector *runtime_env_uris) { +std::shared_ptr CoreWorker::OverrideTaskOrActorRuntimeEnvInfo( + const std::string &serialized_runtime_env_info) { + // TODO(Catch-Bull,SongGuyang): task runtime env not support the field eager_install + // yet, we will overwrite the filed eager_install when it did. std::shared_ptr parent = nullptr; + std::shared_ptr runtime_env_info = nullptr; + runtime_env_info.reset(new rpc::RuntimeEnvInfo()); + + if (!IsRuntimeEnvInfoEmpty(serialized_runtime_env_info)) { + RAY_CHECK(google::protobuf::util::JsonStringToMessage(serialized_runtime_env_info, + runtime_env_info.get()) + .ok()); + } + if (options_.worker_type == WorkerType::DRIVER) { - if (IsRuntimeEnvEmpty(serialized_runtime_env)) { - *runtime_env_uris = GetUrisFromRuntimeEnv(job_runtime_env_.get()); - return job_config_->runtime_env_info().serialized_runtime_env(); + if (IsRuntimeEnvEmpty(runtime_env_info->serialized_runtime_env())) { + runtime_env_info->set_serialized_runtime_env( + job_config_->runtime_env_info().serialized_runtime_env()); + runtime_env_info->clear_uris(); + for (const std::string &uri : GetUrisFromRuntimeEnv(job_runtime_env_.get())) { + runtime_env_info->add_uris(uri); + } + + return runtime_env_info; } parent = job_runtime_env_; } else { - if (IsRuntimeEnvEmpty(serialized_runtime_env)) { - *runtime_env_uris = - GetUrisFromRuntimeEnv(worker_context_.GetCurrentRuntimeEnv().get()); - return worker_context_.GetCurrentSerializedRuntimeEnv(); + if (IsRuntimeEnvEmpty(runtime_env_info->serialized_runtime_env())) { + runtime_env_info->set_serialized_runtime_env( + worker_context_.GetCurrentSerializedRuntimeEnv()); + runtime_env_info->clear_uris(); + for (const std::string &uri : + GetUrisFromRuntimeEnv(worker_context_.GetCurrentRuntimeEnv().get())) { + runtime_env_info->add_uris(uri); + } + + return runtime_env_info; } parent = worker_context_.GetCurrentRuntimeEnv(); } if (parent) { + std::string serialized_runtime_env = runtime_env_info->serialized_runtime_env(); rpc::RuntimeEnv child_runtime_env; if (!google::protobuf::util::JsonStringToMessage(serialized_runtime_env, &child_runtime_env) .ok()) { - RAY_LOG(WARNING) << "Parse runtime env failed for " << serialized_runtime_env; + RAY_LOG(WARNING) << "Parse runtime env failed for " << serialized_runtime_env + << ". serialized runtime env info: " + << serialized_runtime_env_info; // TODO(SongGuyang): We pass the raw string here and the task will fail after an // exception raised in runtime env agent. Actually, we can fail the task here. - return serialized_runtime_env; + return runtime_env_info; } auto override_runtime_env = OverrideRuntimeEnv(child_runtime_env, parent); - std::string result; - RAY_CHECK( - google::protobuf::util::MessageToJsonString(override_runtime_env, &result).ok()); - *runtime_env_uris = GetUrisFromRuntimeEnv(&override_runtime_env); - return result; + std::string serialized_override_runtime_env; + RAY_CHECK(google::protobuf::util::MessageToJsonString( + override_runtime_env, &serialized_override_runtime_env) + .ok()); + runtime_env_info->set_serialized_runtime_env(serialized_override_runtime_env); + runtime_env_info->clear_uris(); + for (const std::string &uri : GetUrisFromRuntimeEnv(&override_runtime_env)) { + runtime_env_info->add_uris(uri); + } + return runtime_env_info; } else { - *runtime_env_uris = GetUrisFromSerializedRuntimeEnv(serialized_runtime_env); - return serialized_runtime_env; + return runtime_env_info; } } @@ -1501,17 +1517,16 @@ void CoreWorker::BuildCommonTaskSpec( const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources, const std::string &debugger_breakpoint, int64_t depth, - const std::string &serialized_runtime_env, + const std::string &serialized_runtime_env_info, const std::string &concurrency_group_name) { // Build common task spec. - std::vector runtime_env_uris; - auto override_runtime_env = - OverrideTaskOrActorRuntimeEnv(serialized_runtime_env, &runtime_env_uris); + auto override_runtime_env_info = + OverrideTaskOrActorRuntimeEnvInfo(serialized_runtime_env_info); builder.SetCommonTaskSpec( task_id, name, function.GetLanguage(), function.GetFunctionDescriptor(), job_id, current_task_id, task_index, caller_id, address, num_returns, required_resources, - required_placement_resources, debugger_breakpoint, depth, override_runtime_env, - runtime_env_uris, concurrency_group_name); + required_placement_resources, debugger_breakpoint, depth, override_runtime_env_info, + concurrency_group_name); // Set task arguments. for (const auto &arg : args) { builder.AddArg(*arg); @@ -1544,7 +1559,7 @@ std::vector CoreWorker::SubmitTask( worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(), rpc_address_, function, args, task_options.num_returns, constrained_resources, required_resources, debugger_breakpoint, - depth, task_options.serialized_runtime_env); + depth, task_options.serialized_runtime_env_info); builder.SetNormalTaskSpec(max_retries, retry_exceptions, scheduling_strategy); TaskSpecification task_spec = builder.Build(); RAY_LOG(DEBUG) << "Submitting normal task " << task_spec.DebugString(); @@ -1612,7 +1627,7 @@ Status CoreWorker::CreateActor(const RayFunction &function, worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(), rpc_address_, function, args, 1, new_resource, new_placement_resources, "" /* debugger_breakpoint */, depth, - actor_creation_options.serialized_runtime_env); + actor_creation_options.serialized_runtime_env_info); // If the namespace is not specified, get it from the job. const auto &ray_namespace = (actor_creation_options.ray_namespace.empty() @@ -1801,7 +1816,7 @@ std::optional> CoreWorker::SubmitActorTask( rpc_address_, function, args, num_returns, task_options.resources, required_resources, "", /* debugger_breakpoint */ depth, /*depth*/ - "{}", /* serialized_runtime_env */ + "{}", /* serialized_runtime_env_info */ task_options.concurrency_group_name); // NOTE: placement_group_capture_child_tasks and runtime_env will // be ignored in the actor because we should always follow the actor's option. diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index d3c5ff4d7e04f..d12a41b89e255 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -791,9 +791,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { FRIEND_TEST(TestOverrideRuntimeEnv, TestCondaInherit); FRIEND_TEST(TestOverrideRuntimeEnv, TestCondaOverride); - std::string OverrideTaskOrActorRuntimeEnv( - const std::string &serialized_runtime_env, - std::vector *runtime_env_uris /* output */); + std::shared_ptr OverrideTaskOrActorRuntimeEnvInfo( + const std::string &serialized_runtime_env_info); void BuildCommonTaskSpec( TaskSpecBuilder &builder, const JobID &job_id, const TaskID &task_id, @@ -803,7 +802,7 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources, const std::string &debugger_breakpoint, int64_t depth, - const std::string &serialized_runtime_env, + const std::string &serialized_runtime_env_info, const std::string &concurrency_group_name = ""); void SetCurrentTaskId(const TaskID &task_id, uint64_t attempt_number); diff --git a/src/ray/raylet/scheduling/cluster_task_manager_test.cc b/src/ray/raylet/scheduling/cluster_task_manager_test.cc index 93f87b610ce58..74a5cf65cf083 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager_test.cc +++ b/src/ray/raylet/scheduling/cluster_task_manager_test.cc @@ -135,10 +135,10 @@ std::shared_ptr CreateSingleNodeScheduler( return scheduler; } -RayTask CreateTask(const std::unordered_map &required_resources, - int num_args = 0, std::vector args = {}, - const std::string &serialized_runtime_env = "{}", - const std::vector &runtime_env_uris = {}) { +RayTask CreateTask( + const std::unordered_map &required_resources, int num_args = 0, + std::vector args = {}, + const std::shared_ptr runtime_env_info = nullptr) { TaskSpecBuilder spec_builder; TaskID id = RandomTaskId(); JobID job_id = RandomJobId(); @@ -146,8 +146,7 @@ RayTask CreateTask(const std::unordered_map &required_resou spec_builder.SetCommonTaskSpec(id, "dummy_task", Language::PYTHON, FunctionDescriptorBuilder::BuildPython("", "", "", ""), job_id, TaskID::Nil(), 0, TaskID::Nil(), address, 0, - required_resources, {}, "", 0, serialized_runtime_env, - runtime_env_uris); + required_resources, {}, "", 0, runtime_env_info); if (!args.empty()) { for (auto &arg : args) { @@ -474,8 +473,12 @@ TEST_F(ClusterTaskManagerTest, DispatchQueueNonBlockingTest) { {ray::kCPU_ResourceLabel, 4}}; std::string serialized_runtime_env_A = "mock_env_A"; - RayTask task_A = CreateTask(required_resources, /*num_args=*/0, /*args=*/{}, - serialized_runtime_env_A); + std::shared_ptr runtime_env_info_A = nullptr; + runtime_env_info_A.reset(new rpc::RuntimeEnvInfo()); + runtime_env_info_A->set_serialized_runtime_env(serialized_runtime_env_A); + + RayTask task_A = + CreateTask(required_resources, /*num_args=*/0, /*args=*/{}, runtime_env_info_A); rpc::RequestWorkerLeaseReply reply_A; bool callback_occurred = false; bool *callback_occurred_ptr = &callback_occurred; @@ -485,10 +488,14 @@ TEST_F(ClusterTaskManagerTest, DispatchQueueNonBlockingTest) { }; std::string serialized_runtime_env_B = "mock_env_B"; - RayTask task_B_1 = CreateTask(required_resources, /*num_args=*/0, /*args=*/{}, - serialized_runtime_env_B); - RayTask task_B_2 = CreateTask(required_resources, /*num_args=*/0, /*args=*/{}, - serialized_runtime_env_B); + std::shared_ptr runtime_env_info_B = nullptr; + runtime_env_info_B.reset(new rpc::RuntimeEnvInfo()); + runtime_env_info_B->set_serialized_runtime_env(serialized_runtime_env_B); + + RayTask task_B_1 = + CreateTask(required_resources, /*num_args=*/0, /*args=*/{}, runtime_env_info_B); + RayTask task_B_2 = + CreateTask(required_resources, /*num_args=*/0, /*args=*/{}, runtime_env_info_B); rpc::RequestWorkerLeaseReply reply_B_1; rpc::RequestWorkerLeaseReply reply_B_2; auto empty_callback = [](Status, std::function, std::function) {}; @@ -1785,8 +1792,12 @@ TEST_F(ClusterTaskManagerTest, TestResourceDiff) { TEST_F(ClusterTaskManagerTest, PopWorkerExactlyOnce) { // Create and queue one task. std::string serialized_runtime_env = "mock_env"; + std::shared_ptr runtime_env_info = nullptr; + runtime_env_info.reset(new rpc::RuntimeEnvInfo()); + runtime_env_info->set_serialized_runtime_env(serialized_runtime_env); + RayTask task = CreateTask({{ray::kCPU_ResourceLabel, 4}}, /*num_args=*/0, /*args=*/{}, - serialized_runtime_env); + runtime_env_info); auto runtime_env_hash = task.GetTaskSpecification().GetRuntimeEnvHash(); rpc::RequestWorkerLeaseReply reply; bool callback_occurred = false;