import logging
import os
from abc import abstractmethod
from collections import OrderedDict, defaultdict
from contextlib import contextmanager
from datetime import datetime, timezone
from typing import (
TYPE_CHECKING,
Any,
ContextManager,
Dict,
Iterable,
Iterator,
List,
Mapping,
NamedTuple,
Optional,
Sequence,
Set,
Tuple,
Union,
cast,
)
import pendulum
import sqlalchemy as db
import sqlalchemy.exc as db_exc
from sqlalchemy.engine import Connection
from typing_extensions import TypeAlias
import dagster._check as check
import dagster._seven as seven
from dagster._core.assets import AssetDetails
from dagster._core.definitions.asset_check_evaluation import (
AssetCheckEvaluation,
AssetCheckEvaluationPlanned,
)
from dagster._core.definitions.asset_check_spec import AssetCheckKey
from dagster._core.definitions.events import AssetKey, AssetMaterialization
from dagster._core.errors import (
DagsterEventLogInvalidForRun,
DagsterInvalidInvocationError,
DagsterInvariantViolationError,
)
from dagster._core.event_api import (
EventRecordsResult,
RunShardedEventsCursor,
RunStatusChangeRecordsFilter,
)
from dagster._core.events import (
ASSET_CHECK_EVENTS,
ASSET_EVENTS,
EVENT_TYPE_TO_PIPELINE_RUN_STATUS,
MARKER_EVENTS,
DagsterEventType,
)
from dagster._core.events.log import EventLogEntry
from dagster._core.execution.stats import RunStepKeyStatsSnapshot, build_run_step_stats_from_events
from dagster._core.storage.asset_check_execution_record import (
AssetCheckExecutionRecord,
AssetCheckExecutionRecordStatus,
)
from dagster._core.storage.sql import SqlAlchemyQuery, SqlAlchemyRow
from dagster._core.storage.sqlalchemy_compat import (
db_case,
db_fetch_mappings,
db_select,
db_subquery,
)
from dagster._serdes import (
deserialize_value,
serialize_value,
)
from dagster._serdes.errors import DeserializationError
from dagster._utils import (
PrintFn,
datetime_as_float,
utc_datetime_from_naive,
utc_datetime_from_timestamp,
)
from dagster._utils.concurrency import (
ClaimedSlotInfo,
ConcurrencyClaimStatus,
ConcurrencyKeyInfo,
ConcurrencySlotStatus,
PendingStepInfo,
get_max_concurrency_limit_value,
)
from ..dagster_run import DagsterRunStatsSnapshot
from .base import (
AssetEntry,
AssetRecord,
AssetRecordsFilter,
EventLogConnection,
EventLogCursor,
EventLogRecord,
EventLogStorage,
EventRecordsFilter,
PlannedMaterializationInfo,
)
from .migration import ASSET_DATA_MIGRATIONS, ASSET_KEY_INDEX_COLS, EVENT_LOG_DATA_MIGRATIONS
from .schema import (
AssetCheckExecutionsTable,
AssetEventTagsTable,
AssetKeyTable,
ConcurrencyLimitsTable,
ConcurrencySlotsTable,
DynamicPartitionsTable,
PendingStepsTable,
SecondaryIndexMigrationTable,
SqlEventLogStorageTable,
)
if TYPE_CHECKING:
from dagster._core.storage.partition_status_cache import AssetStatusCacheValue
MIN_ASSET_ROWS = 25

# Fallback cap on how many event records may be fetched in a single call.
DEFAULT_MAX_LIMIT_EVENT_RECORDS = 10000


def get_max_event_records_limit() -> int:
    """Return the maximum number of event records fetchable per call.

    Reads the ``MAX_LIMIT_GET_EVENT_RECORDS`` environment variable; falls back
    to ``DEFAULT_MAX_LIMIT_EVENT_RECORDS`` when it is unset, empty, or not a
    valid integer.
    """
    raw_value = os.getenv("MAX_LIMIT_GET_EVENT_RECORDS")
    if raw_value:
        try:
            return int(raw_value)
        except ValueError:
            pass
    return DEFAULT_MAX_LIMIT_EVENT_RECORDS
def enforce_max_records_limit(limit: int):
    """Raise if the requested limit exceeds the configured per-call maximum.

    Raises:
        DagsterInvariantViolationError: When ``limit`` is above the cap returned
            by ``get_max_event_records_limit``.
    """
    requested, allowed = limit, get_max_event_records_limit()
    if requested > allowed:
        raise DagsterInvariantViolationError(
            f"Cannot fetch more than {allowed} event records at a time. Requested {requested}."
        )


# We are using third-party library objects for DB connections-- at this time, these libraries are
# untyped. When/if we upgrade to typed variants, the `Any` here can be replaced or the alias as a
# whole can be dropped.
SqlDbConnection: TypeAlias = Any
class SqlEventLogStorage(EventLogStorage):
    """Base class for SQL backed event log storages.

    Distinguishes between run-based connections and index connections in order to support run-level
    sharding, while maintaining the ability to do cross-run queries.
    """
@abstractmethod
def run_connection(self, run_id: Optional[str]) -> ContextManager[Connection]:
    """Context manager yielding a connection to access the event logs for a specific run.

    Args:
        run_id (Optional[str]): Enables those storages which shard based on run_id, e.g.,
            SqliteEventLogStorage, to connect appropriately.
    """
@abstractmethod
def index_connection(self) -> ContextManager[Connection]:
    """Context manager yielding a connection to access cross-run indexed tables.

    Used for queries that span runs (asset records, secondary indexes, event tags),
    as opposed to `run_connection`, which may be sharded per run.
    """
@contextmanager
def index_transaction(self) -> Iterator[Connection]:
    """Yield an index-shard connection that is guaranteed to be inside a transaction.

    Reuses the connection's already-open transaction when there is one; otherwise
    opens a new transaction scoped to this context.
    """
    with self.index_connection() as conn:
        if not conn.in_transaction():
            with conn.begin():
                yield conn
        else:
            yield conn
@abstractmethod
def upgrade(self) -> None:
    """This method should perform any schema migrations necessary to bring an
    out-of-date instance of the storage up to date.
    """
@abstractmethod
def has_table(self, table_name: str) -> bool:
    """This method checks if a table exists in the database.

    Used throughout this class to guard reads/writes against optional,
    migration-created tables (e.g. asset event tags, concurrency tables).
    """
def prepare_insert_event(self, event: EventLogEntry) -> Any:
    """Helper method for preparing the event log SQL insertion statement. Abstracted away to
    have a single place for the logical table representation of the event, while having a way
    for SQL backends to implement different execution implementations for `store_event`. See
    the `dagster-postgres` implementation which overrides the generic SQL implementation of
    `store_event`.

    Returns an (unexecuted) SQLAlchemy insert statement for the event log table.
    """
    # https://stackoverflow.com/a/54386260/324449
    return SqlEventLogStorageTable.insert().values(**self._event_to_row(event))
def prepare_insert_event_batch(self, events: Sequence[EventLogEntry]) -> Any:
    """Build a single multi-row insert statement covering a batch of event log entries."""
    # https://stackoverflow.com/a/54386260/324449
    rows = [self._event_to_row(event) for event in events]
    return SqlEventLogStorageTable.insert().values(rows)
def _event_to_row(self, event: EventLogEntry) -> Dict[str, Any]:
    """Map an EventLogEntry onto the column dict used for event log table inserts."""
    # Start from the non-dagster-event defaults and overwrite below when the
    # entry wraps a structured dagster event.
    row: Dict[str, Any] = {
        "run_id": event.run_id,
        "event": serialize_value(event),
        "dagster_event_type": None,
        "timestamp": self._event_insert_timestamp(event),
        "step_key": event.step_key,
        "asset_key": None,
        "partition": None,
    }
    if event.is_dagster_event:
        dagster_event = event.get_dagster_event()
        row["dagster_event_type"] = dagster_event.event_type_value
        row["step_key"] = dagster_event.step_key
        if dagster_event.asset_key:
            check.inst_param(dagster_event.asset_key, "asset_key", AssetKey)
            row["asset_key"] = dagster_event.asset_key.to_string()
        if dagster_event.partition:
            row["partition"] = dagster_event.partition
    return row
def has_asset_key_col(self, column_name: str) -> bool:
    """Return True if the asset key table currently has the named column."""
    with self.index_connection() as conn:
        inspector = db.inspect(conn)
        return any(
            col.get("name") == column_name
            for col in inspector.get_columns(AssetKeyTable.name)
        )
def has_asset_key_index_cols(self) -> bool:
    """Return True once the asset-key index-column migration has been applied.

    `last_materialization_timestamp` is used as the sentinel column for that
    migration.
    """
    sentinel_column = "last_materialization_timestamp"
    return self.has_asset_key_col(sentinel_column)
def store_asset_event(self, event: EventLogEntry, event_id: int):
    """Upsert the asset key table row for the asset referenced by this event.

    Args:
        event (EventLogEntry): The asset event that was just written to the event log.
        event_id (int): The storage id assigned to the stored event row.
    """
    check.inst_param(event, "event", EventLogEntry)
    # Only structured dagster events that reference an asset key touch the asset table.
    if not (event.dagster_event and event.dagster_event.asset_key):
        return
    # We switched to storing the entire event record of the last materialization instead of just
    # the AssetMaterialization object, so that we have access to metadata like timestamp,
    # pipeline, run_id, etc.
    #
    # This should make certain asset queries way more performant, without having to do extra
    # queries against the event log.
    #
    # This should be accompanied by a schema change in 0.12.0, renaming `last_materialization`
    # to `last_materialization_event`, for clarity. For now, we should do some back-compat.
    #
    # https://github.com/dagster-io/dagster/issues/3945
    values = self._get_asset_entry_values(event, event_id, self.has_asset_key_index_cols())
    insert_statement = AssetKeyTable.insert().values(
        asset_key=event.dagster_event.asset_key.to_string(), **values
    )
    update_statement = (
        AssetKeyTable.update()
        .values(**values)
        .where(
            AssetKeyTable.c.asset_key == event.dagster_event.asset_key.to_string(),
        )
    )
    with self.index_connection() as conn:
        try:
            # Optimistic insert; the unique constraint on asset_key signals that a row
            # already exists, in which case we fall back to an update (manual upsert).
            conn.execute(insert_statement)
        except db_exc.IntegrityError:
            conn.execute(update_statement)
def _get_asset_entry_values(
    self, event: EventLogEntry, event_id: int, has_asset_key_index_cols: bool
) -> Dict[str, Any]:
    """Compute the asset key table column values to write for an asset event.

    The AssetKeyTable contains a `last_materialization_timestamp` column that is exclusively
    used to determine if an asset exists (last materialization timestamp > wipe timestamp).
    This column is used nowhere else, and as of AssetObservation/AssetMaterializationPlanned
    event creation, we want to extend this functionality to ensure that assets with any event
    (observation, materialization, or materialization planned) yielded with timestamp
    > wipe timestamp display in the Dagster UI.

    As of the following PRs, we update last_materialization_timestamp to store the timestamp
    of the latest asset observation, materialization, or materialization_planned that has
    occurred.
    https://github.com/dagster-io/dagster/pull/6885
    https://github.com/dagster-io/dagster/pull/7319
    """
    dagster_event = check.not_none(event.dagster_event)
    entry_values: Dict[str, Any] = {}
    if dagster_event.is_step_materialization:
        # Store the full wrapped event record so downstream asset queries have
        # access to run_id, timestamp, etc. without re-querying the event log.
        entry_values["last_materialization"] = serialize_value(
            EventLogRecord(
                storage_id=event_id,
                event_log_entry=event,
            )
        )
        entry_values["last_run_id"] = event.run_id
    elif dagster_event.is_asset_materialization_planned:
        # The AssetKeyTable also contains a `last_run_id` column that is updated upon asset
        # materialization. This column was not being used until the below PR. This new change
        # writes to the column upon `ASSET_MATERIALIZATION_PLANNED` events to fetch the last
        # run id for a set of assets in one roundtrip call to event log storage.
        # https://github.com/dagster-io/dagster/pull/7319
        entry_values["last_run_id"] = event.run_id
    # All three event kinds bump the wipe-detection timestamp when the migrated
    # index columns are present.
    if has_asset_key_index_cols and (
        dagster_event.is_step_materialization
        or dagster_event.is_asset_materialization_planned
        or dagster_event.is_asset_observation
    ):
        entry_values["last_materialization_timestamp"] = utc_datetime_from_timestamp(
            event.timestamp
        )
    return entry_values
def supports_add_asset_event_tags(self) -> bool:
    # Tag mutation requires the asset event tags table, which is created by a
    # migration and therefore may not exist on older instances.
    return self.has_table(AssetEventTagsTable.name)
def add_asset_event_tags(
    self,
    event_id: int,
    event_timestamp: float,
    asset_key: AssetKey,
    new_tags: Mapping[str, str],
) -> None:
    """Add or update tags on a single stored asset event.

    Args:
        event_id (int): Storage id of the event row being tagged.
        event_timestamp (float): Epoch timestamp of the event.
        asset_key (AssetKey): The asset the event belongs to.
        new_tags (Mapping[str, str]): Tags to upsert onto the event.

    Raises:
        DagsterInvalidInvocationError: If the asset event tags table has not been
            created yet (requires `dagster instance migrate`).
    """
    check.int_param(event_id, "event_id")
    check.float_param(event_timestamp, "event_timestamp")
    check.inst_param(asset_key, "asset_key", AssetKey)
    check.mapping_param(new_tags, "new_tags", key_type=str, value_type=str)
    if not self.supports_add_asset_event_tags():
        raise DagsterInvalidInvocationError(
            "In order to add asset event tags, you must run `dagster instance migrate` to "
            "create the AssetEventTags table."
        )
    current_tags_list = self.get_event_tags_for_asset(asset_key, filter_event_id=event_id)
    asset_key_str = asset_key.to_string()
    if len(current_tags_list) == 0:
        current_tags: Mapping[str, str] = {}
    else:
        current_tags = current_tags_list[0]
    with self.index_connection() as conn:
        current_tags_set = set(current_tags.keys())
        new_tags_set = set(new_tags.keys())
        existing_tags = current_tags_set & new_tags_set
        added_tags = new_tags_set.difference(existing_tags)
        # Keys already present on the event are updated in place, one statement per key.
        for tag in existing_tags:
            conn.execute(
                AssetEventTagsTable.update()
                .where(
                    db.and_(
                        AssetEventTagsTable.c.event_id == event_id,
                        AssetEventTagsTable.c.asset_key == asset_key_str,
                        AssetEventTagsTable.c.key == tag,
                    )
                )
                .values(value=new_tags[tag])
            )
        # Brand-new keys are inserted in a single batched insert.
        if added_tags:
            conn.execute(
                AssetEventTagsTable.insert(),
                [
                    dict(
                        event_id=event_id,
                        asset_key=asset_key_str,
                        key=tag,
                        value=new_tags[tag],
                        # Postgres requires a datetime that is in UTC but has no timezone
                        # info set in order to be stored correctly. `datetime.utcfromtimestamp`
                        # is deprecated as of Python 3.12, so build an aware UTC datetime and
                        # strip the tzinfo -- the resulting naive value is identical.
                        event_timestamp=datetime.fromtimestamp(
                            event_timestamp, timezone.utc
                        ).replace(tzinfo=None),
                    )
                    for tag in added_tags
                ],
            )
