import logging
from abc import abstractmethod
from collections import OrderedDict, defaultdict
from datetime import datetime
from typing import (
TYPE_CHECKING,
Any,
ContextManager,
Dict,
Iterable,
List,
Mapping,
NamedTuple,
Optional,
Sequence,
Set,
Tuple,
Union,
cast,
)
import pendulum
import sqlalchemy as db
import sqlalchemy.exc as db_exc
from sqlalchemy.engine import Connection
from typing_extensions import TypeAlias
import dagster._check as check
import dagster._seven as seven
from dagster._core.assets import AssetDetails
from dagster._core.definitions.asset_check_evaluation import (
AssetCheckEvaluation,
AssetCheckEvaluationPlanned,
)
from dagster._core.definitions.events import AssetKey, AssetMaterialization
from dagster._core.errors import (
DagsterEventLogInvalidForRun,
DagsterInvalidInvocationError,
DagsterInvariantViolationError,
)
from dagster._core.event_api import RunShardedEventsCursor
from dagster._core.events import ASSET_CHECK_EVENTS, ASSET_EVENTS, MARKER_EVENTS, DagsterEventType
from dagster._core.events.log import EventLogEntry
from dagster._core.execution.stats import RunStepKeyStatsSnapshot, build_run_step_stats_from_events
from dagster._core.storage.asset_check_execution_record import (
AssetCheckExecutionRecord,
AssetCheckExecutionRecordStatus,
)
from dagster._core.storage.sql import SqlAlchemyQuery, SqlAlchemyRow
from dagster._core.storage.sqlalchemy_compat import (
db_case,
db_fetch_mappings,
db_select,
db_subquery,
)
from dagster._serdes import (
deserialize_value,
serialize_value,
)
from dagster._serdes.errors import DeserializationError
from dagster._utils import (
PrintFn,
datetime_as_float,
utc_datetime_from_naive,
utc_datetime_from_timestamp,
)
from dagster._utils.concurrency import (
ConcurrencyClaimStatus,
ConcurrencyKeyInfo,
ConcurrencySlotStatus,
)
from ..dagster_run import DagsterRunStatsSnapshot
from .base import (
AssetEntry,
AssetRecord,
EventLogConnection,
EventLogCursor,
EventLogRecord,
EventLogStorage,
EventRecordsFilter,
)
from .migration import ASSET_DATA_MIGRATIONS, ASSET_KEY_INDEX_COLS, EVENT_LOG_DATA_MIGRATIONS
from .schema import (
AssetCheckExecutionsTable,
AssetEventTagsTable,
AssetKeyTable,
ConcurrencySlotsTable,
DynamicPartitionsTable,
PendingStepsTable,
SecondaryIndexMigrationTable,
SqlEventLogStorageTable,
)
if TYPE_CHECKING:
from dagster._core.storage.partition_status_cache import AssetStatusCacheValue
MAX_CONCURRENCY_SLOTS = 1000
MIN_ASSET_ROWS = 25
# We are using third-party library objects for DB connections-- at this time, these libraries are
# untyped. When/if we upgrade to typed variants, the `Any` here can be replaced or the alias as a
# whole can be dropped.
SqlDbConnection: TypeAlias = Any
class SqlEventLogStorage(EventLogStorage):
"""Base class for SQL backed event log storages.
Distinguishes between run-based connections and index connections in order to support run-level
sharding, while maintaining the ability to do cross-run queries
"""
@abstractmethod
def run_connection(self, run_id: Optional[str]) -> ContextManager[Connection]:
"""Context manager yielding a connection to access the event logs for a specific run.
Args:
run_id (Optional[str]): Enables those storages which shard based on run_id, e.g.,
SqliteEventLogStorage, to connect appropriately.
"""
@abstractmethod
def index_connection(self) -> ContextManager[Connection]:
"""Context manager yielding a connection to access cross-run indexed tables."""
@abstractmethod
def upgrade(self) -> None:
"""This method should perform any schema migrations necessary to bring an
out-of-date instance of the storage up to date.
"""
@abstractmethod
def has_table(self, table_name: str) -> bool:
"""This method checks if a table exists in the database."""
def prepare_insert_event(self, event):
"""Helper method for preparing the event log SQL insertion statement. Abstracted away to
have a single place for the logical table representation of the event, while having a way
for SQL backends to implement different execution implementations for `store_event`. See
the `dagster-postgres` implementation which overrides the generic SQL implementation of
`store_event`.
"""
dagster_event_type = None
asset_key_str = None
partition = None
step_key = event.step_key
if event.is_dagster_event:
dagster_event_type = event.dagster_event.event_type_value
step_key = event.dagster_event.step_key
if event.dagster_event.asset_key:
check.inst_param(event.dagster_event.asset_key, "asset_key", AssetKey)
asset_key_str = event.dagster_event.asset_key.to_string()
if event.dagster_event.partition:
partition = event.dagster_event.partition
# https://stackoverflow.com/a/54386260/324449
return SqlEventLogStorageTable.insert().values(
run_id=event.run_id,
event=serialize_value(event),
dagster_event_type=dagster_event_type,
# Postgres requires a datetime that is in UTC but has no timezone info set
# in order to be stored correctly
timestamp=datetime.utcfromtimestamp(event.timestamp),
step_key=step_key,
asset_key=asset_key_str,
partition=partition,
)
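    # Illustrative sketch: how a SQL backend's `store_event` override might consume the
    # statement prepared above. This mirrors the generic implementation below; a backend such as
    # dagster-postgres may add its own conflict handling around the same prepared values.
    #
    #     insert_event_statement = self.prepare_insert_event(event)
    #     with self.run_connection(event.run_id) as conn:
    #         result = conn.execute(insert_event_statement)
    #         event_id = result.inserted_primary_key[0]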
def has_asset_key_col(self, column_name: str) -> bool:
with self.index_connection() as conn:
column_names = [x.get("name") for x in db.inspect(conn).get_columns(AssetKeyTable.name)]
return column_name in column_names
def has_asset_key_index_cols(self) -> bool:
return self.has_asset_key_col("last_materialization_timestamp")
def store_asset_event(self, event: EventLogEntry, event_id: int):
check.inst_param(event, "event", EventLogEntry)
if not (event.dagster_event and event.dagster_event.asset_key):
return
# We switched to storing the entire event record of the last materialization instead of just
# the AssetMaterialization object, so that we have access to metadata like timestamp,
# pipeline, run_id, etc.
#
# This should make certain asset queries way more performant, without having to do extra
# queries against the event log.
#
# This should be accompanied by a schema change in 0.12.0, renaming `last_materialization`
# to `last_materialization_event`, for clarity. For now, we should do some back-compat.
#
# https://github.com/dagster-io/dagster/issues/3945
values = self._get_asset_entry_values(event, event_id, self.has_asset_key_index_cols())
insert_statement = AssetKeyTable.insert().values(
asset_key=event.dagster_event.asset_key.to_string(), **values
)
update_statement = (
AssetKeyTable.update()
.values(**values)
.where(
AssetKeyTable.c.asset_key == event.dagster_event.asset_key.to_string(),
)
)
with self.index_connection() as conn:
try:
conn.execute(insert_statement)
except db_exc.IntegrityError:
conn.execute(update_statement)
def _get_asset_entry_values(
self, event: EventLogEntry, event_id: int, has_asset_key_index_cols: bool
) -> Dict[str, Any]:
# The AssetKeyTable contains a `last_materialization_timestamp` column that is exclusively
# used to determine if an asset exists (last materialization timestamp > wipe timestamp).
# This column is used nowhere else, and as of AssetObservation/AssetMaterializationPlanned
# event creation, we want to extend this functionality to ensure that assets with any event
# (observation, materialization, or materialization planned) yielded with timestamp
# > wipe timestamp display in the Dagster UI.
# As of the following PRs, we update last_materialization_timestamp to store the timestamp
# of the latest asset observation, materialization, or materialization_planned that has occurred.
# https://github.com/dagster-io/dagster/pull/6885
# https://github.com/dagster-io/dagster/pull/7319
entry_values: Dict[str, Any] = {}
dagster_event = check.not_none(event.dagster_event)
if dagster_event.is_step_materialization:
entry_values.update(
{
"last_materialization": serialize_value(
EventLogRecord(
storage_id=event_id,
event_log_entry=event,
)
),
"last_run_id": event.run_id,
}
)
if has_asset_key_index_cols:
entry_values.update(
{
"last_materialization_timestamp": utc_datetime_from_timestamp(
event.timestamp
),
}
)
elif dagster_event.is_asset_materialization_planned:
# The AssetKeyTable also contains a `last_run_id` column that is updated upon asset
# materialization. This column was not being used until the below PR. This new change
# writes to the column upon `ASSET_MATERIALIZATION_PLANNED` events to fetch the last
# run id for a set of assets in one roundtrip call to event log storage.
# https://github.com/dagster-io/dagster/pull/7319
entry_values.update({"last_run_id": event.run_id})
if has_asset_key_index_cols:
entry_values.update(
{
"last_materialization_timestamp": utc_datetime_from_timestamp(
event.timestamp
),
}
)
elif dagster_event.is_asset_observation:
if has_asset_key_index_cols:
entry_values.update(
{
"last_materialization_timestamp": utc_datetime_from_timestamp(
event.timestamp
),
}
)
return entry_values
def supports_add_asset_event_tags(self) -> bool:
return self.has_table(AssetEventTagsTable.name)
def add_asset_event_tags(
self,
event_id: int,
event_timestamp: float,
asset_key: AssetKey,
new_tags: Mapping[str, str],
) -> None:
check.int_param(event_id, "event_id")
check.float_param(event_timestamp, "event_timestamp")
check.inst_param(asset_key, "asset_key", AssetKey)
check.mapping_param(new_tags, "new_tags", key_type=str, value_type=str)
if not self.supports_add_asset_event_tags():
raise DagsterInvalidInvocationError(
"In order to add asset event tags, you must run `dagster instance migrate` to "
"create the AssetEventTags table."
)
current_tags_list = self.get_event_tags_for_asset(asset_key, filter_event_id=event_id)
asset_key_str = asset_key.to_string()
if len(current_tags_list) == 0:
current_tags: Mapping[str, str] = {}
else:
current_tags = current_tags_list[0]
with self.index_connection() as conn:
current_tags_set = set(current_tags.keys())
new_tags_set = set(new_tags.keys())
existing_tags = current_tags_set & new_tags_set
added_tags = new_tags_set.difference(existing_tags)
for tag in existing_tags:
conn.execute(
AssetEventTagsTable.update()
.where(
db.and_(
AssetEventTagsTable.c.event_id == event_id,
AssetEventTagsTable.c.asset_key == asset_key_str,
AssetEventTagsTable.c.key == tag,
)
)
.values(value=new_tags[tag])
)
if added_tags:
conn.execute(
AssetEventTagsTable.insert(),
[
dict(
event_id=event_id,
asset_key=asset_key_str,
key=tag,
value=new_tags[tag],
# Postgres requires a datetime that is in UTC but has no timezone info
# set in order to be stored correctly
event_timestamp=datetime.utcfromtimestamp(event_timestamp),
)
for tag in added_tags
],
)
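    # Illustrative usage (hypothetical values): attaching or updating tags on a previously
    # stored asset event, given `storage`, a concrete SqlEventLogStorage subclass for which
    # `storage.supports_add_asset_event_tags()` is True.
    #
    #     storage.add_asset_event_tags(
    #         event_id=42,
    #         event_timestamp=1700000000.0,
    #         asset_key=AssetKey("my_asset"),
    #         new_tags={"dagster/partition/country": "US"},
    #     )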
def store_asset_event_tags(self, event: EventLogEntry, event_id: int) -> None:
check.inst_param(event, "event", EventLogEntry)
check.int_param(event_id, "event_id")
if event.dagster_event and event.dagster_event.asset_key:
if event.dagster_event.is_step_materialization:
tags = event.dagster_event.step_materialization_data.materialization.tags
elif event.dagster_event.is_asset_observation:
tags = event.dagster_event.asset_observation_data.asset_observation.tags
else:
tags = None
if not tags or not self.has_table(AssetEventTagsTable.name):
# If tags table does not exist, silently exit. This is to support OSS
# users who have not yet run the migration to create the table.
# On read, we will throw an error if the table does not exist.
return
check.inst_param(event.dagster_event.asset_key, "asset_key", AssetKey)
asset_key_str = event.dagster_event.asset_key.to_string()
with self.index_connection() as conn:
conn.execute(
AssetEventTagsTable.insert(),
[
dict(
event_id=event_id,
asset_key=asset_key_str,
key=key,
value=value,
# Postgres requires a datetime that is in UTC but has no timezone info
# set in order to be stored correctly
event_timestamp=datetime.utcfromtimestamp(event.timestamp),
)
for key, value in tags.items()
],
)
def store_event(self, event: EventLogEntry) -> None:
"""Store an event corresponding to a pipeline run.
Args:
event (EventLogEntry): The event to store.
"""
check.inst_param(event, "event", EventLogEntry)
insert_event_statement = self.prepare_insert_event(event)
run_id = event.run_id
event_id = None
with self.run_connection(run_id) as conn:
result = conn.execute(insert_event_statement)
event_id = result.inserted_primary_key[0]
if (
event.is_dagster_event
and event.dagster_event_type in ASSET_EVENTS
and event.dagster_event.asset_key # type: ignore
):
self.store_asset_event(event, event_id)
if event_id is None:
raise DagsterInvariantViolationError(
"Cannot store asset event tags for null event id."
)
self.store_asset_event_tags(event, event_id)
if event.is_dagster_event and event.dagster_event_type in ASSET_CHECK_EVENTS:
self.store_asset_check_event(event, event_id)
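    # Illustrative usage: persisting a framework-built EventLogEntry through a concrete storage
    # (hypothetical `storage` and `event` objects; in practice Dagster calls this internally
    # while executing a run).
    #
    #     storage.store_event(event)  # writes the event row, plus asset / asset-check side tables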
def get_records_for_run(
self,
run_id,
cursor: Optional[str] = None,
of_type: Optional[Union[DagsterEventType, Set[DagsterEventType]]] = None,
limit: Optional[int] = None,
ascending: bool = True,
) -> EventLogConnection:
"""Get all of the logs corresponding to a run.
Args:
run_id (str): The id of the run for which to fetch logs.
cursor (Optional[int]): Zero-indexed logs will be returned starting from cursor + 1,
i.e., if cursor is -1, all logs will be returned. (default: -1)
of_type (Optional[DagsterEventType]): the dagster event type to filter the logs.
limit (Optional[int]): the maximum number of events to fetch
"""
check.str_param(run_id, "run_id")
check.opt_str_param(cursor, "cursor")
check.invariant(not of_type or isinstance(of_type, (DagsterEventType, frozenset, set)))
dagster_event_types = (
{of_type}
if isinstance(of_type, DagsterEventType)
else check.opt_set_param(of_type, "dagster_event_type", of_type=DagsterEventType)
)
query = (
db_select([SqlEventLogStorageTable.c.id, SqlEventLogStorageTable.c.event])
.where(SqlEventLogStorageTable.c.run_id == run_id)
.order_by(
SqlEventLogStorageTable.c.id.asc()
if ascending
else SqlEventLogStorageTable.c.id.desc()
)
)
if dagster_event_types:
query = query.where(
SqlEventLogStorageTable.c.dagster_event_type.in_(
[dagster_event_type.value for dagster_event_type in dagster_event_types]
)
)
# adjust 0 based index cursor to SQL offset
if cursor is not None:
cursor_obj = EventLogCursor.parse(cursor)
if cursor_obj.is_offset_cursor():
query = query.offset(cursor_obj.offset())
elif cursor_obj.is_id_cursor():
if ascending:
query = query.where(SqlEventLogStorageTable.c.id > cursor_obj.storage_id())
else:
query = query.where(SqlEventLogStorageTable.c.id < cursor_obj.storage_id())
if limit:
query = query.limit(limit)
with self.run_connection(run_id) as conn:
results = conn.execute(query).fetchall()
last_record_id = None
try:
records = []
for (
record_id,
json_str,
) in results:
records.append(
EventLogRecord(
storage_id=record_id,
event_log_entry=deserialize_value(json_str, EventLogEntry),
)
)
last_record_id = record_id
except (seven.JSONDecodeError, DeserializationError) as err:
raise DagsterEventLogInvalidForRun(run_id=run_id) from err
if last_record_id is not None:
next_cursor = EventLogCursor.from_storage_id(last_record_id).to_string()
elif cursor:
# record fetch returned no new logs, return the same cursor
next_cursor = cursor
else:
# rely on the fact that all storage ids will be positive integers
next_cursor = EventLogCursor.from_storage_id(-1).to_string()
return EventLogConnection(
records=records,
cursor=next_cursor,
has_more=bool(limit and len(results) == limit),
)
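    # Illustrative sketch: paging through a run's event log using the returned cursor. Assumes
    # `storage` is a concrete SqlEventLogStorage subclass and `run_id` is a valid run id.
    #
    #     cursor = None
    #     while True:
    #         conn_result = storage.get_records_for_run(run_id, cursor=cursor, limit=100)
    #         for record in conn_result.records:
    #             print(record.storage_id, record.event_log_entry.dagster_event_type)
    #         if not conn_result.has_more:
    #             break
    #         cursor = conn_result.cursor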
def get_stats_for_run(self, run_id: str) -> DagsterRunStatsSnapshot:
check.str_param(run_id, "run_id")
query = (
db_select(
[
SqlEventLogStorageTable.c.dagster_event_type,
db.func.count().label("n_events_of_type"),
db.func.max(SqlEventLogStorageTable.c.timestamp).label("last_event_timestamp"),
]
)
.where(
db.and_(
SqlEventLogStorageTable.c.run_id == run_id,
SqlEventLogStorageTable.c.dagster_event_type != None, # noqa: E711
)
)
.group_by("dagster_event_type")
)
with self.run_connection(run_id) as conn:
results = conn.execute(query).fetchall()
try:
counts = {}
times = {}
for result in results:
(dagster_event_type, n_events_of_type, last_event_timestamp) = result
check.invariant(dagster_event_type is not None)
counts[dagster_event_type] = n_events_of_type
times[dagster_event_type] = last_event_timestamp
enqueued_time = times.get(DagsterEventType.PIPELINE_ENQUEUED.value, None)
launch_time = times.get(DagsterEventType.PIPELINE_STARTING.value, None)
start_time = times.get(DagsterEventType.PIPELINE_START.value, None)
end_time = times.get(
DagsterEventType.PIPELINE_SUCCESS.value,
times.get(
DagsterEventType.PIPELINE_FAILURE.value,
times.get(DagsterEventType.PIPELINE_CANCELED.value, None),
),
)
return DagsterRunStatsSnapshot(
run_id=run_id,
steps_succeeded=counts.get(DagsterEventType.STEP_SUCCESS.value, 0),
steps_failed=counts.get(DagsterEventType.STEP_FAILURE.value, 0),
materializations=counts.get(DagsterEventType.ASSET_MATERIALIZATION.value, 0),
expectations=counts.get(DagsterEventType.STEP_EXPECTATION_RESULT.value, 0),
enqueued_time=datetime_as_float(enqueued_time) if enqueued_time else None,
launch_time=datetime_as_float(launch_time) if launch_time else None,
start_time=datetime_as_float(start_time) if start_time else None,
end_time=datetime_as_float(end_time) if end_time else None,
)
except (seven.JSONDecodeError, DeserializationError) as err:
raise DagsterEventLogInvalidForRun(run_id=run_id) from err
def get_step_stats_for_run(
self, run_id: str, step_keys: Optional[Sequence[str]] = None
) -> Sequence[RunStepKeyStatsSnapshot]:
check.str_param(run_id, "run_id")
check.opt_list_param(step_keys, "step_keys", of_type=str)
# Originally, this was two different queries:
# 1) one query which aggregated top-level step stats by grouping by event type / step_key in
# a single query, using pure SQL (e.g. start_time, end_time, status, attempt counts).
# 2) one query which fetched all the raw events for a specific event type and then inspected
# the deserialized event object to aggregate stats derived from sequences of events.
        # (e.g. marker events, materializations, expectation results, attempt timings, etc.)
#
# For simplicity, we now just do the second type of query and derive the stats in Python
# from the raw events. This has the benefit of being easier to read and also the benefit of
# being able to share code with the in-memory event log storage implementation. We may
# choose to revisit this in the future, especially if we are able to do JSON-column queries
# in SQL as a way of bypassing the serdes layer in all cases.
raw_event_query = (
db_select([SqlEventLogStorageTable.c.event])
.where(SqlEventLogStorageTable.c.run_id == run_id)
.where(SqlEventLogStorageTable.c.step_key != None) # noqa: E711
.where(
SqlEventLogStorageTable.c.dagster_event_type.in_(
[
DagsterEventType.STEP_START.value,
DagsterEventType.STEP_SUCCESS.value,
DagsterEventType.STEP_SKIPPED.value,
DagsterEventType.STEP_FAILURE.value,
DagsterEventType.STEP_RESTARTED.value,
DagsterEventType.ASSET_MATERIALIZATION.value,
DagsterEventType.STEP_EXPECTATION_RESULT.value,
DagsterEventType.STEP_RESTARTED.value,
DagsterEventType.STEP_UP_FOR_RETRY.value,
]
+ [marker_event.value for marker_event in MARKER_EVENTS]
)
)
.order_by(SqlEventLogStorageTable.c.id.asc())
)
if step_keys:
raw_event_query = raw_event_query.where(
SqlEventLogStorageTable.c.step_key.in_(step_keys)
)
with self.run_connection(run_id) as conn:
results = conn.execute(raw_event_query).fetchall()
try:
records = [deserialize_value(json_str, EventLogEntry) for (json_str,) in results]
return build_run_step_stats_from_events(run_id, records)
except (seven.JSONDecodeError, DeserializationError) as err:
raise DagsterEventLogInvalidForRun(run_id=run_id) from err
def _apply_migration(self, migration_name, migration_fn, print_fn, force):
if self.has_secondary_index(migration_name):
if not force:
if print_fn:
print_fn(f"Skipping already applied data migration: {migration_name}")
return
if print_fn:
print_fn(f"Starting data migration: {migration_name}")
migration_fn()(self, print_fn)
self.enable_secondary_index(migration_name)
if print_fn:
print_fn(f"Finished data migration: {migration_name}")
def reindex_events(self, print_fn: Optional[PrintFn] = None, force: bool = False) -> None:
"""Call this method to run any data migrations across the event_log table."""
for migration_name, migration_fn in EVENT_LOG_DATA_MIGRATIONS.items():
self._apply_migration(migration_name, migration_fn, print_fn, force)
def reindex_assets(self, print_fn: Optional[PrintFn] = None, force: bool = False) -> None:
"""Call this method to run any data migrations across the asset_keys table."""
for migration_name, migration_fn in ASSET_DATA_MIGRATIONS.items():
self._apply_migration(migration_name, migration_fn, print_fn, force)
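    # Illustrative usage: running the registered data migrations, typically triggered via
    # `dagster instance reindex`. Passing `print` surfaces per-migration progress.
    #
    #     storage.reindex_events(print_fn=print)
    #     storage.reindex_assets(print_fn=print)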
def wipe(self) -> None:
"""Clears the event log storage."""
# Should be overridden by SqliteEventLogStorage and other storages that shard based on
# run_id
# https://stackoverflow.com/a/54386260/324449
with self.run_connection(run_id=None) as conn:
conn.execute(SqlEventLogStorageTable.delete())
conn.execute(AssetKeyTable.delete())
if self.has_table("asset_event_tags"):
conn.execute(AssetEventTagsTable.delete())
if self.has_table("dynamic_partitions"):
conn.execute(DynamicPartitionsTable.delete())
if self.has_table("concurrency_slots"):
conn.execute(ConcurrencySlotsTable.delete())
if self.has_table("pending_steps"):
conn.execute(PendingStepsTable.delete())
if self.has_table("asset_check_executions"):
conn.execute(AssetCheckExecutionsTable.delete())
self._wipe_index()
def _wipe_index(self):
with self.index_connection() as conn:
conn.execute(SqlEventLogStorageTable.delete())
conn.execute(AssetKeyTable.delete())
if self.has_table("asset_event_tags"):
conn.execute(AssetEventTagsTable.delete())
if self.has_table("dynamic_partitions"):
conn.execute(DynamicPartitionsTable.delete())
if self.has_table("concurrency_slots"):
conn.execute(ConcurrencySlotsTable.delete())
if self.has_table("pending_steps"):
conn.execute(PendingStepsTable.delete())
if self.has_table("asset_check_executions"):
conn.execute(AssetCheckExecutionsTable.delete())
def delete_events(self, run_id: str) -> None:
with self.run_connection(run_id) as conn:
self.delete_events_for_run(conn, run_id)
with self.index_connection() as conn:
self.delete_events_for_run(conn, run_id)
def delete_events_for_run(self, conn: Connection, run_id: str) -> None:
check.str_param(run_id, "run_id")
conn.execute(
SqlEventLogStorageTable.delete().where(SqlEventLogStorageTable.c.run_id == run_id)
)
@property
def is_persistent(self) -> bool:
return True
def update_event_log_record(self, record_id: int, event: EventLogEntry) -> None:
"""Utility method for migration scripts to update SQL representation of event records."""
check.int_param(record_id, "record_id")
check.inst_param(event, "event", EventLogEntry)
dagster_event_type = None
asset_key_str = None
if event.is_dagster_event:
dagster_event_type = event.dagster_event.event_type_value # type: ignore
if event.dagster_event.asset_key: # type: ignore
check.inst_param(event.dagster_event.asset_key, "asset_key", AssetKey) # type: ignore
asset_key_str = event.dagster_event.asset_key.to_string() # type: ignore
with self.run_connection(run_id=event.run_id) as conn:
conn.execute(
SqlEventLogStorageTable.update()
.where(SqlEventLogStorageTable.c.id == record_id)
.values(
event=serialize_value(event),
dagster_event_type=dagster_event_type,
timestamp=datetime.utcfromtimestamp(event.timestamp),
step_key=event.step_key,
asset_key=asset_key_str,
)
)
def get_event_log_table_data(self, run_id: str, record_id: int) -> Optional[SqlAlchemyRow]:
"""Utility method to test representation of the record in the SQL table. Returns all of
the columns stored in the event log storage (as opposed to the deserialized `EventLogEntry`).
This allows checking that certain fields are extracted to support performant lookups (e.g.
extracting `step_key` for fast filtering).
"""
with self.run_connection(run_id=run_id) as conn:
query = (
db_select([SqlEventLogStorageTable])
.where(SqlEventLogStorageTable.c.id == record_id)
.order_by(SqlEventLogStorageTable.c.id.asc())
)
return conn.execute(query).fetchone()
def has_secondary_index(self, name: str) -> bool:
"""This method uses a checkpoint migration table to see if summary data has been constructed
in a secondary index table. Can be used to checkpoint event_log data migrations.
"""
query = (
db_select([1])
.where(SecondaryIndexMigrationTable.c.name == name)
.where(SecondaryIndexMigrationTable.c.migration_completed != None) # noqa: E711
.limit(1)
)
with self.index_connection() as conn:
results = conn.execute(query).fetchall()
return len(results) > 0
def enable_secondary_index(self, name: str) -> None:
"""This method marks an event_log data migration as complete, to indicate that a summary
data migration is complete.
"""
query = SecondaryIndexMigrationTable.insert().values(
name=name,
migration_completed=datetime.now(),
)
with self.index_connection() as conn:
try:
conn.execute(query)
except db_exc.IntegrityError:
conn.execute(
SecondaryIndexMigrationTable.update()
.where(SecondaryIndexMigrationTable.c.name == name)
.values(migration_completed=datetime.now())
)
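    # Illustrative sketch of the checkpointing pattern these two methods support: a migration
    # runs once, then records completion so that future reindex calls skip it (hypothetical
    # migration name and helper shown).
    #
    #     if not storage.has_secondary_index("my_summary_migration"):
    #         run_my_summary_migration(storage)  # hypothetical migration helper
    #         storage.enable_secondary_index("my_summary_migration")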
def _apply_filter_to_query(
self,
query: SqlAlchemyQuery,
event_records_filter: EventRecordsFilter,
asset_details: Optional[AssetDetails] = None,
apply_cursor_filters: bool = True,
) -> SqlAlchemyQuery:
query = query.where(
SqlEventLogStorageTable.c.dagster_event_type == event_records_filter.event_type.value
)
if event_records_filter.asset_key:
query = query.where(
SqlEventLogStorageTable.c.asset_key == event_records_filter.asset_key.to_string(),
)
if event_records_filter.asset_partitions:
query = query.where(
SqlEventLogStorageTable.c.partition.in_(event_records_filter.asset_partitions)
)
if asset_details and asset_details.last_wipe_timestamp:
query = query.where(
SqlEventLogStorageTable.c.timestamp
> datetime.utcfromtimestamp(asset_details.last_wipe_timestamp)
)
if apply_cursor_filters:
# allow the run-sharded sqlite implementation to disable this cursor filtering so that
# it can implement its own custom cursor logic, as cursor ids are not unique across run
# shards
if event_records_filter.before_cursor is not None:
before_cursor_id = (
event_records_filter.before_cursor.id
if isinstance(event_records_filter.before_cursor, RunShardedEventsCursor)
else event_records_filter.before_cursor
)
query = query.where(SqlEventLogStorageTable.c.id < before_cursor_id)
if event_records_filter.after_cursor is not None:
after_cursor_id = (
event_records_filter.after_cursor.id
if isinstance(event_records_filter.after_cursor, RunShardedEventsCursor)
else event_records_filter.after_cursor
)
query = query.where(SqlEventLogStorageTable.c.id > after_cursor_id)
if event_records_filter.before_timestamp:
query = query.where(
SqlEventLogStorageTable.c.timestamp
< datetime.utcfromtimestamp(event_records_filter.before_timestamp)
)
if event_records_filter.after_timestamp:
query = query.where(
SqlEventLogStorageTable.c.timestamp
> datetime.utcfromtimestamp(event_records_filter.after_timestamp)
)
if event_records_filter.storage_ids:
query = query.where(SqlEventLogStorageTable.c.id.in_(event_records_filter.storage_ids))
if event_records_filter.tags and self.has_table(AssetEventTagsTable.name):
# If we don't have the tags table, we'll filter the results after the query
check.invariant(
isinstance(event_records_filter.asset_key, AssetKey),
"Asset key must be set in event records filter to filter by tags.",
)
if self.supports_intersect:
intersections = [
db_select([AssetEventTagsTable.c.event_id]).where(
db.and_(
AssetEventTagsTable.c.asset_key
== event_records_filter.asset_key.to_string(), # type: ignore # (bad sig?)
AssetEventTagsTable.c.key == key,
(
AssetEventTagsTable.c.value == value
if isinstance(value, str)
else AssetEventTagsTable.c.value.in_(value)
),
)
)
for key, value in event_records_filter.tags.items()
]
query = query.where(SqlEventLogStorageTable.c.id.in_(db.intersect(*intersections)))
return query
def _apply_tags_table_joins(
self,
table: db.Table,
tags: Mapping[str, Union[str, Sequence[str]]],
asset_key: Optional[AssetKey],
) -> db.Table:
event_id_col = table.c.id if table == SqlEventLogStorageTable else table.c.event_id
i = 0
for key, value in tags.items():
i += 1
tags_table = db_subquery(
db_select([AssetEventTagsTable]), f"asset_event_tags_subquery_{i}"
)
table = table.join(
tags_table,
db.and_(
event_id_col == tags_table.c.event_id,
not asset_key or tags_table.c.asset_key == asset_key.to_string(),
tags_table.c.key == key,
(
tags_table.c.value == value
if isinstance(value, str)
else tags_table.c.value.in_(value)
),
),
)
return table
def get_event_records(
self,
event_records_filter: EventRecordsFilter,
limit: Optional[int] = None,
ascending: bool = False,
) -> Sequence[EventLogRecord]:
"""Returns a list of (record_id, record)."""
check.inst_param(event_records_filter, "event_records_filter", EventRecordsFilter)
check.opt_int_param(limit, "limit")
check.bool_param(ascending, "ascending")
if event_records_filter.asset_key:
asset_details = next(iter(self._get_assets_details([event_records_filter.asset_key])))
else:
asset_details = None
if (
event_records_filter.tags
and not self.supports_intersect
and self.has_table(AssetEventTagsTable.name)
):
table = self._apply_tags_table_joins(
SqlEventLogStorageTable, event_records_filter.tags, event_records_filter.asset_key
)
else:
table = SqlEventLogStorageTable
query = db_select(
[SqlEventLogStorageTable.c.id, SqlEventLogStorageTable.c.event]
).select_from(table)
query = self._apply_filter_to_query(
query=query,
event_records_filter=event_records_filter,
asset_details=asset_details,
)
if limit:
query = query.limit(limit)
if ascending:
query = query.order_by(SqlEventLogStorageTable.c.id.asc())
else:
query = query.order_by(SqlEventLogStorageTable.c.id.desc())
with self.index_connection() as conn:
results = conn.execute(query).fetchall()
event_records = []
for row_id, json_str in results:
try:
event_record = deserialize_value(json_str, NamedTuple)
if not isinstance(event_record, EventLogEntry):
logging.warning(
"Could not resolve event record as EventLogEntry for id `%s`.", row_id
)
continue
if event_records_filter.tags and not self.has_table(AssetEventTagsTable.name):
# If we can't filter tags via the tags table, filter the returned records
if limit is not None:
raise DagsterInvalidInvocationError(
"Cannot filter events on tags with a limit, without the asset event "
"tags table. To fix, run `dagster instance migrate`."
)
event_record_tags = event_record.tags
if not event_record_tags or any(
event_record_tags.get(k) != v for k, v in event_records_filter.tags.items()
):
continue
event_records.append(
EventLogRecord(storage_id=row_id, event_log_entry=event_record)
)
except seven.JSONDecodeError:
logging.warning("Could not parse event record id `%s`.", row_id)
return event_records
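    # Illustrative usage: fetching the most recent materialization records for an asset
    # (hypothetical asset key), newest first.
    #
    #     records = storage.get_event_records(
    #         EventRecordsFilter(
    #             event_type=DagsterEventType.ASSET_MATERIALIZATION,
    #             asset_key=AssetKey("my_asset"),
    #         ),
    #         limit=5,
    #     )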
def supports_event_consumer_queries(self) -> bool:
return True
@property
def supports_intersect(self) -> bool:
return True
def get_logs_for_all_runs_by_log_id(
self,
after_cursor: int = -1,
dagster_event_type: Optional[Union[DagsterEventType, Set[DagsterEventType]]] = None,
limit: Optional[int] = None,
) -> Mapping[int, EventLogEntry]:
check.int_param(after_cursor, "after_cursor")
check.invariant(
after_cursor >= -1,
f"Don't know what to do with negative cursor {after_cursor}",
)
dagster_event_types = (
{dagster_event_type}
if isinstance(dagster_event_type, DagsterEventType)
else check.opt_set_param(
dagster_event_type, "dagster_event_type", of_type=DagsterEventType
)
)
query = (
db_select([SqlEventLogStorageTable.c.id, SqlEventLogStorageTable.c.event])
.where(SqlEventLogStorageTable.c.id > after_cursor)
.order_by(SqlEventLogStorageTable.c.id.asc())
)
if dagster_event_types:
query = query.where(
SqlEventLogStorageTable.c.dagster_event_type.in_(
[dagster_event_type.value for dagster_event_type in dagster_event_types]
)
)
if limit:
query = query.limit(limit)
with self.index_connection() as conn:
results = conn.execute(query).fetchall()
events = {}
record_id = None
try:
for (
record_id,
json_str,
) in results:
events[record_id] = deserialize_value(json_str, EventLogEntry)
except (seven.JSONDecodeError, DeserializationError):
logging.warning("Could not parse event record id `%s`.", record_id)
return events
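    # Illustrative sketch: an event-consumer loop polling for new run-failure events across all
    # runs, advancing the cursor to the maximum record id seen so far (add sleep/backoff in real
    # code; `handle` is a hypothetical callback).
    #
    #     cursor = -1
    #     while True:
    #         events = storage.get_logs_for_all_runs_by_log_id(
    #             after_cursor=cursor,
    #             dagster_event_type={DagsterEventType.RUN_FAILURE},
    #             limit=100,
    #         )
    #         for record_id, entry in events.items():
    #             handle(entry)
    #         if events:
    #             cursor = max(events.keys())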
def get_maximum_record_id(self) -> Optional[int]:
with self.index_connection() as conn:
result = conn.execute(db_select([db.func.max(SqlEventLogStorageTable.c.id)])).fetchone()
return result[0] # type: ignore
def _construct_asset_record_from_row(
self,
row,
last_materialization_record: Optional[EventLogRecord],
can_cache_asset_status_data: bool,
) -> AssetRecord:
from dagster._core.storage.partition_status_cache import AssetStatusCacheValue
asset_key = AssetKey.from_db_string(row["asset_key"])
if asset_key:
return AssetRecord(
storage_id=row["id"],
asset_entry=AssetEntry(
asset_key=asset_key,
last_materialization_record=last_materialization_record,
last_run_id=row["last_run_id"],
asset_details=AssetDetails.from_db_string(row["asset_details"]),
cached_status=(
AssetStatusCacheValue.from_db_string(row["cached_status_data"])
if can_cache_asset_status_data
else None
),
),
)
else:
check.failed("Row did not contain asset key.")
def _get_latest_materialization_records(
self, raw_asset_rows
) -> Mapping[AssetKey, Optional[EventLogRecord]]:
# Given a list of raw asset rows, returns a mapping of asset key to latest asset materialization
# event log entry. Fetches backcompat EventLogEntry records when the last_materialization
# in the raw asset row is an AssetMaterialization.
to_backcompat_fetch = set()
results: Dict[AssetKey, Optional[EventLogRecord]] = {}
for row in raw_asset_rows:
asset_key = AssetKey.from_db_string(row["asset_key"])
if not asset_key:
continue
event_or_materialization = (
deserialize_value(row["last_materialization"], NamedTuple)
if row["last_materialization"]
else None
)
if isinstance(event_or_materialization, EventLogRecord):
results[asset_key] = event_or_materialization
else:
to_backcompat_fetch.add(asset_key)
latest_event_subquery = db_subquery(
db_select(
[
SqlEventLogStorageTable.c.asset_key,
db.func.max(SqlEventLogStorageTable.c.id).label("id"),
]
)
.where(
db.and_(
SqlEventLogStorageTable.c.asset_key.in_(
[asset_key.to_string() for asset_key in to_backcompat_fetch]
),
SqlEventLogStorageTable.c.dagster_event_type
== DagsterEventType.ASSET_MATERIALIZATION.value,
)
)
.group_by(SqlEventLogStorageTable.c.asset_key),
"latest_event_subquery",
)
backcompat_query = db_select(
[
SqlEventLogStorageTable.c.asset_key,
SqlEventLogStorageTable.c.id,
SqlEventLogStorageTable.c.event,
]
).select_from(
latest_event_subquery.join(
SqlEventLogStorageTable,
db.and_(
SqlEventLogStorageTable.c.asset_key == latest_event_subquery.c.asset_key,
SqlEventLogStorageTable.c.id == latest_event_subquery.c.id,
),
)
)
with self.index_connection() as conn:
event_rows = db_fetch_mappings(conn, backcompat_query)
for row in event_rows:
asset_key = AssetKey.from_db_string(cast(Optional[str], row["asset_key"]))
if asset_key:
results[asset_key] = EventLogRecord(
storage_id=cast(int, row["id"]),
event_log_entry=deserialize_value(cast(str, row["event"]), EventLogEntry),
)
return results
def can_cache_asset_status_data(self) -> bool:
return self.has_asset_key_col("cached_status_data")
def wipe_asset_cached_status(self, asset_key: AssetKey) -> None:
if self.can_cache_asset_status_data():
check.inst_param(asset_key, "asset_key", AssetKey)
with self.index_connection() as conn:
conn.execute(
AssetKeyTable.update()
.values(dict(cached_status_data=None))
.where(
AssetKeyTable.c.asset_key == asset_key.to_string(),
)
)
def get_asset_records(
self, asset_keys: Optional[Sequence[AssetKey]] = None
) -> Sequence[AssetRecord]:
rows = self._fetch_asset_rows(asset_keys=asset_keys)
latest_materialization_records = self._get_latest_materialization_records(rows)
can_cache_asset_status_data = self.can_cache_asset_status_data()
asset_records: List[AssetRecord] = []
for row in rows:
asset_key = AssetKey.from_db_string(row["asset_key"])
if asset_key:
asset_records.append(
self._construct_asset_record_from_row(
row,
latest_materialization_records.get(asset_key),
can_cache_asset_status_data,
)
)
return asset_records
def has_asset_key(self, asset_key: AssetKey) -> bool:
check.inst_param(asset_key, "asset_key", AssetKey)
rows = self._fetch_asset_rows(asset_keys=[asset_key])
return bool(rows)
def all_asset_keys(self):
rows = self._fetch_asset_rows()
asset_keys = [
AssetKey.from_db_string(row["asset_key"])
for row in sorted(rows, key=lambda x: x["asset_key"])
]
return [asset_key for asset_key in asset_keys if asset_key]
def get_asset_keys(
self,
prefix: Optional[Sequence[str]] = None,
limit: Optional[int] = None,
cursor: Optional[str] = None,
) -> Sequence[AssetKey]:
rows = self._fetch_asset_rows(prefix=prefix, limit=limit, cursor=cursor)
asset_keys = [
AssetKey.from_db_string(row["asset_key"])
for row in sorted(rows, key=lambda x: x["asset_key"])
]
return [asset_key for asset_key in asset_keys if asset_key]
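    # Illustrative sketch: paging through asset keys in key order, using the serialized form of
    # the last key seen as the cursor (this matches the `asset_key > cursor` filter applied in
    # `_apply_asset_filter_to_query` below). `process` is a hypothetical handler.
    #
    #     page = storage.get_asset_keys(prefix=["my_prefix"], limit=100)
    #     while page:
    #         process(page)
    #         page = storage.get_asset_keys(
    #             prefix=["my_prefix"], limit=100, cursor=page[-1].to_string()
    #         )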
def get_latest_materialization_events(
self, asset_keys: Iterable[AssetKey]
) -> Mapping[AssetKey, Optional[EventLogEntry]]:
check.iterable_param(asset_keys, "asset_keys", AssetKey)
rows = self._fetch_asset_rows(asset_keys=asset_keys)
return {
asset_key: event_log_record.event_log_entry if event_log_record is not None else None
for asset_key, event_log_record in self._get_latest_materialization_records(
rows
).items()
}
def _fetch_asset_rows(
self,
asset_keys=None,
prefix: Optional[Sequence[str]] = None,
limit: Optional[int] = None,
cursor: Optional[str] = None,
) -> Sequence[SqlAlchemyRow]:
# fetches rows containing asset_key, last_materialization, and asset_details from the DB,
# applying the filters specified in the arguments.
#
# Differs from _fetch_raw_asset_rows, in that it loops through to make sure enough rows are
# returned to satisfy the limit.
#
# returns a list of rows where each row is a tuple of serialized asset_key, materialization,
# and asset_details
should_query = True
current_cursor = cursor
if self.has_secondary_index(ASSET_KEY_INDEX_COLS):
# if we have migrated, we can limit using SQL
fetch_limit = limit
else:
# if we haven't migrated, overfetch in case the first N results are wiped
fetch_limit = max(limit, MIN_ASSET_ROWS) if limit else None
result = []
while should_query:
rows, has_more, current_cursor = self._fetch_raw_asset_rows(
asset_keys=asset_keys, prefix=prefix, limit=fetch_limit, cursor=current_cursor
)
result.extend(rows)
should_query = bool(has_more) and bool(limit) and len(result) < cast(int, limit)
is_partial_query = asset_keys is not None or bool(prefix) or bool(limit) or bool(cursor)
if not is_partial_query and self._can_mark_assets_as_migrated(rows): # type: ignore
self.enable_secondary_index(ASSET_KEY_INDEX_COLS)
return result[:limit] if limit else result
def _fetch_raw_asset_rows(
self,
asset_keys: Optional[Sequence[AssetKey]] = None,
prefix: Optional[Sequence[str]] = None,
limit: Optional[int] = None,
cursor=None,
) -> Tuple[Iterable[SqlAlchemyRow], bool, Optional[str]]:
# fetches rows containing asset_key, last_materialization, and asset_details from the DB,
# applying the filters specified in the arguments. Does not guarantee that the number of
# rows returned will match the limit specified. This helper function is used to fetch a
# chunk of asset key rows, which may or may not be wiped.
#
# Returns a tuple of (rows, has_more, cursor), where each row is a tuple of serialized
# asset_key, materialization, and asset_details
# TODO update comment
columns = [
AssetKeyTable.c.id,
AssetKeyTable.c.asset_key,
AssetKeyTable.c.last_materialization,
AssetKeyTable.c.last_run_id,
AssetKeyTable.c.asset_details,
]
if self.can_cache_asset_status_data():
columns.extend([AssetKeyTable.c.cached_status_data])
is_partial_query = asset_keys is not None or bool(prefix) or bool(limit) or bool(cursor)
if self.has_asset_key_index_cols() and not is_partial_query:
# if the schema has been migrated, fetch the last_materialization_timestamp to see if
# we can lazily migrate the data table
columns.append(AssetKeyTable.c.last_materialization_timestamp)
columns.append(AssetKeyTable.c.wipe_timestamp)
query = db_select(columns).order_by(AssetKeyTable.c.asset_key.asc())
query = self._apply_asset_filter_to_query(query, asset_keys, prefix, limit, cursor)
if self.has_secondary_index(ASSET_KEY_INDEX_COLS):
query = query.where(
db.or_(
AssetKeyTable.c.wipe_timestamp.is_(None),
AssetKeyTable.c.last_materialization_timestamp > AssetKeyTable.c.wipe_timestamp,
)
)
with self.index_connection() as conn:
rows = db_fetch_mappings(conn, query)
return rows, False, None
with self.index_connection() as conn:
rows = db_fetch_mappings(conn, query)
wiped_timestamps_by_asset_key: Dict[AssetKey, float] = {}
row_by_asset_key: Dict[AssetKey, SqlAlchemyRow] = OrderedDict()
for row in rows:
asset_key = AssetKey.from_db_string(cast(str, row["asset_key"]))
if not asset_key:
continue
asset_details = AssetDetails.from_db_string(row["asset_details"])
if not asset_details or not asset_details.last_wipe_timestamp:
row_by_asset_key[asset_key] = row
continue
materialization_or_event_or_record = (
deserialize_value(cast(str, row["last_materialization"]), NamedTuple)
if row["last_materialization"]
else None
)
if isinstance(materialization_or_event_or_record, (EventLogRecord, EventLogEntry)):
if isinstance(materialization_or_event_or_record, EventLogRecord):
event_timestamp = materialization_or_event_or_record.event_log_entry.timestamp
else:
event_timestamp = materialization_or_event_or_record.timestamp
if asset_details.last_wipe_timestamp > event_timestamp:
# this asset has not been materialized since being wiped, skip
continue
else:
# add the key
row_by_asset_key[asset_key] = row
else:
row_by_asset_key[asset_key] = row
wiped_timestamps_by_asset_key[asset_key] = asset_details.last_wipe_timestamp
if wiped_timestamps_by_asset_key:
materialization_times = self._fetch_backcompat_materialization_times(
wiped_timestamps_by_asset_key.keys() # type: ignore
)
for asset_key, wiped_timestamp in wiped_timestamps_by_asset_key.items():
materialization_time = materialization_times.get(asset_key)
if not materialization_time or utc_datetime_from_naive(
materialization_time
) < utc_datetime_from_timestamp(wiped_timestamp):
# remove rows that have not been materialized since being wiped
row_by_asset_key.pop(asset_key)
has_more = limit and len(rows) == limit
new_cursor = rows[-1]["id"] if rows else None
return row_by_asset_key.values(), has_more, new_cursor # type: ignore
def update_asset_cached_status_data(
self, asset_key: AssetKey, cache_values: "AssetStatusCacheValue"
) -> None:
if self.can_cache_asset_status_data():
with self.index_connection() as conn:
conn.execute(
AssetKeyTable.update()
.where(
AssetKeyTable.c.asset_key == asset_key.to_string(),
)
.values(cached_status_data=serialize_value(cache_values))
)
def _fetch_backcompat_materialization_times(
self, asset_keys: Sequence[AssetKey]
) -> Mapping[AssetKey, datetime]:
# fetches the latest materialization timestamp for the given asset_keys. Uses the (slower)
# raw event log table.
backcompat_query = (
db_select(
[
SqlEventLogStorageTable.c.asset_key,
db.func.max(SqlEventLogStorageTable.c.timestamp).label("timestamp"),
]
)
.where(
SqlEventLogStorageTable.c.asset_key.in_(
[asset_key.to_string() for asset_key in asset_keys]
)
)
.group_by(SqlEventLogStorageTable.c.asset_key)
.order_by(db.func.max(SqlEventLogStorageTable.c.timestamp).asc())
)
with self.index_connection() as conn:
backcompat_rows = db_fetch_mappings(conn, backcompat_query)
return {AssetKey.from_db_string(row["asset_key"]): row["timestamp"] for row in backcompat_rows} # type: ignore
def _can_mark_assets_as_migrated(self, rows):
if not self.has_asset_key_index_cols():
return False
if self.has_secondary_index(ASSET_KEY_INDEX_COLS):
# we have already migrated
return False
for row in rows:
if not _get_from_row(row, "last_materialization_timestamp"):
return False
if _get_from_row(row, "asset_details") and not _get_from_row(row, "wipe_timestamp"):
return False
return True
def _apply_asset_filter_to_query(
self,
query: SqlAlchemyQuery,
asset_keys: Optional[Sequence[AssetKey]] = None,
prefix=None,
limit: Optional[int] = None,
cursor: Optional[str] = None,
) -> SqlAlchemyQuery:
if asset_keys is not None:
query = query.where(
AssetKeyTable.c.asset_key.in_([asset_key.to_string() for asset_key in asset_keys])
)
if prefix:
prefix_str = seven.dumps(prefix)[:-1]
query = query.where(AssetKeyTable.c.asset_key.startswith(prefix_str))
if cursor:
query = query.where(AssetKeyTable.c.asset_key > cursor)
if limit:
query = query.limit(limit)
return query
def _get_assets_details(
self, asset_keys: Sequence[AssetKey]
) -> Sequence[Optional[AssetDetails]]:
check.sequence_param(asset_keys, "asset_key", AssetKey)
rows = None
with self.index_connection() as conn:
rows = db_fetch_mappings(
conn,
db_select([AssetKeyTable.c.asset_key, AssetKeyTable.c.asset_details]).where(
AssetKeyTable.c.asset_key.in_(
[asset_key.to_string() for asset_key in asset_keys]
),
),
)
asset_key_to_details = {
cast(str, row["asset_key"]): (
deserialize_value(cast(str, row["asset_details"]), AssetDetails)
if row["asset_details"]
else None
)
for row in rows
}
# returns a list of the corresponding asset_details to provided asset_keys
return [
asset_key_to_details.get(asset_key.to_string(), None) for asset_key in asset_keys
]
def _add_assets_wipe_filter_to_query(
self,
query: SqlAlchemyQuery,
assets_details: Sequence[Optional[AssetDetails]],
asset_keys: Sequence[AssetKey],
) -> SqlAlchemyQuery:
check.invariant(
len(assets_details) == len(asset_keys),
"asset_details and asset_keys must be the same length",
)
for i in range(len(assets_details)):
asset_key, asset_details = asset_keys[i], assets_details[i]
if asset_details and asset_details.last_wipe_timestamp:
asset_key_in_row = SqlEventLogStorageTable.c.asset_key == asset_key.to_string()
# If asset key is in row, keep the row if the timestamp > wipe timestamp, else remove the row.
# If asset key is not in row, keep the row.
query = query.where(
db.or_(
db.and_(
asset_key_in_row,
SqlEventLogStorageTable.c.timestamp
> datetime.utcfromtimestamp(asset_details.last_wipe_timestamp),
),
db.not_(asset_key_in_row),
)
)
return query
def get_event_tags_for_asset(
self,
asset_key: AssetKey,
filter_tags: Optional[Mapping[str, str]] = None,
filter_event_id: Optional[int] = None,
) -> Sequence[Mapping[str, str]]:
"""Fetches asset event tags for the given asset key.
If filter_tags is provided, searches for events containing all of the filter tags. Then,
returns all tags for those events. This enables searching for multipartitioned asset
partition tags with a fixed dimension value, e.g. all of the tags for events where
"country" == "US".
If filter_event_id is provided, fetches only tags applied to the given event.
Returns a list of dicts, where each dict is a mapping of tag key to tag value for a
single event.
"""
asset_key = check.inst_param(asset_key, "asset_key", AssetKey)
filter_tags = check.opt_mapping_param(
filter_tags, "filter_tags", key_type=str, value_type=str
)
filter_event_id = check.opt_int_param(filter_event_id, "filter_event_id")
if not self.has_table(AssetEventTagsTable.name):
raise DagsterInvalidInvocationError(
"In order to search for asset event tags, you must run "
"`dagster instance migrate` to create the AssetEventTags table."
)
asset_details = self._get_assets_details([asset_key])[0]
if not filter_tags:
tags_query = db_select(
[
AssetEventTagsTable.c.key,
AssetEventTagsTable.c.value,
AssetEventTagsTable.c.event_id,
]
).where(AssetEventTagsTable.c.asset_key == asset_key.to_string())
if asset_details and asset_details.last_wipe_timestamp:
tags_query = tags_query.where(
AssetEventTagsTable.c.event_timestamp
> datetime.utcfromtimestamp(asset_details.last_wipe_timestamp)
)
elif self.supports_intersect:
def get_tag_filter_query(tag_key, tag_value):
filter_query = db_select([AssetEventTagsTable.c.event_id]).where(
db.and_(
AssetEventTagsTable.c.asset_key == asset_key.to_string(),
AssetEventTagsTable.c.key == tag_key,
AssetEventTagsTable.c.value == tag_value,
)
)
if asset_details and asset_details.last_wipe_timestamp:
filter_query = filter_query.where(
AssetEventTagsTable.c.event_timestamp
> datetime.utcfromtimestamp(asset_details.last_wipe_timestamp)
)
return filter_query
intersections = [
get_tag_filter_query(tag_key, tag_value)
for tag_key, tag_value in filter_tags.items()
]
tags_query = db_select(
[
AssetEventTagsTable.c.key,
AssetEventTagsTable.c.value,
AssetEventTagsTable.c.event_id,
]
).where(
db.and_(
AssetEventTagsTable.c.event_id.in_(db.intersect(*intersections)),
)
)
else:
table = self._apply_tags_table_joins(AssetEventTagsTable, filter_tags, asset_key)
tags_query = db_select(
[
AssetEventTagsTable.c.key,
AssetEventTagsTable.c.value,
AssetEventTagsTable.c.event_id,
]
).select_from(table)
if asset_details and asset_details.last_wipe_timestamp:
tags_query = tags_query.where(
AssetEventTagsTable.c.event_timestamp
> datetime.utcfromtimestamp(asset_details.last_wipe_timestamp)
)
if filter_event_id is not None:
tags_query = tags_query.where(AssetEventTagsTable.c.event_id == filter_event_id)
with self.index_connection() as conn:
results = conn.execute(tags_query).fetchall()
tags_by_event_id: Dict[int, Dict[str, str]] = defaultdict(dict)
for row in results:
key, value, event_id = row
tags_by_event_id[event_id][key] = value
return list(tags_by_event_id.values())
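    # Illustrative usage: the multipartition search described in the docstring above, i.e. all
    # tags for this asset's events where the "country" dimension tag equals "US" (hypothetical
    # asset key).
    #
    #     tag_dicts = storage.get_event_tags_for_asset(
    #         AssetKey("my_asset"), filter_tags={"country": "US"}
    #     )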
def _asset_materialization_from_json_column(
self, json_str: str
) -> Optional[AssetMaterialization]:
if not json_str:
return None
# We switched to storing the entire event record of the last materialization instead of just
# the AssetMaterialization object, so that we have access to metadata like timestamp,
# pipeline, run_id, etc.
#
# This should make certain asset queries way more performant, without having to do extra
# queries against the event log.
#
# This should be accompanied by a schema change in 0.12.0, renaming `last_materialization`
# to `last_materialization_event`, for clarity. For now, we should do some back-compat.
#
# https://github.com/dagster-io/dagster/issues/3945
event_or_materialization = deserialize_value(json_str, NamedTuple)
if isinstance(event_or_materialization, AssetMaterialization):
return event_or_materialization
if (
not isinstance(event_or_materialization, EventLogEntry)
or not event_or_materialization.is_dagster_event
or not event_or_materialization.dagster_event.asset_key # type: ignore
):
return None
return event_or_materialization.dagster_event.step_materialization_data.materialization # type: ignore
def _get_asset_key_values_on_wipe(self) -> Mapping[str, Any]:
wipe_timestamp = pendulum.now("UTC").timestamp()
values = {
"asset_details": serialize_value(AssetDetails(last_wipe_timestamp=wipe_timestamp)),
"last_run_id": None,
}
if self.has_asset_key_index_cols():
values.update(
dict(
wipe_timestamp=utc_datetime_from_timestamp(wipe_timestamp),
)
)
if self.can_cache_asset_status_data():
values.update(dict(cached_status_data=None))
return values
def wipe_asset(self, asset_key: AssetKey) -> None:
check.inst_param(asset_key, "asset_key", AssetKey)
wiped_values = self._get_asset_key_values_on_wipe()
with self.index_connection() as conn:
conn.execute(
AssetKeyTable.update()
.values(**wiped_values)
.where(
AssetKeyTable.c.asset_key == asset_key.to_string(),
)
)
def get_materialized_partitions(
self, asset_key: AssetKey, after_cursor: Optional[int] = None
) -> Set[str]:
query = (
db_select(
[
SqlEventLogStorageTable.c.partition,
db.func.max(SqlEventLogStorageTable.c.id),
]
)
.where(
db.and_(
SqlEventLogStorageTable.c.asset_key == asset_key.to_string(),
SqlEventLogStorageTable.c.partition != None, # noqa: E711
SqlEventLogStorageTable.c.dagster_event_type
== DagsterEventType.ASSET_MATERIALIZATION.value,
)
)
.group_by(SqlEventLogStorageTable.c.partition)
)
assets_details = self._get_assets_details([asset_key])
query = self._add_assets_wipe_filter_to_query(query, assets_details, [asset_key])
if after_cursor:
query = query.where(SqlEventLogStorageTable.c.id > after_cursor)
with self.index_connection() as conn:
results = conn.execute(query).fetchall()
return set([cast(str, row[0]) for row in results])
def get_materialization_count_by_partition(
self, asset_keys: Sequence[AssetKey], after_cursor: Optional[int] = None
) -> Mapping[AssetKey, Mapping[str, int]]:
check.sequence_param(asset_keys, "asset_keys", AssetKey)
query = (
db_select(
[
SqlEventLogStorageTable.c.asset_key,
SqlEventLogStorageTable.c.partition,
db.func.count(SqlEventLogStorageTable.c.id),
]
)
.where(
db.and_(
SqlEventLogStorageTable.c.asset_key.in_(
[asset_key.to_string() for asset_key in asset_keys]
),
SqlEventLogStorageTable.c.partition != None, # noqa: E711
SqlEventLogStorageTable.c.dagster_event_type
== DagsterEventType.ASSET_MATERIALIZATION.value,
)
)
.group_by(SqlEventLogStorageTable.c.asset_key, SqlEventLogStorageTable.c.partition)
)
assets_details = self._get_assets_details(asset_keys)
query = self._add_assets_wipe_filter_to_query(query, assets_details, asset_keys)
if after_cursor:
query = query.where(SqlEventLogStorageTable.c.id > after_cursor)
with self.index_connection() as conn:
results = conn.execute(query).fetchall()
materialization_count_by_partition: Dict[AssetKey, Dict[str, int]] = {
asset_key: {} for asset_key in asset_keys
}
for row in results:
asset_key = AssetKey.from_db_string(cast(Optional[str], row[0]))
if asset_key:
materialization_count_by_partition[asset_key][cast(str, row[1])] = cast(int, row[2])
return materialization_count_by_partition
def _latest_event_ids_by_partition_subquery(
self,
asset_key: AssetKey,
event_types: Sequence[DagsterEventType],
asset_partitions: Optional[Sequence[str]] = None,
before_cursor: Optional[int] = None,
after_cursor: Optional[int] = None,
):
"""Subquery for locating the latest event ids by partition for a given asset key and set
of event types.
"""
query = db_select(
[
SqlEventLogStorageTable.c.dagster_event_type,
SqlEventLogStorageTable.c.partition,
db.func.max(SqlEventLogStorageTable.c.id).label("id"),
]
).where(
db.and_(
SqlEventLogStorageTable.c.asset_key == asset_key.to_string(),
SqlEventLogStorageTable.c.partition != None, # noqa: E711
SqlEventLogStorageTable.c.dagster_event_type.in_(
[event_type.value for event_type in event_types]
),
)
)
if asset_partitions is not None:
query = query.where(SqlEventLogStorageTable.c.partition.in_(asset_partitions))
if before_cursor is not None:
query = query.where(SqlEventLogStorageTable.c.id < before_cursor)
if after_cursor is not None:
query = query.where(SqlEventLogStorageTable.c.id > after_cursor)
latest_event_ids_subquery = query.group_by(
SqlEventLogStorageTable.c.dagster_event_type, SqlEventLogStorageTable.c.partition
)
assets_details = self._get_assets_details([asset_key])
return db_subquery(
self._add_assets_wipe_filter_to_query(
latest_event_ids_subquery, assets_details, [asset_key]
),
"latest_event_ids_by_partition_subquery",
)
def get_latest_storage_id_by_partition(
self, asset_key: AssetKey, event_type: DagsterEventType
) -> Mapping[str, int]:
"""Fetch the latest materialzation storage id for each partition for a given asset key.
Returns a mapping of partition to storage id.
"""
check.inst_param(asset_key, "asset_key", AssetKey)
latest_event_ids_by_partition_subquery = self._latest_event_ids_by_partition_subquery(
asset_key, [event_type]
)
latest_event_ids_by_partition = db_select(
[
latest_event_ids_by_partition_subquery.c.partition,
latest_event_ids_by_partition_subquery.c.id,
]
)
with self.index_connection() as conn:
rows = conn.execute(latest_event_ids_by_partition).fetchall()
latest_materialization_storage_id_by_partition: Dict[str, int] = {}
for row in rows:
latest_materialization_storage_id_by_partition[cast(str, row[0])] = cast(int, row[1])
return latest_materialization_storage_id_by_partition
def get_latest_tags_by_partition(
self,
asset_key: AssetKey,
event_type: DagsterEventType,
tag_keys: Sequence[str],
asset_partitions: Optional[Sequence[str]] = None,
before_cursor: Optional[int] = None,
after_cursor: Optional[int] = None,
) -> Mapping[str, Mapping[str, str]]:
check.inst_param(asset_key, "asset_key", AssetKey)
check.inst_param(event_type, "event_type", DagsterEventType)
check.sequence_param(tag_keys, "tag_keys", of_type=str)
check.opt_nullable_sequence_param(asset_partitions, "asset_partitions", of_type=str)
check.opt_int_param(before_cursor, "before_cursor")
check.opt_int_param(after_cursor, "after_cursor")
latest_event_ids_subquery = self._latest_event_ids_by_partition_subquery(
asset_key=asset_key,
event_types=[event_type],
asset_partitions=asset_partitions,
before_cursor=before_cursor,
after_cursor=after_cursor,
)
latest_tags_by_partition_query = (
db_select(
[
latest_event_ids_subquery.c.partition,
AssetEventTagsTable.c.key,
AssetEventTagsTable.c.value,
]
)
.select_from(
latest_event_ids_subquery.join(
AssetEventTagsTable,
AssetEventTagsTable.c.event_id == latest_event_ids_subquery.c.id,
)
)
.where(AssetEventTagsTable.c.key.in_(tag_keys))
)
latest_tags_by_partition: Dict[str, Dict[str, str]] = defaultdict(dict)
with self.index_connection() as conn:
rows = conn.execute(latest_tags_by_partition_query).fetchall()
for row in rows:
latest_tags_by_partition[cast(str, row[0])][cast(str, row[1])] = cast(str, row[2])
# convert defaultdict to dict
return dict(latest_tags_by_partition)
def get_latest_asset_partition_materialization_attempts_without_materializations(
self, asset_key: AssetKey
) -> Mapping[str, Tuple[str, int]]:
"""Fetch the latest materialzation and materialization planned events for each partition of the given asset.
Return the partitions that have a materialization planned event but no matching (same run) materialization event.
These materializations could be in progress, or they could have failed. A separate query checking the run status
is required to know.
Returns a mapping of partition to [run id, event id].
"""
check.inst_param(asset_key, "asset_key", AssetKey)
latest_event_ids_subquery = self._latest_event_ids_by_partition_subquery(
asset_key,
[
DagsterEventType.ASSET_MATERIALIZATION,
DagsterEventType.ASSET_MATERIALIZATION_PLANNED,
],
)
latest_events_subquery = db_subquery(
db_select(
[
SqlEventLogStorageTable.c.dagster_event_type,
SqlEventLogStorageTable.c.partition,
SqlEventLogStorageTable.c.run_id,
SqlEventLogStorageTable.c.id,
]
).select_from(
latest_event_ids_subquery.join(
SqlEventLogStorageTable,
SqlEventLogStorageTable.c.id == latest_event_ids_subquery.c.id,
),
),
"latest_events_subquery",
)
materialization_planned_events = db_select(
[
latest_events_subquery.c.dagster_event_type,
latest_events_subquery.c.partition,
latest_events_subquery.c.run_id,
latest_events_subquery.c.id,
]
).where(
latest_events_subquery.c.dagster_event_type
== DagsterEventType.ASSET_MATERIALIZATION_PLANNED.value
)
materialization_events = db_select(
[
latest_events_subquery.c.dagster_event_type,
latest_events_subquery.c.partition,
latest_events_subquery.c.run_id,
]
).where(
latest_events_subquery.c.dagster_event_type
== DagsterEventType.ASSET_MATERIALIZATION.value
)
with self.index_connection() as conn:
materialization_planned_rows = db_fetch_mappings(conn, materialization_planned_events)
materialization_rows = db_fetch_mappings(conn, materialization_events)
materialization_planned_rows_by_partition = {
cast(str, row["partition"]): (cast(str, row["run_id"]), cast(int, row["id"]))
for row in materialization_planned_rows
}
for row in materialization_rows:
if (
row["partition"] in materialization_planned_rows_by_partition
and materialization_planned_rows_by_partition[cast(str, row["partition"])][0]
== row["run_id"]
):
materialization_planned_rows_by_partition.pop(cast(str, row["partition"]))
return materialization_planned_rows_by_partition
def _check_partitions_table(self) -> None:
# Guards against cases where the user is not running the latest migration for
# partitions storage. Should be updated when the partitions storage schema changes.
if not self.has_table("dynamic_partitions"):
raise DagsterInvalidInvocationError(
"Using dynamic partitions definitions requires the dynamic partitions table, which"
" currently does not exist. Add this table by running `dagster"
" instance migrate`."
)
def _fetch_partition_keys_for_partition_def(self, partitions_def_name: str) -> Sequence[str]:
columns = [
DynamicPartitionsTable.c.partitions_def_name,
DynamicPartitionsTable.c.partition,
]
query = (
db_select(columns)
.where(DynamicPartitionsTable.c.partitions_def_name == partitions_def_name)
.order_by(DynamicPartitionsTable.c.id)
)
with self.index_connection() as conn:
rows = conn.execute(query).fetchall()
return [cast(str, row[1]) for row in rows]
def get_dynamic_partitions(self, partitions_def_name: str) -> Sequence[str]:
"""Get the list of partition keys for a partition definition."""
self._check_partitions_table()
return self._fetch_partition_keys_for_partition_def(partitions_def_name)
def has_dynamic_partition(self, partitions_def_name: str, partition_key: str) -> bool:
self._check_partitions_table()
return partition_key in self._fetch_partition_keys_for_partition_def(partitions_def_name)
def add_dynamic_partitions(
self, partitions_def_name: str, partition_keys: Sequence[str]
) -> None:
self._check_partitions_table()
with self.index_connection() as conn:
existing_rows = conn.execute(
db_select([DynamicPartitionsTable.c.partition]).where(
db.and_(
DynamicPartitionsTable.c.partition.in_(partition_keys),
DynamicPartitionsTable.c.partitions_def_name == partitions_def_name,
)
)
).fetchall()
existing_keys = set([row[0] for row in existing_rows])
new_keys = [
partition_key
for partition_key in partition_keys
if partition_key not in existing_keys
]
if new_keys:
conn.execute(
DynamicPartitionsTable.insert(),
[
dict(partitions_def_name=partitions_def_name, partition=partition_key)
for partition_key in new_keys
],
)
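# Note: only keys not already stored for this partitions definition are inserted above, so
# re-adding an overlapping set of partition keys is effectively a no-op for the duplicates.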
def delete_dynamic_partition(self, partitions_def_name: str, partition_key: str) -> None:
self._check_partitions_table()
with self.index_connection() as conn:
conn.execute(
DynamicPartitionsTable.delete().where(
db.and_(
DynamicPartitionsTable.c.partitions_def_name == partitions_def_name,
DynamicPartitionsTable.c.partition == partition_key,
)
)
)
@property
def supports_global_concurrency_limits(self) -> bool:
return self.has_table(ConcurrencySlotsTable.name)
def set_concurrency_slots(self, concurrency_key: str, num: int) -> None:
"""Allocate a set of concurrency slots.
Args:
concurrency_key (str): The key to allocate the slots for.
num (int): The number of slots to allocate.
"""
if num > MAX_CONCURRENCY_SLOTS:
raise DagsterInvalidInvocationError(
f"Cannot have more than {MAX_CONCURRENCY_SLOTS} slots per concurrency key."
)
if num < 0:
raise DagsterInvalidInvocationError("Cannot have a negative number of slots.")
keys_to_assign = None
with self.index_connection() as conn:
count_row = conn.execute(
db_select([db.func.count()])
.select_from(ConcurrencySlotsTable)
.where(
db.and_(
ConcurrencySlotsTable.c.concurrency_key == concurrency_key,
ConcurrencySlotsTable.c.deleted == False, # noqa: E712
)
)
).fetchone()
existing = cast(int, count_row[0]) if count_row else 0
if existing > num:
# need to delete some slots, favoring ones where the slot is unallocated
rows = conn.execute(
db_select([ConcurrencySlotsTable.c.id])
.select_from(ConcurrencySlotsTable)
.where(
db.and_(
ConcurrencySlotsTable.c.concurrency_key == concurrency_key,
ConcurrencySlotsTable.c.deleted == False, # noqa: E712
)
)
.order_by(
db_case([(ConcurrencySlotsTable.c.run_id.is_(None), 1)], else_=0).desc(),
ConcurrencySlotsTable.c.id.desc(),
)
.limit(existing - num)
).fetchall()
if rows:
# mark rows as deleted
conn.execute(
ConcurrencySlotsTable.update()
.values(deleted=True)
.where(ConcurrencySlotsTable.c.id.in_([row[0] for row in rows]))
)
# actually delete rows that are marked as deleted and are not currently claimed... the rest
# will be deleted when their slots are released via free_concurrency_slots
conn.execute(
ConcurrencySlotsTable.delete().where(
db.and_(
ConcurrencySlotsTable.c.deleted == True, # noqa: E712
ConcurrencySlotsTable.c.run_id == None, # noqa: E711
)
)
)
elif num > existing:
# need to add some slots
rows = [
{
"concurrency_key": concurrency_key,
"run_id": None,
"step_key": None,
"deleted": False,
}
for _ in range(existing, num)
]
conn.execute(ConcurrencySlotsTable.insert().values(rows))
keys_to_assign = [concurrency_key for _ in range(existing, num)]
if keys_to_assign:
# we've added some slots... if there are any pending steps, assign them now; otherwise the
# new slots will sit unutilized until free_concurrency_slots is called
self.assign_pending_steps(keys_to_assign)
def has_unassigned_slots(self, concurrency_key: str) -> bool:
with self.index_connection() as conn:
pending_row = conn.execute(
db_select([db.func.count()])
.select_from(PendingStepsTable)
.where(
db.and_(
PendingStepsTable.c.concurrency_key == concurrency_key,
PendingStepsTable.c.assigned_timestamp != None, # noqa: E711
)
)
).fetchone()
slots = conn.execute(
db_select([db.func.count()])
.select_from(ConcurrencySlotsTable)
.where(
db.and_(
ConcurrencySlotsTable.c.concurrency_key == concurrency_key,
ConcurrencySlotsTable.c.deleted == False, # noqa: E712
)
)
).fetchone()
pending_count = cast(int, pending_row[0]) if pending_row else 0
slots_count = cast(int, slots[0]) if slots else 0
return slots_count > pending_count
def check_concurrency_claim(
self, concurrency_key: str, run_id: str, step_key: str
) -> ConcurrencyClaimStatus:
with self.index_connection() as conn:
pending_row = conn.execute(
db_select(
[
PendingStepsTable.c.assigned_timestamp,
PendingStepsTable.c.priority,
PendingStepsTable.c.create_timestamp,
]
).where(
db.and_(
PendingStepsTable.c.run_id == run_id,
PendingStepsTable.c.step_key == step_key,
PendingStepsTable.c.concurrency_key == concurrency_key,
)
)
).fetchone()
if not pending_row:
# no pending step row exists; the slot is blocked and the enqueued timestamp is None
return ConcurrencyClaimStatus(
concurrency_key=concurrency_key,
slot_status=ConcurrencySlotStatus.BLOCKED,
priority=None,
assigned_timestamp=None,
enqueued_timestamp=None,
)
priority = cast(int, pending_row[1]) if pending_row[1] else None
assigned_timestamp = cast(datetime, pending_row[0]) if pending_row[0] else None
create_timestamp = cast(datetime, pending_row[2]) if pending_row[2] else None
if assigned_timestamp is None:
return ConcurrencyClaimStatus(
concurrency_key=concurrency_key,
slot_status=ConcurrencySlotStatus.BLOCKED,
priority=priority,
assigned_timestamp=None,
enqueued_timestamp=create_timestamp,
)
# pending step is assigned, check to see if it's been claimed
slot_row = conn.execute(
db_select([db.func.count()]).where(
db.and_(
ConcurrencySlotsTable.c.concurrency_key == concurrency_key,
ConcurrencySlotsTable.c.run_id == run_id,
ConcurrencySlotsTable.c.step_key == step_key,
)
)
).fetchone()
return ConcurrencyClaimStatus(
concurrency_key=concurrency_key,
slot_status=(
ConcurrencySlotStatus.CLAIMED
if slot_row and slot_row[0]
else ConcurrencySlotStatus.BLOCKED
),
priority=priority,
assigned_timestamp=assigned_timestamp,
enqueued_timestamp=create_timestamp,
)
def can_claim_from_pending(self, concurrency_key: str, run_id: str, step_key: str):
with self.index_connection() as conn:
row = conn.execute(
db_select([PendingStepsTable.c.assigned_timestamp]).where(
db.and_(
PendingStepsTable.c.run_id == run_id,
PendingStepsTable.c.step_key == step_key,
PendingStepsTable.c.concurrency_key == concurrency_key,
)
)
).fetchone()
return row and row[0] is not None
def has_pending_step(self, concurrency_key: str, run_id: str, step_key: str):
with self.index_connection() as conn:
row = conn.execute(
db_select([db.func.count()])
.select_from(PendingStepsTable)
.where(
db.and_(
PendingStepsTable.c.concurrency_key == concurrency_key,
PendingStepsTable.c.run_id == run_id,
PendingStepsTable.c.step_key == step_key,
)
)
).fetchone()
return row and cast(int, row[0]) > 0
def assign_pending_steps(self, concurrency_keys: Sequence[str]):
if not concurrency_keys:
return
with self.index_connection() as conn:
for key in concurrency_keys:
row = conn.execute(
db_select([PendingStepsTable.c.id])
.where(
db.and_(
PendingStepsTable.c.concurrency_key == key,
PendingStepsTable.c.assigned_timestamp == None, # noqa: E711
)
)
.order_by(
PendingStepsTable.c.priority.desc(),
PendingStepsTable.c.create_timestamp.asc(),
)
.limit(1)
).fetchone()
if row:
conn.execute(
PendingStepsTable.update()
.where(PendingStepsTable.c.id == row[0])
.values(assigned_timestamp=db.func.now())
)
def add_pending_step(
self,
concurrency_key: str,
run_id: str,
step_key: str,
priority: Optional[int] = None,
should_assign: bool = False,
):
with self.index_connection() as conn:
try:
conn.execute(
PendingStepsTable.insert().values(
[
dict(
run_id=run_id,
step_key=step_key,
concurrency_key=concurrency_key,
priority=priority or 0,
assigned_timestamp=db.func.now() if should_assign else None,
)
]
)
)
except db_exc.IntegrityError:
# the step has already been added to the pending queue; nothing to do
pass
def _remove_pending_steps(self, run_id: str, step_key: Optional[str] = None):
query = PendingStepsTable.delete().where(PendingStepsTable.c.run_id == run_id)
if step_key:
query = query.where(PendingStepsTable.c.step_key == step_key)
with self.index_connection() as conn:
conn.execute(query)
def claim_concurrency_slot(
self, concurrency_key: str, run_id: str, step_key: str, priority: Optional[int] = None
) -> ConcurrencyClaimStatus:
"""Claim concurrency slot for step.
Args:
concurrency_key (str): The concurrency key to claim a slot for.
run_id (str): The run id to claim for.
step_key (str): The step key to claim for.
"""
# first, register the step by adding to pending queue
if not self.has_pending_step(
concurrency_key=concurrency_key, run_id=run_id, step_key=step_key
):
has_unassigned_slots = self.has_unassigned_slots(concurrency_key)
self.add_pending_step(
concurrency_key=concurrency_key,
run_id=run_id,
step_key=step_key,
priority=priority,
should_assign=has_unassigned_slots,
)
# if the step is not assigned (i.e. has not been popped from queue), block the claim
claim_status = self.check_concurrency_claim(
concurrency_key=concurrency_key, run_id=run_id, step_key=step_key
)
if claim_status.is_claimed or not claim_status.is_assigned:
return claim_status
# attempt to claim a concurrency slot... this should generally succeed because we only assign
# steps based on the number of unclaimed slots, but the update below acts as a safeguard,
# using the slot rows as a semaphore
slot_status = self._claim_concurrency_slot(
concurrency_key=concurrency_key, run_id=run_id, step_key=step_key
)
return claim_status.with_slot_status(slot_status)
def _claim_concurrency_slot(
self, concurrency_key: str, run_id: str, step_key: str
) -> ConcurrencySlotStatus:
"""Claim a concurrency slot for the step. Helper method that is called for steps that are
popped off the priority queue.
Args:
concurrency_key (str): The concurrency key to claim.
run_id (str): The run id to claim a slot for.
step_key (str): The step key to claim a slot for.
"""
with self.index_connection() as conn:
result = conn.execute(
db_select([ConcurrencySlotsTable.c.id])
.select_from(ConcurrencySlotsTable)
.where(
db.and_(
ConcurrencySlotsTable.c.concurrency_key == concurrency_key,
ConcurrencySlotsTable.c.step_key == None, # noqa: E711
ConcurrencySlotsTable.c.deleted == False, # noqa: E712
)
)
.with_for_update(skip_locked=True)
.limit(1)
).fetchone()
if not result or not result[0]:
return ConcurrencySlotStatus.BLOCKED
if not conn.execute(
ConcurrencySlotsTable.update()
.values(run_id=run_id, step_key=step_key)
.where(ConcurrencySlotsTable.c.id == result[0])
).rowcount:
return ConcurrencySlotStatus.BLOCKED
return ConcurrencySlotStatus.CLAIMED
def get_concurrency_keys(self) -> Set[str]:
"""Get the set of concurrency limited keys."""
with self.index_connection() as conn:
rows = conn.execute(
db_select([ConcurrencySlotsTable.c.concurrency_key])
.select_from(ConcurrencySlotsTable)
.where(ConcurrencySlotsTable.c.deleted == False) # noqa: E712
.distinct()
).fetchall()
return {cast(str, row[0]) for row in rows}
def get_concurrency_info(self, concurrency_key: str) -> ConcurrencyKeyInfo:
"""Get the list of concurrency slots for a given concurrency key.
Args:
concurrency_key (str): The concurrency key to get the slots for.
Returns:
ConcurrencyKeyInfo: Slot counts, pending/assigned step counts, and the associated run ids
for the given concurrency key.
"""
with self.index_connection() as conn:
slot_query = (
db_select(
[
ConcurrencySlotsTable.c.run_id,
ConcurrencySlotsTable.c.deleted,
db.func.count().label("count"),
]
)
.select_from(ConcurrencySlotsTable)
.where(ConcurrencySlotsTable.c.concurrency_key == concurrency_key)
.group_by(ConcurrencySlotsTable.c.run_id, ConcurrencySlotsTable.c.deleted)
)
slot_rows = db_fetch_mappings(conn, slot_query)
pending_query = (
db_select(
[
PendingStepsTable.c.run_id,
db_case(
[(PendingStepsTable.c.assigned_timestamp.is_(None), False)],
else_=True,
).label("is_assigned"),
db.func.count().label("count"),
]
)
.select_from(PendingStepsTable)
.where(PendingStepsTable.c.concurrency_key == concurrency_key)
.group_by(PendingStepsTable.c.run_id, "is_assigned")
)
pending_rows = db_fetch_mappings(conn, pending_query)
return ConcurrencyKeyInfo(
concurrency_key=concurrency_key,
slot_count=sum(
[
cast(int, slot_row["count"])
for slot_row in slot_rows
if not slot_row["deleted"]
]
),
active_slot_count=sum(
[cast(int, slot_row["count"]) for slot_row in slot_rows if slot_row["run_id"]]
),
active_run_ids={
cast(str, slot_row["run_id"]) for slot_row in slot_rows if slot_row["run_id"]
},
pending_step_count=sum(
[cast(int, row["count"]) for row in pending_rows if not row["is_assigned"]]
),
pending_run_ids={
cast(str, row["run_id"]) for row in pending_rows if not row["is_assigned"]
},
assigned_step_count=sum(
[cast(int, row["count"]) for row in pending_rows if row["is_assigned"]]
),
assigned_run_ids={
cast(str, row["run_id"]) for row in pending_rows if row["is_assigned"]
},
)
def get_concurrency_run_ids(self) -> Set[str]:
with self.index_connection() as conn:
rows = conn.execute(db_select([PendingStepsTable.c.run_id]).distinct()).fetchall()
return set([cast(str, row[0]) for row in rows])
def free_concurrency_slots_for_run(self, run_id: str) -> None:
freed_concurrency_keys = self._free_concurrency_slots(run_id=run_id)
self._remove_pending_steps(run_id=run_id)
if freed_concurrency_keys:
# assign any pending steps that can now claim a slot
self.assign_pending_steps(freed_concurrency_keys)
def free_concurrency_slot_for_step(self, run_id: str, step_key: str) -> None:
freed_concurrency_keys = self._free_concurrency_slots(run_id=run_id, step_key=step_key)
self._remove_pending_steps(run_id=run_id, step_key=step_key)
if freed_concurrency_keys:
# assign any pending steps that can now claim a slot
self.assign_pending_steps(freed_concurrency_keys)
def _free_concurrency_slots(self, run_id: str, step_key: Optional[str] = None) -> Sequence[str]:
"""Frees concurrency slots for a given run/step.
Args:
run_id (str): The run id to free the slots for.
step_key (Optional[str]): The step key to free the slots for. If not provided, all the
slots for all the steps of the run will be freed.
"""
with self.index_connection() as conn:
# first delete any rows for this run/step that are marked as deleted. This happens when the
# configured number of slots has been reduced and some of the pruned slots were already
# allocated to the run/step
delete_query = ConcurrencySlotsTable.delete().where(
db.and_(
ConcurrencySlotsTable.c.run_id == run_id,
ConcurrencySlotsTable.c.deleted == True, # noqa: E712
)
)
if step_key:
delete_query = delete_query.where(ConcurrencySlotsTable.c.step_key == step_key)
conn.execute(delete_query)
# next, fetch the slots to free up, while grabbing the concurrency keys so that we can
# allocate any pending steps from the queue for the freed slots, if necessary
select_query = (
db_select([ConcurrencySlotsTable.c.id, ConcurrencySlotsTable.c.concurrency_key])
.select_from(ConcurrencySlotsTable)
.where(ConcurrencySlotsTable.c.run_id == run_id)
.with_for_update(skip_locked=True)
)
if step_key:
select_query = select_query.where(ConcurrencySlotsTable.c.step_key == step_key)
rows = conn.execute(select_query).fetchall()
if not rows:
return []
# now, actually free the slots
conn.execute(
ConcurrencySlotsTable.update()
.values(run_id=None, step_key=None)
.where(
db.and_(
ConcurrencySlotsTable.c.id.in_([row[0] for row in rows]),
)
)
)
# return the concurrency keys for the freed slots
return [cast(str, row[1]) for row in rows]
def store_asset_check_event(self, event: EventLogEntry, event_id: Optional[int]) -> None:
check.inst_param(event, "event", EventLogEntry)
check.opt_int_param(event_id, "event_id")
check.invariant(
self.supports_asset_checks,
"Asset checks require a database schema migration. Run `dagster instance migrate`.",
)
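# Dispatch: planned evaluations insert a new execution row; evaluation results either insert a
# runless row or update the row created for the corresponding planned event in the same run.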
if event.dagster_event_type == DagsterEventType.ASSET_CHECK_EVALUATION_PLANNED:
self._store_asset_check_evaluation_planned(event, event_id)
if event.dagster_event_type == DagsterEventType.ASSET_CHECK_EVALUATION:
if event.run_id == "" or event.run_id is None:
self._store_runless_asset_check_evaluation(event, event_id)
else:
self._update_asset_check_evaluation(event, event_id)
def _store_asset_check_evaluation_planned(
self, event: EventLogEntry, event_id: Optional[int]
) -> None:
planned = cast(
AssetCheckEvaluationPlanned, check.not_none(event.dagster_event).event_specific_data
)
with self.index_connection() as conn:
conn.execute(
AssetCheckExecutionsTable.insert().values(
asset_key=planned.asset_key.to_string(),
check_name=planned.check_name,
run_id=event.run_id,
execution_status=AssetCheckExecutionRecordStatus.PLANNED.value,
)
)
def _store_runless_asset_check_evaluation(
self, event: EventLogEntry, event_id: Optional[int]
) -> None:
evaluation = cast(
AssetCheckEvaluation, check.not_none(event.dagster_event).event_specific_data
)
with self.index_connection() as conn:
conn.execute(
AssetCheckExecutionsTable.insert().values(
asset_key=evaluation.asset_key.to_string(),
check_name=evaluation.check_name,
run_id=event.run_id,
execution_status=(
AssetCheckExecutionRecordStatus.SUCCEEDED.value
if evaluation.success
else AssetCheckExecutionRecordStatus.FAILED.value
),
evaluation_event=serialize_value(event),
evaluation_event_timestamp=datetime.utcfromtimestamp(event.timestamp),
evaluation_event_storage_id=event_id,
materialization_event_storage_id=(
evaluation.target_materialization_data.storage_id
if evaluation.target_materialization_data
else None
),
)
)
def _update_asset_check_evaluation(self, event: EventLogEntry, event_id: Optional[int]) -> None:
evaluation = cast(
AssetCheckEvaluation, check.not_none(event.dagster_event).event_specific_data
)
with self.index_connection() as conn:
rows_updated = conn.execute(
AssetCheckExecutionsTable.update()
.where(
# (asset_key, check_name, run_id) uniquely identifies the row created for the planned event
db.and_(
AssetCheckExecutionsTable.c.asset_key == evaluation.asset_key.to_string(),
AssetCheckExecutionsTable.c.check_name == evaluation.check_name,
AssetCheckExecutionsTable.c.run_id == event.run_id,
)
)
.values(
execution_status=(
AssetCheckExecutionRecordStatus.SUCCEEDED.value
if evaluation.success
else AssetCheckExecutionRecordStatus.FAILED.value
),
evaluation_event=serialize_value(event),
evaluation_event_timestamp=datetime.utcfromtimestamp(event.timestamp),
evaluation_event_storage_id=event_id,
materialization_event_storage_id=(
evaluation.target_materialization_data.storage_id
if evaluation.target_materialization_data
else None
),
)
).rowcount
if rows_updated != 1:
raise DagsterInvariantViolationError(
"Expected to update one row for asset check evaluation, but updated"
f" {rows_updated}."
)
def get_asset_check_executions(
self,
asset_key: AssetKey,
check_name: str,
limit: int,
cursor: Optional[int] = None,
materialization_event_storage_id: Optional[int] = None,
include_planned: bool = True,
) -> Sequence[AssetCheckExecutionRecord]:
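# Illustrative usage (hypothetical `storage` instance, asset key, and check name), fetching the
# most recent non-planned executions of a check:
#   storage.get_asset_check_executions(
#       AssetKey(["my_asset"]), "my_check", limit=5, include_planned=False
#   )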
query = (
db_select(
[
AssetCheckExecutionsTable.c.id,
AssetCheckExecutionsTable.c.run_id,
AssetCheckExecutionsTable.c.execution_status,
AssetCheckExecutionsTable.c.evaluation_event,
AssetCheckExecutionsTable.c.create_timestamp,
]
)
.where(
db.and_(
AssetCheckExecutionsTable.c.asset_key == asset_key.to_string(),
AssetCheckExecutionsTable.c.check_name == check_name,
)
)
.order_by(AssetCheckExecutionsTable.c.id.desc())
).limit(limit)
if cursor:
query = query.where(AssetCheckExecutionsTable.c.id < cursor)
if not include_planned:
query = query.where(
AssetCheckExecutionsTable.c.execution_status
!= AssetCheckExecutionRecordStatus.PLANNED.value
)
if materialization_event_storage_id:
if include_planned:
# rows in PLANNED status are not associated with a materialization event yet
query = query.where(
db.or_(
AssetCheckExecutionsTable.c.materialization_event_storage_id
== materialization_event_storage_id,
AssetCheckExecutionsTable.c.execution_status
== AssetCheckExecutionRecordStatus.PLANNED.value,
)
)
else:
query = query.where(
AssetCheckExecutionsTable.c.materialization_event_storage_id
== materialization_event_storage_id
)
with self.index_connection() as conn:
rows = conn.execute(query).fetchall()
return [
AssetCheckExecutionRecord(
id=cast(int, row[0]),
run_id=cast(str, row[1]),
status=AssetCheckExecutionRecordStatus(row[2]),
evaluation_event=(
deserialize_value(cast(str, row[3]), EventLogEntry) if row[3] else None
),
create_timestamp=datetime_as_float(cast(datetime, row[4])),
)
for row in rows
]
@property
def supports_asset_checks(self):
return self.has_table(AssetCheckExecutionsTable.name)
def _get_from_row(row: SqlAlchemyRow, column: str) -> object:
"""Utility function for extracting a column from a sqlalchemy row proxy, since '_asdict' is not
supported in sqlalchemy 1.3.
"""
if column not in row.keys():
return None
return row[column]