# Source code for dagster._core.storage.runs.sql_run_storage

import logging
import uuid
import zlib
from abc import abstractmethod
from collections import defaultdict
from datetime import datetime
from enum import Enum
from typing import (
    Any,
    Callable,
    ContextManager,
    Dict,
    Iterable,
    Mapping,
    NamedTuple,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
    cast,
)

import pendulum
import sqlalchemy as db
import sqlalchemy.exc as db_exc
from sqlalchemy.engine import Connection

import dagster._check as check
from dagster._core.errors import (
    DagsterInvariantViolationError,
    DagsterRunAlreadyExists,
    DagsterRunNotFoundError,
    DagsterSnapshotDoesNotExist,
)
from dagster._core.events import (
    EVENT_TYPE_TO_PIPELINE_RUN_STATUS,
    DagsterEvent,
    DagsterEventType,
    RunFailureReason,
)
from dagster._core.execution.backfill import BulkActionStatus, PartitionBackfill
from dagster._core.remote_representation.origin import RemoteJobOrigin
from dagster._core.snap import (
    ExecutionPlanSnapshot,
    JobSnapshot,
    create_execution_plan_snapshot_id,
    create_job_snapshot_id,
)
from dagster._core.storage.sql import SqlAlchemyQuery
from dagster._core.storage.sqlalchemy_compat import (
    db_fetch_mappings,
    db_scalar_subquery,
    db_select,
    db_subquery,
)
from dagster._core.storage.tags import (
    PARTITION_NAME_TAG,
    PARTITION_SET_TAG,
    REPOSITORY_LABEL_TAG,
    ROOT_RUN_ID_TAG,
    RUN_FAILURE_REASON_TAG,
)
from dagster._daemon.types import DaemonHeartbeat
from dagster._serdes import (
    deserialize_value,
    serialize_value,
)
from dagster._seven import JSONDecodeError
from dagster._utils import PrintFn, utc_datetime_from_timestamp
from dagster._utils.merger import merge_dicts

from ..dagster_run import (
    DagsterRun,
    DagsterRunStatus,
    JobBucket,
    RunPartitionData,
    RunRecord,
    RunsFilter,
    TagBucket,
)
from .base import RunStorage
from .migration import (
    OPTIONAL_DATA_MIGRATIONS,
    REQUIRED_DATA_MIGRATIONS,
    RUN_PARTITIONS,
    MigrationFn,
)
from .schema import (
    BulkActionsTable,
    DaemonHeartbeatsTable,
    InstanceInfo,
    KeyValueStoreTable,
    RunsTable,
    RunTagsTable,
    SecondaryIndexMigrationTable,
    SnapshotsTable,
)


class SnapshotType(Enum):
    PIPELINE = "PIPELINE"
    EXECUTION_PLAN = "EXECUTION_PLAN"