def store_asset_event_tags(
    self, events: Sequence[EventLogEntry], event_ids: Sequence[int]
) -> None:
    """Persist the tags carried by a batch of asset events, one row per (event, tag)."""
    check.sequence_param(events, "events", EventLogEntry)
    check.sequence_param(event_ids, "event_ids", int)
    all_values = []
    for event_id, event in zip(event_ids, events):
        tags = self._tags_for_asset_event(event)
        if not tags:
            continue
        asset_key_str = check.not_none(event.get_dagster_event().asset_key).to_string()
        row_timestamp = self._event_insert_timestamp(event)
        for key, value in tags.items():
            all_values.append(
                dict(
                    event_id=event_id,
                    asset_key=asset_key_str,
                    key=key,
                    value=value,
                    event_timestamp=row_timestamp,
                )
            )
    # Only execute if tags table exists. This is to support OSS users who have not yet run the
    # migration to create the table. On read, we will throw an error if the table does not
    # exist.
    if all_values and self.has_table(AssetEventTagsTable.name):
        with self.index_connection() as conn:
            conn.execute(AssetEventTagsTable.insert(), all_values)
def _tags_for_asset_event(self, event: EventLogEntry) -> Mapping[str, str]:
    """Extract the tag dict from a materialization or observation event; empty otherwise."""
    dagster_event = event.dagster_event
    if not dagster_event or not dagster_event.asset_key:
        return {}
    if dagster_event.is_step_materialization:
        tags = event.get_dagster_event().step_materialization_data.materialization.tags
        return tags or {}
    if dagster_event.is_asset_observation:
        return event.get_dagster_event().asset_observation_data.asset_observation.tags
    return {}
def store_event(self, event: EventLogEntry) -> None:
    """Store an event corresponding to a pipeline run.

    Writes the raw event row, then fans out to the asset key table, the asset
    event tags table, and asset check storage as appropriate for the event type.

    Args:
        event (EventLogEntry): The event to store.
    """
    check.inst_param(event, "event", EventLogEntry)
    insert_event_statement = self.prepare_insert_event(event)
    run_id = event.run_id
    event_id = None
    # The raw event always goes to the (possibly run-sharded) event log table.
    with self.run_connection(run_id) as conn:
        result = conn.execute(insert_event_statement)
        event_id = result.inserted_primary_key[0]
    if (
        event.is_dagster_event
        and event.dagster_event_type in ASSET_EVENTS
        and event.dagster_event.asset_key  # type: ignore
    ):
        self.store_asset_event(event, event_id)
        # Tags are keyed on the storage id, so a missing id makes them unstorable.
        if event_id is None:
            raise DagsterInvariantViolationError(
                "Cannot store asset event tags for null event id."
            )
        self.store_asset_event_tags([event], [event_id])
    if event.is_dagster_event and event.dagster_event_type in ASSET_CHECK_EVENTS:
        self.store_asset_check_event(event, event_id)
def get_records_for_run(
    self,
    run_id,
    cursor: Optional[str] = None,
    of_type: Optional[Union[DagsterEventType, Set[DagsterEventType]]] = None,
    limit: Optional[int] = None,
    ascending: bool = True,
) -> EventLogConnection:
    """Get all of the logs corresponding to a run.

    Args:
        run_id (str): The id of the run for which to fetch logs.
        cursor (Optional[str]): Zero-indexed logs will be returned starting from cursor + 1,
            i.e., if cursor is -1, all logs will be returned. (default: -1)
        of_type (Optional[DagsterEventType]): the dagster event type to filter the logs.
        limit (Optional[int]): the maximum number of events to fetch
        ascending (bool): order of the returned records by storage id.
    """
    check.str_param(run_id, "run_id")
    check.opt_str_param(cursor, "cursor")
    check.invariant(not of_type or isinstance(of_type, (DagsterEventType, frozenset, set)))
    # Normalize `of_type` into a (possibly empty) set of event types.
    dagster_event_types = (
        {of_type}
        if isinstance(of_type, DagsterEventType)
        else check.opt_set_param(of_type, "dagster_event_type", of_type=DagsterEventType)
    )
    query = (
        db_select([SqlEventLogStorageTable.c.id, SqlEventLogStorageTable.c.event])
        .where(SqlEventLogStorageTable.c.run_id == run_id)
        .order_by(
            SqlEventLogStorageTable.c.id.asc()
            if ascending
            else SqlEventLogStorageTable.c.id.desc()
        )
    )
    if dagster_event_types:
        query = query.where(
            SqlEventLogStorageTable.c.dagster_event_type.in_(
                [dagster_event_type.value for dagster_event_type in dagster_event_types]
            )
        )
    # adjust 0 based index cursor to SQL offset
    if cursor is not None:
        cursor_obj = EventLogCursor.parse(cursor)
        if cursor_obj.is_offset_cursor():
            query = query.offset(cursor_obj.offset())
        elif cursor_obj.is_id_cursor():
            # Id cursors translate to strict inequality on storage id, direction-aware.
            if ascending:
                query = query.where(SqlEventLogStorageTable.c.id > cursor_obj.storage_id())
            else:
                query = query.where(SqlEventLogStorageTable.c.id < cursor_obj.storage_id())
    if limit:
        query = query.limit(limit)
    with self.run_connection(run_id) as conn:
        results = conn.execute(query).fetchall()
    last_record_id = None
    try:
        records = []
        for (
            record_id,
            json_str,
        ) in results:
            records.append(
                EventLogRecord(
                    storage_id=record_id,
                    event_log_entry=deserialize_value(json_str, EventLogEntry),
                )
            )
            last_record_id = record_id
    except (seven.JSONDecodeError, DeserializationError) as err:
        # Any undeserializable row marks the whole run's log as invalid.
        raise DagsterEventLogInvalidForRun(run_id=run_id) from err
    if last_record_id is not None:
        next_cursor = EventLogCursor.from_storage_id(last_record_id).to_string()
    elif cursor:
        # record fetch returned no new logs, return the same cursor
        next_cursor = cursor
    else:
        # rely on the fact that all storage ids will be positive integers
        next_cursor = EventLogCursor.from_storage_id(-1).to_string()
    return EventLogConnection(
        records=records,
        cursor=next_cursor,
        # A full page implies there may be more rows to fetch.
        has_more=bool(limit and len(results) == limit),
    )
def get_stats_for_run(self, run_id: str) -> DagsterRunStatsSnapshot:
    """Aggregate per-event-type counts and latest timestamps for a run into a stats snapshot.

    Args:
        run_id (str): The run to compute stats for.

    Raises:
        DagsterEventLogInvalidForRun: If the stored rows cannot be processed.
    """
    check.str_param(run_id, "run_id")
    # One grouped query: count of events and max timestamp per dagster event type.
    query = (
        db_select(
            [
                SqlEventLogStorageTable.c.dagster_event_type,
                db.func.count().label("n_events_of_type"),
                db.func.max(SqlEventLogStorageTable.c.timestamp).label("last_event_timestamp"),
            ]
        )
        .where(
            db.and_(
                SqlEventLogStorageTable.c.run_id == run_id,
                SqlEventLogStorageTable.c.dagster_event_type != None,  # noqa: E711
            )
        )
        .group_by("dagster_event_type")
    )
    with self.run_connection(run_id) as conn:
        results = conn.execute(query).fetchall()
    try:
        counts = {}
        times = {}
        for result in results:
            (dagster_event_type, n_events_of_type, last_event_timestamp) = result
            check.invariant(dagster_event_type is not None)
            counts[dagster_event_type] = n_events_of_type
            times[dagster_event_type] = last_event_timestamp
        enqueued_time = times.get(DagsterEventType.PIPELINE_ENQUEUED.value, None)
        launch_time = times.get(DagsterEventType.PIPELINE_STARTING.value, None)
        start_time = times.get(DagsterEventType.PIPELINE_START.value, None)
        # End time is whichever terminal event occurred: success, failure, or cancellation.
        end_time = times.get(
            DagsterEventType.PIPELINE_SUCCESS.value,
            times.get(
                DagsterEventType.PIPELINE_FAILURE.value,
                times.get(DagsterEventType.PIPELINE_CANCELED.value, None),
            ),
        )
        return DagsterRunStatsSnapshot(
            run_id=run_id,
            steps_succeeded=counts.get(DagsterEventType.STEP_SUCCESS.value, 0),
            steps_failed=counts.get(DagsterEventType.STEP_FAILURE.value, 0),
            materializations=counts.get(DagsterEventType.ASSET_MATERIALIZATION.value, 0),
            expectations=counts.get(DagsterEventType.STEP_EXPECTATION_RESULT.value, 0),
            enqueued_time=datetime_as_float(enqueued_time) if enqueued_time else None,
            launch_time=datetime_as_float(launch_time) if launch_time else None,
            start_time=datetime_as_float(start_time) if start_time else None,
            end_time=datetime_as_float(end_time) if end_time else None,
        )
    except (seven.JSONDecodeError, DeserializationError) as err:
        raise DagsterEventLogInvalidForRun(run_id=run_id) from err
def get_step_stats_for_run(
    self, run_id: str, step_keys: Optional[Sequence[str]] = None
) -> Sequence[RunStepKeyStatsSnapshot]:
    """Build per-step stats for a run by replaying its step-related raw events.

    Args:
        run_id (str): The run to compute step stats for.
        step_keys (Optional[Sequence[str]]): If provided, restrict stats to these step keys.

    Raises:
        DagsterEventLogInvalidForRun: If any stored event fails to deserialize.
    """
    check.str_param(run_id, "run_id")
    check.opt_list_param(step_keys, "step_keys", of_type=str)
    # Originally, this was two different queries:
    # 1) one query which aggregated top-level step stats by grouping by event type / step_key in
    # a single query, using pure SQL (e.g. start_time, end_time, status, attempt counts).
    # 2) one query which fetched all the raw events for a specific event type and then inspected
    # the deserialized event object to aggregate stats derived from sequences of events.
    # (e.g. marker events, materializations, expectations resuls, attempts timing, etc.)
    #
    # For simplicity, we now just do the second type of query and derive the stats in Python
    # from the raw events. This has the benefit of being easier to read and also the benefit of
    # being able to share code with the in-memory event log storage implementation. We may
    # choose to revisit this in the future, especially if we are able to do JSON-column queries
    # in SQL as a way of bypassing the serdes layer in all cases.
    raw_event_query = (
        db_select([SqlEventLogStorageTable.c.event])
        .where(SqlEventLogStorageTable.c.run_id == run_id)
        .where(SqlEventLogStorageTable.c.step_key != None)  # noqa: E711
        .where(
            SqlEventLogStorageTable.c.dagster_event_type.in_(
                # Note: STEP_RESTARTED was previously listed twice here; the duplicate
                # has been removed (duplicates in an IN clause are redundant).
                [
                    DagsterEventType.STEP_START.value,
                    DagsterEventType.STEP_SUCCESS.value,
                    DagsterEventType.STEP_SKIPPED.value,
                    DagsterEventType.STEP_FAILURE.value,
                    DagsterEventType.STEP_RESTARTED.value,
                    DagsterEventType.ASSET_MATERIALIZATION.value,
                    DagsterEventType.STEP_EXPECTATION_RESULT.value,
                    DagsterEventType.STEP_UP_FOR_RETRY.value,
                ]
                + [marker_event.value for marker_event in MARKER_EVENTS]
            )
        )
        .order_by(SqlEventLogStorageTable.c.id.asc())
    )
    if step_keys:
        raw_event_query = raw_event_query.where(
            SqlEventLogStorageTable.c.step_key.in_(step_keys)
        )
    with self.run_connection(run_id) as conn:
        results = conn.execute(raw_event_query).fetchall()
    try:
        records = [deserialize_value(json_str, EventLogEntry) for (json_str,) in results]
        return build_run_step_stats_from_events(run_id, records)
    except (seven.JSONDecodeError, DeserializationError) as err:
        raise DagsterEventLogInvalidForRun(run_id=run_id) from err
def _apply_migration(self, migration_name, migration_fn, print_fn, force):
    """Run a single data migration, skipping it when already applied unless forced."""
    already_applied = self.has_secondary_index(migration_name)
    if already_applied and not force:
        if print_fn:
            print_fn(f"Skipping already applied data migration: {migration_name}")
        return
    if print_fn:
        print_fn(f"Starting data migration: {migration_name}")
    # migration_fn is a zero-arg factory returning the callable that performs the work.
    migration_fn()(self, print_fn)
    # Record completion so future calls (without force) become no-ops.
    self.enable_secondary_index(migration_name)
    if print_fn:
        print_fn(f"Finished data migration: {migration_name}")
def reindex_events(self, print_fn: Optional[PrintFn] = None, force: bool = False) -> None:
    """Call this method to run any data migrations across the event_log table."""
    for name, migration_factory in EVENT_LOG_DATA_MIGRATIONS.items():
        self._apply_migration(name, migration_factory, print_fn, force)
def reindex_assets(self, print_fn: Optional[PrintFn] = None, force: bool = False) -> None:
    """Call this method to run any data migrations across the asset_keys table."""
    for name, migration_factory in ASSET_DATA_MIGRATIONS.items():
        self._apply_migration(name, migration_factory, print_fn, force)
def wipe(self) -> None:
    """Clears the event log storage."""
    # Should be overridden by SqliteEventLogStorage and other storages that shard based on
    # run_id
    # https://stackoverflow.com/a/54386260/324449
    with self.run_connection(run_id=None) as conn:
        conn.execute(SqlEventLogStorageTable.delete())
        conn.execute(AssetKeyTable.delete())
        # The remaining tables are created by migrations, so they may not exist yet.
        if self.has_table("asset_event_tags"):
            conn.execute(AssetEventTagsTable.delete())
        if self.has_table("dynamic_partitions"):
            conn.execute(DynamicPartitionsTable.delete())
        if self.has_table("concurrency_limits"):
            conn.execute(ConcurrencyLimitsTable.delete())
        if self.has_table("concurrency_slots"):
            conn.execute(ConcurrencySlotsTable.delete())
        if self.has_table("pending_steps"):
            conn.execute(PendingStepsTable.delete())
        if self.has_table("asset_check_executions"):
            conn.execute(AssetCheckExecutionsTable.delete())
    # Also clear the cross-run index shard.
    self._wipe_index()
