Experimental flag to attach row count metadata as part of `dagster-dbt` execution (dagster-io#21542)

## Summary

Adds a new `fetch_table_metadata` experimental flag to
`DbtCliResource.cli` which allows us to attach `dagster/total_row_count`
metadata (introduced in dagster-io#21524) to dbt-built tables:

```python
from dagster import AssetExecutionContext
from dagster_dbt import DbtCliResource, dbt_assets


@dbt_assets(manifest=dbt_manifest)
def jaffle_shop_dbt_assets(
    context: AssetExecutionContext,
    dbt: DbtCliResource,
):
    yield from dbt.cli(
        ["build"],
        context=context,
        fetch_table_metadata=True,
    ).stream()
```

![Screenshot 2024-05-03 at 11 03 19 AM](https://github.com/dagster-io/dagster/assets/10215173/c3e64633-5fc3-44e4-99e3-601f0c7a0856)

Under the hood, this PR uses dbt's `dbt.adapters.base.impl.BaseAdapter`
abstraction to let Dagster connect to the user's warehouse using the
dbt-provided credentials. Right now, we just run a simple `select
count(*)` on the tables specified in each `AssetMaterialization` and
`Output`, but this lays some groundwork we could use for fetching other
data as well.
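
To make this concrete, the core of the fetch reduces to roughly the sketch below, distilled from `_attach_post_materialization_metadata` in the diff; the `fetch_row_count` helper and its parameters are illustrative, not part of the PR:

```python
from dbt.adapters.base.impl import BaseAdapter


def fetch_row_count(adapter: BaseAdapter, database: str, schema: str, name: str) -> int:
    # Open a short-lived, named connection using the credentials dbt has
    # already resolved from the user's profile, then count rows in the table.
    with adapter.connection_named(f"row_count_{database}.{schema}.{name}"):
        query_result = adapter.execute(
            f"SELECT count(*) AS row_count FROM {database}.{schema}.{name}",
            fetch=True,
        )
        # execute returns an (AdapterResponse, table) pair; pull the single cell.
        return query_result[1][0]["row_count"]
```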

There are a few caveats:
- When using DuckDB, we wait for the dbt run to conclude before querying, since
DuckDB does not allow simultaneous connections while a write connection (such
as the one dbt holds during a run) is open; a sketch of the read-only
workaround follows this list.
- We don't query row counts on views, since they may be backed by non-trivial
SQL that could be expensive to execute.
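
For reference, the DuckDB handling lands in `_initialize_adapter` in the diff below; a minimal sketch of the idea, assuming `config` is the dbt `RuntimeConfig` built from the user's profile:

```python
# DuckDB only permits concurrent connections alongside a writer when they are
# read-only, so the Dagster-side adapter opts into READ_ONLY access mode.
try:
    from dbt.adapters.duckdb.credentials import DuckDBCredentials

    if isinstance(config.credentials, DuckDBCredentials):
        if not config.credentials.config_options:
            config.credentials.config_options = {}
        config.credentials.config_options["access_mode"] = "READ_ONLY"
except ImportError:
    # dbt-duckdb is not installed, so there is nothing to adjust.
    pass
```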

## Test Plan

Tested locally with DuckDB, BigQuery, and Snowflake. Introduced a basic pytest
test against DuckDB; a sketch of what such a test could look like follows.
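
The sketch below is hedged: the `test_jaffle_shop_manifest` fixture and project path are illustrative stand-ins, not the actual test added in this PR. It drives the `enable_fetch_row_count()` hook from the diff below:

```python
from dagster import AssetExecutionContext, materialize
from dagster._core.definitions.metadata import TableMetadataSet
from dagster_dbt import DbtCliResource, dbt_assets


def test_row_count_metadata(test_jaffle_shop_manifest) -> None:
    # `test_jaffle_shop_manifest` stands in for a fixture yielding the parsed
    # manifest of a DuckDB-backed dbt project; the real test's fixtures differ.
    @dbt_assets(manifest=test_jaffle_shop_manifest)
    def my_dbt_assets(context: AssetExecutionContext, dbt: DbtCliResource):
        yield from dbt.cli(["build"], context=context).enable_fetch_row_count().stream()

    result = materialize(
        [my_dbt_assets],
        resources={"dbt": DbtCliResource(project_dir="path/to/jaffle_shop")},
    )
    assert result.success

    # Views are skipped, so only table-materialized models carry a row count.
    row_counts = [
        TableMetadataSet.extract(
            event.step_materialization_data.materialization.metadata
        ).row_count
        for event in result.get_asset_materialization_events()
    ]
    assert any(count is not None and count > 0 for count in row_counts)
```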
benpankow committed May 9, 2024
1 parent 00ee5a2 commit 536685d
Showing 3 changed files with 282 additions and 23 deletions.
223 changes: 200 additions & 23 deletions python_modules/libraries/dagster-dbt/dagster_dbt/core/resources_v2.py
@@ -1,12 +1,14 @@
import contextlib
import copy
import dataclasses
import os
import shutil
import signal
import subprocess
import sys
import uuid
from argparse import Namespace
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import suppress
from dataclasses import InitVar, dataclass, field
from pathlib import Path
@@ -39,14 +41,17 @@
Output,
TableColumnDep,
TableColumnLineage,
_check as check,
get_dagster_logger,
)
from dagster._annotations import public
from dagster._annotations import experimental, public
from dagster._core.definitions.metadata import (
TableMetadataSet,
TextMetadataValue,
)
from dagster._core.errors import DagsterExecutionInterruptedError, DagsterInvalidPropertyError
from dagster._model.pydantic_compat_layer import compat_model_validator
from dagster._utils import pushd
from dagster._utils.warnings import disable_dagster_warnings
from dbt.adapters.base.impl import BaseAdapter
from dbt.adapters.factory import get_adapter, register_adapter, reset_adapters
@@ -93,6 +98,7 @@
ASSET_RESOURCE_TYPES,
get_dbt_resource_props_by_dbt_unique_id_from_manifest,
)
from .utils import get_future_completion_state_or_err

logger = get_dagster_logger()

@@ -106,6 +112,8 @@
DBT_INDIRECT_SELECTION_ENV: Final[str] = "DBT_INDIRECT_SELECTION"
DBT_EMPTY_INDIRECT_SELECTION: Final[str] = "empty"

STREAM_EVENTS_THREADPOOL_SIZE: Final[int] = 4


def _get_dbt_target_path() -> Path:
return Path(os.getenv("DBT_TARGET_PATH", "target"))
@@ -505,6 +513,9 @@ def _build_column_lineage_metadata(
)


DbtDagsterEventType = Union[Output, AssetMaterialization, AssetCheckResult, AssetObservation]


@dataclass
class DbtCliInvocation:
"""The representation of an invoked dbt command.
@@ -528,6 +539,7 @@ class DbtCliInvocation:
init=False, default=DAGSTER_DBT_TERMINATION_TIMEOUT_SECONDS
)
adapter: Optional[BaseAdapter] = field(default=None)
should_fetch_row_count: bool = field(default=False)
_stdout: List[str] = field(init=False, default_factory=list)
_error_messages: List[str] = field(init=False, default_factory=list)

@@ -585,6 +597,7 @@ def run(
raise_on_error=raise_on_error,
context=context,
adapter=adapter,
should_fetch_row_count=False,
)
logger.info(f"Running dbt command: `{dbt_cli_invocation.dbt_command}`.")

@@ -633,6 +646,109 @@ def is_successful(self) -> bool:

return self.process.wait() == 0

def _stream_asset_events(
self,
) -> Iterator[DbtDagsterEventType]:
"""Stream the dbt CLI events and convert them to Dagster events."""
for event in self.stream_raw_events():
yield from event.to_default_asset_events(
manifest=self.manifest,
dagster_dbt_translator=self.dagster_dbt_translator,
context=self.context,
target_path=self.target_path,
)

def _get_dbt_resource_props_from_event(self, event: DbtDagsterEventType) -> Dict[str, Any]:
unique_id = cast(TextMetadataValue, event.metadata["unique_id"]).text
return check.not_none(self.manifest["nodes"].get(unique_id))

def _attach_post_materialization_metadata(
self,
event: DbtDagsterEventType,
) -> DbtDagsterEventType:
"""Threaded task which runs any postprocessing steps on the given event before it's
emitted to user code.
This is used to, for example, query the row count of a table after it has been
materialized by dbt.
"""
adapter = check.not_none(self.adapter)

dbt_resource_props = self._get_dbt_resource_props_from_event(event)
is_view = dbt_resource_props["config"]["materialized"] == "view"

# Avoid counting rows for views, since they may include complex SQL queries
# that are costly to execute. We can revisit this in the future if there is
# a demand for it.
if is_view:
return event

# If the adapter is DuckDB, we need to wait for the dbt CLI process to complete
# so that the DuckDB lock is released. This is because DuckDB does not allow for
# opening multiple connections to the same database when a write connection, such
# as the one dbt uses, is open.
try:
from dbt.adapters.duckdb import DuckDBAdapter

if isinstance(adapter, DuckDBAdapter):
self._dbt_run_thread.result()
except ImportError:
pass

unique_id = dbt_resource_props["unique_id"]
logger.debug("Fetching row count for %s", unique_id)
table_str = f"{dbt_resource_props['database']}.{dbt_resource_props['schema']}.{dbt_resource_props['name']}"

with adapter.connection_named(f"row_count_{unique_id}"):
query_result = adapter.execute(
f"""
SELECT
count(*) as row_count
FROM
{table_str}
""",
fetch=True,
)
row_count = query_result[1][0]["row_count"]
additional_metadata = {**TableMetadataSet(row_count=row_count)}

if isinstance(event, Output):
return event.with_metadata(metadata={**event.metadata, **additional_metadata})
else:
return event._replace(metadata={**event.metadata, **additional_metadata})

def _stream_dbt_events_and_enqueue_postprocessing(
self,
output_events_and_futures: List[Union[Future, DbtDagsterEventType]],
executor: ThreadPoolExecutor,
) -> None:
"""Task which streams dbt events and either directly places them in
the output_events list to be emitted to user code, or enqueues post-processing tasks
where needed.
"""
for event in self._stream_asset_events():
# For any materialization or output event, we run postprocessing steps
# to attach additional metadata to the event.
if self.should_fetch_row_count and isinstance(event, (AssetMaterialization, Output)):
output_events_and_futures.append(
executor.submit(
self._attach_post_materialization_metadata,
event,
)
)
else:
output_events_and_futures.append(event)

@experimental
def enable_fetch_row_count(
self,
) -> "DbtCliInvocation":
"""Experimental functionality which will fetch row counts for materialized dbt
models in a dbt run once they are built. Note that row counts will not be fetched
for views, since this requires running the view's SQL query which may be costly.
"""
return dataclasses.replace(self, should_fetch_row_count=True)

@public
def stream(
self,
@@ -669,14 +785,60 @@ def stream(
def my_dbt_assets(context, dbt: DbtCliResource):
yield from dbt.cli(["run"], context=context).stream()
"""
for event in self.stream_raw_events():
yield from event.to_default_asset_events(
manifest=self.manifest,
dagster_dbt_translator=self.dagster_dbt_translator,
context=self.context,
target_path=self.target_path,
has_any_parallel_tasks = self.should_fetch_row_count

if not has_any_parallel_tasks:
# If we're not enqueuing any parallel tasks, we can just stream the events in
# the main thread.
yield from self._stream_asset_events()
return

if self.should_fetch_row_count:
logger.info(
"Row counts will be fetched for non-view models once they are materialized."
)

# We keep a list of emitted Dagster events and pending futures which augment
# emitted events with additional metadata. This ensures we can yield events in the order
# they are emitted by dbt.
output_events_and_futures: List[Union[Future, DbtDagsterEventType]] = []

# Point at project directory to ensure dbt adapters run correctly
with pushd(str(self.project_dir)), ThreadPoolExecutor(
max_workers=STREAM_EVENTS_THREADPOOL_SIZE
) as executor:
self._dbt_run_thread = executor.submit(
self._stream_dbt_events_and_enqueue_postprocessing,
output_events_and_futures,
executor,
)

# Step through the list of output events and futures, yielding them in order
# once they are ready to be emitted
event_to_emit_idx = 0
while True:
all_work_complete = get_future_completion_state_or_err(
[self._dbt_run_thread, *output_events_and_futures]
)
if all_work_complete and event_to_emit_idx >= len(output_events_and_futures):
break

if event_to_emit_idx < len(output_events_and_futures):
event_to_emit = output_events_and_futures[event_to_emit_idx]

if isinstance(event_to_emit, Future):
# If the next event to emit is a Future (waiting on postprocessing),
# we need to wait for it to complete before yielding the event.
try:
event = event_to_emit.result(timeout=0.1)
yield event
event_to_emit_idx += 1
except:
pass
else:
yield event_to_emit
event_to_emit_idx += 1

@public
def stream_raw_events(self) -> Iterator[DbtCliEventMessage]:
"""Stream the events from the dbt CLI process.
@@ -1116,6 +1278,18 @@ def _initialize_adapter(self, args: Sequence[str]) -> BaseAdapter:
project = load_project(self.project_dir, False, profile, {})
config = RuntimeConfig.from_parts(project, profile, flags)

# If the dbt adapter is DuckDB, set the access mode to READ_ONLY, since DuckDB only allows
# simultaneous connections for read-only access.
try:
from dbt.adapters.duckdb.credentials import DuckDBCredentials

if isinstance(config.credentials, DuckDBCredentials):
if not config.credentials.config_options:
config.credentials.config_options = {}
config.credentials.config_options["access_mode"] = "READ_ONLY"
except ImportError:
pass

cleanup_event_logger()
register_adapter(config)
adapter = cast(BaseAdapter, get_adapter(config))
@@ -1369,23 +1543,26 @@ def my_dbt_op(dbt: DbtCliResource):
if not target_path.is_absolute():
target_path = project_dir.joinpath(target_path)

try:
adapter = self._initialize_adapter(args)
except:
# defer exceptions until they can be raised in the runtime context of the invocation
adapter = None
adapter: Optional[BaseAdapter] = None
with pushd(self.project_dir):
try:
adapter = self._initialize_adapter(args)

return DbtCliInvocation.run(
args=args,
env=env,
manifest=manifest,
dagster_dbt_translator=dagster_dbt_translator,
project_dir=project_dir,
target_path=target_path,
raise_on_error=raise_on_error,
context=context,
adapter=adapter,
)
except:
# defer exceptions until they can be raised in the runtime context of the invocation
pass

return DbtCliInvocation.run(
args=args,
env=env,
manifest=manifest,
dagster_dbt_translator=dagster_dbt_translator,
project_dir=project_dir,
target_path=target_path,
raise_on_error=raise_on_error,
context=context,
adapter=adapter,
)


def _get_subset_selection_for_context(
18 changes: 18 additions & 0 deletions python_modules/libraries/dagster-dbt/dagster_dbt/core/utils.py
@@ -1,6 +1,7 @@
import json
import os
import subprocess
from concurrent.futures import Future
from typing import Any, Iterator, List, Mapping, NamedTuple, Optional, Sequence, Union

import dagster._check as check
@@ -282,3 +283,20 @@ def parse_manifest(path: str, target_path: str = DEFAULT_DBT_TARGET_PATH) -> Map
return json.load(file)
except FileNotFoundError:
raise DagsterDbtCliOutputsNotFoundError(path=manifest_path)


def get_future_completion_state_or_err(futures: List[Union[Future, Any]]) -> bool:
"""Given a list of futures (and potentially other objects), return True if all futures are completed.
If any future has an exception, raise the exception.
"""
for future in futures:
if not isinstance(future, Future):
continue

if not future.done():
return False

exception = future.exception()
if exception:
raise exception
return True
