Ask AI

Source code for dagster_dbt.core.dbt_cli_invocation

import contextlib
import copy
import os
import shutil
import signal
import subprocess
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Iterator, List, Mapping, NamedTuple, Optional, Sequence, Union, cast

import orjson
from dagster import (
    AssetCheckResult,
    AssetExecutionContext,
    AssetMaterialization,
    AssetObservation,
    OpExecutionContext,
    Output,
    get_dagster_logger,
)
from dagster._annotations import public
from dagster._core.errors import DagsterExecutionInterruptedError
from dbt.adapters.base.impl import BaseAdapter, BaseColumn, BaseRelation
from typing_extensions import Final, Literal

from dagster_dbt.core.dbt_cli_event import DbtCliEventMessage
from dagster_dbt.core.dbt_event_iterator import DbtDagsterEventType, DbtEventIterator
from dagster_dbt.dagster_dbt_translator import DagsterDbtTranslator
from dagster_dbt.errors import DagsterDbtCliRuntimeError

# Name of the partial-parse file that `DbtCliInvocation.run` copies between target folders.
PARTIAL_PARSE_FILE_NAME = "partial_parse.msgpack"

# Seconds to wait for the dbt process to exit after an interrupt has been forwarded to it.
# Overridable through the environment variable of the same name; defaults to 25.
DAGSTER_DBT_TERMINATION_TIMEOUT_SECONDS = int(
    os.environ.get("DAGSTER_DBT_TERMINATION_TIMEOUT_SECONDS", "25")
)

# Default for `DbtCliInvocation.postprocessing_threadpool_num_threads`.
DEFAULT_EVENT_POSTPROCESSING_THREADPOOL_SIZE: Final[int] = 4


# Module-wide Dagster logger.
logger = get_dagster_logger()


def _get_dbt_target_path() -> Path:
    """Return the dbt target path from ``DBT_TARGET_PATH``, defaulting to ``target``."""
    configured_target = os.environ.get("DBT_TARGET_PATH", "target")
    return Path(configured_target)


class RelationKey(NamedTuple):
    """Hashable (database, schema, identifier) triple that identifies a database relation.

    Being a tuple, it can be used directly as a dictionary key.
    """

    # Database that contains the relation.
    database: str
    # Schema that contains the relation.
    schema: str
    # Identifier of the relation itself.
    identifier: str


class RelationData(NamedTuple):
    """Metadata about a relation, as queried from a database."""

    # Name of the relation.
    name: str
    # Columns of the relation.
    columns: List[BaseColumn]

def _get_relation_from_adapter(adapter: BaseAdapter, relation_key: RelationKey) -> BaseRelation:
    """Create the adapter's relation object for the relation identified by ``relation_key``.

    Args:
        adapter (BaseAdapter): The adapter whose ``Relation.create`` factory is used.
        relation_key (RelationKey): The database, schema and identifier of the relation.

    Returns:
        BaseRelation: The relation created by the adapter.
    """
    database = relation_key.database
    schema = relation_key.schema
    identifier = relation_key.identifier
    return adapter.Relation.create(database=database, schema=schema, identifier=identifier)


