Source code for dagster._core.storage.schedules.sql_schedule_storage

from abc import abstractmethod
from collections import defaultdict
from datetime import datetime
from typing import (
    Any,
    Callable,
    ContextManager,
    Mapping,
    NamedTuple,
    Optional,
    Sequence,
    Set,
    Type,
    TypeVar,
)

import pendulum
import sqlalchemy as db
import sqlalchemy.exc as db_exc
from sqlalchemy.engine import Connection

import dagster._check as check
from dagster._core.definitions.asset_condition.asset_condition import (
    AssetConditionEvaluationWithRunIds,
)
from dagster._core.definitions.events import AssetKey
from dagster._core.definitions.run_request import InstigatorType
from dagster._core.errors import DagsterInvariantViolationError
from dagster._core.scheduler.instigation import (
    AutoMaterializeAssetEvaluationRecord,
    InstigatorState,
    InstigatorStatus,
    InstigatorTick,
    TickData,
    TickStatus,
)
from dagster._core.storage.sql import SqlAlchemyQuery, SqlAlchemyRow
from dagster._core.storage.sqlalchemy_compat import db_fetch_mappings, db_select, db_subquery
from dagster._serdes import serialize_value
from dagster._serdes.serdes import deserialize_value
from dagster._utils import PrintFn, utc_datetime_from_timestamp

from .base import ScheduleStorage
from .migration import (
    OPTIONAL_SCHEDULE_DATA_MIGRATIONS,
    REQUIRED_SCHEDULE_DATA_MIGRATIONS,
    SCHEDULE_JOBS_SELECTOR_ID,
    SCHEDULE_TICKS_SELECTOR_ID,
)
from .schema import (
    AssetDaemonAssetEvaluationsTable,
    InstigatorsTable,
    JobTable,
    JobTickTable,
    SecondaryIndexMigrationTable,
)

T_NamedTuple = TypeVar("T_NamedTuple", bound=NamedTuple)


class SqlScheduleStorage(ScheduleStorage):
    """Base class for SQL backed schedule storage."""

    @abstractmethod
    def connect(self) -> ContextManager[Connection]:
        """Context manager yielding a sqlalchemy.engine.Connection."""

    def execute(self, query: SqlAlchemyQuery) -> Sequence[SqlAlchemyRow]:
        with self.connect() as conn:
            result_proxy = conn.execute(query)
            res = result_proxy.fetchall()
            result_proxy.close()

        return res

    def _deserialize_rows(
        self, rows: Sequence[SqlAlchemyRow], as_type: Type[T_NamedTuple]
    ) -> Sequence[T_NamedTuple]:
        return list(map(lambda r: deserialize_value(r[0], as_type), rows))

    def all_instigator_state(
        self,
        repository_origin_id: Optional[str] = None,
        repository_selector_id: Optional[str] = None,
        instigator_type: Optional[InstigatorType] = None,
        instigator_statuses: Optional[Set[InstigatorStatus]] = None,
    ) -> Sequence[InstigatorState]:
        check.opt_inst_param(instigator_type, "instigator_type", InstigatorType)

        if self.has_instigators_table() and self.has_built_index(SCHEDULE_JOBS_SELECTOR_ID):
            query = db_select([InstigatorsTable.c.instigator_body]).select_from(InstigatorsTable)
            if repository_selector_id:
                query = query.where(
                    InstigatorsTable.c.repository_selector_id == repository_selector_id
                )
            if instigator_type:
                query = query.where(InstigatorsTable.c.instigator_type == instigator_type.value)
            if instigator_statuses:
                query = query.where(
                    InstigatorsTable.c.status.in_(
                        [status.value for status in instigator_statuses]
                    )
                )
        else:
            query = db_select([JobTable.c.job_body]).select_from(JobTable)
            if repository_origin_id:
                query = query.where(JobTable.c.repository_origin_id == repository_origin_id)
            if instigator_type:
                query = query.where(JobTable.c.job_type == instigator_type.value)
            if instigator_statuses:
                query = query.where(
                    JobTable.c.status.in_([status.value for status in instigator_statuses])
                )

        rows = self.execute(query)
        return self._deserialize_rows(rows, InstigatorState)

    def get_instigator_state(self, origin_id: str, selector_id: str) -> Optional[InstigatorState]:
        check.str_param(origin_id, "origin_id")
        check.str_param(selector_id, "selector_id")

        if self.has_instigators_table() and self.has_built_index(SCHEDULE_JOBS_SELECTOR_ID):
            query = (
                db_select([InstigatorsTable.c.instigator_body])
                .select_from(InstigatorsTable)
                .where(InstigatorsTable.c.selector_id == selector_id)
            )
        else:
            query = (
                db_select([JobTable.c.job_body])
                .select_from(JobTable)
                .where(JobTable.c.job_origin_id == origin_id)
            )

        rows = self.execute(query)
        return self._deserialize_rows(rows[:1], InstigatorState)[0] if len(rows) else None

    def _has_instigator_state_by_selector(self, selector_id: str) -> bool:
        check.str_param(selector_id, "selector_id")

        query = (
            db_select([JobTable.c.job_body])
            .select_from(JobTable)
            .where(JobTable.c.selector_id == selector_id)
        )

        rows = self.execute(query)
        return self._deserialize_rows(rows[:1])[0] if len(rows) else None  # type: ignore

    def _add_or_update_instigators_table(self, conn: Connection, state: InstigatorState) -> None:
        selector_id = state.selector_id
        try:
            conn.execute(
                InstigatorsTable.insert().values(
                    selector_id=selector_id,
                    repository_selector_id=state.repository_selector_id,
                    status=state.status.value,
                    instigator_type=state.instigator_type.value,
                    instigator_body=serialize_value(state),
                )
            )
        except db_exc.IntegrityError:
            conn.execute(
                InstigatorsTable.update()
                .where(InstigatorsTable.c.selector_id == selector_id)
                .values(
                    status=state.status.value,
                    instigator_type=state.instigator_type.value,
                    instigator_body=serialize_value(state),
                    update_timestamp=pendulum.now("UTC"),
                )
            )

    def add_instigator_state(self, state: InstigatorState) -> InstigatorState:
        check.inst_param(state, "state", InstigatorState)
        with self.connect() as conn:
            try:
                conn.execute(
                    JobTable.insert().values(
                        job_origin_id=state.instigator_origin_id,
                        repository_origin_id=state.repository_origin_id,
                        status=state.status.value,
                        job_type=state.instigator_type.value,
                        job_body=serialize_value(state),
                    )
                )
            except db_exc.IntegrityError as exc:
                raise DagsterInvariantViolationError(
                    f"InstigatorState {state.instigator_origin_id} is already present in storage"
                ) from exc

            # try writing to the instigators table
            if self._has_instigators_table(conn):
                self._add_or_update_instigators_table(conn, state)

            return state

    def update_instigator_state(self, state: InstigatorState) -> InstigatorState:
        check.inst_param(state, "state", InstigatorState)
        if not self.get_instigator_state(state.instigator_origin_id, state.selector_id):
            raise DagsterInvariantViolationError(
                f"InstigatorState {state.instigator_origin_id} is not present in storage"
            )

        values = {
            "status": state.status.value,
            "job_body": serialize_value(state),
            "update_timestamp": pendulum.now("UTC"),
        }
        if self.has_instigators_table():
            values["selector_id"] = state.selector_id

        with self.connect() as conn:
            conn.execute(
                JobTable.update()
                .where(JobTable.c.job_origin_id == state.instigator_origin_id)
                .values(**values)
            )
            if self._has_instigators_table(conn):
                self._add_or_update_instigators_table(conn, state)

            return state

    def delete_instigator_state(self, origin_id: str, selector_id: str) -> None:
        check.str_param(origin_id, "origin_id")
        check.str_param(selector_id, "selector_id")

        if not self.get_instigator_state(origin_id, selector_id):
            raise DagsterInvariantViolationError(
                f"InstigatorState {origin_id} is not present in storage"
            )

        with self.connect() as conn:
            conn.execute(JobTable.delete().where(JobTable.c.job_origin_id == origin_id))

            if self._has_instigators_table(conn):
                if not self._jobs_has_selector_state(conn, selector_id):
                    conn.execute(
                        InstigatorsTable.delete().where(
                            InstigatorsTable.c.selector_id == selector_id
                        )
                    )

    def _jobs_has_selector_state(self, conn: Connection, selector_id: str) -> bool:
        query = (
            db_select([db.func.count()])
            .select_from(JobTable)
            .where(JobTable.c.selector_id == selector_id)
        )
        result = conn.execute(query)
        row = result.fetchone()
        result.close()
        return row[0] > 0  # type: ignore  # (possible none)

    def _add_filter_limit(
        self,
        query: SqlAlchemyQuery,
        before: Optional[float] = None,
        after: Optional[float] = None,
        limit: Optional[int] = None,
        statuses=None,
    ) -> SqlAlchemyQuery:
        check.opt_float_param(before, "before")
        check.opt_float_param(after, "after")
        check.opt_int_param(limit, "limit")
        check.opt_list_param(statuses, "statuses", of_type=TickStatus)

        if before:
            query = query.where(JobTickTable.c.timestamp < utc_datetime_from_timestamp(before))
        if after:
            query = query.where(JobTickTable.c.timestamp > utc_datetime_from_timestamp(after))
        if limit:
            query = query.limit(limit)
        if statuses:
            query = query.where(JobTickTable.c.status.in_([status.value for status in statuses]))
        return query

    @property
    def supports_batch_queries(self) -> bool:
        return self.has_instigators_table() and self.has_built_index(SCHEDULE_TICKS_SELECTOR_ID)

    def has_instigators_table(self) -> bool:
        with self.connect() as conn:
            return self._has_instigators_table(conn)

    def _has_instigators_table(self, conn: Connection) -> bool:
        table_names = db.inspect(conn).get_table_names()
        return "instigators" in table_names

    def _has_asset_daemon_asset_evaluations_table(self, conn: Connection) -> bool:
        table_names = db.inspect(conn).get_table_names()
        return "asset_daemon_asset_evaluations" in table_names

    def get_batch_ticks(
        self,
        selector_ids: Sequence[str],
        limit: Optional[int] = None,
        statuses: Optional[Sequence[TickStatus]] = None,
    ) -> Mapping[str, Sequence[InstigatorTick]]:
        check.sequence_param(selector_ids, "selector_ids", of_type=str)
        check.opt_int_param(limit, "limit")
        check.opt_sequence_param(statuses, "statuses", of_type=TickStatus)

        bucket_rank_column = (
            db.func.rank()
            .over(
                order_by=db.desc(JobTickTable.c.timestamp),
                partition_by=JobTickTable.c.selector_id,
            )
            .label("rank")
        )
        subquery = db_subquery(
            db_select(
                [
                    JobTickTable.c.id,
                    JobTickTable.c.selector_id,
                    JobTickTable.c.tick_body,
                    bucket_rank_column,
                ]
            )
            .select_from(JobTickTable)
            .where(JobTickTable.c.selector_id.in_(selector_ids))
        )
        if statuses:
            subquery = subquery.where(
                JobTickTable.c.status.in_([status.value for status in statuses])
            )

        query = (
            db_select([subquery.c.id, subquery.c.selector_id, subquery.c.tick_body])
            .order_by(subquery.c.rank.asc())
            .where(subquery.c.rank <= limit)
        )

        rows = self.execute(query)
        results = defaultdict(list)
        for row in rows:
            tick_id = row[0]
            selector_id = row[1]
            tick_data = deserialize_value(row[2], TickData)
            results[selector_id].append(InstigatorTick(tick_id, tick_data))

        return results

    def get_tick(self, tick_id: int) -> InstigatorTick:
        check.int_param(tick_id, "tick_id")
        query = (
            db_select([JobTickTable.c.id, JobTickTable.c.tick_body])
            .select_from(JobTickTable)
            .where(JobTickTable.c.id == tick_id)
        )
        rows = self.execute(query)
        if not rows:
            raise DagsterInvariantViolationError(
                f"InstigatorTick {tick_id} is not present in storage"
            )

        tick_id, tick_data = rows[0]
        return InstigatorTick(tick_id, deserialize_value(tick_data, TickData))

    def get_ticks(
        self,
        origin_id: str,
        selector_id: str,
        before: Optional[float] = None,
        after: Optional[float] = None,
        limit: Optional[int] = None,
        statuses: Optional[Sequence[TickStatus]] = None,
    ) -> Sequence[InstigatorTick]:
        check.str_param(origin_id, "origin_id")
        check.opt_float_param(before, "before")
        check.opt_float_param(after, "after")
        check.opt_int_param(limit, "limit")
        check.opt_list_param(statuses, "statuses", of_type=TickStatus)

        base_query = (
            db_select([JobTickTable.c.id, JobTickTable.c.tick_body])
            .select_from(JobTickTable)
            .order_by(JobTickTable.c.timestamp.desc())
        )
        if self.has_instigators_table():
            query = base_query.where(
                db.or_(
                    JobTickTable.c.selector_id == selector_id,
                    db.and_(
                        JobTickTable.c.selector_id.is_(None),
                        JobTickTable.c.job_origin_id == origin_id,
                    ),
                )
            )
        else:
            query = base_query.where(JobTickTable.c.job_origin_id == origin_id)

        query = self._add_filter_limit(
            query, before=before, after=after, limit=limit, statuses=statuses
        )

        rows = self.execute(query)
        return list(map(lambda r: InstigatorTick(r[0], deserialize_value(r[1], TickData)), rows))

    def create_tick(self, tick_data: TickData) -> InstigatorTick:
        check.inst_param(tick_data, "tick_data", TickData)

        values = {
            "job_origin_id": tick_data.instigator_origin_id,
            "status": tick_data.status.value,
            "type": tick_data.instigator_type.value,
            "timestamp": utc_datetime_from_timestamp(tick_data.timestamp),
            "tick_body": serialize_value(tick_data),
        }
        if self.has_instigators_table() and tick_data.selector_id:
            values["selector_id"] = tick_data.selector_id

        with self.connect() as conn:
            try:
                tick_insert = JobTickTable.insert().values(**values)
                result = conn.execute(tick_insert)
                tick_id = result.inserted_primary_key[0]
                return InstigatorTick(tick_id, tick_data)
            except db_exc.IntegrityError as exc:
                raise DagsterInvariantViolationError(
                    f"Unable to insert InstigatorTick for job {tick_data.instigator_name} in"
                    " storage"
                ) from exc

    def update_tick(self, tick: InstigatorTick) -> InstigatorTick:
        check.inst_param(tick, "tick", InstigatorTick)

        values = {
            "status": tick.status.value,
            "type": tick.instigator_type.value,
            "timestamp": utc_datetime_from_timestamp(tick.timestamp),
            "tick_body": serialize_value(tick.tick_data),
        }
        if self.has_instigators_table() and tick.selector_id:
            values["selector_id"] = tick.selector_id

        with self.connect() as conn:
            conn.execute(
                JobTickTable.update().where(JobTickTable.c.id == tick.tick_id).values(**values)
            )

        return tick

    def purge_ticks(
        self,
        origin_id: str,
        selector_id: str,
        before: float,
        tick_statuses: Optional[Sequence[TickStatus]] = None,
    ) -> None:
        check.str_param(origin_id, "origin_id")
        check.float_param(before, "before")
        check.opt_list_param(tick_statuses, "tick_statuses", of_type=TickStatus)

        utc_before = utc_datetime_from_timestamp(before)

        query = JobTickTable.delete().where(JobTickTable.c.timestamp < utc_before)
        if tick_statuses:
            query = query.where(
                JobTickTable.c.status.in_([tick_status.value for tick_status in tick_statuses])
            )

        if self.has_instigators_table():
            query = query.where(
                db.or_(
                    JobTickTable.c.selector_id == selector_id,
                    db.and_(
                        JobTickTable.c.selector_id.is_(None),
                        JobTickTable.c.job_origin_id == origin_id,
                    ),
                )
            )
        else:
            query = query.where(JobTickTable.c.job_origin_id == origin_id)

        with self.connect() as conn:
            conn.execute(query)

    @property
    def supports_auto_materialize_asset_evaluations(self) -> bool:
        with self.connect() as conn:
            return self._has_asset_daemon_asset_evaluations_table(conn)

    def add_auto_materialize_asset_evaluations(
        self,
        evaluation_id: int,
        asset_evaluations: Sequence[AssetConditionEvaluationWithRunIds],
    ):
        if not asset_evaluations:
            return

        with self.connect() as conn:
            for evaluation in asset_evaluations:
                insert_stmt = AssetDaemonAssetEvaluationsTable.insert().values(
                    [
                        {
                            "evaluation_id": evaluation_id,
                            "asset_key": evaluation.asset_key.to_string(),
                            "asset_evaluation_body": serialize_value(evaluation),
                            "num_requested": evaluation.num_requested,
                        }
                    ]
                )
                try:
                    conn.execute(insert_stmt)
                except db_exc.IntegrityError:
                    conn.execute(
                        AssetDaemonAssetEvaluationsTable.update()
                        .where(
                            db.and_(
                                AssetDaemonAssetEvaluationsTable.c.evaluation_id == evaluation_id,
                                AssetDaemonAssetEvaluationsTable.c.asset_key
                                == evaluation.asset_key.to_string(),
                            )
                        )
                        .values(
                            asset_evaluation_body=serialize_value(evaluation),
                            num_requested=evaluation.num_requested,
                        )
                    )

    def get_auto_materialize_asset_evaluations(
        self, asset_key: AssetKey, limit: int, cursor: Optional[int] = None
    ) -> Sequence[AutoMaterializeAssetEvaluationRecord]:
        with self.connect() as conn:
            query = (
                db_select(
                    [
                        AssetDaemonAssetEvaluationsTable.c.id,
                        AssetDaemonAssetEvaluationsTable.c.asset_evaluation_body,
                        AssetDaemonAssetEvaluationsTable.c.evaluation_id,
                        AssetDaemonAssetEvaluationsTable.c.create_timestamp,
                        AssetDaemonAssetEvaluationsTable.c.asset_key,
                    ]
                )
                .where(AssetDaemonAssetEvaluationsTable.c.asset_key == asset_key.to_string())
                .order_by(AssetDaemonAssetEvaluationsTable.c.evaluation_id.desc())
            ).limit(limit)

            if cursor:
                query = query.where(AssetDaemonAssetEvaluationsTable.c.evaluation_id < cursor)

            rows = db_fetch_mappings(conn, query)
            return [AutoMaterializeAssetEvaluationRecord.from_db_row(row) for row in rows]

    def get_auto_materialize_evaluations_for_evaluation_id(
        self, evaluation_id: int
    ) -> Sequence[AutoMaterializeAssetEvaluationRecord]:
        with self.connect() as conn:
            query = db_select(
                [
                    AssetDaemonAssetEvaluationsTable.c.id,
                    AssetDaemonAssetEvaluationsTable.c.asset_evaluation_body,
                    AssetDaemonAssetEvaluationsTable.c.evaluation_id,
                    AssetDaemonAssetEvaluationsTable.c.create_timestamp,
                    AssetDaemonAssetEvaluationsTable.c.asset_key,
                ]
            ).where(AssetDaemonAssetEvaluationsTable.c.evaluation_id == evaluation_id)

            rows = db_fetch_mappings(conn, query)
            return [AutoMaterializeAssetEvaluationRecord.from_db_row(row) for row in rows]

    def purge_asset_evaluations(self, before: float):
        check.float_param(before, "before")

        utc_before = utc_datetime_from_timestamp(before)
        query = AssetDaemonAssetEvaluationsTable.delete().where(
            AssetDaemonAssetEvaluationsTable.c.create_timestamp < utc_before
        )

        with self.connect() as conn:
            conn.execute(query)

    def wipe(self) -> None:
        """Clears the schedule storage."""
        with self.connect() as conn:
            # https://stackoverflow.com/a/54386260/324449
            conn.execute(JobTable.delete())
            conn.execute(JobTickTable.delete())
            if self._has_instigators_table(conn):
                conn.execute(InstigatorsTable.delete())
            if self._has_asset_daemon_asset_evaluations_table(conn):
                conn.execute(AssetDaemonAssetEvaluationsTable.delete())

    # MIGRATIONS

    def has_secondary_index_table(self) -> bool:
        with self.connect() as conn:
            return "secondary_indexes" in db.inspect(conn).get_table_names()

    def has_built_index(self, migration_name: str) -> bool:
        if not self.has_secondary_index_table():
            return False

        query = (
            db_select([1])
            .where(SecondaryIndexMigrationTable.c.name == migration_name)
            .where(SecondaryIndexMigrationTable.c.migration_completed != None)  # noqa: E711
            .limit(1)
        )
        with self.connect() as conn:
            results = conn.execute(query).fetchall()

        return len(results) > 0

    def mark_index_built(self, migration_name: str) -> None:
        query = SecondaryIndexMigrationTable.insert().values(
            name=migration_name,
            migration_completed=datetime.now(),
        )
        with self.connect() as conn:
            try:
                conn.execute(query)
            except db_exc.IntegrityError:
                conn.execute(
                    SecondaryIndexMigrationTable.update()
                    .where(SecondaryIndexMigrationTable.c.name == migration_name)
                    .values(migration_completed=datetime.now())
                )

    def _execute_data_migrations(
        self,
        migrations: Mapping[str, Callable[..., Any]],
        print_fn: Optional[Callable] = None,
        force_rebuild_all: bool = False,
    ) -> None:
        for migration_name, migration_fn in migrations.items():
            if self.has_built_index(migration_name):
                if not force_rebuild_all:
                    if print_fn:
                        print_fn(f"Skipping already applied migration: {migration_name}")
                    continue
            if print_fn:
                print_fn(f"Starting data migration: {migration_name}")
            migration_fn()(self, print_fn)
            self.mark_index_built(migration_name)
            if print_fn:
                print_fn(f"Finished data migration: {migration_name}")

    def migrate(self, print_fn: Optional[PrintFn] = None, force_rebuild_all: bool = False) -> None:
        self._execute_data_migrations(
            REQUIRED_SCHEDULE_DATA_MIGRATIONS, print_fn, force_rebuild_all
        )

    def optimize(self, print_fn: Optional[PrintFn] = None, force_rebuild_all: bool = False) -> None:
        self._execute_data_migrations(
            OPTIONAL_SCHEDULE_DATA_MIGRATIONS, print_fn, force_rebuild_all
        )
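
The abstract connect() above is supplied by concrete subclasses; every other method in the class funnels its SQL through that one context manager. A minimal usage sketch follows, assuming Dagster's bundled SqliteScheduleStorage subclass and its from_local constructor (both defined outside this module) provide a working connect(); it is an illustration, not part of this file:

# Usage sketch (assumption: SqliteScheduleStorage and from_local behave as
# described above, creating the schedule-storage schema in the given directory).
import tempfile

from dagster._core.storage.schedules import SqliteScheduleStorage

with tempfile.TemporaryDirectory() as tmpdir:
    # Concrete subclass supplies connect(); the base class does the rest.
    storage = SqliteScheduleStorage.from_local(tmpdir)

    # Reads go through SqlScheduleStorage.execute() -> connect().
    print(storage.all_instigator_state())  # [] on a fresh database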