def _wipe_index(self):
    # Clear the same tables on the index shard; for non-sharded storages this hits the
    # same database as `wipe`, making these deletes redundant but harmless.
    # NOTE(review): unlike `wipe`, this does not clear `concurrency_limits` -- confirm
    # whether that asymmetry is intentional.
    with self.index_connection() as conn:
        conn.execute(SqlEventLogStorageTable.delete())
        conn.execute(AssetKeyTable.delete())
        # Migration-created tables may not exist yet.
        if self.has_table("asset_event_tags"):
            conn.execute(AssetEventTagsTable.delete())
        if self.has_table("dynamic_partitions"):
            conn.execute(DynamicPartitionsTable.delete())
        if self.has_table("concurrency_slots"):
            conn.execute(ConcurrencySlotsTable.delete())
        if self.has_table("pending_steps"):
            conn.execute(PendingStepsTable.delete())
        if self.has_table("asset_check_executions"):
            conn.execute(AssetCheckExecutionsTable.delete())
def delete_events(self, run_id: str) -> None:
    """Delete all event rows for a run from both the run shard and the index shard."""
    with self.run_connection(run_id) as conn:
        self.delete_events_for_run(conn, run_id)
    with self.index_connection() as conn:
        self.delete_events_for_run(conn, run_id)
    # Release any concurrency slots the run may still be holding.
    if self.supports_global_concurrency_limits:
        self.free_concurrency_slots_for_run(run_id)
def delete_events_for_run(self, conn: Connection, run_id: str) -> None:
    """Delete a run's event rows plus any asset-event tag rows that reference them."""
    check.str_param(run_id, "run_id")
    # Collect the ids of materialization/observation events first, since their tag
    # rows must be purged along with the events themselves.
    asset_event_filter = db.or_(
        SqlEventLogStorageTable.c.dagster_event_type
        == DagsterEventType.ASSET_MATERIALIZATION.value,
        SqlEventLogStorageTable.c.dagster_event_type
        == DagsterEventType.ASSET_OBSERVATION.value,
    )
    id_rows = conn.execute(
        db_select([SqlEventLogStorageTable.c.id]).where(
            db.and_(
                SqlEventLogStorageTable.c.run_id == run_id,
                asset_event_filter,
            )
        )
    ).fetchall()
    asset_event_ids = [row[0] for row in id_rows]
    conn.execute(
        SqlEventLogStorageTable.delete().where(SqlEventLogStorageTable.c.run_id == run_id)
    )
    if asset_event_ids:
        conn.execute(
            AssetEventTagsTable.delete().where(
                AssetEventTagsTable.c.event_id.in_(asset_event_ids)
            )
        )
@property
def is_persistent(self) -> bool:
    # SQL-backed storages are durable, unlike the in-memory implementation.
    return True
def update_event_log_record(self, record_id: int, event: EventLogEntry) -> None:
    """Utility method for migration scripts to update SQL representation of event records.

    Args:
        record_id (int): Storage id of the row to rewrite.
        event (EventLogEntry): The new event payload to serialize into that row.
    """
    check.int_param(record_id, "record_id")
    check.inst_param(event, "event", EventLogEntry)
    dagster_event_type = None
    asset_key_str = None
    if event.is_dagster_event:
        # Use the typed accessor (consistent with `_event_to_row`) instead of raw
        # `event.dagster_event` attribute access, which required `type: ignore`s.
        dagster_event = event.get_dagster_event()
        dagster_event_type = dagster_event.event_type_value
        if dagster_event.asset_key:
            check.inst_param(dagster_event.asset_key, "asset_key", AssetKey)
            asset_key_str = dagster_event.asset_key.to_string()
    with self.run_connection(run_id=event.run_id) as conn:
        conn.execute(
            SqlEventLogStorageTable.update()
            .where(SqlEventLogStorageTable.c.id == record_id)
            .values(
                event=serialize_value(event),
                dagster_event_type=dagster_event_type,
                timestamp=self._event_insert_timestamp(event),
                step_key=event.step_key,
                asset_key=asset_key_str,
            )
        )
def get_event_log_table_data(self, run_id: str, record_id: int) -> Optional[SqlAlchemyRow]:
    """Utility method to test representation of the record in the SQL table.

    Returns all of the columns stored in the event log storage (as opposed to the
    deserialized `EventLogEntry`). This allows checking that certain fields are
    extracted to support performant lookups (e.g. extracting `step_key` for fast
    filtering).
    """
    query = (
        db_select([SqlEventLogStorageTable])
        .where(SqlEventLogStorageTable.c.id == record_id)
        .order_by(SqlEventLogStorageTable.c.id.asc())
    )
    with self.run_connection(run_id=run_id) as conn:
        return conn.execute(query).fetchone()
def has_secondary_index(self, name: str) -> bool:
    """This method uses a checkpoint migration table to see if summary data has been
    constructed in a secondary index table. Can be used to checkpoint event_log data
    migrations.
    """
    # A migration counts as applied only if its row exists AND is marked completed.
    query = (
        db_select([1])
        .where(SecondaryIndexMigrationTable.c.name == name)
        .where(SecondaryIndexMigrationTable.c.migration_completed != None)  # noqa: E711
        .limit(1)
    )
    with self.index_connection() as conn:
        rows = conn.execute(query).fetchall()
    return bool(rows)
def enable_secondary_index(self, name: str) -> None:
    """This method marks an event_log data migration as complete, to indicate that a summary
    data migration is complete.
    """
    insert_statement = SecondaryIndexMigrationTable.insert().values(
        name=name,
        migration_completed=datetime.now(),
    )
    with self.index_connection() as conn:
        try:
            conn.execute(insert_statement)
        except db_exc.IntegrityError:
            # A row for this migration already exists -- refresh its completion time.
            conn.execute(
                SecondaryIndexMigrationTable.update()
                .where(SecondaryIndexMigrationTable.c.name == name)
                .values(migration_completed=datetime.now())
            )
def _apply_filter_to_query(
    self,
    query: SqlAlchemyQuery,
    event_records_filter: EventRecordsFilter,
    asset_details: Optional[AssetDetails] = None,
    apply_cursor_filters: bool = True,
) -> SqlAlchemyQuery:
    """Apply the fields of an EventRecordsFilter to a query over the event log table.

    Args:
        query (SqlAlchemyQuery): The base query to narrow.
        event_records_filter (EventRecordsFilter): Filter fields (event type, asset key,
            partitions, cursors, timestamps, storage ids, tags).
        asset_details (Optional[AssetDetails]): If provided and the asset was wiped,
            excludes events before the wipe timestamp.
        apply_cursor_filters (bool): Allows run-sharded storages to opt out of the
            id-based cursor clauses and implement their own cursor logic.
    """
    query = query.where(
        SqlEventLogStorageTable.c.dagster_event_type == event_records_filter.event_type.value
    )
    if event_records_filter.asset_key:
        query = query.where(
            SqlEventLogStorageTable.c.asset_key == event_records_filter.asset_key.to_string(),
        )
    if event_records_filter.asset_partitions:
        query = query.where(
            SqlEventLogStorageTable.c.partition.in_(event_records_filter.asset_partitions)
        )
    # Timestamps are stored as naive UTC datetimes. `datetime.utcfromtimestamp` is
    # deprecated as of Python 3.12, so build aware UTC datetimes and strip the tzinfo;
    # the resulting naive values are identical.
    if asset_details and asset_details.last_wipe_timestamp:
        query = query.where(
            SqlEventLogStorageTable.c.timestamp
            > datetime.fromtimestamp(asset_details.last_wipe_timestamp, timezone.utc).replace(
                tzinfo=None
            )
        )
    if apply_cursor_filters:
        # allow the run-sharded sqlite implementation to disable this cursor filtering so that
        # it can implement its own custom cursor logic, as cursor ids are not unique across run
        # shards
        if event_records_filter.before_cursor is not None:
            before_cursor_id = (
                event_records_filter.before_cursor.id
                if isinstance(event_records_filter.before_cursor, RunShardedEventsCursor)
                else event_records_filter.before_cursor
            )
            query = query.where(SqlEventLogStorageTable.c.id < before_cursor_id)
        if event_records_filter.after_cursor is not None:
            after_cursor_id = (
                event_records_filter.after_cursor.id
                if isinstance(event_records_filter.after_cursor, RunShardedEventsCursor)
                else event_records_filter.after_cursor
            )
            query = query.where(SqlEventLogStorageTable.c.id > after_cursor_id)
    if event_records_filter.before_timestamp:
        query = query.where(
            SqlEventLogStorageTable.c.timestamp
            < datetime.fromtimestamp(
                event_records_filter.before_timestamp, timezone.utc
            ).replace(tzinfo=None)
        )
    if event_records_filter.after_timestamp:
        query = query.where(
            SqlEventLogStorageTable.c.timestamp
            > datetime.fromtimestamp(
                event_records_filter.after_timestamp, timezone.utc
            ).replace(tzinfo=None)
        )
    if event_records_filter.storage_ids:
        query = query.where(SqlEventLogStorageTable.c.id.in_(event_records_filter.storage_ids))
    if event_records_filter.tags and self.has_table(AssetEventTagsTable.name):
        # If we don't have the tags table, we'll filter the results after the query
        check.invariant(
            isinstance(event_records_filter.asset_key, AssetKey),
            "Asset key must be set in event records filter to filter by tags.",
        )
    return query
def _apply_tags_table_joins(
    self,
    table: db.Table,
    tags: Mapping[str, Union[str, Sequence[str]]],
    asset_key: Optional[AssetKey],
) -> db.Table:
    """Join the events table against the asset event tags table once per tag filter.

    Each (key, value) filter gets its own aliased subquery join, so an event must
    match *all* of the requested tag filters to survive the combined join.
    """
    # The base events table keys rows by `id`; already-joined tables expose `event_id`.
    event_id_col = table.c.id if table == SqlEventLogStorageTable else table.c.event_id
    i = 0
    for key, value in tags.items():
        i += 1
        tags_table = db_subquery(
            db_select([AssetEventTagsTable]), f"asset_event_tags_subquery_{i}"
        )
        table = table.join(
            tags_table,
            db.and_(
                event_id_col == tags_table.c.event_id,
                # Evaluated in Python: when no asset_key is given, `not asset_key` is True
                # and the clause degenerates to a constant instead of an equality check.
                not asset_key or tags_table.c.asset_key == asset_key.to_string(),
                tags_table.c.key == key,
                # A scalar value means exact match; a sequence means membership.
                (
                    tags_table.c.value == value
                    if isinstance(value, str)
                    else tags_table.c.value.in_(value)
                ),
            ),
        )
    return table
def get_event_records(
    self,
    event_records_filter: EventRecordsFilter,
    limit: Optional[int] = None,
    ascending: bool = False,
) -> Sequence[EventLogRecord]:
    """Public entry point for filtered event record queries; delegates to the internal impl."""
    return self._get_event_records(event_records_filter, limit=limit, ascending=ascending)
def _get_event_records(
    self,
    event_records_filter: EventRecordsFilter,
    limit: Optional[int] = None,
    ascending: bool = False,
) -> Sequence[EventLogRecord]:
    """Return the event log records matching the filter, ordered by storage id."""
    check.inst_param(event_records_filter, "event_records_filter", EventRecordsFilter)
    check.opt_int_param(limit, "limit")
    check.bool_param(ascending, "ascending")
    if event_records_filter.asset_key:
        # Needed to exclude events that predate an asset wipe.
        asset_details = next(iter(self._get_assets_details([event_records_filter.asset_key])))
    else:
        asset_details = None
    # Prefer filtering tags in SQL via joins against the tags table when it exists.
    if event_records_filter.tags and self.has_table(AssetEventTagsTable.name):
        table = self._apply_tags_table_joins(
            SqlEventLogStorageTable, event_records_filter.tags, event_records_filter.asset_key
        )
    else:
        table = SqlEventLogStorageTable
    query = db_select(
        [SqlEventLogStorageTable.c.id, SqlEventLogStorageTable.c.event]
    ).select_from(table)
    query = self._apply_filter_to_query(
        query=query,
        event_records_filter=event_records_filter,
        asset_details=asset_details,
    )
    if limit:
        query = query.limit(limit)
    if ascending:
        query = query.order_by(SqlEventLogStorageTable.c.id.asc())
    else:
        query = query.order_by(SqlEventLogStorageTable.c.id.desc())
    with self.index_connection() as conn:
        results = conn.execute(query).fetchall()
    event_records = []
    for row_id, json_str in results:
        try:
            event_record = deserialize_value(json_str, NamedTuple)
            if not isinstance(event_record, EventLogEntry):
                # Skip (but log) rows that deserialize to something unexpected.
                logging.warning(
                    "Could not resolve event record as EventLogEntry for id `%s`.", row_id
                )
                continue
            if event_records_filter.tags and not self.has_table(AssetEventTagsTable.name):
                # If we can't filter tags via the tags table, filter the returned records
                if limit is not None:
                    # Post-filtering with a SQL-side limit would silently drop results.
                    raise DagsterInvalidInvocationError(
                        "Cannot filter events on tags with a limit, without the asset event "
                        "tags table. To fix, run `dagster instance migrate`."
                    )
                event_record_tags = event_record.tags
                if not event_record_tags or any(
                    event_record_tags.get(k) != v for k, v in event_records_filter.tags.items()
                ):
                    continue
            event_records.append(
                EventLogRecord(storage_id=row_id, event_log_entry=event_record)
            )
        except seven.JSONDecodeError:
            # A single bad row is skipped rather than failing the whole query.
            logging.warning("Could not parse event record id `%s`.", row_id)
    return event_records
def supports_event_consumer_queries(self) -> bool:
    """Whether this storage supports cursor-based event consumer queries."""
    return True
def _get_event_records_result(
    self,
    event_records_filter: EventRecordsFilter,
    limit: int,
    cursor: Optional[str],
    ascending: bool,
):
    """Fetch event records matching the filter and package them with pagination info.

    The returned cursor points at the last fetched record; when nothing was
    fetched, the incoming cursor (or a `-1` sentinel cursor) is carried forward.
    `has_more` is inferred from whether a full page was returned.
    """
    records = self._get_event_records(
        event_records_filter=event_records_filter,
        limit=limit,
        ascending=ascending,
    )
    if records:
        next_cursor = EventLogCursor.from_storage_id(records[-1].storage_id).to_string()
    else:
        next_cursor = cursor if cursor else EventLogCursor.from_storage_id(-1).to_string()
    return EventRecordsResult(records, cursor=next_cursor, has_more=len(records) == limit)
def fetch_materializations(
    self,
    records_filter: Union[AssetKey, AssetRecordsFilter],
    limit: int,
    cursor: Optional[str] = None,
    ascending: bool = False,
) -> EventRecordsResult:
    """Fetch asset materialization event records, paginated by storage id."""
    enforce_max_records_limit(limit)
    if isinstance(records_filter, AssetRecordsFilter):
        event_records_filter = records_filter.to_event_records_filter(
            event_type=DagsterEventType.ASSET_MATERIALIZATION,
            cursor=cursor,
            ascending=ascending,
        )
    else:
        # a bare AssetKey was passed; build the filter from the raw cursor params
        before, after = EventRecordsFilter.get_cursor_params(cursor, ascending)
        event_records_filter = EventRecordsFilter(
            event_type=DagsterEventType.ASSET_MATERIALIZATION,
            asset_key=records_filter,
            before_cursor=before,
            after_cursor=after,
        )
    return self._get_event_records_result(event_records_filter, limit, cursor, ascending)
