Skip to content

Commit

Permalink
[serdes] don't recreate type map
Browse files Browse the repository at this point in the history
  • Loading branch information
alangenfeld committed Jul 23, 2024
1 parent 6fcaef6 commit 9eb3769
Showing 1 changed file with 66 additions and 22 deletions.
88 changes: 66 additions & 22 deletions python_modules/dagster/dagster/_serdes/serdes.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,12 @@
import dagster._check as check
import dagster._seven as seven
from dagster._model.pydantic_compat_layer import ModelFieldCompat, model_fields
from dagster._record import as_dict_for_new, get_record_annotations, has_generated_new, is_record
from dagster._record import (
as_dict_for_new,
get_record_annotations,
has_generated_new,
is_record,
)
from dagster._utils import is_named_tuple_instance, is_named_tuple_subclass
from dagster._utils.warnings import disable_dagster_warnings

Expand Down Expand Up @@ -148,6 +153,7 @@ class WhitelistMap(NamedTuple):
object_serializers: Dict[str, "ObjectSerializer"]
object_deserializers: Dict[str, "ObjectSerializer"]
enum_serializers: Dict[str, "EnumSerializer"]
object_type_map: Dict[str, Type]

def register_object(
self,
Expand Down Expand Up @@ -194,6 +200,8 @@ def register_object(
for old_storage_name in old_storage_names:
self.object_deserializers[old_storage_name] = serializer

self.object_type_map[name] = serializer.klass

def register_enum(
self,
name: str,
Expand All @@ -216,11 +224,12 @@ def register_enum(

@staticmethod
def create() -> "WhitelistMap":
return WhitelistMap(object_serializers={}, object_deserializers={}, enum_serializers={})

def get_type_map(self) -> Mapping[str, Type]:
# Return string name -> type mapping of all registered types
return {name: ser.klass for name, ser in self.object_serializers.items()}
return WhitelistMap(
object_serializers={},
object_deserializers={},
enum_serializers={},
object_type_map={},
)


_WHITELIST_MAP: Final[WhitelistMap] = WhitelistMap.create()
Expand Down Expand Up @@ -418,7 +427,9 @@ def __whitelist_for_serdes(klass: T_Type) -> T_Type:
)
return klass
else:
raise SerdesUsageError(f"Can not whitelist class {klass} for serializer {serializer}")
raise SerdesUsageError(
f"Can not whitelist class {klass} for serializer {serializer}"
)

return __whitelist_for_serdes

Expand Down Expand Up @@ -517,7 +528,12 @@ def get_storage_name(self) -> str:
return self.storage_name or self.klass.__name__


EMPTY_VALUES_TO_SKIP: Tuple[None, List[Any], Dict[Any, Any], Set[Any]] = (None, [], {}, set())
EMPTY_VALUES_TO_SKIP: Tuple[None, List[Any], Dict[Any, Any], Set[Any]] = (
None,
[],
{},
set(),
)


class ObjectSerializer(Serializer, Generic[T]):
Expand Down Expand Up @@ -607,12 +623,17 @@ def pack_items(
self,
value: T,
whitelist_map: WhitelistMap,
object_handler: Callable[[SerializableObject, WhitelistMap, str], JsonSerializableValue],
object_handler: Callable[
[SerializableObject, WhitelistMap, str], JsonSerializableValue
],
descent_path: str,
) -> Iterator[Tuple[str, JsonSerializableValue]]:
yield "__class__", self.get_storage_name()
for key, inner_value in self.object_as_mapping(self.before_pack(value)).items():
if key in self.skip_when_empty_fields and inner_value in EMPTY_VALUES_TO_SKIP:
if (
key in self.skip_when_empty_fields
and inner_value in EMPTY_VALUES_TO_SKIP
):
continue
storage_key = self.storage_field_names.get(key, key)
custom = self.field_serializers.get(key)
Expand Down Expand Up @@ -682,7 +703,9 @@ def constructor_param_names(self) -> Sequence[str]:
return names


T_Dataclass = TypeVar("T_Dataclass", bound="DataclassInstance", default="DataclassInstance")
T_Dataclass = TypeVar(
"T_Dataclass", bound="DataclassInstance", default="DataclassInstance"
)


class DataclassSerializer(ObjectSerializer[T_Dataclass]):
Expand All @@ -704,7 +727,8 @@ def object_as_mapping(self, value: T_PydanticModel) -> Mapping[str, Any]:
result = {}
for key, field in self._model_fields.items():
if field.alias is None and (
field.serialization_alias is not None or field.validation_alias is not None
field.serialization_alias is not None
or field.validation_alias is not None
):
raise SerializationError(
"Can't serialize pydantic models with serialization or validation aliases. Use "
Expand Down Expand Up @@ -760,10 +784,15 @@ def unpack(
return set(sequence_value) if sequence_value is not None else None

def pack(
self, set_value: Optional[AbstractSet[Any]], whitelist_map: WhitelistMap, descent_path: str
self,
set_value: Optional[AbstractSet[Any]],
whitelist_map: WhitelistMap,
descent_path: str,
) -> Optional[Sequence[Any]]:
return (
sorted([pack_value(x, whitelist_map, descent_path) for x in set_value], key=str)
sorted(
[pack_value(x, whitelist_map, descent_path) for x in set_value], key=str
)
if set_value is not None
else None
)
Expand Down Expand Up @@ -849,7 +878,9 @@ def pack_value(
def _transform_for_serialization(
val: PackableValue,
whitelist_map: WhitelistMap,
object_handler: Callable[[SerializableObject, WhitelistMap, str], JsonSerializableValue],
object_handler: Callable[
[SerializableObject, WhitelistMap, str], JsonSerializableValue
],
descent_path: str,
) -> JsonSerializableValue:
# this is a hot code path so we handle the common base cases without isinstance
Expand Down Expand Up @@ -1066,7 +1097,9 @@ def deserialize_value(
def deserialize_value(
val: str,
as_type: Optional[
Union[Type[T_PackableValue], Tuple[Type[T_PackableValue], Type[U_PackableValue]]]
Union[
Type[T_PackableValue], Tuple[Type[T_PackableValue], Type[U_PackableValue]]
]
] = None,
whitelist_map: WhitelistMap = _WHITELIST_MAP,
) -> Union[PackableValue, T_PackableValue, Union[T_PackableValue, U_PackableValue]]:
Expand All @@ -1081,11 +1114,14 @@ def deserialize_value(

# Never issue warnings when deserializing deprecated objects.
with disable_dagster_warnings(), check.EvalContext.contextual_namespace(
whitelist_map.get_type_map()
whitelist_map.object_type_map
):
context = UnpackContext()
unpacked_value = seven.json.loads(
val, object_hook=partial(_unpack_object, whitelist_map=whitelist_map, context=context)
val,
object_hook=partial(
_unpack_object, whitelist_map=whitelist_map, context=context
),
)
unpacked_value = context.finalize_unpack(unpacked_value)
if as_type and not (
Expand All @@ -1106,7 +1142,9 @@ def __init__(self, message: str, value: Mapping[str, UnpackedValue]):
self.value = value


def _unpack_object(val: dict, whitelist_map: WhitelistMap, context: UnpackContext) -> UnpackedValue:
def _unpack_object(
val: dict, whitelist_map: WhitelistMap, context: UnpackContext
) -> UnpackedValue:
if "__class__" in val:
klass_name = val["__class__"]
if klass_name not in whitelist_map.object_deserializers:
Expand Down Expand Up @@ -1184,7 +1222,9 @@ def unpack_value(
def unpack_value(
val: JsonSerializableValue,
as_type: Optional[
Union[Type[T_PackableValue], Tuple[Type[T_PackableValue], Type[U_PackableValue]]]
Union[
Type[T_PackableValue], Tuple[Type[T_PackableValue], Type[U_PackableValue]]
]
] = None,
whitelist_map: WhitelistMap = _WHITELIST_MAP,
context: Optional[UnpackContext] = None,
Expand Down Expand Up @@ -1221,7 +1261,9 @@ def _unpack_value(
return [_unpack_value(item, whitelist_map, context) for item in val]

if isinstance(val, dict):
unpacked_vals = {k: _unpack_value(v, whitelist_map, context) for k, v in val.items()}
unpacked_vals = {
k: _unpack_value(v, whitelist_map, context) for k, v in val.items()
}
return _unpack_object(unpacked_vals, whitelist_map, context)

return val
Expand Down Expand Up @@ -1251,7 +1293,9 @@ def _with_header(msg: str) -> str:

if cls_param.name not in {"cls", "_cls"}:
raise SerdesUsageError(
_with_header(f'First parameter must be _cls or cls. Got "{cls_param.name}".')
_with_header(
f'First parameter must be _cls or cls. Got "{cls_param.name}".'
)
)

value_params = dunder_new_params[1:]
Expand Down

0 comments on commit 9eb3769

Please sign in to comment.