Skip to content

Commit

Permalink
[runtime env] plugin refactor [5/n]: support priority (ray-project#26659
Browse files Browse the repository at this point in the history
)
  • Loading branch information
SongGuyang authored Jul 20, 2022
1 parent b87731c commit f96f5a1
Show file tree
Hide file tree
Showing 6 changed files with 271 additions and 47 deletions.
34 changes: 20 additions & 14 deletions dashboard/modules/runtime_env/runtime_env_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,20 +296,26 @@ async def _setup_runtime_env(

def setup_plugins():
# Run setup function from all the plugins
for name, config in runtime_env.plugins():
per_job_logger.debug(f"Setting up runtime env plugin {name}")
plugin = self._runtime_env_plugin_manager.get_plugin(name)
if plugin is None:
raise RuntimeError(f"runtime env plugin {name} not found.")
# TODO(architkulkarni): implement uri support
plugin.validate(runtime_env)
plugin.create("uri not implemented", config, context)
plugin.modify_context(
"uri not implemented",
config,
context,
per_job_logger,
)
if runtime_env.plugins():
for (
setup_context
) in self._runtime_env_plugin_manager.sorted_plugin_setup_contexts(
runtime_env.plugins()
):
per_job_logger.debug(
f"Setting up runtime env plugin {setup_context.name}"
)
# TODO(architkulkarni): implement uri support
setup_context.class_instance.validate(runtime_env)
setup_context.class_instance.create(
"uri not implemented", setup_context.config, context
)
setup_context.class_instance.modify_context(
"uri not implemented",
setup_context.config,
context,
per_job_logger,
)

loop = asyncio.get_event_loop()
# Plugins setup method is sync process, running in other threads
Expand Down
19 changes: 18 additions & 1 deletion python/ray/_private/runtime_env/constants.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,26 @@
# Env var set by job manager to pass runtime env and metadata to subprocess
RAY_JOB_CONFIG_JSON_ENV_VAR = "RAY_JOB_CONFIG_JSON_ENV_VAR"

# The plugins which should be loaded when ray cluster starts.
# The plugin config which should be loaded when ray cluster starts.
# It is a json formatted config,
# e.g. [{"class": "xxx.xxx.xxx_plugin", "priority": 10}].
RAY_RUNTIME_ENV_PLUGINS_ENV_VAR = "RAY_RUNTIME_ENV_PLUGINS"

# The field name of plugin class in the plugin config.
RAY_RUNTIME_ENV_CLASS_FIELD_NAME = "class"

# The field name of priority in the plugin config.
RAY_RUNTIME_ENV_PRIORITY_FIELD_NAME = "priority"

# The default priority of runtime env plugin.
RAY_RUNTIME_ENV_PLUGIN_DEFAULT_PRIORITY = 10

# The minimum priority of runtime env plugin.
RAY_RUNTIME_ENV_PLUGIN_MIN_PRIORITY = 0

# The maximum priority of runtime env plugin.
RAY_RUNTIME_ENV_PLUGIN_MAX_PRIORITY = 100

# The schema files or directories of plugins which should be loaded in workers.
RAY_RUNTIME_ENV_PLUGIN_SCHEMAS_ENV_VAR = "RAY_RUNTIME_ENV_PLUGIN_SCHEMAS"

Expand Down
116 changes: 91 additions & 25 deletions python/ray/_private/runtime_env/plugin.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
import logging
import os
import json
from abc import ABC
from typing import List
from typing import List, Dict, Tuple, Any

from ray._private.runtime_env.context import RuntimeEnvContext
from ray._private.runtime_env.uri_cache import URICache
from ray._private.runtime_env.constants import RAY_RUNTIME_ENV_PLUGINS_ENV_VAR
from ray._private.runtime_env.constants import (
RAY_RUNTIME_ENV_PLUGINS_ENV_VAR,
RAY_RUNTIME_ENV_PLUGIN_DEFAULT_PRIORITY,
RAY_RUNTIME_ENV_CLASS_FIELD_NAME,
RAY_RUNTIME_ENV_PRIORITY_FIELD_NAME,
RAY_RUNTIME_ENV_PLUGIN_MIN_PRIORITY,
RAY_RUNTIME_ENV_PLUGIN_MAX_PRIORITY,
)
from ray.util.annotations import DeveloperAPI
from ray._private.utils import import_attr

Expand All @@ -17,6 +25,7 @@ class RuntimeEnvPlugin(ABC):
"""Abstract base class for runtime environment plugins."""

name: str = None
priority: int = RAY_RUNTIME_ENV_PLUGIN_DEFAULT_PRIORITY

@staticmethod
def validate(runtime_env_dict: dict) -> str:
Expand Down Expand Up @@ -91,43 +100,100 @@ def delete_uri(self, uri: str, logger: logging.Logger) -> float:
return 0


class PluginSetupContext:
def __init__(self, name: str, config: Any, class_instance: object):
self.name = name
self.config = config
self.class_instance = class_instance


class RuntimeEnvPluginManager:
"""This manager is used to load plugins in runtime env agent."""

class Context:
def __init__(self, class_instance, priority):
self.class_instance = class_instance
self.priority = priority

def __init__(self):
self.plugins = {}
plugins_config = os.environ.get(RAY_RUNTIME_ENV_PLUGINS_ENV_VAR)
if plugins_config:
self.load_plugins(plugins_config.split(","))
self.plugins: Dict[str, RuntimeEnvPluginManager.Context] = {}
plugin_config_str = os.environ.get(RAY_RUNTIME_ENV_PLUGINS_ENV_VAR)
if plugin_config_str:
plugin_configs = json.loads(plugin_config_str)
self.load_plugins(plugin_configs)

def load_plugins(self, plugin_classes: List[str]):
def load_plugins(self, plugin_configs: List[Dict]):
"""Load runtime env plugins"""
for plugin_class_path in plugin_classes:
plugin_class = import_attr(plugin_class_path)
for plugin_config in plugin_configs:
if (
not isinstance(plugin_config, dict)
or RAY_RUNTIME_ENV_CLASS_FIELD_NAME not in plugin_config
):
raise RuntimeError(
f"Invalid runtime env plugin config {plugin_config}, "
"it should be a object which contains the "
f"{RAY_RUNTIME_ENV_CLASS_FIELD_NAME} field."
)
plugin_class = import_attr(plugin_config[RAY_RUNTIME_ENV_CLASS_FIELD_NAME])
if not issubclass(plugin_class, RuntimeEnvPlugin):
default_logger.warning(
"Invalid runtime env plugin class %s. "
raise RuntimeError(
f"Invalid runtime env plugin class {plugin_class}. "
"The plugin class must inherit "
"ray._private.runtime_env.plugin.RuntimeEnvPlugin.",
plugin_class,
"ray._private.runtime_env.plugin.RuntimeEnvPlugin."
)
continue
if not plugin_class.name:
default_logger.warning(
"No valid name in runtime env plugin %s", plugin_class
raise RuntimeError(
f"No valid name in runtime env plugin {plugin_class}."
)
continue
if plugin_class.name in self.plugins:
default_logger.warning(
"The name of runtime env plugin %s conflicts with %s",
plugin_class,
self.plugins[plugin_class.name],
raise RuntimeError(
f"The name of runtime env plugin {plugin_class} conflicts "
f"with {self.plugins[plugin_class.name]}.",
)
continue
self.plugins[plugin_class.name] = plugin_class()

def get_plugin(self, name: str):
return self.plugins.get(name)
# The priority should be an integer between 0 and 100.
# The default priority is 10. A smaller number indicates a
# higher priority and the plugin will be set up first.
if RAY_RUNTIME_ENV_PRIORITY_FIELD_NAME in plugin_config:
priority = plugin_config[RAY_RUNTIME_ENV_PRIORITY_FIELD_NAME]
else:
priority = plugin_class.priority
if (
not isinstance(priority, int)
or priority < RAY_RUNTIME_ENV_PLUGIN_MIN_PRIORITY
or priority > RAY_RUNTIME_ENV_PLUGIN_MAX_PRIORITY
):
raise RuntimeError(
f"Invalid runtime env priority {priority}, "
"it should be an integer between "
f"{RAY_RUNTIME_ENV_PLUGIN_MIN_PRIORITY} "
f"and {RAY_RUNTIME_ENV_PLUGIN_MAX_PRIORITY}."
)

self.plugins[plugin_class.name] = RuntimeEnvPluginManager.Context(
plugin_class(), priority
)

def sorted_plugin_setup_contexts(
self, inputs: List[Tuple[str, Any]]
) -> List[PluginSetupContext]:
used_plugins = []
for name, config in inputs:
if name not in self.plugins:
raise RuntimeError(f"Runtime env plugin {name} not found.")
used_plugins.append(
(
name,
config,
self.plugins[name].class_instance,
self.plugins[name].priority,
)
)
sort_used_plugins = sorted(used_plugins, key=lambda x: x[3], reverse=False)
return [
PluginSetupContext(name, config, class_instance)
for name, config, class_instance, _ in sort_used_plugins
]


@DeveloperAPI
Expand Down
2 changes: 1 addition & 1 deletion python/ray/tests/test_placement_group_4.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def f():
@pytest.mark.parametrize(
"set_runtime_env_plugins",
[
MOCK_WORKER_STARTUP_SLOWLY_PLUGIN_CLASS_PATH,
'[{"class":"' + MOCK_WORKER_STARTUP_SLOWLY_PLUGIN_CLASS_PATH + '"}]',
],
indirect=True,
)
Expand Down
2 changes: 1 addition & 1 deletion python/ray/tests/test_runtime_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@ def modify_context(
@pytest.mark.parametrize(
"set_runtime_env_plugins",
[
MY_PLUGIN_CLASS_PATH,
'[{"class":"' + MY_PLUGIN_CLASS_PATH + '"}]',
],
indirect=True,
)
Expand Down
Loading

0 comments on commit f96f5a1

Please sign in to comment.