def fetch_observations(
    self,
    records_filter: Union[AssetKey, AssetRecordsFilter],
    limit: int,
    cursor: Optional[str] = None,
    ascending: bool = False,
) -> EventRecordsResult:
    """Fetch asset observation event records, paginated by storage id."""
    enforce_max_records_limit(limit)
    if isinstance(records_filter, AssetRecordsFilter):
        event_records_filter = records_filter.to_event_records_filter(
            event_type=DagsterEventType.ASSET_OBSERVATION,
            cursor=cursor,
            ascending=ascending,
        )
    else:
        # a bare AssetKey was passed; build the filter from the raw cursor params
        before, after = EventRecordsFilter.get_cursor_params(cursor, ascending)
        event_records_filter = EventRecordsFilter(
            event_type=DagsterEventType.ASSET_OBSERVATION,
            asset_key=records_filter,
            before_cursor=before,
            after_cursor=after,
        )
    return self._get_event_records_result(event_records_filter, limit, cursor, ascending)
def fetch_run_status_changes(
    self,
    records_filter: Union[DagsterEventType, RunStatusChangeRecordsFilter],
    limit: int,
    cursor: Optional[str] = None,
    ascending: bool = False,
) -> EventRecordsResult:
    """Fetch run-status-change event records, paginated by storage id.

    The filter's event type must be one of the run-status event types.
    """
    enforce_max_records_limit(limit)
    if isinstance(records_filter, DagsterEventType):
        event_type = records_filter
    else:
        event_type = records_filter.event_type
    if event_type not in EVENT_TYPE_TO_PIPELINE_RUN_STATUS:
        expected = ", ".join(EVENT_TYPE_TO_PIPELINE_RUN_STATUS.keys())
        check.failed(f"Expected one of {expected}, received {event_type.value}")
    # cursor params are parsed up front (validating the cursor) even when the
    # records_filter builds its own filter below
    before_cursor, after_cursor = EventRecordsFilter.get_cursor_params(cursor, ascending)
    if isinstance(records_filter, RunStatusChangeRecordsFilter):
        event_records_filter = records_filter.to_event_records_filter(cursor, ascending)
    else:
        event_records_filter = EventRecordsFilter(
            event_type, before_cursor=before_cursor, after_cursor=after_cursor
        )
    return self._get_event_records_result(event_records_filter, limit, cursor, ascending)
def get_logs_for_all_runs_by_log_id(
    self,
    after_cursor: int = -1,
    dagster_event_type: Optional[Union[DagsterEventType, Set[DagsterEventType]]] = None,
    limit: Optional[int] = None,
) -> Mapping[int, EventLogEntry]:
    """Return event log entries across all runs, keyed by storage id.

    Args:
        after_cursor: only return records with storage id strictly greater than this.
        dagster_event_type: optional event type (or set of types) to filter on.
        limit: optional maximum number of records to return.
    """
    check.int_param(after_cursor, "after_cursor")
    check.invariant(
        after_cursor >= -1,
        f"Don't know what to do with negative cursor {after_cursor}",
    )
    # normalize a single event type into a set
    dagster_event_types = (
        {dagster_event_type}
        if isinstance(dagster_event_type, DagsterEventType)
        else check.opt_set_param(
            dagster_event_type, "dagster_event_type", of_type=DagsterEventType
        )
    )
    query = (
        db_select([SqlEventLogStorageTable.c.id, SqlEventLogStorageTable.c.event])
        .where(SqlEventLogStorageTable.c.id > after_cursor)
        .order_by(SqlEventLogStorageTable.c.id.asc())
    )
    if dagster_event_types:
        query = query.where(
            SqlEventLogStorageTable.c.dagster_event_type.in_(
                [dagster_event_type.value for dagster_event_type in dagster_event_types]
            )
        )
    if limit:
        query = query.limit(limit)
    with self.index_connection() as conn:
        results = conn.execute(query).fetchall()
    events = {}
    for record_id, json_str in results:
        try:
            events[record_id] = deserialize_value(json_str, EventLogEntry)
        except (seven.JSONDecodeError, DeserializationError):
            # Handle corruption per-record (log and skip) rather than aborting the
            # whole fetch on the first bad row, matching the per-row handling in
            # _get_event_records.
            logging.warning("Could not parse event record id `%s`.", record_id)
    return events
def get_maximum_record_id(self) -> Optional[int]:
    """Return the largest event log storage id, or None if the table is empty."""
    max_id_query = db_select([db.func.max(SqlEventLogStorageTable.c.id)])
    with self.index_connection() as conn:
        row = conn.execute(max_id_query).fetchone()
    return row[0]  # type: ignore
def _construct_asset_record_from_row(
    self,
    row,
    last_materialization_record: Optional[EventLogRecord],
    can_cache_asset_status_data: bool,
) -> AssetRecord:
    """Build an AssetRecord from a raw asset-key table row.

    Raises (via check.failed) if the row's asset key cannot be parsed.
    """
    from dagster._core.storage.partition_status_cache import AssetStatusCacheValue

    asset_key = AssetKey.from_db_string(row["asset_key"])
    if not asset_key:
        check.failed("Row did not contain asset key.")
    cached_status = (
        AssetStatusCacheValue.from_db_string(row["cached_status_data"])
        if can_cache_asset_status_data
        else None
    )
    return AssetRecord(
        storage_id=row["id"],
        asset_entry=AssetEntry(
            asset_key=asset_key,
            last_materialization_record=last_materialization_record,
            last_run_id=row["last_run_id"],
            asset_details=AssetDetails.from_db_string(row["asset_details"]),
            cached_status=cached_status,
        ),
    )
def _get_latest_materialization_records(
    self, raw_asset_rows
) -> Mapping[AssetKey, Optional[EventLogRecord]]:
    """Given raw asset rows, return a mapping of asset key to the latest asset
    materialization event log record.

    Rows whose serialized `last_materialization` column already contains an
    EventLogRecord are used directly; older rows (backcompat, where the column
    stored an AssetMaterialization) fall back to a query against the event log
    table.
    """
    to_backcompat_fetch = set()
    results: Dict[AssetKey, Optional[EventLogRecord]] = {}
    for row in raw_asset_rows:
        asset_key = AssetKey.from_db_string(row["asset_key"])
        if not asset_key:
            continue
        event_or_materialization = (
            deserialize_value(row["last_materialization"], NamedTuple)
            if row["last_materialization"]
            else None
        )
        if isinstance(event_or_materialization, EventLogRecord):
            results[asset_key] = event_or_materialization
        else:
            to_backcompat_fetch.add(asset_key)

    if not to_backcompat_fetch:
        # nothing needs the backcompat path; skip the pointless DB round-trip
        # that would otherwise execute with an empty IN clause
        return results

    # locate the max materialization event id per backcompat asset key...
    latest_event_subquery = db_subquery(
        db_select(
            [
                SqlEventLogStorageTable.c.asset_key,
                db.func.max(SqlEventLogStorageTable.c.id).label("id"),
            ]
        )
        .where(
            db.and_(
                SqlEventLogStorageTable.c.asset_key.in_(
                    [asset_key.to_string() for asset_key in to_backcompat_fetch]
                ),
                SqlEventLogStorageTable.c.dagster_event_type
                == DagsterEventType.ASSET_MATERIALIZATION.value,
            )
        )
        .group_by(SqlEventLogStorageTable.c.asset_key),
        "latest_event_subquery",
    )
    # ...then join back to fetch the full event payload for those ids
    backcompat_query = db_select(
        [
            SqlEventLogStorageTable.c.asset_key,
            SqlEventLogStorageTable.c.id,
            SqlEventLogStorageTable.c.event,
        ]
    ).select_from(
        latest_event_subquery.join(
            SqlEventLogStorageTable,
            db.and_(
                SqlEventLogStorageTable.c.asset_key == latest_event_subquery.c.asset_key,
                SqlEventLogStorageTable.c.id == latest_event_subquery.c.id,
            ),
        )
    )
    with self.index_connection() as conn:
        event_rows = db_fetch_mappings(conn, backcompat_query)
    for row in event_rows:
        asset_key = AssetKey.from_db_string(cast(Optional[str], row["asset_key"]))
        if asset_key:
            results[asset_key] = EventLogRecord(
                storage_id=cast(int, row["id"]),
                event_log_entry=deserialize_value(cast(str, row["event"]), EventLogEntry),
            )
    return results
def can_cache_asset_status_data(self) -> bool:
    """Whether the asset key table has the `cached_status_data` column."""
    return self.has_asset_key_col("cached_status_data")
def wipe_asset_cached_status(self, asset_key: AssetKey) -> None:
    """Clear the cached status data for the given asset key.

    No-op when the `cached_status_data` column does not exist.
    """
    if not self.can_cache_asset_status_data():
        return
    check.inst_param(asset_key, "asset_key", AssetKey)
    update_stmt = (
        AssetKeyTable.update()
        .values(dict(cached_status_data=None))
        .where(
            AssetKeyTable.c.asset_key == asset_key.to_string(),
        )
    )
    with self.index_connection() as conn:
        conn.execute(update_stmt)
def get_asset_records(
    self, asset_keys: Optional[Sequence[AssetKey]] = None
) -> Sequence[AssetRecord]:
    """Return AssetRecords for the given asset keys (all assets when None)."""
    rows = self._fetch_asset_rows(asset_keys=asset_keys)
    latest_records = self._get_latest_materialization_records(rows)
    can_cache = self.can_cache_asset_status_data()
    records: List[AssetRecord] = []
    for row in rows:
        key = AssetKey.from_db_string(row["asset_key"])
        if not key:
            # skip rows whose serialized key cannot be parsed
            continue
        records.append(
            self._construct_asset_record_from_row(row, latest_records.get(key), can_cache)
        )
    return records
def has_asset_key(self, asset_key: AssetKey) -> bool:
    """Whether any asset row matches the given asset key."""
    check.inst_param(asset_key, "asset_key", AssetKey)
    return bool(self._fetch_asset_rows(asset_keys=[asset_key]))
def all_asset_keys(self):
    """Return all asset keys, sorted by serialized key string."""
    rows = self._fetch_asset_rows()
    ordered = sorted(rows, key=lambda row: row["asset_key"])
    parsed = (AssetKey.from_db_string(row["asset_key"]) for row in ordered)
    # drop rows whose serialized key could not be parsed
    return [key for key in parsed if key]
def get_asset_keys(
    self,
    prefix: Optional[Sequence[str]] = None,
    limit: Optional[int] = None,
    cursor: Optional[str] = None,
) -> Sequence[AssetKey]:
    """Return asset keys matching the optional prefix/limit/cursor, sorted by
    serialized key string."""
    rows = self._fetch_asset_rows(prefix=prefix, limit=limit, cursor=cursor)
    ordered = sorted(rows, key=lambda row: row["asset_key"])
    parsed = (AssetKey.from_db_string(row["asset_key"]) for row in ordered)
    # drop rows whose serialized key could not be parsed
    return [key for key in parsed if key]
def get_latest_materialization_events(
    self, asset_keys: Iterable[AssetKey]
) -> Mapping[AssetKey, Optional[EventLogEntry]]:
    """Return the latest materialization event log entry (or None) for each asset key."""
    check.iterable_param(asset_keys, "asset_keys", AssetKey)
    rows = self._fetch_asset_rows(asset_keys=asset_keys)
    latest_records = self._get_latest_materialization_records(rows)
    events: Dict[AssetKey, Optional[EventLogEntry]] = {}
    for asset_key, record in latest_records.items():
        events[asset_key] = record.event_log_entry if record is not None else None
    return events
def _fetch_asset_rows(
    self,
    asset_keys=None,
    prefix: Optional[Sequence[str]] = None,
    limit: Optional[int] = None,
    cursor: Optional[str] = None,
) -> Sequence[SqlAlchemyRow]:
    """Fetch rows containing asset_key, last_materialization, and asset_details
    from the DB, applying the filters specified in the arguments.

    Differs from _fetch_raw_asset_rows in that it loops to make sure enough rows
    are returned to satisfy the limit, since _fetch_raw_asset_rows may discard
    rows for wiped assets after fetching.
    """
    should_query = True
    current_cursor = cursor
    if self.has_secondary_index(ASSET_KEY_INDEX_COLS):
        # if we have migrated, we can limit using SQL
        fetch_limit = limit
    else:
        # if we haven't migrated, overfetch in case the first N results are wiped
        fetch_limit = max(limit, MIN_ASSET_ROWS) if limit else None
    result = []
    while should_query:
        rows, has_more, current_cursor = self._fetch_raw_asset_rows(
            asset_keys=asset_keys, prefix=prefix, limit=fetch_limit, cursor=current_cursor
        )
        result.extend(rows)
        # keep paging only while more rows exist and we are still short of the limit
        should_query = bool(has_more) and bool(limit) and len(result) < cast(int, limit)
    is_partial_query = asset_keys is not None or bool(prefix) or bool(limit) or bool(cursor)
    if not is_partial_query and self._can_mark_assets_as_migrated(rows):
        # a full, unfiltered scan whose rows all carry the index columns means the
        # data has been lazily migrated; flip the secondary index so future reads
        # can filter wiped assets in SQL
        self.enable_secondary_index(ASSET_KEY_INDEX_COLS)
    return result[:limit] if limit else result
