Skip to content

Commit

Permalink
[runtime env] plugin refactor[3/n]: support strong type by @DataClass (
Browse files Browse the repository at this point in the history
  • Loading branch information
SongGuyang committed Jul 12, 2022
1 parent b3878e2 commit 781c2a7
Show file tree
Hide file tree
Showing 18 changed files with 591 additions and 4 deletions.
23 changes: 23 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -401,3 +401,26 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

--------------------------------------------------------------------------------
Code in python/ray/_private/thirdparty/dacite is adapted from https://github.com/konradhalas/dacite/blob/master/dacite

Copyright (c) 2018 Konrad Hałas

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
4 changes: 3 additions & 1 deletion python/ray/_private/runtime_env/plugin_schema_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
class RuntimeEnvPluginSchemaManager:
"""This manager is used to load plugin json schemas."""

default_schema_path = os.path.join(os.path.dirname(__file__), "schemas")
default_schema_path = os.path.join(
os.path.dirname(__file__), "../../runtime_env/schemas"
)
schemas = {}
loaded = False

Expand Down
21 changes: 21 additions & 0 deletions python/ray/_private/thirdparty/dacite/LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2018 Konrad Hałas

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
3 changes: 3 additions & 0 deletions python/ray/_private/thirdparty/dacite/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .config import Config
from .core import from_dict
from .exceptions import *
12 changes: 12 additions & 0 deletions python/ray/_private/thirdparty/dacite/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from dataclasses import dataclass, field
from typing import Dict, Any, Callable, Optional, Type, List


@dataclass
class Config:
type_hooks: Dict[Type, Callable[[Any], Any]] = field(default_factory=dict)
cast: List[Type] = field(default_factory=list)
forward_references: Optional[Dict[str, Any]] = None
check_types: bool = True
strict: bool = False
strict_unions_match: bool = False
140 changes: 140 additions & 0 deletions python/ray/_private/thirdparty/dacite/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import copy
from dataclasses import is_dataclass
from itertools import zip_longest
from typing import TypeVar, Type, Optional, get_type_hints, Mapping, Any

from .config import Config
from .data import Data
from .dataclasses import get_default_value_for_field, create_instance, DefaultValueNotFoundError, get_fields
from .exceptions import (
ForwardReferenceError,
WrongTypeError,
DaciteError,
UnionMatchError,
MissingValueError,
DaciteFieldError,
UnexpectedDataError,
StrictUnionMatchError,
)
from .types import (
is_instance,
is_generic_collection,
is_union,
extract_generic,
is_optional,
transform_value,
extract_origin_collection,
is_init_var,
extract_init_var,
)

T = TypeVar("T")


def from_dict(data_class: Type[T], data: Data, config: Optional[Config] = None) -> T:
"""Create a data class instance from a dictionary.
:param data_class: a data class type
:param data: a dictionary of a input data
:param config: a configuration of the creation process
:return: an instance of a data class
"""
init_values: Data = {}
post_init_values: Data = {}
config = config or Config()
try:
data_class_hints = get_type_hints(data_class, globalns=config.forward_references)
except NameError as error:
raise ForwardReferenceError(str(error))
data_class_fields = get_fields(data_class)
if config.strict:
extra_fields = set(data.keys()) - {f.name for f in data_class_fields}
if extra_fields:
raise UnexpectedDataError(keys=extra_fields)
for field in data_class_fields:
field = copy.copy(field)
field.type = data_class_hints[field.name]
try:
try:
field_data = data[field.name]
transformed_value = transform_value(
type_hooks=config.type_hooks, cast=config.cast, target_type=field.type, value=field_data
)
value = _build_value(type_=field.type, data=transformed_value, config=config)
except DaciteFieldError as error:
error.update_path(field.name)
raise
if config.check_types and not is_instance(value, field.type):
raise WrongTypeError(field_path=field.name, field_type=field.type, value=value)
except KeyError:
try:
value = get_default_value_for_field(field)
except DefaultValueNotFoundError:
if not field.init:
continue
raise MissingValueError(field.name)
if field.init:
init_values[field.name] = value
else:
post_init_values[field.name] = value

return create_instance(data_class=data_class, init_values=init_values, post_init_values=post_init_values)


def _build_value(type_: Type, data: Any, config: Config) -> Any:
if is_init_var(type_):
type_ = extract_init_var(type_)
if is_union(type_):
return _build_value_for_union(union=type_, data=data, config=config)
elif is_generic_collection(type_) and is_instance(data, extract_origin_collection(type_)):
return _build_value_for_collection(collection=type_, data=data, config=config)
elif is_dataclass(type_) and is_instance(data, Data):
return from_dict(data_class=type_, data=data, config=config)
return data


def _build_value_for_union(union: Type, data: Any, config: Config) -> Any:
types = extract_generic(union)
if is_optional(union) and len(types) == 2:
return _build_value(type_=types[0], data=data, config=config)
union_matches = {}
for inner_type in types:
try:
# noinspection PyBroadException
try:
data = transform_value(
type_hooks=config.type_hooks, cast=config.cast, target_type=inner_type, value=data
)
except Exception: # pylint: disable=broad-except
continue
value = _build_value(type_=inner_type, data=data, config=config)
if is_instance(value, inner_type):
if config.strict_unions_match:
union_matches[inner_type] = value
else:
return value
except DaciteError:
pass
if config.strict_unions_match:
if len(union_matches) > 1:
raise StrictUnionMatchError(union_matches)
return union_matches.popitem()[1]
if not config.check_types:
return data
raise UnionMatchError(field_type=union, value=data)


def _build_value_for_collection(collection: Type, data: Any, config: Config) -> Any:
data_type = data.__class__
if is_instance(data, Mapping):
item_type = extract_generic(collection, defaults=(Any, Any))[1]
return data_type((key, _build_value(type_=item_type, data=value, config=config)) for key, value in data.items())
elif is_instance(data, tuple):
types = extract_generic(collection)
if len(types) == 2 and types[1] == Ellipsis:
return data_type(_build_value(type_=types[0], data=item, config=config) for item in data)
return data_type(
_build_value(type_=type_, data=item, config=config) for item, type_ in zip_longest(data, types)
)
item_type = extract_generic(collection, defaults=(Any,))[0]
return data_type(_build_value(type_=item_type, data=item, config=config) for item in data)
3 changes: 3 additions & 0 deletions python/ray/_private/thirdparty/dacite/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from typing import Dict, Any

Data = Dict[str, Any]
33 changes: 33 additions & 0 deletions python/ray/_private/thirdparty/dacite/dataclasses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from dataclasses import Field, MISSING, _FIELDS, _FIELD, _FIELD_INITVAR # type: ignore
from typing import Type, Any, TypeVar, List

from .data import Data
from .types import is_optional

T = TypeVar("T", bound=Any)


class DefaultValueNotFoundError(Exception):
pass


def get_default_value_for_field(field: Field) -> Any:
if field.default != MISSING:
return field.default
elif field.default_factory != MISSING: # type: ignore
return field.default_factory() # type: ignore
elif is_optional(field.type):
return None
raise DefaultValueNotFoundError()


def create_instance(data_class: Type[T], init_values: Data, post_init_values: Data) -> T:
instance = data_class(**init_values)
for key, value in post_init_values.items():
setattr(instance, key, value)
return instance


def get_fields(data_class: Type[T]) -> List[Field]:
fields = getattr(data_class, _FIELDS)
return [f for f in fields.values() if f._field_type is _FIELD or f._field_type is _FIELD_INITVAR]
79 changes: 79 additions & 0 deletions python/ray/_private/thirdparty/dacite/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from typing import Any, Type, Optional, Set, Dict


def _name(type_: Type) -> str:
return type_.__name__ if hasattr(type_, "__name__") else str(type_)


class DaciteError(Exception):
pass


class DaciteFieldError(DaciteError):
def __init__(self, field_path: Optional[str] = None):
super().__init__()
self.field_path = field_path

def update_path(self, parent_field_path: str) -> None:
if self.field_path:
self.field_path = f"{parent_field_path}.{self.field_path}"
else:
self.field_path = parent_field_path


class WrongTypeError(DaciteFieldError):
def __init__(self, field_type: Type, value: Any, field_path: Optional[str] = None) -> None:
super().__init__(field_path=field_path)
self.field_type = field_type
self.value = value

def __str__(self) -> str:
return (
f'wrong value type for field "{self.field_path}" - should be "{_name(self.field_type)}" '
f'instead of value "{self.value}" of type "{_name(type(self.value))}"'
)


class MissingValueError(DaciteFieldError):
def __init__(self, field_path: Optional[str] = None):
super().__init__(field_path=field_path)

def __str__(self) -> str:
return f'missing value for field "{self.field_path}"'


class UnionMatchError(WrongTypeError):
def __str__(self) -> str:
return (
f'can not match type "{_name(type(self.value))}" to any type '
f'of "{self.field_path}" union: {_name(self.field_type)}'
)


class StrictUnionMatchError(DaciteFieldError):
def __init__(self, union_matches: Dict[Type, Any], field_path: Optional[str] = None) -> None:
super().__init__(field_path=field_path)
self.union_matches = union_matches

def __str__(self) -> str:
conflicting_types = ", ".join(_name(type_) for type_ in self.union_matches)
return f'can not choose between possible Union matches for field "{self.field_path}": {conflicting_types}'


class ForwardReferenceError(DaciteError):
def __init__(self, message: str) -> None:
super().__init__()
self.message = message

def __str__(self) -> str:
return f"can not resolve forward reference: {self.message}"


class UnexpectedDataError(DaciteError):
def __init__(self, keys: Set[str]) -> None:
super().__init__()
self.keys = keys

def __str__(self) -> str:
formatted_keys = ", ".join(f'"{key}"' for key in self.keys)
return f"can not match {formatted_keys} to any data class field"
Empty file.
Loading

0 comments on commit 781c2a7

Please sign in to comment.