class SqlRunStorage(RunStorage):
    """Base class for SQL based run storages.

    Implements the ``RunStorage`` interface with SQLAlchemy queries against the tables
    imported from ``.schema`` (runs, run tags, snapshots, daemon heartbeats, bulk actions,
    key/value store, secondary-index migrations). Subclasses must provide ``connect`` (a
    connection context manager) and ``upgrade`` (schema/data migrations).
    """

    @abstractmethod
    def connect(self) -> ContextManager[Connection]:
        """Context manager yielding a sqlalchemy.engine.Connection."""

    @abstractmethod
    def upgrade(self) -> None:
        """This method should perform any schema or data migrations necessary to bring an
        out-of-date instance of the storage up to date.
        """

    def fetchall(self, query: SqlAlchemyQuery) -> Sequence[Any]:
        """Run ``query`` on a new connection and return all result rows.

        Rows are read by column name (``row["col"]``) throughout this class.
        """
        with self.connect() as conn:
            return db_fetch_mappings(conn, query)

    def fetchone(self, query: SqlAlchemyQuery) -> Optional[Any]:
        """Run ``query`` on a new connection and return its first row, or ``None``."""
        with self.connect() as conn:
            # Under SQLAlchemy 2.x, request mapping rows so callers can keep indexing by
            # column name, as they do with the rows returned by 1.x ``fetchone``.
            if db.__version__.startswith("2."):
                return conn.execute(query).mappings().first()
            else:
                return conn.execute(query).fetchone()

    def add_run(self, dagster_run: DagsterRun) -> DagsterRun:
        """Insert a new run row plus one run-tags row per storage tag.

        The partition / partition set columns are populated from the run's
        ``PARTITION_NAME_TAG`` / ``PARTITION_SET_TAG`` tags when present.

        Returns:
            The same ``dagster_run`` that was passed in.

        Raises:
            DagsterSnapshotDoesNotExist: if the run references a job snapshot id that is not
                stored in this run storage.
            DagsterRunAlreadyExists: if the runs-table insert violates an integrity constraint
                (e.g. a duplicate run_id).
        """
        check.inst_param(dagster_run, "dagster_run", DagsterRun)

        if dagster_run.job_snapshot_id and not self.has_job_snapshot(dagster_run.job_snapshot_id):
            raise DagsterSnapshotDoesNotExist(
                f"Snapshot {dagster_run.job_snapshot_id} does not exist in run storage"
            )

        has_tags = dagster_run.tags and len(dagster_run.tags) > 0
        partition = dagster_run.tags.get(PARTITION_NAME_TAG) if has_tags else None
        partition_set = dagster_run.tags.get(PARTITION_SET_TAG) if has_tags else None

        runs_insert = RunsTable.insert().values(
            run_id=dagster_run.run_id,
            pipeline_name=dagster_run.job_name,
            status=dagster_run.status.value,
            run_body=serialize_value(dagster_run),
            snapshot_id=dagster_run.job_snapshot_id,
            partition=partition,
            partition_set=partition_set,
        )
        with self.connect() as conn:
            try:
                conn.execute(runs_insert)
            except db_exc.IntegrityError as exc:
                raise DagsterRunAlreadyExists from exc

            tags_to_insert = dagster_run.tags_for_storage()
            if tags_to_insert:
                conn.execute(
                    RunTagsTable.insert(),
                    [
                        dict(run_id=dagster_run.run_id, key=k, value=v)
                        for k, v in tags_to_insert.items()
                    ],
                )

        return dagster_run

    def handle_run_event(self, run_id: str, event: DagsterEvent) -> None:
        """Apply a run-lifecycle event to the stored run.

        Events whose type is not a key of ``EVENT_TYPE_TO_PIPELINE_RUN_STATUS`` are ignored, as
        are events for runs not found in storage. Otherwise the run's status column, serialized
        body and update timestamp are rewritten. When the run-stats columns exist, a start event
        records ``start_time`` and a canceled/failure/success event records ``end_time``. A
        failure event carrying a known ``RunFailureReason`` also tags the run with
        ``RUN_FAILURE_REASON_TAG``.
        """
        # Local import; presumably avoids an import cycle with dagster._core.events.
        from dagster._core.events import JobFailureData

        check.str_param(run_id, "run_id")
        check.inst_param(event, "event", DagsterEvent)

        if event.event_type not in EVENT_TYPE_TO_PIPELINE_RUN_STATUS:
            return

        run = self._get_run_by_id(run_id)
        if not run:
            # TODO log?
            return

        new_job_status = EVENT_TYPE_TO_PIPELINE_RUN_STATUS[event.event_type]

        run_stats_cols_in_index = self.has_run_stats_index_cols()

        # Extra column values (start_time / end_time) written only if those columns exist.
        kwargs: Dict[str, float] = {}

        # consider changing the `handle_run_event` signature to get timestamp off of the
        # EventLogEntry instead of the DagsterEvent, for consistency
        now = pendulum.now("UTC")

        if run_stats_cols_in_index and event.event_type == DagsterEventType.PIPELINE_START:
            kwargs["start_time"] = now.timestamp()

        if run_stats_cols_in_index and event.event_type in {
            DagsterEventType.PIPELINE_CANCELED,
            DagsterEventType.PIPELINE_FAILURE,
            DagsterEventType.PIPELINE_SUCCESS,
        }:
            kwargs["end_time"] = now.timestamp()

        with self.connect() as conn:
            conn.execute(
                RunsTable.update()
                .where(RunsTable.c.run_id == run_id)
                .values(
                    run_body=serialize_value(run.with_status(new_job_status)),
                    status=new_job_status.value,
                    update_timestamp=now,
                    **kwargs,
                )
            )

        if event.event_type == DagsterEventType.PIPELINE_FAILURE and isinstance(
            event.event_specific_data, JobFailureData
        ):
            failure_reason = event.event_specific_data.failure_reason
            if failure_reason and failure_reason != RunFailureReason.UNKNOWN:
                self.add_run_tags(run_id, {RUN_FAILURE_REASON_TAG: failure_reason.value})

    def _row_to_run(self, row: Dict) -> DagsterRun:
        """Deserialize a row's ``run_body`` and overlay the status from the ``status`` column."""
        run = deserialize_value(row["run_body"], DagsterRun)
        status = DagsterRunStatus(row["status"])
        # NOTE: the status column is more trustworthy than the status in the run body, since
        # concurrent writes (e.g. handle_run_event and add_run_tags) can leave the status in the
        # body out of date, overwritten with an old value.
        return run.with_status(status)

    def _rows_to_runs(self, rows: Iterable[Dict]) -> Sequence[DagsterRun]:
        """Convert each row (with ``run_body`` and ``status``) into a ``DagsterRun``."""
        return list(map(self._row_to_run, rows))

    def _add_cursor_limit_to_query(
        self,
        query: SqlAlchemyQuery,
        cursor: Optional[str],
        limit: Optional[int],
        order_by: Optional[str],
        ascending: Optional[bool],
    ) -> SqlAlchemyQuery:
        """Helper function to deal with cursor/limit pagination args.

        ``cursor`` is a run_id: results are restricted to storage ids strictly greater than
        (ascending) or less than (descending) that run's storage id. Results are ordered by the
        ``RunsTable`` column named ``order_by``, or by storage id when it is not given.
        """
        if cursor:
            cursor_query = db_select([RunsTable.c.id]).where(RunsTable.c.run_id == cursor)
            if ascending:
                query = query.where(RunsTable.c.id > db_scalar_subquery(cursor_query))
            else:
                query = query.where(RunsTable.c.id < db_scalar_subquery(cursor_query))

        if limit:
            query = query.limit(limit)

        sorting_column = getattr(RunsTable.c, order_by) if order_by else RunsTable.c.id
        direction = db.asc if ascending else db.desc
        query = query.order_by(direction(sorting_column))

        return query

    def _add_filters_to_query(self, query: SqlAlchemyQuery, filters: RunsFilter) -> SqlAlchemyQuery:
        """Add a WHERE clause for each populated field of ``filters``.

        Timestamp bounds are exclusive; tag filters are delegated to
        ``_apply_tags_table_filters``.
        """
        check.inst_param(filters, "filters", RunsFilter)

        if filters.run_ids:
            query = query.where(RunsTable.c.run_id.in_(filters.run_ids))

        if filters.job_name:
            query = query.where(RunsTable.c.pipeline_name == filters.job_name)

        if filters.statuses:
            query = query.where(
                RunsTable.c.status.in_([status.value for status in filters.statuses])
            )

        if filters.snapshot_id:
            query = query.where(RunsTable.c.snapshot_id == filters.snapshot_id)

        if filters.updated_after:
            query = query.where(RunsTable.c.update_timestamp > filters.updated_after)

        if filters.updated_before:
            query = query.where(RunsTable.c.update_timestamp < filters.updated_before)

        if filters.created_after:
            query = query.where(RunsTable.c.create_timestamp > filters.created_after)

        if filters.created_before:
            query = query.where(RunsTable.c.create_timestamp < filters.created_before)

        if filters.tags:
            query = self._apply_tags_table_filters(query, filters.tags)

        return query

    def _runs_query(
        self,
        filters: Optional[RunsFilter] = None,
        cursor: Optional[str] = None,
        limit: Optional[int] = None,
        columns: Optional[Sequence[str]] = None,
        order_by: Optional[str] = None,
        ascending: bool = False,
        bucket_by: Optional[Union[JobBucket, TagBucket]] = None,
    ) -> SqlAlchemyQuery:
        """Build a filtered, paginated SELECT over ``RunsTable``.

        ``columns`` names the ``RunsTable`` columns to select (default: ``run_body`` and
        ``status``, which is what ``_row_to_run`` needs). ``bucket_by`` is accepted for
        signature compatibility but is not used by this base implementation.
        """
        filters = check.opt_inst_param(filters, "filters", RunsFilter, default=RunsFilter())
        check.opt_str_param(cursor, "cursor")
        check.opt_int_param(limit, "limit")
        check.opt_sequence_param(columns, "columns")
        check.opt_str_param(order_by, "order_by")
        check.opt_bool_param(ascending, "ascending")

        if columns is None:
            columns = ["run_body", "status"]

        base_query = db_select([getattr(RunsTable.c, column) for column in columns]).select_from(
            RunsTable
        )
        base_query = self._add_filters_to_query(base_query, filters)
        return self._add_cursor_limit_to_query(base_query, cursor, limit, order_by, ascending)

    def _apply_tags_table_filters(
        self, query: SqlAlchemyQuery, tags: Mapping[str, Union[str, Sequence[str]]]
    ) -> SqlAlchemyQuery:
        """Efficient query pattern for filtering by multiple tags.

        Each tag value may be a single string (exact match) or a sequence of strings (match any).
        A run must satisfy every tag key to be selected.
        """
        expected_count = len(tags)

        if expected_count == 1:
            key, value = next(iter(tags.items()))
            # since run tags should be much larger than runs, select where exists
            # should be more efficient than joining
            subquery = db.exists().where(
                (RunsTable.c.run_id == RunTagsTable.c.run_id)
                & (RunTagsTable.c.key == key)
                & (
                    (RunTagsTable.c.value == value)
                    if isinstance(value, str)
                    else RunTagsTable.c.value.in_(value)
                )
            )
            query = query.where(subquery)

        elif expected_count > 1:
            # efficient query for filtering by multiple tags. first find all run_ids that match
            # all tags, then select from runs table where run_id in that set
            subquery = db_select([RunTagsTable.c.run_id])
            expressions = []
            for key, value in tags.items():
                expression = RunTagsTable.c.key == key
                if isinstance(value, str):
                    expression &= RunTagsTable.c.value == value
                else:
                    expression &= RunTagsTable.c.value.in_(value)
                expressions.append(expression)

            subquery = subquery.where(db.or_(*expressions))
            subquery = subquery.group_by(RunTagsTable.c.run_id)
            # A run matches only if it satisfied a distinct tag key for every filter entry.
            subquery = subquery.having(
                db.func.count(db.distinct(RunTagsTable.c.key)) == expected_count
            )
            query = query.where(RunsTable.c.run_id.in_(subquery))

        return query

    def get_runs(
        self,
        filters: Optional[RunsFilter] = None,
        cursor: Optional[str] = None,
        limit: Optional[int] = None,
        bucket_by: Optional[Union[JobBucket, TagBucket]] = None,
        ascending: bool = False,
    ) -> Sequence[DagsterRun]:
        """Return runs matching ``filters``, paginated by ``cursor``/``limit``."""
        query = self._runs_query(filters, cursor, limit, bucket_by=bucket_by, ascending=ascending)
        rows = self.fetchall(query)
        return self._rows_to_runs(rows)

    def get_run_ids(
        self,
        filters: Optional[RunsFilter] = None,
        cursor: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> Sequence[str]:
        """Return only the run_ids of runs matching ``filters``."""
        query = self._runs_query(filters=filters, cursor=cursor, limit=limit, columns=["run_id"])
        rows = self.fetchall(query)
        return [row["run_id"] for row in rows]

    def get_runs_count(self, filters: Optional[RunsFilter] = None) -> int:
        """Return the number of runs matching ``filters`` (0 if the count query returns no row)."""
        subquery = db_subquery(self._runs_query(filters=filters))
        query = db_select([db.func.count().label("count")]).select_from(subquery)
        row = self.fetchone(query)
        count = row["count"] if row else 0
        return count

    def _get_run_by_id(self, run_id: str) -> Optional[DagsterRun]:
        """Return the run with ``run_id``, or ``None`` if it is not stored."""
        check.str_param(run_id, "run_id")

        query = db_select([RunsTable.c.run_body, RunsTable.c.status]).where(
            RunsTable.c.run_id == run_id
        )
        rows = self.fetchall(query)
        return self._row_to_run(rows[0]) if rows else None

    def get_run_records(
        self,
        filters: Optional[RunsFilter] = None,
        limit: Optional[int] = None,
        order_by: Optional[str] = None,
        ascending: bool = False,
        cursor: Optional[str] = None,
        bucket_by: Optional[Union[JobBucket, TagBucket]] = None,
    ) -> Sequence[RunRecord]:
        """Return ``RunRecord``s (run plus storage id and timestamps) matching ``filters``.

        ``start_time``/``end_time`` are only populated when the run-stats columns exist;
        otherwise they are ``None``.
        """
        filters = check.opt_inst_param(filters, "filters", RunsFilter, default=RunsFilter())
        check.opt_int_param(limit, "limit")

        columns = ["id", "run_body", "status", "create_timestamp", "update_timestamp"]

        if self.has_run_stats_index_cols():
            columns += ["start_time", "end_time"]
        # only fetch columns we use to build RunRecord
        query = self._runs_query(
            filters=filters,
            limit=limit,
            columns=columns,
            order_by=order_by,
            ascending=ascending,
            cursor=cursor,
            bucket_by=bucket_by,
        )

        rows = self.fetchall(query)
        return [
            RunRecord(
                storage_id=check.int_param(row["id"], "id"),
                dagster_run=self._row_to_run(row),
                create_timestamp=check.inst(row["create_timestamp"], datetime),
                update_timestamp=check.inst(row["update_timestamp"], datetime),
                start_time=(
                    check.opt_inst(row["start_time"], float) if "start_time" in row else None
                ),
                end_time=check.opt_inst(row["end_time"], float) if "end_time" in row else None,
            )
            for row in rows
        ]

    def get_run_tags(
        self,
        tag_keys: Sequence[str],
        value_prefix: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> Sequence[Tuple[str, Set[str]]]:
        """Return ``(key, {values})`` pairs for the requested tag keys, sorted by key.

        ``value_prefix`` restricts values to those starting with it; ``limit`` caps the number
        of distinct (key, value) rows read from storage before grouping.
        """
        result = defaultdict(set)
        query = (
            db_select([RunTagsTable.c.key, RunTagsTable.c.value])
            .distinct()
            .order_by(RunTagsTable.c.key, RunTagsTable.c.value)
            .where(RunTagsTable.c.key.in_(tag_keys))
        )
        if value_prefix:
            query = query.where(RunTagsTable.c.value.startswith(value_prefix))
        if limit:
            query = query.limit(limit)
        rows = self.fetchall(query)
        for r in rows:
            result[r["key"]].add(r["value"])
        return sorted(result.items(), key=lambda item: item[0])

    def get_run_tag_keys(self) -> Sequence[str]:
        """Return every distinct tag key in storage, sorted."""
        query = db_select([RunTagsTable.c.key]).distinct().order_by(RunTagsTable.c.key)
        rows = self.fetchall(query)
        return sorted([r["key"] for r in rows])

    def add_run_tags(self, run_id: str, new_tags: Mapping[str, str]) -> None:
        """Merge ``new_tags`` into the tags of run ``run_id``.

        The serialized run body, partition columns and update timestamp are rewritten; tag rows
        whose key already exists are updated and new keys are inserted.

        Raises:
            DagsterRunNotFoundError: if no run with ``run_id`` is stored.
        """
        check.str_param(run_id, "run_id")
        check.mapping_param(new_tags, "new_tags", key_type=str, value_type=str)

        run = self._get_run_by_id(run_id)
        if not run:
            raise DagsterRunNotFoundError(
                f"Run {run_id} was not found in instance.", invalid_run_id=run_id
            )
        current_tags = run.tags if run.tags else {}

        all_tags = merge_dicts(current_tags, new_tags)
        partition = all_tags.get(PARTITION_NAME_TAG)
        partition_set = all_tags.get(PARTITION_SET_TAG)

        with self.connect() as conn:
            conn.execute(
                RunsTable.update()
                .where(RunsTable.c.run_id == run_id)
                .values(
                    run_body=serialize_value(run.with_tags(all_tags)),
                    partition=partition,
                    partition_set=partition_set,
                    update_timestamp=pendulum.now("UTC"),
                )
            )

            current_tags_set = set(current_tags.keys())
            new_tags_set = set(new_tags.keys())

            existing_tags = current_tags_set & new_tags_set
            added_tags = new_tags_set.difference(existing_tags)

            for tag in existing_tags:
                conn.execute(
                    RunTagsTable.update()
                    .where(db.and_(RunTagsTable.c.run_id == run_id, RunTagsTable.c.key == tag))
                    .values(value=new_tags[tag])
                )

            if added_tags:
                conn.execute(
                    RunTagsTable.insert(),
                    [dict(run_id=run_id, key=tag, value=new_tags[tag]) for tag in added_tags],
                )

    def get_run_group(self, run_id: str) -> Tuple[str, Sequence[DagsterRun]]:
        """Return ``(root_run_id, runs)`` for the group containing ``run_id``.

        ``runs`` is the root run followed by every run tagged with ``ROOT_RUN_ID_TAG`` equal to
        the root's run_id.

        Raises:
            DagsterRunNotFoundError: if ``run_id`` or its root run is not stored.
        """
        check.str_param(run_id, "run_id")
        dagster_run = self._get_run_by_id(run_id)
        if not dagster_run:
            raise DagsterRunNotFoundError(
                f"Run {run_id} was not found in instance.", invalid_run_id=run_id
            )

        # find root_run
        root_run_id = dagster_run.root_run_id if dagster_run.root_run_id else dagster_run.run_id
        root_run = self._get_run_by_id(root_run_id)
        if not root_run:
            raise DagsterRunNotFoundError(
                f"Run id {root_run_id} set as root run id for run {run_id} was not found in"
                " instance.",
                invalid_run_id=root_run_id,
            )

        # root_run_id to run_id 1:1 mapping
        # https://github.com/dagster-io/dagster/issues/2495
        # Note: we currently use tags to persist the run group info
        root_to_run = db_subquery(
            db_select(
                [RunTagsTable.c.value.label("root_run_id"), RunTagsTable.c.run_id.label("run_id")]
            ).where(
                db.and_(RunTagsTable.c.key == ROOT_RUN_ID_TAG, RunTagsTable.c.value == root_run_id)
            ),
            "root_to_run",
        )
        # get run group
        run_group_query = db_select([RunsTable.c.run_body, RunsTable.c.status]).select_from(
            root_to_run.join(
                RunsTable,
                root_to_run.c.run_id == RunsTable.c.run_id,
                isouter=True,
            )
        )

        res = self.fetchall(run_group_query)
        run_group = self._rows_to_runs(res)

        return (root_run_id, [root_run, *run_group])

    def has_run(self, run_id: str) -> bool:
        """Return whether a run with ``run_id`` is stored."""
        check.str_param(run_id, "run_id")
        return bool(self._get_run_by_id(run_id))

    def delete_run(self, run_id: str) -> None:
        """Delete the runs-table row for ``run_id`` (only RunsTable is touched here)."""
        check.str_param(run_id, "run_id")
        query = db.delete(RunsTable).where(RunsTable.c.run_id == run_id)
        with self.connect() as conn:
            conn.execute(query)

    def has_job_snapshot(self, job_snapshot_id: str) -> bool:
        """Return whether a snapshot row with id ``job_snapshot_id`` exists."""
        check.str_param(job_snapshot_id, "job_snapshot_id")
        return self._has_snapshot_id(job_snapshot_id)

    def add_job_snapshot(self, job_snapshot: JobSnapshot, snapshot_id: Optional[str] = None) -> str:
        """Store ``job_snapshot`` and return its id (computed from the snapshot if omitted)."""
        check.inst_param(job_snapshot, "job_snapshot", JobSnapshot)
        check.opt_str_param(snapshot_id, "snapshot_id")

        if not snapshot_id:
            snapshot_id = create_job_snapshot_id(job_snapshot)

        return self._add_snapshot(
            snapshot_id=snapshot_id,
            snapshot_obj=job_snapshot,
            snapshot_type=SnapshotType.PIPELINE,
        )

    def get_job_snapshot(self, job_snapshot_id: str) -> JobSnapshot:
        """Load the job snapshot with the given id (may actually return ``None``; see below)."""
        check.str_param(job_snapshot_id, "job_snapshot_id")
        return self._get_snapshot(job_snapshot_id)  # type: ignore  # (allowed to return None?)

    def has_execution_plan_snapshot(self, execution_plan_snapshot_id: str) -> bool:
        """Return whether the execution plan snapshot exists and can be decoded.

        NOTE(review): unlike ``has_job_snapshot`` this loads and deserializes the whole
        snapshot, so a stored-but-undecodable snapshot reports ``False``.
        """
        check.str_param(execution_plan_snapshot_id, "execution_plan_snapshot_id")
        return bool(self.get_execution_plan_snapshot(execution_plan_snapshot_id))

    def add_execution_plan_snapshot(
        self, execution_plan_snapshot: ExecutionPlanSnapshot, snapshot_id: Optional[str] = None
    ) -> str:
        """Store an execution plan snapshot and return its id (computed if omitted)."""
        check.inst_param(execution_plan_snapshot, "execution_plan_snapshot", ExecutionPlanSnapshot)
        check.opt_str_param(snapshot_id, "snapshot_id")

        if not snapshot_id:
            snapshot_id = create_execution_plan_snapshot_id(execution_plan_snapshot)

        return self._add_snapshot(
            snapshot_id=snapshot_id,
            snapshot_obj=execution_plan_snapshot,
            snapshot_type=SnapshotType.EXECUTION_PLAN,
        )

    def get_execution_plan_snapshot(self, execution_plan_snapshot_id: str) -> ExecutionPlanSnapshot:
        """Load the execution plan snapshot with the given id (may actually return ``None``)."""
        check.str_param(execution_plan_snapshot_id, "execution_plan_snapshot_id")
        return self._get_snapshot(execution_plan_snapshot_id)  # type: ignore  # (allowed to return None?)

    def _add_snapshot(self, snapshot_id: str, snapshot_obj, snapshot_type: SnapshotType) -> str:
        """Insert a zlib-compressed serialized snapshot; an existing id is left untouched."""
        check.str_param(snapshot_id, "snapshot_id")
        check.not_none_param(snapshot_obj, "snapshot_obj")
        check.inst_param(snapshot_type, "snapshot_type", SnapshotType)

        with self.connect() as conn:
            snapshot_insert = SnapshotsTable.insert().values(
                snapshot_id=snapshot_id,
                snapshot_body=zlib.compress(serialize_value(snapshot_obj).encode("utf-8")),
                snapshot_type=snapshot_type.value,
            )
            try:
                conn.execute(snapshot_insert)
            except db_exc.IntegrityError:
                # on_conflict_do_nothing equivalent
                pass

            return snapshot_id

    def get_run_storage_id(self) -> str:
        """Return the storage's id from ``InstanceInfo``, creating a random UUID on first use."""
        query = db_select([InstanceInfo.c.run_storage_id])
        row = self.fetchone(query)
        if not row:
            run_storage_id = str(uuid.uuid4())
            with self.connect() as conn:
                conn.execute(InstanceInfo.insert().values(run_storage_id=run_storage_id))
            return run_storage_id
        else:
            return row["run_storage_id"]

    def _has_snapshot_id(self, snapshot_id: str) -> bool:
        """Return whether a snapshot row with ``snapshot_id`` exists (without decoding it)."""
        query = db_select([SnapshotsTable.c.snapshot_id]).where(
            SnapshotsTable.c.snapshot_id == snapshot_id
        )

        row = self.fetchone(query)

        return bool(row)

    def _get_snapshot(self, snapshot_id: str) -> Optional[JobSnapshot]:
        """Load and decode a snapshot body; ``None`` if missing or undecodable."""
        query = db_select([SnapshotsTable.c.snapshot_body]).where(
            SnapshotsTable.c.snapshot_id == snapshot_id
        )

        row = self.fetchone(query)

        return (
            defensively_unpack_execution_plan_snapshot_query(logging, [row["snapshot_body"]])  # type: ignore
            if row
            else None
        )

    def get_run_partition_data(self, runs_filter: RunsFilter) -> Sequence[RunPartitionData]:
        """Return one ``RunPartitionData`` per partition among runs matching ``runs_filter``.

        Runs are scanned in the default order (descending storage id) and only the first run
        seen for each partition is kept. The fast path reads the partition and run-stats columns
        directly; otherwise partitions come from run tags and start/end times are ``None``.
        """
        if self.has_built_index(RUN_PARTITIONS) and self.has_run_stats_index_cols():
            query = self._runs_query(
                filters=runs_filter,
                columns=["run_id", "status", "start_time", "end_time", "partition"],
            )
            rows = self.fetchall(query)

            # dedup by partition
            _partition_data_by_partition = {}
            for row in rows:
                if not row["partition"] or row["partition"] in _partition_data_by_partition:
                    continue

                _partition_data_by_partition[row["partition"]] = RunPartitionData(
                    run_id=row["run_id"],
                    partition=row["partition"],
                    status=DagsterRunStatus[row["status"]],
                    start_time=row["start_time"],
                    end_time=row["end_time"],
                )

            return list(_partition_data_by_partition.values())
        else:
            query = self._runs_query(filters=runs_filter)
            rows = self.fetchall(query)
            _partition_data_by_partition = {}
            for row in rows:
                run = self._row_to_run(row)
                partition = run.tags.get(PARTITION_NAME_TAG)
                if not partition or partition in _partition_data_by_partition:
                    continue

                _partition_data_by_partition[partition] = RunPartitionData(
                    run_id=run.run_id,
                    partition=partition,
                    status=run.status,
                    start_time=None,
                    end_time=None,
                )

            return list(_partition_data_by_partition.values())

    def _get_partition_runs(
        self, partition_set_name: str, partition_name: str
    ) -> Sequence[DagsterRun]:
        # utility method to help test reads off of the partition column
        if not self.has_built_index(RUN_PARTITIONS):
            # query by tags
            return self.get_runs(
                filters=RunsFilter(
                    tags={
                        PARTITION_SET_TAG: partition_set_name,
                        PARTITION_NAME_TAG: partition_name,
                    }
                )
            )
        else:
            query = (
                self._runs_query()
                .where(RunsTable.c.partition == partition_name)
                .where(RunsTable.c.partition_set == partition_set_name)
            )
            rows = self.fetchall(query)
            return self._rows_to_runs(rows)

    # Tracking data migrations over secondary indexes

    def _execute_data_migrations(
        self,
        migrations: Mapping[str, Callable[[], MigrationFn]],
        print_fn: Optional[PrintFn] = None,
        force_rebuild_all: bool = False,
    ) -> None:
        """Run each migration not yet marked built (or all, if forced), then mark it built."""
        for migration_name, migration_fn in migrations.items():
            if self.has_built_index(migration_name):
                if not force_rebuild_all:
                    if print_fn:
                        print_fn(f"Skipping already applied data migration: {migration_name}")
                    continue
            if print_fn:
                print_fn(f"Starting data migration: {migration_name}")
            migration_fn()(self, print_fn)
            self.mark_index_built(migration_name)
            if print_fn:
                print_fn(f"Finished data migration: {migration_name}")

    def migrate(self, print_fn: Optional[PrintFn] = None, force_rebuild_all: bool = False) -> None:
        """Apply ``REQUIRED_DATA_MIGRATIONS``."""
        self._execute_data_migrations(REQUIRED_DATA_MIGRATIONS, print_fn, force_rebuild_all)

    def optimize(self, print_fn: Optional[PrintFn] = None, force_rebuild_all: bool = False) -> None:
        """Apply ``OPTIONAL_DATA_MIGRATIONS``."""
        self._execute_data_migrations(OPTIONAL_DATA_MIGRATIONS, print_fn, force_rebuild_all)

    def has_built_index(self, migration_name: str) -> bool:
        """Return whether ``migration_name`` has a non-null completion timestamp recorded."""
        query = (
            db_select([1])
            .where(SecondaryIndexMigrationTable.c.name == migration_name)
            .where(SecondaryIndexMigrationTable.c.migration_completed != None)  # noqa: E711
            .limit(1)
        )
        results = self.fetchall(query)
        return len(results) > 0

    def mark_index_built(self, migration_name: str) -> None:
        """Record ``migration_name`` as completed now, updating the row if it already exists."""
        query = SecondaryIndexMigrationTable.insert().values(
            name=migration_name,
            migration_completed=datetime.now(),
        )
        with self.connect() as conn:
            try:
                conn.execute(query)
            except db_exc.IntegrityError:
                conn.execute(
                    SecondaryIndexMigrationTable.update()
                    .where(SecondaryIndexMigrationTable.c.name == migration_name)
                    .values(migration_completed=datetime.now())
                )

    # Checking for migrations

    def has_run_stats_index_cols(self) -> bool:
        """Return whether RunsTable has both ``start_time`` and ``end_time`` columns.

        Inspects the live schema on every call.
        """
        with self.connect() as conn:
            column_names = [x.get("name") for x in db.inspect(conn).get_columns(RunsTable.name)]
            return "start_time" in column_names and "end_time" in column_names

    def has_bulk_actions_selector_cols(self) -> bool:
        """Return whether BulkActionsTable has a ``selector_id`` column (live schema check)."""
        with self.connect() as conn:
            column_names = [
                x.get("name") for x in db.inspect(conn).get_columns(BulkActionsTable.name)
            ]
            return "selector_id" in column_names

    # Daemon heartbeats

    def add_daemon_heartbeat(self, daemon_heartbeat: DaemonHeartbeat) -> None:
        """Store the latest heartbeat, keyed by daemon type."""
        with self.connect() as conn:
            # insert, or update if already present
            try:
                conn.execute(
                    DaemonHeartbeatsTable.insert().values(
                        timestamp=utc_datetime_from_timestamp(daemon_heartbeat.timestamp),
                        daemon_type=daemon_heartbeat.daemon_type,
                        daemon_id=daemon_heartbeat.daemon_id,
                        body=serialize_value(daemon_heartbeat),
                    )
                )
            except db_exc.IntegrityError:
                conn.execute(
                    DaemonHeartbeatsTable.update()
                    .where(DaemonHeartbeatsTable.c.daemon_type == daemon_heartbeat.daemon_type)
                    .values(
                        timestamp=utc_datetime_from_timestamp(daemon_heartbeat.timestamp),
                        daemon_id=daemon_heartbeat.daemon_id,
                        body=serialize_value(daemon_heartbeat),
                    )
                )

    def get_daemon_heartbeats(self) -> Mapping[str, DaemonHeartbeat]:
        """Return all stored heartbeats keyed by daemon type."""
        rows = self.fetchall(db_select([DaemonHeartbeatsTable.c.body]))
        heartbeats = [deserialize_value(row["body"], DaemonHeartbeat) for row in rows]
        return {heartbeat.daemon_type: heartbeat for heartbeat in heartbeats}

    def wipe(self) -> None:
        """Clears the run storage."""
        with self.connect() as conn:
            # https://stackoverflow.com/a/54386260/324449
            conn.execute(RunsTable.delete())
            conn.execute(RunTagsTable.delete())
            conn.execute(SnapshotsTable.delete())
            conn.execute(DaemonHeartbeatsTable.delete())
            conn.execute(BulkActionsTable.delete())

    def wipe_daemon_heartbeats(self) -> None:
        """Delete every stored daemon heartbeat."""
        with self.connect() as conn:
            # https://stackoverflow.com/a/54386260/324449
            conn.execute(DaemonHeartbeatsTable.delete())

    def get_backfills(
        self,
        status: Optional[BulkActionStatus] = None,
        cursor: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> Sequence[PartitionBackfill]:
        """Return backfills newest-first (descending storage id).

        ``cursor`` is a backfill key; only backfills stored before it are returned.
        """
        check.opt_inst_param(status, "status", BulkActionStatus)
        query = db_select([BulkActionsTable.c.body])
        if status:
            query = query.where(BulkActionsTable.c.status == status.value)
        if cursor:
            cursor_query = db_select([BulkActionsTable.c.id]).where(
                BulkActionsTable.c.key == cursor
            )
            # Wrap explicitly as a scalar subquery (as _add_cursor_limit_to_query does);
            # comparing a column against a bare SELECT relies on implicit coercion that
            # SQLAlchemy deprecates.
            query = query.where(BulkActionsTable.c.id < db_scalar_subquery(cursor_query))
        if limit:
            query = query.limit(limit)
        query = query.order_by(BulkActionsTable.c.id.desc())
        rows = self.fetchall(query)
        return [deserialize_value(row["body"], PartitionBackfill) for row in rows]

    def get_backfill(self, backfill_id: str) -> Optional[PartitionBackfill]:
        """Return the backfill with key ``backfill_id``, or ``None``."""
        check.str_param(backfill_id, "backfill_id")
        query = db_select([BulkActionsTable.c.body]).where(BulkActionsTable.c.key == backfill_id)
        row = self.fetchone(query)
        return deserialize_value(row["body"], PartitionBackfill) if row else None

    def add_backfill(self, partition_backfill: PartitionBackfill) -> None:
        """Insert a backfill; selector/action-type columns are written only if they exist."""
        check.inst_param(partition_backfill, "partition_backfill", PartitionBackfill)
        values: Dict[str, Any] = dict(
            key=partition_backfill.backfill_id,
            status=partition_backfill.status.value,
            timestamp=utc_datetime_from_timestamp(partition_backfill.backfill_timestamp),
            body=serialize_value(cast(NamedTuple, partition_backfill)),
        )

        if self.has_bulk_actions_selector_cols():
            values["selector_id"] = partition_backfill.selector_id
            values["action_type"] = partition_backfill.bulk_action_type.value

        with self.connect() as conn:
            conn.execute(BulkActionsTable.insert().values(**values))

    def update_backfill(self, partition_backfill: PartitionBackfill) -> None:
        """Overwrite the status and body of an existing backfill.

        Raises:
            DagsterInvariantViolationError: if the backfill is not stored.
        """
        check.inst_param(partition_backfill, "partition_backfill", PartitionBackfill)
        backfill_id = partition_backfill.backfill_id
        if not self.get_backfill(backfill_id):
            raise DagsterInvariantViolationError(
                f"Backfill {backfill_id} is not present in storage"
            )
        with self.connect() as conn:
            conn.execute(
                BulkActionsTable.update()
                .where(BulkActionsTable.c.key == backfill_id)
                .values(
                    status=partition_backfill.status.value,
                    body=serialize_value(partition_backfill),
                )
            )

    def get_cursor_values(self, keys: Set[str]) -> Mapping[str, str]:
        """Return stored key/value pairs for ``keys``; missing keys are omitted."""
        check.set_param(keys, "keys", of_type=str)

        rows = self.fetchall(
            db_select([KeyValueStoreTable.c.key, KeyValueStoreTable.c.value]).where(
                KeyValueStoreTable.c.key.in_(keys)
            ),
        )
        return {row["key"]: row["value"] for row in rows}

    def set_cursor_values(self, pairs: Mapping[str, str]) -> None:
        """Upsert every key/value pair in ``pairs`` into the key/value store."""
        check.mapping_param(pairs, "pairs", key_type=str, value_type=str)
        if not pairs:
            # Nothing to write; avoid issuing an INSERT with an empty VALUES list.
            return

        db_values = [{"key": k, "value": v} for k, v in pairs.items()]

        with self.connect() as conn:
            try:
                conn.execute(KeyValueStoreTable.insert().values(db_values))
            except db_exc.IntegrityError:
                # At least one key already exists, so the multi-row insert was rejected as a
                # whole. Update the keys that exist, then insert the ones that don't, so new
                # keys in a mixed batch are not silently dropped.
                existing_keys = {
                    row["key"]
                    for row in db_fetch_mappings(
                        conn,
                        db_select([KeyValueStoreTable.c.key]).where(
                            KeyValueStoreTable.c.key.in_(list(pairs.keys()))
                        ),
                    )
                }
                conn.execute(
                    KeyValueStoreTable.update()
                    .where(KeyValueStoreTable.c.key.in_(pairs.keys()))
                    .values(value=db.sql.case(pairs, value=KeyValueStoreTable.c.key))
                )
                missing_values = [v for v in db_values if v["key"] not in existing_keys]
                if missing_values:
                    conn.execute(KeyValueStoreTable.insert().values(missing_values))

    # Migrating run history
    def replace_job_origin(self, run: DagsterRun, job_origin: RemoteJobOrigin) -> None:
        """Rewrite ``run``'s stored body with ``job_origin`` and a refreshed repository label.

        The ``REPOSITORY_LABEL_TAG`` tag row is updated in place; no tag row is inserted if the
        run has none.
        """
        new_label = job_origin.repository_origin.get_label()
        with self.connect() as conn:
            conn.execute(
                RunsTable.update()
                .where(RunsTable.c.run_id == run.run_id)
                .values(
                    run_body=serialize_value(
                        run.with_job_origin(job_origin).with_tags(
                            {**run.tags, REPOSITORY_LABEL_TAG: new_label}
                        )
                    ),
                )
            )
            conn.execute(
                RunTagsTable.update()
                .where(RunTagsTable.c.run_id == run.run_id)
                .where(RunTagsTable.c.key == REPOSITORY_LABEL_TAG)
                .values(value=new_label)
            )
GET_PIPELINE_SNAPSHOT_QUERY_ID = "get-pipeline-snapshot"


def defensively_unpack_execution_plan_snapshot_query(
    logger: logging.Logger, row: Sequence[Any]
) -> Optional[Union[ExecutionPlanSnapshot, JobSnapshot]]:
    """Best-effort decode of a stored snapshot body.

    The first element of ``row`` is expected to hold zlib-compressed, UTF-8 encoded,
    serialized ``ExecutionPlanSnapshot`` or ``JobSnapshot`` data. Decoding happens in
    stages (type check, decompress, text decode, deserialize). The first stage that fails
    logs a warning through ``logger`` and makes the function return ``None`` instead of
    raising.

    Args:
        logger: Receives a warning naming the stage that failed.
        row: Sequence whose first element is the raw snapshot body.

    Returns:
        The deserialized snapshot, or ``None`` when any stage fails.
    """

    def _report(problem: str) -> None:
        logger.warning(f"get-pipeline-snapshot: {problem}")

    # minimal checking here because sqlalchemy returns a different type based on what
    # version of SqlAlchemy you are using
    blob = row[0]
    if not isinstance(blob, bytes):
        _report("First entry in row is not a binary type.")
        return None

    try:
        inflated = zlib.decompress(blob)
    except zlib.error:
        _report("Could not decompress bytes stored in snapshot table.")
        return None

    try:
        serialized = inflated.decode("utf-8")
    except UnicodeDecodeError:
        _report("Could not unicode decode decompressed bytes stored in snapshot table.")
        return None

    try:
        return deserialize_value(serialized, (ExecutionPlanSnapshot, JobSnapshot))
    except JSONDecodeError:
        _report("Could not parse json in snapshot table.")
        return None