def _fetch_raw_asset_rows(
    self,
    asset_keys: Optional[Sequence[AssetKey]] = None,
    prefix: Optional[Sequence[str]] = None,
    limit: Optional[int] = None,
    cursor=None,
) -> Tuple[Iterable[SqlAlchemyRow], bool, Optional[str]]:
    """Fetch a single chunk of asset rows (asset_key, last_materialization,
    asset_details, ...) from the DB, applying the filters specified in the
    arguments.

    Does not guarantee that the number of rows returned matches the limit:
    rows for wiped assets may be discarded after fetching.

    Returns a tuple of (rows, has_more, cursor).
    """
    columns = [
        AssetKeyTable.c.id,
        AssetKeyTable.c.asset_key,
        AssetKeyTable.c.last_materialization,
        AssetKeyTable.c.last_run_id,
        AssetKeyTable.c.asset_details,
    ]
    if self.can_cache_asset_status_data():
        columns.extend([AssetKeyTable.c.cached_status_data])
    is_partial_query = asset_keys is not None or bool(prefix) or bool(limit) or bool(cursor)
    if self.has_asset_key_index_cols() and not is_partial_query:
        # if the schema has been migrated, fetch the last_materialization_timestamp to see if
        # we can lazily migrate the data table
        columns.append(AssetKeyTable.c.last_materialization_timestamp)
        columns.append(AssetKeyTable.c.wipe_timestamp)
    query = db_select(columns).order_by(AssetKeyTable.c.asset_key.asc())
    query = self._apply_asset_filter_to_query(query, asset_keys, prefix, limit, cursor)
    if self.has_secondary_index(ASSET_KEY_INDEX_COLS):
        # fast path: wiped assets are filtered out in SQL using the migrated
        # timestamp columns, so every returned row is valid and no follow-up
        # paging is needed (has_more=False, cursor=None)
        query = query.where(
            db.or_(
                AssetKeyTable.c.wipe_timestamp.is_(None),
                AssetKeyTable.c.last_materialization_timestamp > AssetKeyTable.c.wipe_timestamp,
            )
        )
        with self.index_connection() as conn:
            rows = db_fetch_mappings(conn, query)
        return rows, False, None
    # slow path: fetch rows, then filter out wiped assets in Python using the
    # serialized asset_details / last_materialization columns
    with self.index_connection() as conn:
        rows = db_fetch_mappings(conn, query)
    wiped_timestamps_by_asset_key: Dict[AssetKey, float] = {}
    row_by_asset_key: Dict[AssetKey, SqlAlchemyRow] = OrderedDict()
    for row in rows:
        asset_key = AssetKey.from_db_string(cast(str, row["asset_key"]))
        if not asset_key:
            continue
        asset_details = AssetDetails.from_db_string(row["asset_details"])
        if not asset_details or not asset_details.last_wipe_timestamp:
            # never wiped; keep the row
            row_by_asset_key[asset_key] = row
            continue
        materialization_or_event_or_record = (
            deserialize_value(cast(str, row["last_materialization"]), NamedTuple)
            if row["last_materialization"]
            else None
        )
        if isinstance(materialization_or_event_or_record, (EventLogRecord, EventLogEntry)):
            # the stored value carries a timestamp we can compare against the wipe
            if isinstance(materialization_or_event_or_record, EventLogRecord):
                event_timestamp = materialization_or_event_or_record.event_log_entry.timestamp
            else:
                event_timestamp = materialization_or_event_or_record.timestamp
            if asset_details.last_wipe_timestamp > event_timestamp:
                # this asset has not been materialized since being wiped, skip
                continue
            else:
                # add the key
                row_by_asset_key[asset_key] = row
        else:
            # no usable timestamp on the stored value; tentatively keep the row and
            # defer the wipe comparison to the backcompat event-log query below
            row_by_asset_key[asset_key] = row
            wiped_timestamps_by_asset_key[asset_key] = asset_details.last_wipe_timestamp
    if wiped_timestamps_by_asset_key:
        materialization_times = self._fetch_backcompat_materialization_times(
            wiped_timestamps_by_asset_key.keys()  # type: ignore
        )
        for asset_key, wiped_timestamp in wiped_timestamps_by_asset_key.items():
            materialization_time = materialization_times.get(asset_key)
            if not materialization_time or utc_datetime_from_naive(
                materialization_time
            ) < utc_datetime_from_timestamp(wiped_timestamp):
                # remove rows that have not been materialized since being wiped
                row_by_asset_key.pop(asset_key)
    # has_more is a heuristic: a full page suggests more rows may exist
    has_more = limit and len(rows) == limit
    new_cursor = rows[-1]["id"] if rows else None
    return row_by_asset_key.values(), has_more, new_cursor  # type: ignore
def update_asset_cached_status_data(
    self, asset_key: AssetKey, cache_values: "AssetStatusCacheValue"
) -> None:
    """Persist the serialized status cache value for the given asset key.

    No-op when the `cached_status_data` column does not exist.
    """
    if not self.can_cache_asset_status_data():
        return
    update_stmt = (
        AssetKeyTable.update()
        .where(
            AssetKeyTable.c.asset_key == asset_key.to_string(),
        )
        .values(cached_status_data=serialize_value(cache_values))
    )
    with self.index_connection() as conn:
        conn.execute(update_stmt)
def _fetch_backcompat_materialization_times(
    self, asset_keys: Sequence[AssetKey]
) -> Mapping[AssetKey, datetime]:
    """Fetch the latest materialization timestamp for each given asset key using
    the (slower) raw event log table."""
    serialized_keys = [asset_key.to_string() for asset_key in asset_keys]
    max_timestamp = db.func.max(SqlEventLogStorageTable.c.timestamp)
    backcompat_query = (
        db_select(
            [
                SqlEventLogStorageTable.c.asset_key,
                max_timestamp.label("timestamp"),
            ]
        )
        .where(SqlEventLogStorageTable.c.asset_key.in_(serialized_keys))
        .group_by(SqlEventLogStorageTable.c.asset_key)
        .order_by(max_timestamp.asc())
    )
    with self.index_connection() as conn:
        backcompat_rows = db_fetch_mappings(conn, backcompat_query)
    return {
        AssetKey.from_db_string(row["asset_key"]): row["timestamp"] for row in backcompat_rows
    }  # type: ignore
def _can_mark_assets_as_migrated(self, rows):
    """Whether every fetched asset row already carries populated index columns,
    meaning the ASSET_KEY_INDEX_COLS secondary index can be enabled."""
    if not self.has_asset_key_index_cols():
        return False
    if self.has_secondary_index(ASSET_KEY_INDEX_COLS):
        # we have already migrated
        return False
    # every row must have a last_materialization_timestamp, and any row with
    # asset_details must also have a wipe_timestamp
    return all(
        _get_from_row(row, "last_materialization_timestamp")
        and (not _get_from_row(row, "asset_details") or _get_from_row(row, "wipe_timestamp"))
        for row in rows
    )
def _apply_asset_filter_to_query(
    self,
    query: SqlAlchemyQuery,
    asset_keys: Optional[Sequence[AssetKey]] = None,
    prefix=None,
    limit: Optional[int] = None,
    cursor: Optional[str] = None,
) -> SqlAlchemyQuery:
    """Apply asset-key / prefix / cursor / limit filters to an asset-table query."""
    if asset_keys is not None:
        serialized_keys = [asset_key.to_string() for asset_key in asset_keys]
        query = query.where(AssetKeyTable.c.asset_key.in_(serialized_keys))
    if prefix:
        # serialize the prefix components and strip the final character so the
        # string matches any serialized key starting with those components
        prefix_str = seven.dumps(prefix)[:-1]
        query = query.where(AssetKeyTable.c.asset_key.startswith(prefix_str))
    if cursor:
        # cursor is the serialized asset key to page after
        query = query.where(AssetKeyTable.c.asset_key > cursor)
    if limit:
        query = query.limit(limit)
    return query
def _get_assets_details(
    self, asset_keys: Sequence[AssetKey]
) -> Sequence[Optional[AssetDetails]]:
    """Return the AssetDetails (or None) for each given asset key, in the same
    order as the provided asset_keys."""
    check.sequence_param(asset_keys, "asset_key", AssetKey)
    serialized_keys = [asset_key.to_string() for asset_key in asset_keys]
    details_query = db_select(
        [AssetKeyTable.c.asset_key, AssetKeyTable.c.asset_details]
    ).where(
        AssetKeyTable.c.asset_key.in_(serialized_keys),
    )
    with self.index_connection() as conn:
        rows = db_fetch_mappings(conn, details_query)
    details_by_serialized_key = {}
    for row in rows:
        raw_details = row["asset_details"]
        details_by_serialized_key[cast(str, row["asset_key"])] = (
            deserialize_value(cast(str, raw_details), AssetDetails) if raw_details else None
        )
    # preserve input ordering; missing keys map to None
    return [details_by_serialized_key.get(serialized, None) for serialized in serialized_keys]
def _add_assets_wipe_filter_to_query(
    self,
    query: SqlAlchemyQuery,
    assets_details: Sequence[Optional[AssetDetails]],
    asset_keys: Sequence[AssetKey],
) -> SqlAlchemyQuery:
    """Filter out event rows that predate each asset's last wipe.

    Args:
        query: the event-log query to constrain.
        assets_details: AssetDetails (or None) per asset key, same length/order
            as asset_keys.
        asset_keys: the asset keys the details correspond to.
    """
    check.invariant(
        len(assets_details) == len(asset_keys),
        "asset_details and asset_keys must be the same length",
    )
    # iterate pairwise with zip instead of indexing via range(len(...))
    for asset_key, asset_details in zip(asset_keys, assets_details):
        if asset_details and asset_details.last_wipe_timestamp:
            asset_key_in_row = SqlEventLogStorageTable.c.asset_key == asset_key.to_string()
            # If asset key is in row, keep the row if the timestamp > wipe timestamp, else remove the row.
            # If asset key is not in row, keep the row.
            query = query.where(
                db.or_(
                    db.and_(
                        asset_key_in_row,
                        SqlEventLogStorageTable.c.timestamp
                        > datetime.utcfromtimestamp(asset_details.last_wipe_timestamp),
                    ),
                    db.not_(asset_key_in_row),
                )
            )
    return query
def get_event_tags_for_asset(
    self,
    asset_key: AssetKey,
    filter_tags: Optional[Mapping[str, str]] = None,
    filter_event_id: Optional[int] = None,
) -> Sequence[Mapping[str, str]]:
    """Fetches asset event tags for the given asset key.

    If filter_tags is provided, searches for events containing all of the filter tags. Then,
    returns all tags for those events. This enables searching for multipartitioned asset
    partition tags with a fixed dimension value, e.g. all of the tags for events where
    "country" == "US".

    If filter_event_id is provided, fetches only tags applied to the given event.

    Returns a list of dicts, where each dict is a mapping of tag key to tag value for a
    single event.
    """
    asset_key = check.inst_param(asset_key, "asset_key", AssetKey)
    filter_tags = check.opt_mapping_param(
        filter_tags, "filter_tags", key_type=str, value_type=str
    )
    filter_event_id = check.opt_int_param(filter_event_id, "filter_event_id")
    if not self.has_table(AssetEventTagsTable.name):
        raise DagsterInvalidInvocationError(
            "In order to search for asset event tags, you must run "
            "`dagster instance migrate` to create the AssetEventTags table."
        )
    asset_details = self._get_assets_details([asset_key])[0]
    if not filter_tags:
        tags_query = db_select(
            [
                AssetEventTagsTable.c.key,
                AssetEventTagsTable.c.value,
                AssetEventTagsTable.c.event_id,
            ]
        ).where(AssetEventTagsTable.c.asset_key == asset_key.to_string())
    else:
        table = self._apply_tags_table_joins(AssetEventTagsTable, filter_tags, asset_key)
        tags_query = db_select(
            [
                AssetEventTagsTable.c.key,
                AssetEventTagsTable.c.value,
                AssetEventTagsTable.c.event_id,
            ]
        ).select_from(table)
    # exclude events that predate the asset's last wipe (applies to both branches)
    if asset_details and asset_details.last_wipe_timestamp:
        tags_query = tags_query.where(
            AssetEventTagsTable.c.event_timestamp
            > datetime.utcfromtimestamp(asset_details.last_wipe_timestamp)
        )
    if filter_event_id is not None:
        tags_query = tags_query.where(AssetEventTagsTable.c.event_id == filter_event_id)
    with self.index_connection() as conn:
        results = conn.execute(tags_query).fetchall()
    # group (key, value) pairs by the event they were applied to
    tags_by_event_id: Dict[int, Dict[str, str]] = defaultdict(dict)
    for row in results:
        key, value, event_id = row
        tags_by_event_id[event_id][key] = value
    return list(tags_by_event_id.values())
def _asset_materialization_from_json_column(
self, json_str: str
) -> Optional[AssetMaterialization]:
if not json_str:
return None
# We switched to storing the entire event record of the last materialization instead of just
# the AssetMaterialization object, so that we have access to metadata like timestamp,
# pipeline, run_id, etc.
#
# This should make certain asset queries way more performant, without having to do extra
# queries against the event log.
#
# This should be accompanied by a schema change in 0.12.0, renaming `last_materialization`
# to `last_materialization_event`, for clarity. For now, we should do some back-compat.
#
# https://github.com/dagster-io/dagster/issues/3945
event_or_materialization = deserialize_value(json_str, NamedTuple)
if isinstance(event_or_materialization, AssetMaterialization):
return event_or_materialization
if (
not isinstance(event_or_materialization, EventLogEntry)
or not event_or_materialization.is_dagster_event
or not event_or_materialization.dagster_event.asset_key # type: ignore
):
return None
return event_or_materialization.dagster_event.step_materialization_data.materialization # type: ignore
def _get_asset_key_values_on_wipe(self) -> Mapping[str, Any]:
    """Build the column values written to the asset key table when wiping an asset."""
    wipe_timestamp = pendulum.now("UTC").timestamp()
    values: Dict[str, Any] = {
        "asset_details": serialize_value(AssetDetails(last_wipe_timestamp=wipe_timestamp)),
        "last_run_id": None,
    }
    if self.has_asset_key_index_cols():
        # also stamp the queryable wipe_timestamp index column
        values["wipe_timestamp"] = utc_datetime_from_timestamp(wipe_timestamp)
    if self.can_cache_asset_status_data():
        values["cached_status_data"] = None
    return values
def wipe_asset(self, asset_key: AssetKey) -> None:
    """Mark the given asset as wiped by updating its asset key table row."""
    check.inst_param(asset_key, "asset_key", AssetKey)
    wiped_values = self._get_asset_key_values_on_wipe()
    wipe_stmt = (
        AssetKeyTable.update()
        .values(**wiped_values)
        .where(
            AssetKeyTable.c.asset_key == asset_key.to_string(),
        )
    )
    with self.index_connection() as conn:
        conn.execute(wipe_stmt)
def get_materialized_partitions(
    self,
    asset_key: AssetKey,
    before_cursor: Optional[int] = None,
    after_cursor: Optional[int] = None,
) -> Set[str]:
    """Return the set of partition keys with a materialization event for the
    given asset, optionally bounded by storage-id cursors and excluding events
    that predate the asset's last wipe."""
    query = (
        db_select(
            [
                SqlEventLogStorageTable.c.partition,
                db.func.max(SqlEventLogStorageTable.c.id),
            ]
        )
        .where(
            db.and_(
                SqlEventLogStorageTable.c.asset_key == asset_key.to_string(),
                SqlEventLogStorageTable.c.partition != None,  # noqa: E711
                SqlEventLogStorageTable.c.dagster_event_type
                == DagsterEventType.ASSET_MATERIALIZATION.value,
            )
        )
        .group_by(SqlEventLogStorageTable.c.partition)
    )
    assets_details = self._get_assets_details([asset_key])
    query = self._add_assets_wipe_filter_to_query(query, assets_details, [asset_key])
    if after_cursor:
        query = query.where(SqlEventLogStorageTable.c.id > after_cursor)
    if before_cursor:
        query = query.where(SqlEventLogStorageTable.c.id < before_cursor)
    with self.index_connection() as conn:
        results = conn.execute(query).fetchall()
    # set comprehension instead of set([...]) (flake8-comprehensions C403)
    return {cast(str, row[0]) for row in results}
