Source code for dagster_deltalake.io_manager

import sys
from abc import abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Iterator, Optional, Sequence, Type, Union, cast

from dagster import OutputContext
from dagster._config.pythonic_config import ConfigurableIOManagerFactory
from dagster._core.definitions.time_window_partitions import TimeWindow
from dagster._core.storage.db_io_manager import (
    DbClient,
    DbIOManager,
    DbTypeHandler,
    TablePartitionDimension,
    TableSlice,
)
from pydantic import Field

if sys.version_info >= (3, 8):
    from typing import TypedDict
else:
    from typing_extensions import TypedDict

if sys.version_info >= (3, 11):
    from typing import NotRequired
else:
    from typing_extensions import NotRequired

from dagster_deltalake.config import AzureConfig, ClientConfig, GcsConfig, LocalConfig, S3Config

DELTA_DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S"
DELTA_DATE_FORMAT = "%Y-%m-%d"


@dataclass(frozen=True)
class TableConnection:
    table_uri: str
    storage_options: Dict[str, str]
    table_config: Optional[Dict[str, str]]


class _StorageOptionsConfig(TypedDict, total=False):
    local: Dict[str, str]
    s3: Dict[str, str]
    azure: Dict[str, str]
    gcs: Dict[str, str]


class WriteMode(str, Enum):
    error = "error"
    append = "append"
    overwrite = "overwrite"
    ignore = "ignore"


class WriterEngine(str, Enum):
    pyarrow = "pyarrow"
    rust = "rust"


class _DeltaTableIOManagerResourceConfig(TypedDict):
    root_uri: str
    mode: WriteMode
    overwrite_schema: bool
    writer_engine: WriterEngine
    storage_options: _StorageOptionsConfig
    client_options: NotRequired[Dict[str, str]]
    table_config: NotRequired[Dict[str, str]]
    custom_metadata: NotRequired[Dict[str, str]]
    writer_properties: NotRequired[Dict[str, str]]
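
# Illustrative sketch, not part of the library source: a resource config of the shape
# described by _DeltaTableIOManagerResourceConfig, as the DbClient below reads it from
# context.resource_config. The root_uri and region values are hypothetical placeholders.
_EXAMPLE_RESOURCE_CONFIG: _DeltaTableIOManagerResourceConfig = {
    "root_uri": "s3://my-bucket/delta",
    "mode": WriteMode.append,
    "overwrite_schema": False,
    "writer_engine": WriterEngine.rust,
    "storage_options": {"s3": {"region": "us-east-1"}},
}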


class DeltaLakeIOManager(ConfigurableIOManagerFactory):
    """Base class for an IO manager definition that reads inputs from and writes outputs to Delta Lake.

    Examples:
        .. code-block:: python

            from dagster_deltalake import DeltaLakeIOManager
            from dagster_deltalake_pandas import DeltaLakePandasTypeHandler

            class MyDeltaLakeIOManager(DeltaLakeIOManager):
                @staticmethod
                def type_handlers() -> Sequence[DbTypeHandler]:
                    return [DeltaLakePandasTypeHandler()]

            @asset(
                key_prefix=["my_schema"]  # will be used as the schema (parent folder) in Delta Lake
            )
            def my_table() -> pd.DataFrame:  # the name of the asset will be the table name
                ...

            defs = Definitions(
                assets=[my_table],
                resources={"io_manager": MyDeltaLakeIOManager()}
            )

    If you do not provide a schema, Dagster will determine a schema based on the assets and ops using
    the I/O Manager. For assets, the schema will be determined from the asset key, as in the above example.
    For ops, the schema can be specified by including a "schema" entry in output metadata. If none of
    these is provided, the schema will default to "public".

    .. code-block:: python

        @op(
            out={"my_table": Out(metadata={"schema": "my_schema"})}
        )
        def make_my_table() -> pd.DataFrame:
            ...

    To only use specific columns of a table as input to a downstream op or asset, add the metadata
    "columns" to the In or AssetIn.

    .. code-block:: python

        @asset(
            ins={"my_table": AssetIn("my_table", metadata={"columns": ["a"]})}
        )
        def my_table_a(my_table: pd.DataFrame):
            # my_table will just contain the data from column "a"
            ...

    """

    root_uri: str = Field(description="Storage location where Delta tables are stored.")
    mode: WriteMode = Field(
        default=WriteMode.overwrite.value, description="The write mode passed to save the output."
    )
    overwrite_schema: bool = Field(default=False)
    writer_engine: WriterEngine = Field(
        default=WriterEngine.pyarrow.value, description="Engine passed to write_deltalake."
    )

    storage_options: Union[AzureConfig, S3Config, LocalConfig, GcsConfig] = Field(
        discriminator="provider",
        description="Configuration for accessing storage location.",
    )
    client_options: Optional[ClientConfig] = Field(
        default=None, description="Additional configuration passed to http client."
    )
    table_config: Optional[Dict[str, str]] = Field(
        default=None,
        description="Additional config and metadata added to table on creation.",
    )
    schema_: Optional[str] = Field(
        default=None, alias="schema", description="Name of the schema to use."
    )  # schema is a reserved word for pydantic
    custom_metadata: Optional[Dict[str, str]] = Field(
        default=None, description="Custom metadata that is added to transaction commit."
    )
    writer_properties: Optional[Dict[str, str]] = Field(
        default=None, description="Writer properties passed to the rust engine writer."
    )

    @staticmethod
    @abstractmethod
    def type_handlers() -> Sequence[DbTypeHandler]: ...

    @staticmethod
    def default_load_type() -> Optional[Type]:
        return None

    def create_io_manager(self, context) -> DbIOManager:
        self.storage_options.dict()
        return DbIOManager(
            db_client=DeltaLakeDbClient(),
            database="deltalake",
            schema=self.schema_,
            type_handlers=self.type_handlers(),
            default_load_type=self.default_load_type(),
            io_manager_name="DeltaLakeIOManager",
        )
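
# Configuration sketch, assuming the MyDeltaLakeIOManager subclass from the docstring
# above and a local "data/delta" directory (both hypothetical, not library defaults):
#
#     from dagster import Definitions
#     from dagster_deltalake.config import LocalConfig
#
#     defs = Definitions(
#         assets=[my_table],
#         resources={
#             "io_manager": MyDeltaLakeIOManager(
#                 root_uri="data/delta",          # parent folder for all Delta tables
#                 storage_options=LocalConfig(),  # local filesystem object store
#                 mode="append",                  # append instead of the default overwrite
#             )
#         },
#     )
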
class DeltaLakeDbClient(DbClient):
    @staticmethod
    def delete_table_slice(
        context: OutputContext, table_slice: TableSlice, connection: TableConnection
    ) -> None:
        # deleting the table slice here is a no-op, since we use deltalake's internal mechanism
        # to overwrite table partitions.
        pass

    @staticmethod
    def ensure_schema_exists(
        context: OutputContext, table_slice: TableSlice, connection: TableConnection
    ) -> None:
        # schemas are just folders and automatically created on write.
        pass

    @staticmethod
    def get_select_statement(table_slice: TableSlice) -> str:
        # The select statement here is just for illustrative purposes,
        # and is never actually executed. It does however logically correspond
        # to the operation being executed.
        col_str = ", ".join(table_slice.columns) if table_slice.columns else "*"

        if table_slice.partition_dimensions and len(table_slice.partition_dimensions) > 0:
            query = f"SELECT {col_str} FROM {table_slice.schema}.{table_slice.table} WHERE\n"
            return query + _partition_where_clause(table_slice.partition_dimensions)
        else:
            return f"""SELECT {col_str} FROM {table_slice.schema}.{table_slice.table}"""

    @staticmethod
    @contextmanager
    def connect(context, table_slice: TableSlice) -> Iterator[TableConnection]:
        resource_config = cast(_DeltaTableIOManagerResourceConfig, context.resource_config)
        root_uri = resource_config["root_uri"].rstrip("/")
        storage_options = resource_config["storage_options"]

        if "local" in storage_options:
            storage_options = storage_options["local"]
        elif "s3" in storage_options:
            storage_options = storage_options["s3"]
        elif "azure" in storage_options:
            storage_options = storage_options["azure"]
        elif "gcs" in storage_options:
            storage_options = storage_options["gcs"]
        else:
            storage_options = {}

        client_options = resource_config.get("client_options")
        client_options = client_options or {}

        storage_options = {
            **{k: str(v) for k, v in storage_options.items() if v is not None},
            **{k: str(v) for k, v in client_options.items() if v is not None},
        }
        table_config = resource_config.get("table_config")

        table_uri = f"{root_uri}/{table_slice.schema}/{table_slice.table}"

        conn = TableConnection(
            table_uri=table_uri,
            storage_options=storage_options or {},
            table_config=table_config,
        )

        yield conn


def _partition_where_clause(
    partition_dimensions: Sequence[TablePartitionDimension],
) -> str:
    return " AND\n".join(
        (
            _time_window_where_clause(partition_dimension)
            if isinstance(partition_dimension.partitions, TimeWindow)
            else _static_where_clause(partition_dimension)
        )
        for partition_dimension in partition_dimensions
    )


def _time_window_where_clause(table_partition: TablePartitionDimension) -> str:
    partition = cast(TimeWindow, table_partition.partitions)
    start_dt, end_dt = partition
    start_dt_str = start_dt.strftime(DELTA_DATETIME_FORMAT)
    end_dt_str = end_dt.strftime(DELTA_DATETIME_FORMAT)
    return f"""{table_partition.partition_expr} >= '{start_dt_str}' AND {table_partition.partition_expr} < '{end_dt_str}'"""


def _static_where_clause(table_partition: TablePartitionDimension) -> str:
    partitions = ", ".join(f"'{partition}'" for partition in table_partition.partitions)
    return f"""{table_partition.partition_expr} in ({partitions})"""
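
# Illustrative only: for a TableSlice over schema "my_schema" and table "my_table" with a
# single time-window partition dimension whose partition_expr is "date" (all names
# hypothetical), get_select_statement would render roughly:
#
#     SELECT * FROM my_schema.my_table WHERE
#     date >= '2023-01-01 00:00:00' AND date < '2023-01-02 00:00:00'
#
# As noted in get_select_statement, this SQL is never executed; it only documents the
# logical operation that the type handlers perform against the Delta table.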