Skip to content

Commit

Permalink
[serdes] don't recreate type map
Browse files Browse the repository at this point in the history
  • Loading branch information
alangenfeld committed Jul 23, 2024
1 parent e3faf83 commit 6103eb4
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def partitions_def(self) -> Optional[PartitionsDefinition]:
),
},
enum_serializers=_WHITELIST_MAP.enum_serializers,
object_type_map=_WHITELIST_MAP.object_type_map,
)

return deserialize_value(
Expand Down
31 changes: 22 additions & 9 deletions python_modules/dagster/dagster/_serdes/serdes.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ class WhitelistMap(NamedTuple):
object_serializers: Dict[str, "ObjectSerializer"]
object_deserializers: Dict[str, "ObjectSerializer"]
enum_serializers: Dict[str, "EnumSerializer"]
object_type_map: Dict[str, Type]

def register_object(
self,
Expand Down Expand Up @@ -194,6 +195,8 @@ def register_object(
for old_storage_name in old_storage_names:
self.object_deserializers[old_storage_name] = serializer

self.object_type_map[name] = serializer.klass

def register_enum(
self,
name: str,
Expand All @@ -216,11 +219,12 @@ def register_enum(

@staticmethod
def create() -> "WhitelistMap":
return WhitelistMap(object_serializers={}, object_deserializers={}, enum_serializers={})

def get_type_map(self) -> Mapping[str, Type]:
# Return string name -> type mapping of all registered types
return {name: ser.klass for name, ser in self.object_serializers.items()}
return WhitelistMap(
object_serializers={},
object_deserializers={},
enum_serializers={},
object_type_map={},
)


_WHITELIST_MAP: Final[WhitelistMap] = WhitelistMap.create()
Expand Down Expand Up @@ -511,7 +515,12 @@ def get_storage_name(self) -> str:
return self.storage_name or self.klass.__name__


EMPTY_VALUES_TO_SKIP: Tuple[None, List[Any], Dict[Any, Any], Set[Any]] = (None, [], {}, set())
EMPTY_VALUES_TO_SKIP: Tuple[None, List[Any], Dict[Any, Any], Set[Any]] = (
None,
[],
{},
set(),
)


class ObjectSerializer(Serializer, Generic[T]):
Expand Down Expand Up @@ -754,7 +763,10 @@ def unpack(
return set(sequence_value) if sequence_value is not None else None

def pack(
self, set_value: Optional[AbstractSet[Any]], whitelist_map: WhitelistMap, descent_path: str
self,
set_value: Optional[AbstractSet[Any]],
whitelist_map: WhitelistMap,
descent_path: str,
) -> Optional[Sequence[Any]]:
return (
sorted([pack_value(x, whitelist_map, descent_path) for x in set_value], key=str)
Expand Down Expand Up @@ -1075,11 +1087,12 @@ def deserialize_value(

# Never issue warnings when deserializing deprecated objects.
with disable_dagster_warnings(), check.EvalContext.contextual_namespace(
whitelist_map.get_type_map()
whitelist_map.object_type_map
):
context = UnpackContext()
unpacked_value = seven.json.loads(
val, object_hook=partial(_unpack_object, whitelist_map=whitelist_map, context=context)
val,
object_hook=partial(_unpack_object, whitelist_map=whitelist_map, context=context),
)
unpacked_value = context.finalize_unpack(unpacked_value)
if as_type and not (
Expand Down

0 comments on commit 6103eb4

Please sign in to comment.