def _latest_event_ids_by_partition_subquery(
    self,
    asset_key: AssetKey,
    event_types: Sequence[DagsterEventType],
    asset_partitions: Optional[Sequence[str]] = None,
    before_cursor: Optional[int] = None,
    after_cursor: Optional[int] = None,
):
    """Subquery for locating the latest event ids by partition for a given asset key and set
    of event types.
    """
    base_conditions = [
        SqlEventLogStorageTable.c.asset_key == asset_key.to_string(),
        SqlEventLogStorageTable.c.partition != None,  # noqa: E711
        SqlEventLogStorageTable.c.dagster_event_type.in_(
            [event_type.value for event_type in event_types]
        ),
    ]
    query = db_select(
        [
            SqlEventLogStorageTable.c.dagster_event_type,
            SqlEventLogStorageTable.c.partition,
            db.func.max(SqlEventLogStorageTable.c.id).label("id"),
        ]
    ).where(db.and_(*base_conditions))
    if asset_partitions is not None:
        query = query.where(SqlEventLogStorageTable.c.partition.in_(asset_partitions))
    if before_cursor is not None:
        query = query.where(SqlEventLogStorageTable.c.id < before_cursor)
    if after_cursor is not None:
        query = query.where(SqlEventLogStorageTable.c.id > after_cursor)
    grouped = query.group_by(
        SqlEventLogStorageTable.c.dagster_event_type, SqlEventLogStorageTable.c.partition
    )
    # exclude events that predate the asset's last wipe
    assets_details = self._get_assets_details([asset_key])
    return db_subquery(
        self._add_assets_wipe_filter_to_query(grouped, assets_details, [asset_key]),
        "latest_event_ids_by_partition_subquery",
    )
def get_latest_storage_id_by_partition(
    self, asset_key: AssetKey, event_type: DagsterEventType
) -> Mapping[str, int]:
    """Fetch the latest storage id per partition for events of the given type
    on the given asset key.

    Returns a mapping of partition to storage id.
    """
    check.inst_param(asset_key, "asset_key", AssetKey)
    latest_ids_subquery = self._latest_event_ids_by_partition_subquery(
        asset_key, [event_type]
    )
    query = db_select(
        [
            latest_ids_subquery.c.partition,
            latest_ids_subquery.c.id,
        ]
    )
    with self.index_connection() as conn:
        rows = conn.execute(query).fetchall()
    return {cast(str, row[0]): cast(int, row[1]) for row in rows}
def get_latest_tags_by_partition(
    self,
    asset_key: AssetKey,
    event_type: DagsterEventType,
    tag_keys: Sequence[str],
    asset_partitions: Optional[Sequence[str]] = None,
    before_cursor: Optional[int] = None,
    after_cursor: Optional[int] = None,
) -> Mapping[str, Mapping[str, str]]:
    """For each partition, return the tag values (restricted to tag_keys) applied
    to the latest event of the given type for the asset."""
    check.inst_param(asset_key, "asset_key", AssetKey)
    check.inst_param(event_type, "event_type", DagsterEventType)
    check.sequence_param(tag_keys, "tag_keys", of_type=str)
    check.opt_nullable_sequence_param(asset_partitions, "asset_partitions", of_type=str)
    check.opt_int_param(before_cursor, "before_cursor")
    check.opt_int_param(after_cursor, "after_cursor")
    latest_ids_subquery = self._latest_event_ids_by_partition_subquery(
        asset_key=asset_key,
        event_types=[event_type],
        asset_partitions=asset_partitions,
        before_cursor=before_cursor,
        after_cursor=after_cursor,
    )
    # join the latest event ids against the tags table to pull their tag values
    tags_query = (
        db_select(
            [
                latest_ids_subquery.c.partition,
                AssetEventTagsTable.c.key,
                AssetEventTagsTable.c.value,
            ]
        )
        .select_from(
            latest_ids_subquery.join(
                AssetEventTagsTable,
                AssetEventTagsTable.c.event_id == latest_ids_subquery.c.id,
            )
        )
        .where(AssetEventTagsTable.c.key.in_(tag_keys))
    )
    tags_by_partition: Dict[str, Dict[str, str]] = defaultdict(dict)
    with self.index_connection() as conn:
        for partition, tag_key, tag_value in conn.execute(tags_query).fetchall():
            tags_by_partition[cast(str, partition)][cast(str, tag_key)] = cast(str, tag_value)
    # convert defaultdict to dict
    return dict(tags_by_partition)
def get_latest_asset_partition_materialization_attempts_without_materializations(
    self, asset_key: AssetKey, after_storage_id: Optional[int] = None
) -> Mapping[str, Tuple[str, int]]:
    """Fetch the latest materialization and materialization planned events for each partition of the given asset.
    Return the partitions that have a materialization planned event but no matching (same run) materialization event.
    These materializations could be in progress, or they could have failed. A separate query checking the run status
    is required to know.

    Returns a mapping of partition to (run id, event id).
    """
    check.inst_param(asset_key, "asset_key", AssetKey)
    # latest event id per (event type, partition) for this asset, optionally
    # restricted to events after the given storage id
    latest_event_ids_subquery = self._latest_event_ids_by_partition_subquery(
        asset_key,
        [
            DagsterEventType.ASSET_MATERIALIZATION,
            DagsterEventType.ASSET_MATERIALIZATION_PLANNED,
        ],
        after_cursor=after_storage_id,
    )
    # join back against the event log table to recover run_id for those events
    latest_events_subquery = db_subquery(
        db_select(
            [
                SqlEventLogStorageTable.c.dagster_event_type,
                SqlEventLogStorageTable.c.partition,
                SqlEventLogStorageTable.c.run_id,
                SqlEventLogStorageTable.c.id,
            ]
        ).select_from(
            latest_event_ids_subquery.join(
                SqlEventLogStorageTable,
                SqlEventLogStorageTable.c.id == latest_event_ids_subquery.c.id,
            ),
        ),
        "latest_events_subquery",
    )
    materialization_planned_events = db_select(
        [
            latest_events_subquery.c.dagster_event_type,
            latest_events_subquery.c.partition,
            latest_events_subquery.c.run_id,
            latest_events_subquery.c.id,
        ]
    ).where(
        latest_events_subquery.c.dagster_event_type
        == DagsterEventType.ASSET_MATERIALIZATION_PLANNED.value
    )
    materialization_events = db_select(
        [
            latest_events_subquery.c.dagster_event_type,
            latest_events_subquery.c.partition,
            latest_events_subquery.c.run_id,
        ]
    ).where(
        latest_events_subquery.c.dagster_event_type
        == DagsterEventType.ASSET_MATERIALIZATION.value
    )
    with self.index_connection() as conn:
        materialization_planned_rows = db_fetch_mappings(conn, materialization_planned_events)
        materialization_rows = db_fetch_mappings(conn, materialization_events)
    materialization_planned_rows_by_partition = {
        cast(str, row["partition"]): (cast(str, row["run_id"]), cast(int, row["id"]))
        for row in materialization_planned_rows
    }
    # drop any planned partition whose latest materialization came from the same
    # run -- that attempt has a matching materialization
    for row in materialization_rows:
        if (
            row["partition"] in materialization_planned_rows_by_partition
            and materialization_planned_rows_by_partition[cast(str, row["partition"])][0]
            == row["run_id"]
        ):
            materialization_planned_rows_by_partition.pop(cast(str, row["partition"]))
    return materialization_planned_rows_by_partition
def _check_partitions_table(self) -> None:
    """Raise if the dynamic partitions table has not been created.

    Guards against cases where the user is not running the latest migration for
    partitions storage. Should be updated when the partitions storage schema changes.
    """
    if self.has_table("dynamic_partitions"):
        return
    raise DagsterInvalidInvocationError(
        "Using dynamic partitions definitions requires the dynamic partitions table, which"
        " currently does not exist. Add this table by running `dagster"
        " instance migrate`."
    )
def get_dynamic_partitions(self, partitions_def_name: str) -> Sequence[str]:
    """Get the list of partition keys for a partition definition."""
    self._check_partitions_table()
    partitions_query = (
        db_select(
            [
                DynamicPartitionsTable.c.partitions_def_name,
                DynamicPartitionsTable.c.partition,
            ]
        )
        .where(DynamicPartitionsTable.c.partitions_def_name == partitions_def_name)
        .order_by(DynamicPartitionsTable.c.id)
    )
    with self.index_connection() as conn:
        return [cast(str, row[1]) for row in conn.execute(partitions_query).fetchall()]
def has_dynamic_partition(self, partitions_def_name: str, partition_key: str) -> bool:
    """Whether the given partition key exists for the given partitions definition."""
    self._check_partitions_table()
    exists_query = (
        db_select([DynamicPartitionsTable.c.partition])
        .where(
            db.and_(
                DynamicPartitionsTable.c.partitions_def_name == partitions_def_name,
                DynamicPartitionsTable.c.partition == partition_key,
            )
        )
        .limit(1)
    )
    with self.index_connection() as conn:
        matches = conn.execute(exists_query).fetchall()
    return bool(matches)
def add_dynamic_partitions(
    self, partitions_def_name: str, partition_keys: Sequence[str]
) -> None:
    """Insert the given partition keys for a partitions definition, skipping any
    keys that already exist."""
    self._check_partitions_table()
    with self.index_connection() as conn:
        existing_rows = conn.execute(
            db_select([DynamicPartitionsTable.c.partition]).where(
                db.and_(
                    DynamicPartitionsTable.c.partition.in_(partition_keys),
                    DynamicPartitionsTable.c.partitions_def_name == partitions_def_name,
                )
            )
        ).fetchall()
        # set comprehension instead of set([...]) (flake8-comprehensions C403)
        existing_keys = {row[0] for row in existing_rows}
        new_keys = [
            partition_key
            for partition_key in partition_keys
            if partition_key not in existing_keys
        ]
        if new_keys:
            conn.execute(
                DynamicPartitionsTable.insert(),
                [
                    dict(partitions_def_name=partitions_def_name, partition=partition_key)
                    for partition_key in new_keys
                ],
            )
def delete_dynamic_partition(self, partitions_def_name: str, partition_key: str) -> None:
    """Remove a single partition key for the given dynamic partitions definition."""
    self._check_partitions_table()
    delete_stmt = (
        DynamicPartitionsTable.delete()
        .where(DynamicPartitionsTable.c.partitions_def_name == partitions_def_name)
        .where(DynamicPartitionsTable.c.partition == partition_key)
    )
    with self.index_connection() as conn:
        conn.execute(delete_stmt)
@property
def supports_global_concurrency_limits(self) -> bool:
    # Global op concurrency requires the concurrency slots table, which is added by a
    # schema migration (`dagster instance migrate`).
    return self.has_table(ConcurrencySlotsTable.name)
def _reconcile_concurrency_limits_from_slots(self) -> None:
    """Reconcile the concurrency limits table from the concurrency slots table.

    This should only run when the concurrency limits table exists and is empty,
    since all of the slot configuration operations should keep them in sync. We reconcile
    from the slots table because the initial implementation did not have the limits table.
    """
    if not self.has_table(ConcurrencyLimitsTable.name):
        # the limits table has not been migrated in yet; nothing to reconcile into
        return
    if not self._has_rows(ConcurrencySlotsTable) or self._has_rows(ConcurrencyLimitsTable):
        # either there is nothing to copy from, or the limits table is already populated
        return
    with self.index_transaction() as conn:
        # count the live (non-deleted) slots per concurrency key...
        rows = conn.execute(
            db_select(
                [
                    ConcurrencySlotsTable.c.concurrency_key,
                    db.func.count().label("count"),
                ]
            )
            .where(
                ConcurrencySlotsTable.c.deleted == False,  # noqa: E712
            )
            .group_by(
                ConcurrencySlotsTable.c.concurrency_key,
            )
        ).fetchall()
        # ...and write one limits row per key using that count as the limit
        conn.execute(
            ConcurrencyLimitsTable.insert().values(
                [
                    {
                        "concurrency_key": row[0],
                        "limit": row[1],
                    }
                    for row in rows
                ]
            )
        )
def _has_rows(self, table) -> bool:
    """Return True if the given table contains at least one row."""
    probe = db_select([True]).select_from(table).limit(1)
    with self.index_connection() as conn:
        first = conn.execute(probe).fetchone()
    return first is not None and bool(first[0])
def initialize_concurrency_limit_to_default(self, concurrency_key: str) -> bool:
    """Initialize the limit for a concurrency key to the instance-configured default.

    Returns True when the limits table exists and a default limit is configured —
    including when the key already had a limit row (the duplicate insert's
    IntegrityError is swallowed). Returns False otherwise.
    """
    if not self.has_table(ConcurrencyLimitsTable.name):
        return False
    # ensure the limits table has been backfilled from the legacy slots table
    self._reconcile_concurrency_limits_from_slots()
    if not self.has_instance:
        return False
    default_limit = self._instance.global_op_concurrency_default_limit
    if default_limit is None:
        # no instance-level default configured; leave the key uninitialized
        return False
    with self.index_transaction() as conn:
        try:
            conn.execute(
                ConcurrencyLimitsTable.insert().values(
                    concurrency_key=concurrency_key, limit=default_limit
                )
            )
            self._allocate_concurrency_slots(conn, concurrency_key, default_limit)
        except db_exc.IntegrityError:
            # a limits row already exists for this key; keep its configured value
            pass
    return True
def _upsert_and_lock_limit_row(self, conn, concurrency_key: str, num: int):
    """Helper function that can be overridden by each implementing sql variant which obtains a
    lock on the concurrency limits row for the given key and updates it to the given value.

    Must be called within the transaction that performs the subsequent slot allocation so
    that the row lock is held for its duration.
    """
    if not self.has_table(ConcurrencyLimitsTable.name):
        # no need to grab the lock on the concurrency limits row if the table does not exist
        return None
    # SELECT ... FOR UPDATE locks the matching row for the rest of the transaction
    row = conn.execute(
        db_select([ConcurrencyLimitsTable.c.id])
        .select_from(ConcurrencyLimitsTable)
        .where(ConcurrencyLimitsTable.c.concurrency_key == concurrency_key)
        .with_for_update()
        .limit(1)
    ).fetchone()
    if not row:
        # no existing row to lock; insert a fresh limit row
        conn.execute(
            ConcurrencyLimitsTable.insert().values(concurrency_key=concurrency_key, limit=num)
        )
    else:
        conn.execute(
            ConcurrencyLimitsTable.update()
            .where(ConcurrencyLimitsTable.c.concurrency_key == concurrency_key)
            .values(limit=num)
        )
def set_concurrency_slots(self, concurrency_key: str, num: int) -> None:
    """Configure the slot count for a concurrency key.

    Args:
        concurrency_key (str): The key to allocate the slots for.
        num (int): The number of slots to allocate.
    """
    slot_ceiling = get_max_concurrency_limit_value()
    if num > slot_ceiling:
        raise DagsterInvalidInvocationError(
            f"Cannot have more than {slot_ceiling} slots per concurrency key."
        )
    if num < 0:
        raise DagsterInvalidInvocationError("Cannot have a negative number of slots.")
    # make sure the limits table is populated before mutating it
    self._reconcile_concurrency_limits_from_slots()
    with self.index_transaction() as conn:
        self._upsert_and_lock_limit_row(conn, concurrency_key, num)
        newly_open_keys = self._allocate_concurrency_slots(conn, concurrency_key, num)
    if newly_open_keys:
        # newly added slots can immediately be handed to queued steps; otherwise they sit
        # unutilized until free_concurrency_slots is called
        self.assign_pending_steps(newly_open_keys)