@dataclass
class DbtCliInvocation:
    """The representation of an invoked dbt command.

    Args:
        process (subprocess.Popen): The process running the dbt command.
        manifest (Mapping[str, Any]): The dbt manifest blob.
        dagster_dbt_translator (DagsterDbtTranslator): The translator passed along when dbt
            events are converted to Dagster events.
        project_dir (Path): The path to the dbt project.
        target_path (Path): The path to the dbt target folder.
        raise_on_error (bool): Whether to raise an exception if the dbt command fails.
        context (Optional[Union[OpExecutionContext, AssetExecutionContext]]): The execution
            context passed along when dbt events are converted to Dagster events.
        adapter (Optional[BaseAdapter]): An optional dbt adapter associated with the invocation.
    """

    process: subprocess.Popen
    manifest: Mapping[str, Any]
    dagster_dbt_translator: DagsterDbtTranslator
    project_dir: Path
    target_path: Path
    raise_on_error: bool
    context: Optional[Union[OpExecutionContext, AssetExecutionContext]] = field(
        default=None, repr=False
    )
    # How long to wait for the dbt process to exit after an interrupt has been forwarded to it.
    termination_timeout_seconds: float = field(
        init=False, default=DAGSTER_DBT_TERMINATION_TIMEOUT_SECONDS
    )
    adapter: Optional[BaseAdapter] = field(default=None)
    postprocessing_threadpool_num_threads: int = field(
        init=False, default=DEFAULT_EVENT_POSTPROCESSING_THREADPOOL_SIZE
    )
    # Output buffered by `is_successful`; replayed by `stream_raw_events` when non-empty.
    _stdout: List[Union[str, Dict[str, Any]]] = field(init=False, default_factory=list)
    # Messages of the error-level events seen while reading the process output.
    _error_messages: List[str] = field(init=False, default_factory=list)

    # Caches fetching relation column metadata to avoid redundant queries to the database.
    _relation_column_metadata_cache: Dict[RelationKey, RelationData] = field(
        init=False, default_factory=dict
    )

    def _get_columns_from_dbt_resource_props(
        self, adapter: BaseAdapter, dbt_resource_props: Dict[str, Any]
    ) -> RelationData:
        """Given a dbt resource properties dictionary, fetches the resource's column metadata from
        the database, or returns the cached metadata if it has already been fetched.

        Args:
            adapter (BaseAdapter): The adapter used to build the relation and query its columns.
            dbt_resource_props (Dict[str, Any]): The dbt resource properties. The keys
                ``database``, ``schema``, ``unique_id`` and either ``identifier`` (when the
                unique id starts with ``source``) or ``alias`` (otherwise) are read.

        Returns:
            RelationData: The relation's name and columns.
        """
        relation_key = RelationKey(
            database=dbt_resource_props["database"],
            schema=dbt_resource_props["schema"],
            identifier=(
                dbt_resource_props["identifier"]
                if dbt_resource_props["unique_id"].startswith("source")
                else dbt_resource_props["alias"]
            ),
        )
        if relation_key in self._relation_column_metadata_cache:
            return self._relation_column_metadata_cache[relation_key]

        relation = _get_relation_from_adapter(adapter=adapter, relation_key=relation_key)
        cols: List = adapter.get_columns_in_relation(relation=relation)
        # `setdefault` keeps an entry that was stored for this key in the meantime, if any.
        return self._relation_column_metadata_cache.setdefault(
            relation_key, RelationData(name=str(relation), columns=cols)
        )

    @classmethod
    def run(
        cls,
        args: Sequence[str],
        env: Dict[str, str],
        manifest: Mapping[str, Any],
        dagster_dbt_translator: DagsterDbtTranslator,
        project_dir: Path,
        target_path: Path,
        raise_on_error: bool,
        context: Optional[Union[OpExecutionContext, AssetExecutionContext]],
        adapter: Optional[BaseAdapter],
    ) -> "DbtCliInvocation":
        """Start the dbt CLI command in a subprocess and return its invocation.

        Args:
            args (Sequence[str]): The full command line to execute.
            env (Dict[str, str]): The environment of the subprocess.
            manifest (Mapping[str, Any]): The dbt manifest blob.
            dagster_dbt_translator (DagsterDbtTranslator): The translator stored on the
                invocation.
            project_dir (Path): The path to the dbt project; used as the working directory of
                the subprocess.
            target_path (Path): The path to the dbt target folder for this invocation.
            raise_on_error (bool): Whether to raise an exception if the dbt command fails.
            context (Optional[Union[OpExecutionContext, AssetExecutionContext]]): The execution
                context stored on the invocation.
            adapter (Optional[BaseAdapter]): The dbt adapter stored on the invocation.

        Returns:
            DbtCliInvocation: The invocation wrapping the started process.
        """
        # Attempt to take advantage of partial parsing. If there is a `partial_parse.msgpack`
        # in the target folder, then copy it to the dynamic target path.
        #
        # This effectively allows us to skip the parsing of the manifest, which can be expensive.
        # See https://docs.getdbt.com/reference/programmatic-invocations#reusing-objects for more
        # details.
        current_target_path = _get_dbt_target_path()
        partial_parse_file_path = (
            current_target_path.joinpath(PARTIAL_PARSE_FILE_NAME)
            if current_target_path.is_absolute()
            else project_dir.joinpath(current_target_path, PARTIAL_PARSE_FILE_NAME)
        )
        partial_parse_destination_target_path = target_path.joinpath(PARTIAL_PARSE_FILE_NAME)

        if partial_parse_file_path.exists() and not partial_parse_destination_target_path.exists():
            logger.info(
                f"Copying `{partial_parse_file_path}` to `{partial_parse_destination_target_path}`"
                " to take advantage of partial parsing."
            )

            partial_parse_destination_target_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy(partial_parse_file_path, partial_parse_destination_target_path)

        # Create a subprocess that runs the dbt CLI command.
        # stderr is merged into stdout so that a single pipe carries all of the output.
        process = subprocess.Popen(
            args=args,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            env=env,
            cwd=project_dir,
        )

        dbt_cli_invocation = cls(
            process=process,
            manifest=manifest,
            dagster_dbt_translator=dagster_dbt_translator,
            project_dir=project_dir,
            target_path=target_path,
            raise_on_error=raise_on_error,
            context=context,
            adapter=adapter,
        )
        logger.info(f"Running dbt command: `{dbt_cli_invocation.dbt_command}`.")

        return dbt_cli_invocation

    @public
    def wait(self) -> "DbtCliInvocation":
        """Wait for the dbt CLI process to complete.

        Returns:
            DbtCliInvocation: The current representation of the dbt CLI invocation.

        Examples:
            .. code-block:: python

                from dagster_dbt import DbtCliResource

                dbt = DbtCliResource(project_dir="/path/to/dbt/project")

                dbt_cli_invocation = dbt.cli(["run"]).wait()
        """
        # Exhaust the raw event stream; the events themselves are discarded.
        list(self.stream_raw_events())

        return self

    @public
    def is_successful(self) -> bool:
        """Return whether the dbt CLI process completed successfully.

        Returns:
            bool: True, if the dbt CLI process returns with a zero exit code, and False otherwise.

        Examples:
            .. code-block:: python

                from dagster_dbt import DbtCliResource

                dbt = DbtCliResource(project_dir="/path/to/dbt/project")

                dbt_cli_invocation = dbt.cli(["run"], raise_on_error=False)

                if dbt_cli_invocation.is_successful():
                    ...
        """
        # Buffer any output that has not been consumed yet. Append to the existing buffer instead
        # of overwriting it: on a repeated call the pipe is already closed and yields nothing, so
        # overwriting would throw away the events buffered by an earlier call. The buffer is only
        # rebound once the remaining output has been read completely.
        self._stdout = self._stdout + list(self._stream_stdout())

        # An error-level event marks the invocation as failed even with a zero exit code.
        return self.process.wait() == 0 and not self._error_messages

    @public
    def get_error(self) -> Optional[Exception]:
        """Return an exception if the dbt CLI process failed.

        Returns:
            Optional[Exception]: An exception if the dbt CLI process failed, and None otherwise.

        Examples:
            .. code-block:: python

                from dagster_dbt import DbtCliResource

                dbt = DbtCliResource(project_dir="/path/to/dbt/project")

                dbt_cli_invocation = dbt.cli(["run"], raise_on_error=False)

                error = dbt_cli_invocation.get_error()
                if error:
                    logger.error(error)
        """
        if self.is_successful():
            return None

        log_path = self.target_path.joinpath("dbt.log")
        extra_description = ""

        # Only point at the debug log when it actually exists.
        if log_path.exists():
            extra_description = f", or view the dbt debug log: {log_path}"

        return DagsterDbtCliRuntimeError(
            description=(
                f"The dbt CLI process with command\n\n"
                f"`{self.dbt_command}`\n\n"
                f"failed with exit code `{self.process.returncode}`."
                " Check the stdout in the Dagster compute logs for the full information about"
                f" the error{extra_description}.{self._format_error_messages()}"
            ),
        )

    def _stream_asset_events(
        self,
    ) -> Iterator[DbtDagsterEventType]:
        """Stream the dbt CLI events and convert them to Dagster events."""
        for event in self.stream_raw_events():
            yield from event.to_default_asset_events(
                manifest=self.manifest,
                dagster_dbt_translator=self.dagster_dbt_translator,
                context=self.context,
                target_path=self.target_path,
            )

    @public
    def stream(
        self,
    ) -> (
        "DbtEventIterator[Union[Output, AssetMaterialization, AssetObservation, AssetCheckResult]]"
    ):
        """Stream the events from the dbt CLI process and convert them to Dagster events.

        Returns:
            Iterator[Union[Output, AssetMaterialization, AssetObservation, AssetCheckResult]]:
                A set of corresponding Dagster events.

                In a Dagster asset definition, the following are yielded:
                - Output for refables (e.g. models, seeds, snapshots.)
                - AssetCheckResult for dbt test results that are enabled as asset checks.
                - AssetObservation for dbt test results that are not enabled as asset checks.

                In a Dagster op definition, the following are yielded:
                - AssetMaterialization for dbt test results that are not enabled as asset checks.
                - AssetObservation for dbt test results.

        Examples:
            .. code-block:: python

                from pathlib import Path
                from dagster_dbt import DbtCliResource, dbt_assets

                @dbt_assets(manifest=Path("target", "manifest.json"))
                def my_dbt_assets(context, dbt: DbtCliResource):
                    yield from dbt.cli(["run"], context=context).stream()
        """
        return DbtEventIterator(
            self._stream_asset_events(),
            self,
        )

    @public
    def stream_raw_events(self) -> Iterator[DbtCliEventMessage]:
        """Stream the events from the dbt CLI process.

        Lines that could not be parsed as events are written to stdout and not yielded. Once
        the output is exhausted, an error is raised if the command failed and
        ``raise_on_error`` is set.

        Returns:
            Iterator[DbtCliEventMessage]: An iterator of events from the dbt CLI process.
        """
        event_history_metadata_by_unique_id: Dict[str, Dict[str, Any]] = {}

        # Replay the buffered output if there is any; otherwise read from the process directly.
        for raw_event in self._stdout or self._stream_stdout():
            if isinstance(raw_event, str):
                # If we can't parse the event, then just emit it as a raw log.
                sys.stdout.write(raw_event + "\n")
                sys.stdout.flush()

                continue

            unique_id: Optional[str] = raw_event["data"].get("node_info", {}).get("unique_id")
            is_result_event = DbtCliEventMessage.is_result_event(raw_event)
            event_history_metadata: Dict[str, Any] = {}
            if unique_id and is_result_event:
                # Copy so that the event does not share state with the history mapping.
                event_history_metadata = copy.deepcopy(
                    event_history_metadata_by_unique_id.get(unique_id, {})
                )

            event = DbtCliEventMessage(
                raw_event=raw_event, event_history_metadata=event_history_metadata
            )

            # Attempt to parse the column level metadata from the event message.
            # If it exists, save it as historical metadata to attach to the NodeFinished event.
            if event.raw_event["info"]["name"] == "JinjaLogInfo":
                # A message that is not valid JSON falls through and is emitted like any
                # other event.
                with contextlib.suppress(orjson.JSONDecodeError):
                    column_level_metadata = orjson.loads(event.raw_event["info"]["msg"])

                    event_history_metadata_by_unique_id[cast(str, unique_id)] = (
                        column_level_metadata
                    )

                    # Don't show this message in stdout
                    continue

            # Re-emit the logs from dbt CLI process into stdout.
            sys.stdout.write(str(event) + "\n")
            sys.stdout.flush()

            yield event

        # Ensure that the dbt CLI process has completed.
        self._raise_on_error()

    @public
    def get_artifact(
        self,
        artifact: Union[
            Literal["manifest.json"],
            Literal["catalog.json"],
            Literal["run_results.json"],
            Literal["sources.json"],
        ],
    ) -> Dict[str, Any]:
        """Retrieve a dbt artifact from the target path.

        See https://docs.getdbt.com/reference/artifacts/dbt-artifacts for more information.

        Args:
            artifact (Union[Literal["manifest.json"], Literal["catalog.json"], Literal["run_results.json"], Literal["sources.json"]]): The name of the artifact to retrieve.

        Returns:
            Dict[str, Any]: The artifact as a dictionary.

        Examples:
            .. code-block:: python

                from dagster_dbt import DbtCliResource

                dbt = DbtCliResource(project_dir="/path/to/dbt/project")

                dbt_cli_invocation = dbt.cli(["run"]).wait()

                # Retrieve the run_results.json artifact.
                run_results = dbt_cli_invocation.get_artifact("run_results.json")
        """
        artifact_path = self.target_path.joinpath(artifact)

        return orjson.loads(artifact_path.read_bytes())

    @property
    def dbt_command(self) -> str:
        """The dbt CLI command that was invoked."""
        return " ".join(cast(Sequence[str], self.process.args))

    def _stream_stdout(self) -> Iterator[Union[str, Dict[str, Any]]]:
        """Stream the stdout from the dbt CLI process.

        Each line is yielded as the parsed event when it is JSON with an ``info.level`` entry,
        and as the stripped raw string otherwise. Messages of error-level events are recorded
        in ``_error_messages``.

        Raises:
            DagsterExecutionInterruptedError: Re-raised after the interrupt has been forwarded
                to the dbt process.
        """
        try:
            if not self.process.stdout or self.process.stdout.closed:
                return

            with self.process.stdout:
                for raw_line in self.process.stdout or []:
                    raw_event_str = raw_line.decode().strip()

                    # Only the parsing is guarded. The yields stay outside of the `try` so that
                    # an exception raised at a yield point (e.g. `GeneratorExit` when the
                    # generator is closed early, or an interrupt) propagates instead of being
                    # swallowed and followed by another yield.
                    try:
                        raw_event = orjson.loads(raw_event_str)

                        # Parse the error message from the event, if it exists.
                        is_error_message = raw_event["info"]["level"] == "error"
                        if is_error_message:
                            self._error_messages.append(raw_event["info"]["msg"])
                    except (orjson.JSONDecodeError, KeyError, TypeError):
                        # Not a structured event (invalid JSON, or JSON without the expected
                        # `info` fields): pass the line through as raw text.
                        yield raw_event_str
                    else:
                        yield raw_event

        except DagsterExecutionInterruptedError:
            logger.info(f"Forwarding interrupt signal to dbt command: `{self.dbt_command}`.")

            self.process.send_signal(signal.SIGINT)
            self.process.wait(timeout=self.termination_timeout_seconds)

            logger.info(f"dbt process terminated with exit code `{self.process.returncode}`.")

            raise

    def _format_error_messages(self) -> str:
        """Format the error messages from the dbt CLI process.

        Returns:
            str: An empty string when no error messages were recorded; otherwise a header and
                the messages, each preceded by a blank line.
        """
        if not self._error_messages:
            return ""

        # The leading empty string makes the result start with a blank-line separator.
        return "\n\n".join(
            [
                "",
                "Errors parsed from dbt logs:",
                *self._error_messages,
            ]
        )

    def _raise_on_error(self) -> None:
        """Ensure that the dbt CLI process has completed. If the process has not successfully
        completed, then optionally raise an error.
        """
        logger.info(f"Finished dbt command: `{self.dbt_command}`.")
        error = self.get_error()
        if error and self.raise_on_error:
            raise error