add batch fetching of data version records (dagster-io#21798)
## Summary & Motivation
Asset graphs with large fan-in can incur a hefty data-fetching cost when
used with data versions. This PR fetches asset records for a batched set
of asset keys instead of querying storage once per upstream key. The
asset record carries the last materialization record, and potentially the
last observation record (in Dagster Plus), which eliminates most of the
serial fetches we previously made to resolve the input data versions.
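To sketch the shape of the change (a minimal illustration against a throwaway instance, not the PR's code; the record-access methods are the ones the diff below touches):

```python
from dagster import AssetKey, DagsterInstance, asset, materialize

@asset
def upstream_a():
    return 1

@asset
def upstream_b():
    return 2

instance = DagsterInstance.ephemeral()
materialize([upstream_a, upstream_b], instance=instance)
keys = [AssetKey("upstream_a"), AssetKey("upstream_b")]

# Before: one storage round-trip per upstream key.
latest_by_key = {key: instance.get_latest_data_version_record(key) for key in keys}

# After: one batched round-trip; each asset record already carries the
# last materialization (and, where supported, the last observation) record.
records_by_key = {
    record.asset_entry.asset_key: record.asset_entry.last_materialization_record
    for record in instance.get_asset_records(keys)
}
```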

This batching of calls is only possible because we are not filtering the
records (observations/materializations) that we fetch, either by partition
or by storage id; each input key only needs its latest record, which the
asset record already stores.

## How I Tested These Changes
Added an explicit fan-in data version test that asserts on the underlying
data-fetching calls. Resolving the input data versions went from 200 calls
to `get_event_records` to a single call to `get_asset_records`.
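The call-count assertion relies on dagster's internal tracing utilities, which the new test imports: traced instance methods increment a `Counter` held in the `traced_counter` context variable while one is set. Roughly (a sketch using those same internal helpers):

```python
from dagster import AssetKey, DagsterInstance
from dagster._utils import Counter, traced_counter

instance = DagsterInstance.ephemeral()
counter = Counter()
traced_counter.set(counter)

# Traced instance methods bump the counter by name while one is set.
instance.get_asset_records([AssetKey("some_asset")])
print(counter.counts())  # -> {"DagsterInstance.get_asset_records": 1}
```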
prha authored May 13, 2024
1 parent 7351009 commit e6ddee9
Showing 2 changed files with 107 additions and 43 deletions.
File 1 of 2 (execution context module):

```diff
@@ -1,3 +1,4 @@
+import os
 from typing import TYPE_CHECKING

 """This module contains the execution context objects that are internal to the system.
@@ -29,6 +30,7 @@
 )
 from dagster._core.event_api import EventLogRecord
 from dagster._core.events import DagsterEventType
+from dagster._core.storage.event_log.base import AssetRecord


 if TYPE_CHECKING:
@@ -76,7 +78,7 @@ def maybe_fetch_and_get_input_asset_version_info(
         self, key: AssetKey
     ) -> Optional["InputAssetVersionInfo"]:
         if key not in self.input_asset_version_info:
-            self._fetch_input_asset_version_info(key)
+            self._fetch_input_asset_version_info([key])
         return self.input_asset_version_info[key]

     # "external" refers to records for inputs generated outside of this step
@@ -93,57 +95,88 @@ def fetch_external_input_asset_version_info(self) -> None:
                 all_dep_keys.append(key)

         self.input_asset_version_info = {}
-        for key in all_dep_keys:
-            self._fetch_input_asset_version_info(key)
+        self._fetch_input_asset_version_info(all_dep_keys)
         self.is_external_input_asset_version_info_loaded = True

-    def _fetch_input_asset_version_info(self, key: AssetKey) -> None:
+    def _fetch_input_asset_version_info(self, asset_keys: Sequence[AssetKey]) -> None:
         from dagster._core.definitions.data_version import (
             extract_data_version_from_entry,
         )

-        event = self._get_input_asset_event(key)
-        if event is None:
-            self.input_asset_version_info[key] = None
-        else:
-            storage_id = event.storage_id
-            # Input name will be none if this is an internal dep
-            input_name = self._context.job_def.asset_layer.input_for_asset_key(
-                self._context.node_handle, key
-            )
-            # Exclude AllPartitionMapping for now to avoid huge queries
-            if input_name and self._context.has_asset_partitions_for_input(input_name):
-                subset = self._context.asset_partitions_subset_for_input(
-                    input_name, require_valid_partitions=False
-                )
-                input_keys = list(subset.get_partition_keys())
-
-                # This check represents a temporary constraint that prevents huge query results for upstream
-                # partition data versions from timing out runs. If a partitioned dependency (a) uses an
-                # AllPartitionMapping; and (b) has greater than or equal to
-                # SKIP_PARTITION_DATA_VERSION_DEPENDENCY_THRESHOLD dependency partitions, then we
-                # process it as a non-partitioned dependency (note that this was the behavior for
-                # all partition dependencies prior to 2023-08). This means that stale status
-                # results cannot be accurately computed for the dependency, and there is thus
-                # corresponding logic in the CachingStaleStatusResolver to account for this. This
-                # constraint should be removed when we have thoroughly examined the performance of
-                # the data version retrieval query and can guarantee decent performance.
-                if len(input_keys) < SKIP_PARTITION_DATA_VERSION_DEPENDENCY_THRESHOLD:
-                    data_version = self._get_partitions_data_version_from_keys(key, input_keys)
-                else:
-                    data_version = extract_data_version_from_entry(event.event_log_entry)
-            else:
-                data_version = extract_data_version_from_entry(event.event_log_entry)
-            self.input_asset_version_info[key] = InputAssetVersionInfo(
-                storage_id,
-                check.not_none(event.event_log_entry.dagster_event).event_type,
-                data_version,
-                event.run_id,
-                event.timestamp,
-            )
+        asset_records_by_key = self._fetch_asset_records(asset_keys)
+        for key in asset_keys:
+            asset_record = asset_records_by_key.get(key)
+            event = self._get_input_asset_event(key, asset_record)
+            if event is None:
+                self.input_asset_version_info[key] = None
+            else:
+                storage_id = event.storage_id
+                # Input name will be none if this is an internal dep
+                input_name = self._context.job_def.asset_layer.input_for_asset_key(
+                    self._context.node_handle, key
+                )
+                # Exclude AllPartitionMapping for now to avoid huge queries
+                if input_name and self._context.has_asset_partitions_for_input(input_name):
+                    subset = self._context.asset_partitions_subset_for_input(
+                        input_name, require_valid_partitions=False
+                    )
+                    input_keys = list(subset.get_partition_keys())
+
+                    # This check represents a temporary constraint that prevents huge query results for upstream
+                    # partition data versions from timing out runs. If a partitioned dependency (a) uses an
+                    # AllPartitionMapping; and (b) has greater than or equal to
+                    # SKIP_PARTITION_DATA_VERSION_DEPENDENCY_THRESHOLD dependency partitions, then we
+                    # process it as a non-partitioned dependency (note that this was the behavior for
+                    # all partition dependencies prior to 2023-08). This means that stale status
+                    # results cannot be accurately computed for the dependency, and there is thus
+                    # corresponding logic in the CachingStaleStatusResolver to account for this. This
+                    # constraint should be removed when we have thoroughly examined the performance of
+                    # the data version retrieval query and can guarantee decent performance.
+                    if len(input_keys) < SKIP_PARTITION_DATA_VERSION_DEPENDENCY_THRESHOLD:
+                        data_version = self._get_partitions_data_version_from_keys(key, input_keys)
+                    else:
+                        data_version = extract_data_version_from_entry(event.event_log_entry)
+                else:
+                    data_version = extract_data_version_from_entry(event.event_log_entry)
+                self.input_asset_version_info[key] = InputAssetVersionInfo(
+                    storage_id,
+                    check.not_none(event.event_log_entry.dagster_event).event_type,
+                    data_version,
+                    event.run_id,
+                    event.timestamp,
+                )
+
+    def _fetch_asset_records(self, asset_keys: Sequence[AssetKey]) -> Dict[AssetKey, "AssetRecord"]:
+        batch_size = int(os.getenv("GET_ASSET_RECORDS_FOR_DATA_VERSION_BATCH_SIZE", "100"))
+        asset_records_by_key = {}
+        to_fetch = asset_keys
+        while len(to_fetch):
+            for record in self._context.instance.get_asset_records(to_fetch[:batch_size]):
+                asset_records_by_key[record.asset_entry.asset_key] = record
+            to_fetch = to_fetch[batch_size:]
+
+        return asset_records_by_key

-    def _get_input_asset_event(self, key: AssetKey) -> Optional["EventLogRecord"]:
-        event = self._context.instance.get_latest_data_version_record(key)
+    def _get_input_asset_event(
+        self, key: AssetKey, asset_record: Optional["AssetRecord"]
+    ) -> Optional["EventLogRecord"]:
+        event = None
+        if asset_record and asset_record.asset_entry.last_materialization_record:
+            event = asset_record.asset_entry.last_materialization_record
+        elif (
+            asset_record
+            and self._context.instance.event_log_storage.asset_records_have_last_observation
+        ):
+            event = asset_record.asset_entry.last_observation_record
+
+        if (
+            not event
+            and not self._context.instance.event_log_storage.asset_records_have_last_observation
+        ):
+            event = next(
+                iter(self._context.instance.fetch_observations(key, limit=1).records), None
+            )
+
         if event:
             self._check_input_asset_event(key, event)
         return event
```
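The new `_fetch_asset_records` helper is a plain chunked fetch: it walks the key list in env-tunable batches and indexes the returned records by asset key, so callers can `.get(key)` and treat a missing record as "no event". The same pattern in isolation (a generic sketch, with `fetch` standing in for `instance.get_asset_records`):

```python
import os
from typing import Callable, Dict, List, Sequence, TypeVar

K = TypeVar("K")
R = TypeVar("R")

def fetch_in_batches(
    keys: Sequence[K],
    fetch: Callable[[Sequence[K]], Sequence[R]],
    key_of: Callable[[R], K],
) -> Dict[K, R]:
    # Mirrors GET_ASSET_RECORDS_FOR_DATA_VERSION_BATCH_SIZE in the diff above.
    batch_size = int(os.getenv("GET_ASSET_RECORDS_FOR_DATA_VERSION_BATCH_SIZE", "100"))
    results: Dict[K, R] = {}
    to_fetch: List[K] = list(keys)
    while to_fetch:
        # One storage call per batch; keys with no stored record simply
        # produce no entry in the result dict.
        for record in fetch(to_fetch[:batch_size]):
            results[key_of(record)] = record
        to_fetch = to_fetch[batch_size:]
    return results
```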
File 2 of 2 (data version tests):

```diff
@@ -5,7 +5,9 @@

 import pytest
 from dagster import (
+    AssetIn,
     AssetMaterialization,
+    AssetOut,
     DagsterInstance,
     MaterializeResult,
     RunConfig,
@@ -16,8 +18,6 @@
 )
 from dagster._config.field import Field
 from dagster._config.pythonic_config import Config
-from dagster._core.definitions.asset_in import AssetIn
-from dagster._core.definitions.asset_out import AssetOut
 from dagster._core.definitions.data_version import (
     DATA_VERSION_TAG,
     SKIP_PARTITION_DATA_VERSION_DEPENDENCY_THRESHOLD,
@@ -44,6 +44,7 @@
     ASSET_PARTITION_RANGE_END_TAG,
     ASSET_PARTITION_RANGE_START_TAG,
 )
+from dagster._utils import Counter, traced_counter
 from dagster._utils.test.data_versions import (
     assert_code_version,
     assert_data_version,
@@ -1175,3 +1176,33 @@ def asset1():
     assert extract_data_provenance_from_entry(record.event_log_entry).input_storage_ids == {
         AssetKey(["asset0"]): 500
     }
+
+
+def test_fan_in():
+    def create_upstream_asset(i: int):
+        @asset(name=f"upstream_asset_{i}", code_version="abc")
+        def upstream_asset():
+            return i
+
+        return upstream_asset
+
+    upstream_assets = [create_upstream_asset(i) for i in range(100)]
+
+    @asset(
+        ins={f"input_{i}": AssetIn(key=f"upstream_asset_{i}") for i in range(100)},
+        code_version="abc",
+    )
+    def downstream_asset(**kwargs):
+        return kwargs.values()
+
+    all_assets = [*upstream_assets, downstream_asset]
+    instance = DagsterInstance.ephemeral()
+    materialize_assets(all_assets, instance)
+
+    counter = Counter()
+    traced_counter.set(counter)
+    materialize_assets(all_assets, instance)[downstream_asset.key]
+    assert traced_counter.get().counts() == {
+        "DagsterInstance.get_asset_records": 1,
+        "DagsterInstance.get_run_record_by_id": 1,
+    }
```