def delete_concurrency_limit(self, concurrency_key: str) -> None:
    """Remove the configured limit and all slots for a concurrency key.

    Args:
        concurrency_key (str): The key to delete.
    """
    # make sure the limits table is populated before mutating it
    self._reconcile_concurrency_limits_from_slots()
    with self.index_transaction() as conn:
        if self.has_table(ConcurrencyLimitsTable.name):
            limit_delete = ConcurrencyLimitsTable.delete().where(
                ConcurrencyLimitsTable.c.concurrency_key == concurrency_key
            )
            conn.execute(limit_delete)
        # dropping the slot count to zero removes (or marks for removal) every slot
        self._allocate_concurrency_slots(conn, concurrency_key, 0)
def _allocate_concurrency_slots(self, conn, concurrency_key: str, num: int) -> List[str]:
    """Adjust the number of live slot rows for the given key to exactly `num`.

    Returns the concurrency keys (one entry per newly added slot) that may now be
    assigned to pending steps; empty when slots were removed or unchanged.
    """
    keys_to_assign = []
    # count the current live (non-deleted) slots for this key
    count_row = conn.execute(
        db_select([db.func.count()])
        .select_from(ConcurrencySlotsTable)
        .where(
            db.and_(
                ConcurrencySlotsTable.c.concurrency_key == concurrency_key,
                ConcurrencySlotsTable.c.deleted == False,  # noqa: E712
            )
        )
    ).fetchone()
    existing = cast(int, count_row[0]) if count_row else 0
    if existing > num:
        # need to delete some slots, favoring ones where the slot is unallocated
        rows = conn.execute(
            db_select([ConcurrencySlotsTable.c.id])
            .select_from(ConcurrencySlotsTable)
            .where(
                db.and_(
                    ConcurrencySlotsTable.c.concurrency_key == concurrency_key,
                    ConcurrencySlotsTable.c.deleted == False,  # noqa: E712
                )
            )
            .order_by(
                # unallocated slots (run_id IS NULL) sort first for pruning
                db_case([(ConcurrencySlotsTable.c.run_id.is_(None), 1)], else_=0).desc(),
                ConcurrencySlotsTable.c.id.desc(),
            )
            .limit(existing - num)
        ).fetchall()
        if rows:
            # mark rows as deleted
            conn.execute(
                ConcurrencySlotsTable.update()
                .values(deleted=True)
                .where(ConcurrencySlotsTable.c.id.in_([row[0] for row in rows]))
            )
        # actually delete rows that are marked as deleted and are not claimed... the rest
        # will be deleted when the slots are released by the free_concurrency_slots
        conn.execute(
            ConcurrencySlotsTable.delete().where(
                db.and_(
                    ConcurrencySlotsTable.c.deleted == True,  # noqa: E712
                    ConcurrencySlotsTable.c.run_id == None,  # noqa: E711
                )
            )
        )
    elif num > existing:
        # need to add some slots
        rows = [
            {
                "concurrency_key": concurrency_key,
                "run_id": None,
                "step_key": None,
                "deleted": False,
            }
            for _ in range(existing, num)
        ]
        conn.execute(ConcurrencySlotsTable.insert().values(rows))
        keys_to_assign.extend([concurrency_key for _ in range(existing, num)])
    return keys_to_assign
def has_unassigned_slots(self, concurrency_key: str) -> bool:
    """Return True when the key has more live slots than assigned pending steps."""
    with self.index_connection() as conn:
        assigned_count_row = conn.execute(
            db_select([db.func.count()])
            .select_from(PendingStepsTable)
            .where(
                db.and_(
                    PendingStepsTable.c.concurrency_key == concurrency_key,
                    PendingStepsTable.c.assigned_timestamp != None,  # noqa: E711
                )
            )
        ).fetchone()
        slot_count_row = conn.execute(
            db_select([db.func.count()])
            .select_from(ConcurrencySlotsTable)
            .where(
                db.and_(
                    ConcurrencySlotsTable.c.concurrency_key == concurrency_key,
                    ConcurrencySlotsTable.c.deleted == False,  # noqa: E712
                )
            )
        ).fetchone()
    num_assigned = cast(int, assigned_count_row[0]) if assigned_count_row else 0
    num_slots = cast(int, slot_count_row[0]) if slot_count_row else 0
    return num_slots > num_assigned
def check_concurrency_claim(
    self, concurrency_key: str, run_id: str, step_key: str
) -> ConcurrencyClaimStatus:
    """Report where the given run/step stands in the claim lifecycle for a concurrency key."""
    with self.index_connection() as conn:
        pending_row = conn.execute(
            db_select(
                [
                    PendingStepsTable.c.assigned_timestamp,
                    PendingStepsTable.c.priority,
                    PendingStepsTable.c.create_timestamp,
                ]
            ).where(
                db.and_(
                    PendingStepsTable.c.run_id == run_id,
                    PendingStepsTable.c.step_key == step_key,
                    PendingStepsTable.c.concurrency_key == concurrency_key,
                )
            )
        ).fetchone()
        if not pending_row:
            # no pending step pending_row exists, the slot is blocked and the enqueued timestamp is None
            return ConcurrencyClaimStatus(
                concurrency_key=concurrency_key,
                slot_status=ConcurrencySlotStatus.BLOCKED,
                priority=None,
                assigned_timestamp=None,
                enqueued_timestamp=None,
            )
        # NOTE(review): a stored priority of 0 is coerced to None here by the falsy
        # check — confirm this is intentional
        priority = cast(int, pending_row[1]) if pending_row[1] else None
        assigned_timestamp = cast(datetime, pending_row[0]) if pending_row[0] else None
        create_timestamp = cast(datetime, pending_row[2]) if pending_row[2] else None
        if assigned_timestamp is None:
            # queued, but not yet assigned a slot
            return ConcurrencyClaimStatus(
                concurrency_key=concurrency_key,
                slot_status=ConcurrencySlotStatus.BLOCKED,
                priority=priority,
                assigned_timestamp=None,
                enqueued_timestamp=create_timestamp,
            )
        # pending step is assigned, check to see if it's been claimed
        slot_row = conn.execute(
            db_select([db.func.count()]).where(
                db.and_(
                    ConcurrencySlotsTable.c.concurrency_key == concurrency_key,
                    ConcurrencySlotsTable.c.run_id == run_id,
                    ConcurrencySlotsTable.c.step_key == step_key,
                )
            )
        ).fetchone()
        return ConcurrencyClaimStatus(
            concurrency_key=concurrency_key,
            slot_status=(
                ConcurrencySlotStatus.CLAIMED
                if slot_row and slot_row[0]
                else ConcurrencySlotStatus.BLOCKED
            ),
            priority=priority,
            assigned_timestamp=assigned_timestamp,
            enqueued_timestamp=create_timestamp,
        )
def can_claim_from_pending(self, concurrency_key: str, run_id: str, step_key: str) -> bool:
    """Return True if the pending step row for this run/step has been assigned a slot.

    Args:
        concurrency_key (str): The concurrency key the step is queued under.
        run_id (str): The run id of the step.
        step_key (str): The step key.
    """
    with self.index_connection() as conn:
        row = conn.execute(
            db_select([PendingStepsTable.c.assigned_timestamp]).where(
                db.and_(
                    PendingStepsTable.c.run_id == run_id,
                    PendingStepsTable.c.step_key == step_key,
                    PendingStepsTable.c.concurrency_key == concurrency_key,
                )
            )
        ).fetchone()
    # fix: the previous `row and row[0] is not None` returned None (not False) when no
    # pending row existed; coerce to a proper bool for a consistent predicate
    return row is not None and row[0] is not None
def has_pending_step(self, concurrency_key: str, run_id: str, step_key: str) -> bool:
    """Return True if a pending step row exists for the given key/run/step.

    Args:
        concurrency_key (str): The concurrency key the step is queued under.
        run_id (str): The run id of the step.
        step_key (str): The step key.
    """
    with self.index_connection() as conn:
        row = conn.execute(
            db_select([db.func.count()])
            .select_from(PendingStepsTable)
            .where(
                db.and_(
                    PendingStepsTable.c.concurrency_key == concurrency_key,
                    PendingStepsTable.c.run_id == run_id,
                    PendingStepsTable.c.step_key == step_key,
                )
            )
        ).fetchone()
    # fix: the previous `row and ...` returned the row object / None rather than a bool;
    # coerce to a proper bool for a consistent predicate
    return row is not None and cast(int, row[0]) > 0
def assign_pending_steps(self, concurrency_keys: Sequence[str]):
    """For each key, mark the best unassigned pending step (highest priority, then oldest)
    as assigned.
    """
    if not concurrency_keys:
        return
    with self.index_connection() as conn:
        for concurrency_key in concurrency_keys:
            candidate_query = (
                db_select([PendingStepsTable.c.id])
                .where(
                    db.and_(
                        PendingStepsTable.c.concurrency_key == concurrency_key,
                        PendingStepsTable.c.assigned_timestamp == None,  # noqa: E711
                    )
                )
                .order_by(
                    PendingStepsTable.c.priority.desc(),
                    PendingStepsTable.c.create_timestamp.asc(),
                )
                .limit(1)
            )
            candidate = conn.execute(candidate_query).fetchone()
            if candidate is None:
                continue
            conn.execute(
                PendingStepsTable.update()
                .where(PendingStepsTable.c.id == candidate[0])
                .values(assigned_timestamp=db.func.now())
            )
def add_pending_step(
    self,
    concurrency_key: str,
    run_id: str,
    step_key: str,
    priority: Optional[int] = None,
    should_assign: bool = False,
):
    """Enqueue a pending step row, optionally marking it assigned immediately.

    A duplicate insert for the same run/step/key is silently ignored.
    """
    assigned_at = db.func.now() if should_assign else None
    pending_values = [
        dict(
            run_id=run_id,
            step_key=step_key,
            concurrency_key=concurrency_key,
            priority=priority or 0,
            assigned_timestamp=assigned_at,
        )
    ]
    with self.index_connection() as conn:
        try:
            conn.execute(PendingStepsTable.insert().values(pending_values))
        except db_exc.IntegrityError:
            # the pending row already exists; nothing to do
            pass
def _remove_pending_steps(self, run_id: str, step_key: Optional[str] = None) -> Sequence[str]:
    """Delete the pending step rows for a run (or a single step of that run).

    Returns the concurrency keys of the deleted rows that had been assigned, so callers
    can assign the next set of queued steps for those keys.
    """
    # fetch the assigned steps to delete, while grabbing the concurrency keys so that we can
    # assign the next set of queued steps, if necessary
    select_query = (
        db_select(
            [
                PendingStepsTable.c.id,
                PendingStepsTable.c.assigned_timestamp,
                PendingStepsTable.c.concurrency_key,
            ]
        )
        .select_from(PendingStepsTable)
        .where(PendingStepsTable.c.run_id == run_id)
        # lock the matched rows so a concurrent assigner cannot race the delete below
        .with_for_update()
    )
    if step_key:
        select_query = select_query.where(PendingStepsTable.c.step_key == step_key)
    with self.index_connection() as conn:
        rows = conn.execute(select_query).fetchall()
        if not rows:
            return []
        # now, actually delete the pending steps
        conn.execute(
            PendingStepsTable.delete().where(
                PendingStepsTable.c.id.in_([row[0] for row in rows])
            )
        )
        # return the concurrency keys for the freed slots which were assigned
        to_assign = [cast(str, row[2]) for row in rows if row[1] is not None]
        return to_assign
def claim_concurrency_slot(
    self, concurrency_key: str, run_id: str, step_key: str, priority: Optional[int] = None
) -> ConcurrencyClaimStatus:
    """Claim concurrency slot for step.

    Args:
        concurrency_key (str): The concurrency key to claim.
        run_id (str): The run id to claim for.
        step_key (str): The step key to claim for.
        priority (Optional[int]): Priority used when enqueueing the pending step.
    """
    # first, register the step by adding to pending queue
    if not self.has_pending_step(
        concurrency_key=concurrency_key, run_id=run_id, step_key=step_key
    ):
        has_unassigned_slots = self.has_unassigned_slots(concurrency_key)
        self.add_pending_step(
            concurrency_key=concurrency_key,
            run_id=run_id,
            step_key=step_key,
            priority=priority,
            should_assign=has_unassigned_slots,
        )
    # if the step is not assigned (i.e. has not been popped from queue), block the claim
    claim_status = self.check_concurrency_claim(
        concurrency_key=concurrency_key, run_id=run_id, step_key=step_key
    )
    if claim_status.is_claimed or not claim_status.is_assigned:
        return claim_status
    # attempt to claim a concurrency slot... this should generally work because we only assign
    # based on the number of unclaimed slots, but this should act as a safeguard, using the slot
    # rows as a semaphore
    slot_status = self._claim_concurrency_slot(
        concurrency_key=concurrency_key, run_id=run_id, step_key=step_key
    )
    return claim_status.with_slot_status(slot_status)
def _claim_concurrency_slot(
    self, concurrency_key: str, run_id: str, step_key: str
) -> ConcurrencySlotStatus:
    """Claim a concurrency slot for the step. Helper method that is called for steps that are
    popped off the priority queue.

    Args:
        concurrency_key (str): The concurrency key to claim.
        run_id (str): The run id to claim a slot for.
        step_key (str): The step key to claim a slot for.
    """
    with self.index_connection() as conn:
        # grab one open slot; skip_locked avoids contending with another claimer that
        # has the row locked
        result = conn.execute(
            db_select([ConcurrencySlotsTable.c.id])
            .select_from(ConcurrencySlotsTable)
            .where(
                db.and_(
                    ConcurrencySlotsTable.c.concurrency_key == concurrency_key,
                    ConcurrencySlotsTable.c.step_key == None,  # noqa: E711
                    ConcurrencySlotsTable.c.deleted == False,  # noqa: E712
                )
            )
            .with_for_update(skip_locked=True)
            .limit(1)
        ).fetchone()
        if not result or not result[0]:
            # no open slot available
            return ConcurrencySlotStatus.BLOCKED
        if not conn.execute(
            ConcurrencySlotsTable.update()
            .values(run_id=run_id, step_key=step_key)
            .where(ConcurrencySlotsTable.c.id == result[0])
        ).rowcount:
            # the update took effect on zero rows; treat as blocked
            return ConcurrencySlotStatus.BLOCKED
        return ConcurrencySlotStatus.CLAIMED
def get_concurrency_keys(self) -> Set[str]:
    """Get the set of concurrency limited keys."""
    # fix: the docstring above was previously placed after the first statement, where it
    # was a no-op string expression rather than the function docstring
    self._reconcile_concurrency_limits_from_slots()
    with self.index_connection() as conn:
        if self.has_table(ConcurrencyLimitsTable.name):
            query = db_select([ConcurrencyLimitsTable.c.concurrency_key]).select_from(
                ConcurrencyLimitsTable
            )
        else:
            # fall back to the legacy slots table when the limits table has not been
            # migrated in
            query = (
                db_select([ConcurrencySlotsTable.c.concurrency_key])
                .select_from(ConcurrencySlotsTable)
                .where(ConcurrencySlotsTable.c.deleted == False)  # noqa: E712
                .distinct()
            )
        rows = conn.execute(query).fetchall()
        return {cast(str, row[0]) for row in rows}
