Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[serdes] deserialize_values #23186

Merged
merged 1 commit into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
)
from dagster._serdes import deserialize_value, serialize_value
from dagster._serdes.errors import DeserializationError
from dagster._serdes.serdes import deserialize_values
from dagster._time import datetime_from_timestamp, get_current_timestamp, utc_datetime_from_naive
from dagster._utils import PrintFn
from dagster._utils.concurrency import (
Expand Down Expand Up @@ -665,7 +666,7 @@ def get_step_stats_for_run(
results = conn.execute(raw_event_query).fetchall()

try:
records = [deserialize_value(json_str, EventLogEntry) for (json_str,) in results]
records = deserialize_values((json_str for (json_str,) in results), EventLogEntry)
return build_run_step_stats_from_events(run_id, records)
except (seven.JSONDecodeError, DeserializationError) as err:
raise DagsterEventLogInvalidForRun(run_id=run_id) from err
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
)
from dagster._daemon.types import DaemonHeartbeat
from dagster._serdes import deserialize_value, serialize_value
from dagster._serdes.serdes import deserialize_values
from dagster._seven import JSONDecodeError
from dagster._time import datetime_from_timestamp, get_current_datetime, utc_datetime_from_naive
from dagster._utils import PrintFn
Expand Down Expand Up @@ -850,7 +851,7 @@ def get_backfills(
query = query.limit(limit)
query = query.order_by(BulkActionsTable.c.id.desc())
rows = self.fetchall(query)
return [deserialize_value(row["body"], PartitionBackfill) for row in rows]
return deserialize_values((row["body"] for row in rows), PartitionBackfill)

def get_backfill(self, backfill_id: str) -> Optional[PartitionBackfill]:
check.str_param(backfill_id, "backfill_id")
Expand Down
1 change: 1 addition & 0 deletions python_modules/dagster/dagster/_serdes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
SerializableNonScalarKeyMapping as SerializableNonScalarKeyMapping,
WhitelistMap as WhitelistMap,
deserialize_value as deserialize_value,
deserialize_values as deserialize_values,
pack_value as pack_value,
serialize_value as serialize_value,
unpack_value as unpack_value,
Expand Down
70 changes: 55 additions & 15 deletions python_modules/dagster/dagster/_serdes/serdes.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Dict,
FrozenSet,
Generic,
Iterable,
Iterator,
List,
Mapping,
Expand Down Expand Up @@ -1085,26 +1086,65 @@ def deserialize_value(
"""
check.str_param(val, "val")

# Never issue warnings when deserializing deprecated objects.
return deserialize_values([val], as_type, whitelist_map)[0]


# Overload: a single expected type narrows the result to a sequence of that type.
@overload
def deserialize_values(
    vals: Iterable[str],
    as_type: Type[T_PackableValue],
    whitelist_map: WhitelistMap = ...,
) -> Sequence[T_PackableValue]: ...


# Overload: no expected type yields a sequence of generic packable values.
@overload
def deserialize_values(
    vals: Iterable[str],
    as_type: None = ...,
    whitelist_map: WhitelistMap = ...,
) -> Sequence[PackableValue]: ...


# Overload: an optional type or 2-tuple of types yields the corresponding union.
@overload
def deserialize_values(
    vals: Iterable[str],
    as_type: Optional[
        Union[Type[T_PackableValue], Tuple[Type[T_PackableValue], Type[U_PackableValue]]]
    ],
    whitelist_map: WhitelistMap = ...,
) -> Sequence[Union[PackableValue, T_PackableValue, Union[T_PackableValue, U_PackableValue]]]: ...


def deserialize_values(
    vals: Iterable[str],
    as_type: Optional[
        Union[Type[T_PackableValue], Tuple[Type[T_PackableValue], Type[U_PackableValue]]]
    ] = None,
    whitelist_map: WhitelistMap = _WHITELIST_MAP,
) -> Sequence[Union[PackableValue, T_PackableValue, Union[T_PackableValue, U_PackableValue]]]:
    """Deserialize a collection of values without having to repeatedly exit/enter the deserializing context.

    Args:
        vals: Iterable of serialized JSON strings to deserialize.
        as_type: Optional type (or tuple of types) that every deserialized object must be an
            instance of.
        whitelist_map: Registry of serializable objects used to resolve serialized names.

    Returns:
        A list of the deserialized objects, in the same order as ``vals``.

    Raises:
        DeserializationError: If a deserialized object is not an instance of ``as_type``.
    """
    # Never issue warnings when deserializing deprecated objects; enter the warning /
    # namespace contexts once for the whole batch rather than once per value.
    with disable_dagster_warnings(), check.EvalContext.contextual_namespace(
        whitelist_map.object_type_map
    ):
        unpacked_values = []
        for val in vals:
            # A fresh UnpackContext per value keeps per-object unpack bookkeeping
            # from leaking between entries.
            context = UnpackContext()
            unpacked_value = seven.json.loads(
                val,
                object_hook=partial(_unpack_object, whitelist_map=whitelist_map, context=context),
            )
            unpacked_value = context.finalize_unpack(unpacked_value)
            if as_type and not (
                # typing.NamedTuple does not support isinstance; use the structural check.
                is_named_tuple_instance(unpacked_value)
                if as_type is NamedTuple
                else isinstance(unpacked_value, as_type)
            ):
                raise DeserializationError(
                    f"Deserialized object was not expected type {as_type}, got {type(unpacked_value)}"
                )
            unpacked_values.append(unpacked_value)

    return unpacked_values


class UnknownSerdesValue:
Expand Down