def get_concurrency_info(self, concurrency_key: str) -> ConcurrencyKeyInfo:
    """Get slot and pending-step information for a given concurrency key.

    Args:
        concurrency_key (str): The concurrency key to get the info for.

    Returns:
        ConcurrencyKeyInfo: the live slot count, the claimed slots (run/step pairs), and
            the pending steps (with enqueue/assignment timestamps) for the key.
    """
    with self.index_connection() as conn:
        slot_query = (
            db_select(
                [
                    ConcurrencySlotsTable.c.run_id,
                    ConcurrencySlotsTable.c.step_key,
                    ConcurrencySlotsTable.c.deleted,
                ]
            )
            .select_from(ConcurrencySlotsTable)
            .where(ConcurrencySlotsTable.c.concurrency_key == concurrency_key)
        )
        slot_rows = db_fetch_mappings(conn, slot_query)
        pending_query = (
            db_select(
                [
                    PendingStepsTable.c.run_id,
                    PendingStepsTable.c.step_key,
                    PendingStepsTable.c.assigned_timestamp,
                    PendingStepsTable.c.create_timestamp,
                    PendingStepsTable.c.priority,
                ]
            )
            .select_from(PendingStepsTable)
            .where(PendingStepsTable.c.concurrency_key == concurrency_key)
        )
        pending_rows = db_fetch_mappings(conn, pending_query)
        return ConcurrencyKeyInfo(
            concurrency_key=concurrency_key,
            # rows marked deleted are pruned lazily, so exclude them from the live count
            slot_count=len([slot_row for slot_row in slot_rows if not slot_row["deleted"]]),
            claimed_slots=[
                ClaimedSlotInfo(slot_row["run_id"], slot_row["step_key"])
                for slot_row in slot_rows
                if slot_row["run_id"]
            ],
            pending_steps=[
                PendingStepInfo(
                    run_id=row["run_id"],
                    step_key=row["step_key"],
                    # timestamps are stored naive; normalize to UTC-aware datetimes
                    enqueued_timestamp=utc_datetime_from_naive(row["create_timestamp"]),
                    assigned_timestamp=utc_datetime_from_naive(row["assigned_timestamp"])
                    if row["assigned_timestamp"]
                    else None,
                    priority=row["priority"],
                )
                for row in pending_rows
            ],
        )
def get_concurrency_run_ids(self) -> Set[str]:
    """Return the distinct run ids that currently have pending step rows."""
    query = db_select([PendingStepsTable.c.run_id]).distinct()
    with self.index_connection() as conn:
        return {cast(str, row[0]) for row in conn.execute(query).fetchall()}
def free_concurrency_slots_for_run(self, run_id: str) -> None:
    """Release every concurrency slot held by a run and hand freed capacity to the queue."""
    self._free_concurrency_slots(run_id=run_id)
    freed_keys = self._remove_pending_steps(run_id=run_id)
    if freed_keys:
        # assign any pending steps that can now claim a slot
        self.assign_pending_steps(freed_keys)
def free_concurrency_slot_for_step(self, run_id: str, step_key: str) -> None:
    """Release the concurrency slot held by a single step and hand freed capacity to the queue."""
    self._free_concurrency_slots(run_id=run_id, step_key=step_key)
    freed_keys = self._remove_pending_steps(run_id=run_id, step_key=step_key)
    if freed_keys:
        # assign any pending steps that can now claim a slot
        self.assign_pending_steps(freed_keys)
def _free_concurrency_slots(self, run_id: str, step_key: Optional[str] = None) -> Sequence[str]:
    """Frees concurrency slots for a given run/step.

    Args:
        run_id (str): The run id to free the slots for.
        step_key (Optional[str]): The step key to free the slots for. If not provided, all the
            slots for all the steps of the run will be freed.

    Returns:
        Sequence[str]: the concurrency keys of the freed slots.
    """
    with self.index_connection() as conn:
        # first delete any rows that apply and are marked as deleted. This happens when the
        # configured number of slots has been reduced, and some of the pruned slots included
        # ones that were already allocated to the run/step
        delete_query = ConcurrencySlotsTable.delete().where(
            db.and_(
                ConcurrencySlotsTable.c.run_id == run_id,
                ConcurrencySlotsTable.c.deleted == True,  # noqa: E712
            )
        )
        if step_key:
            delete_query = delete_query.where(ConcurrencySlotsTable.c.step_key == step_key)
        conn.execute(delete_query)
        # next, fetch the slots to free up, while grabbing the concurrency keys so that we can
        # allocate any pending steps from the queue for the freed slots, if necessary
        select_query = (
            db_select([ConcurrencySlotsTable.c.id, ConcurrencySlotsTable.c.concurrency_key])
            .select_from(ConcurrencySlotsTable)
            .where(ConcurrencySlotsTable.c.run_id == run_id)
            # lock the rows so a concurrent claimer cannot race the update below
            .with_for_update()
        )
        if step_key:
            select_query = select_query.where(ConcurrencySlotsTable.c.step_key == step_key)
        rows = conn.execute(select_query).fetchall()
        if not rows:
            return []
        # now, actually free the slots
        conn.execute(
            ConcurrencySlotsTable.update()
            .values(run_id=None, step_key=None)
            .where(
                db.and_(
                    ConcurrencySlotsTable.c.id.in_([row[0] for row in rows]),
                )
            )
        )
        # return the concurrency keys for the freed slots
        return [cast(str, row[1]) for row in rows]
def store_asset_check_event(self, event: EventLogEntry, event_id: Optional[int]) -> None:
    """Persist an asset-check event into the asset check executions table."""
    check.inst_param(event, "event", EventLogEntry)
    check.opt_int_param(event_id, "event_id")
    check.invariant(
        self.supports_asset_checks,
        "Asset checks require a database schema migration. Run `dagster instance migrate`.",
    )
    event_type = event.dagster_event_type
    if event_type == DagsterEventType.ASSET_CHECK_EVALUATION_PLANNED:
        self._store_asset_check_evaluation_planned(event, event_id)
    if event_type == DagsterEventType.ASSET_CHECK_EVALUATION:
        if not event.run_id:
            # runless (external) evaluation: insert a new terminal row
            self._store_runless_asset_check_evaluation(event, event_id)
        else:
            # in-run evaluation: update the row created by the planned event
            self._update_asset_check_evaluation(event, event_id)
def _store_asset_check_evaluation_planned(
    self, event: EventLogEntry, event_id: Optional[int]
) -> None:
    """Insert a PLANNED row for an asset check that is about to execute."""
    planned_data = cast(
        AssetCheckEvaluationPlanned, check.not_none(event.dagster_event).event_specific_data
    )
    insert_stmt = AssetCheckExecutionsTable.insert().values(
        asset_key=planned_data.asset_key.to_string(),
        check_name=planned_data.check_name,
        run_id=event.run_id,
        execution_status=AssetCheckExecutionRecordStatus.PLANNED.value,
        evaluation_event=serialize_value(event),
        evaluation_event_timestamp=self._event_insert_timestamp(event),
    )
    with self.index_connection() as conn:
        conn.execute(insert_stmt)
def _event_insert_timestamp(self, event):
# Postgres requires a datetime that is in UTC but has no timezone info
return datetime.fromtimestamp(event.timestamp, timezone.utc).replace(tzinfo=None)
def _store_runless_asset_check_evaluation(
    self, event: EventLogEntry, event_id: Optional[int]
) -> None:
    """Insert a terminal row for an asset check evaluation that ran outside of a run."""
    evaluation = cast(
        AssetCheckEvaluation, check.not_none(event.dagster_event).event_specific_data
    )
    status = (
        AssetCheckExecutionRecordStatus.SUCCEEDED
        if evaluation.passed
        else AssetCheckExecutionRecordStatus.FAILED
    )
    target_data = evaluation.target_materialization_data
    with self.index_connection() as conn:
        conn.execute(
            AssetCheckExecutionsTable.insert().values(
                asset_key=evaluation.asset_key.to_string(),
                check_name=evaluation.check_name,
                run_id=event.run_id,
                execution_status=status.value,
                evaluation_event=serialize_value(event),
                evaluation_event_timestamp=self._event_insert_timestamp(event),
                evaluation_event_storage_id=event_id,
                materialization_event_storage_id=(
                    target_data.storage_id if target_data else None
                ),
            )
        )
def _update_asset_check_evaluation(self, event: EventLogEntry, event_id: Optional[int]) -> None:
    """Update the PLANNED asset-check row with the terminal evaluation result."""
    evaluation = cast(
        AssetCheckEvaluation, check.not_none(event.dagster_event).event_specific_data
    )
    with self.index_connection() as conn:
        rows_updated = conn.execute(
            AssetCheckExecutionsTable.update()
            .where(
                # (asset_key, check_name, run_id) uniquely identifies the row created for the planned event
                db.and_(
                    AssetCheckExecutionsTable.c.asset_key == evaluation.asset_key.to_string(),
                    AssetCheckExecutionsTable.c.check_name == evaluation.check_name,
                    AssetCheckExecutionsTable.c.run_id == event.run_id,
                )
            )
            .values(
                execution_status=(
                    AssetCheckExecutionRecordStatus.SUCCEEDED.value
                    if evaluation.passed
                    else AssetCheckExecutionRecordStatus.FAILED.value
                ),
                evaluation_event=serialize_value(event),
                evaluation_event_timestamp=self._event_insert_timestamp(event),
                evaluation_event_storage_id=event_id,
                materialization_event_storage_id=(
                    evaluation.target_materialization_data.storage_id
                    if evaluation.target_materialization_data
                    else None
                ),
            )
        ).rowcount
    # 0 isn't normally expected, but occurs with the external instance of step launchers where
    # they don't have planned events.
    if rows_updated > 1:
        raise DagsterInvariantViolationError(
            f"Updated {rows_updated} rows for asset check evaluation {evaluation.asset_check_key} "
            "as a result of duplicate AssetCheckPlanned events."
        )
def get_asset_check_execution_history(
    self,
    check_key: AssetCheckKey,
    limit: int,
    cursor: Optional[int] = None,
) -> Sequence[AssetCheckExecutionRecord]:
    """Fetch execution records for an asset check, newest first.

    Args:
        check_key (AssetCheckKey): The asset check to fetch history for.
        limit (int): Maximum number of records to return.
        cursor (Optional[int]): If provided, only return records with storage id less
            than this value.
    """
    # fix: the parameter-name argument was previously "key", producing misleading
    # validation error messages
    check.inst_param(check_key, "check_key", AssetCheckKey)
    check.int_param(limit, "limit")
    check.opt_int_param(cursor, "cursor")
    query = (
        db_select(
            [
                AssetCheckExecutionsTable.c.id,
                AssetCheckExecutionsTable.c.run_id,
                AssetCheckExecutionsTable.c.execution_status,
                AssetCheckExecutionsTable.c.evaluation_event,
                AssetCheckExecutionsTable.c.create_timestamp,
            ]
        )
        .where(
            db.and_(
                AssetCheckExecutionsTable.c.asset_key == check_key.asset_key.to_string(),
                AssetCheckExecutionsTable.c.check_name == check_key.name,
            )
        )
        .order_by(AssetCheckExecutionsTable.c.id.desc())
    ).limit(limit)
    if cursor:
        query = query.where(AssetCheckExecutionsTable.c.id < cursor)
    with self.index_connection() as conn:
        rows = db_fetch_mappings(conn, query)
        return [AssetCheckExecutionRecord.from_db_row(row) for row in rows]
def get_latest_asset_check_execution_by_key(
    self, check_keys: Sequence[AssetCheckKey]
) -> Mapping[AssetCheckKey, AssetCheckExecutionRecord]:
    """Return the most recent execution record for each of the given asset check keys."""
    if not check_keys:
        return {}
    # NOTE(review): asset_key and check_name are filtered with independent IN clauses, so
    # rows pairing an asset key from one requested key with a check name from another can
    # also match; the group-by/max(id) still yields one row per (asset_key, check_name)
    # pair — confirm this over-matching is acceptable to callers
    latest_ids_subquery = db_subquery(
        db_select(
            [
                db.func.max(AssetCheckExecutionsTable.c.id).label("id"),
            ]
        )
        .where(
            db.and_(
                AssetCheckExecutionsTable.c.asset_key.in_(
                    [key.asset_key.to_string() for key in check_keys]
                ),
                AssetCheckExecutionsTable.c.check_name.in_([key.name for key in check_keys]),
            )
        )
        .group_by(
            AssetCheckExecutionsTable.c.asset_key,
            AssetCheckExecutionsTable.c.check_name,
        )
    )
    # join back to the executions table to fetch the full row for each latest id
    query = db_select(
        [
            AssetCheckExecutionsTable.c.id,
            AssetCheckExecutionsTable.c.asset_key,
            AssetCheckExecutionsTable.c.check_name,
            AssetCheckExecutionsTable.c.run_id,
            AssetCheckExecutionsTable.c.execution_status,
            AssetCheckExecutionsTable.c.evaluation_event,
            AssetCheckExecutionsTable.c.create_timestamp,
        ]
    ).select_from(
        AssetCheckExecutionsTable.join(
            latest_ids_subquery,
            db.and_(
                AssetCheckExecutionsTable.c.id == latest_ids_subquery.c.id,
            ),
        )
    )
    with self.index_connection() as conn:
        rows = db_fetch_mappings(conn, query)
        return {
            AssetCheckKey(
                asset_key=check.not_none(AssetKey.from_db_string(cast(str, row["asset_key"]))),
                name=cast(str, row["check_name"]),
            ): AssetCheckExecutionRecord.from_db_row(row)
            for row in rows
        }
@property
def supports_asset_checks(self) -> bool:
    # Asset check storage requires the asset check executions table, which is added by a
    # schema migration (`dagster instance migrate`).
    return self.has_table(AssetCheckExecutionsTable.name)
def get_latest_planned_materialization_info(
    self,
    asset_key: AssetKey,
    partition: Optional[str] = None,
) -> Optional[PlannedMaterializationInfo]:
    """Return storage id / run id of the most recent planned materialization, if any."""
    partitions_filter = [partition] if partition else None
    records = self._get_event_records(
        event_records_filter=EventRecordsFilter(
            DagsterEventType.ASSET_MATERIALIZATION_PLANNED,
            asset_key=asset_key,
            asset_partitions=partitions_filter,
        ),
        limit=1,
        ascending=False,  # newest first, so the single record is the latest
    )
    if not records:
        return None
    latest = records[0]
    return PlannedMaterializationInfo(storage_id=latest.storage_id, run_id=latest.run_id)
def _get_from_row(row: SqlAlchemyRow, column: str) -> object:
"""Utility function for extracting a column from a sqlalchemy row proxy, since '_asdict' is not
supported in sqlalchemy 1.3.
"""
if column not in row.keys():
return None
return row[column]