Ask AI

Source code for dagster_polars.io_managers.delta

from collections import defaultdict
from enum import Enum
from pprint import pformat
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Union

import polars as pl
from dagster import InputContext, MetadataValue, MultiPartitionKey, OutputContext
from dagster._annotations import experimental
from dagster._core.errors import DagsterInvariantViolationError
from dagster._core.storage.upath_io_manager import is_dict_type

from dagster_polars.io_managers.base import BasePolarsUPathIOManager

try:
    from deltalake import DeltaTable
    from deltalake.exceptions import TableNotFoundError
except ImportError as e:
    if "deltalake" in str(e):
        raise ImportError(
            "Install 'dagster-polars[deltalake]' to use DeltaLake functionality"
        ) from e
    else:
        raise e

if TYPE_CHECKING:
    from upath import UPath


# Polars frame types grouped under the "single loading" name; usage lives outside this
# module section (presumably the non-dict load targets) — TODO confirm against callers.
SINGLE_LOADING_TYPES: Tuple[type, ...] = (pl.DataFrame, pl.LazyFrame)


class DeltaWriteMode(str, Enum):
    """Closed set of DeltaLake write modes understood by ``PolarsDeltaIOManager``.

    Each member subclasses ``str``, so it compares equal to its plain string value.
    The IO manager forwards the string value as the ``mode`` argument of
    ``pl.DataFrame.write_delta``.
    """

    # Member order is part of the enum's iteration order; keep it stable.
    error = "error"
    append = "append"
    overwrite = "overwrite"
    ignore = "ignore"


@experimental
class PolarsDeltaIOManager(BasePolarsUPathIOManager):
    """Implements writing and reading DeltaLake tables.

    Features:
     - All features provided by :py:class:`~dagster_polars.BasePolarsUPathIOManager`.
     - All read/write options can be set via corresponding metadata or config parameters
       (metadata takes precedence).
     - Supports native DeltaLake partitioning by storing different asset partitions in the
       same DeltaLake table. To enable this behavior, set the `partition_by` metadata value
       **and** use a non-dict type annotation when loading the asset. The `partition_by` value
       will be used in `delta_write_options` of `pl.DataFrame.write_delta` and
       `pyarrow_options` of `pl.scan_delta`. When using a one-dimensional
       `PartitionsDefinition`, it should be a single string like "column". When using a
       `MultiPartitionsDefinition`, it should be a dict with dimension to column names
       mapping, like `{"dimension": "column"}`.

    Install `dagster-polars[deltalake]` to use this IOManager.

    Examples:

        .. code-block:: python

            from dagster import Definitions, asset
            from dagster_polars import PolarsDeltaIOManager
            import polars as pl

            @asset(
                io_manager_key="polars_delta_io_manager",
                key_prefix=["my_dataset"]
            )
            def my_asset() -> pl.DataFrame:  # data will be stored at <base_dir>/my_dataset/my_asset.delta
                ...

            defs = Definitions(
                assets=[my_asset],
                resources={
                    "polars_delta_io_manager": PolarsDeltaIOManager(base_dir="s3://my-bucket/my-dir")
                }
            )

        Appending to a DeltaLake table:

        .. code-block:: python

            @asset(
                io_manager_key="polars_delta_io_manager",
                metadata={
                    "mode": "append"
                },
            )
            def my_table() -> pl.DataFrame:
                ...

        Using native DeltaLake partitioning by storing different asset partitions in the same
        DeltaLake table:

        .. code-block:: python

            from dagster import AssetExecutionContext, StaticPartitionsDefinition

            @asset(
                io_manager_key="polars_delta_io_manager",
                metadata={
                    "partition_by": "partition_col"
                },
                partitions_def=StaticPartitionsDefinition(["a", "b", "c"])
            )
            def upstream(context: AssetExecutionContext) -> pl.DataFrame:
                df = ...

                # column with the partition_key must match `partition_by` metadata value
                return df.with_columns(pl.lit(context.partition_key).alias("partition_col"))

            @asset
            def downstream(upstream: pl.LazyFrame) -> pl.DataFrame:
                ...

        When using `MultiPartitionsDefinition`, the `partition_by` metadata value should be a
        dictionary mapping dimensions to column names.

        .. code-block:: python

            from dagster import (
                AssetExecutionContext,
                DailyPartitionsDefinition,
                MultiPartitionsDefinition,
                StaticPartitionsDefinition,
            )

            @asset(
                io_manager_key="polars_delta_io_manager",
                metadata={
                    "partition_by": {"time": "date", "clients": "client"}  # dimension->column mapping
                },
                partitions_def=MultiPartitionsDefinition(
                    {
                        "time": DailyPartitionsDefinition(...),
                        "clients": StaticPartitionsDefinition(...)
                    }
                )
            )
            def upstream(context: AssetExecutionContext) -> pl.DataFrame:
                df = ...

                partition_keys_by_dimension = context.partition_key.keys_by_dimension

                return df.with_columns(
                    pl.lit(partition_keys_by_dimension["time"]).alias("date"),  # time dimension matches date column
                    pl.lit(partition_keys_by_dimension["clients"]).alias("client")  # clients dimension matches client column
                )

            @asset
            def downstream(upstream: pl.LazyFrame) -> pl.DataFrame:
                ...
    """

    # extension of the DeltaLake table path, e.g. `<base_dir>/my_asset.delta`
    extension: str = ".delta"
    # default write mode; the `mode` output metadata value takes precedence
    mode: DeltaWriteMode = DeltaWriteMode.overwrite.value  # type: ignore
    # default for `overwrite_schema`; the `overwrite_schema` output metadata value takes precedence
    overwrite_schema: bool = False
    # table version to read; the `version` input metadata value takes precedence.
    # When neither is set, the latest table version is read.
    version: Optional[int] = None

    def _get_write_mode(self, context: OutputContext) -> str:
        """Return the write mode: the `mode` metadata value if set, otherwise the `mode` config field."""
        context_metadata = context.definition_metadata or {}
        return context_metadata.get("mode") or self.mode.value

    def sink_df_to_path(
        self,
        context: OutputContext,
        df: pl.LazyFrame,
        path: "UPath",
    ):
        """Collect ``df`` (honouring the `streaming` metadata flag) and write it to ``path``."""
        context_metadata = context.definition_metadata or {}
        streaming = context_metadata.get("streaming", False)
        return self.write_df_to_path(context, df.collect(streaming=streaming), path)

    def write_df_to_path(
        self,
        context: OutputContext,
        df: pl.DataFrame,
        path: "UPath",
    ):
        """Write ``df`` to the DeltaLake table at ``path`` and record the new table version.

        When the asset is partitioned and has `partition_by` metadata, the partition columns and
        partition filters are added to ``delta_write_options``. Only the current partition is
        then written into the shared table.

        Raises:
            ValueError: if the `partition_by` metadata value does not match the partition key
                kind (string for single-dimension keys, dict for `MultiPartitionKey`).
        """
        context_metadata = context.definition_metadata or {}

        # NOTE: ideally these options would be flat metadata keys rather than a nested dict.
        delta_write_options = context_metadata.get("delta_write_options")
        if delta_write_options is not None:
            # Copy, so that adding partition options below never mutates the asset's
            # definition metadata, which is shared between all of its outputs.
            delta_write_options = dict(delta_write_options)

        if context.has_asset_partitions:
            delta_write_options = delta_write_options or {}
            # NOTE: the `partition_by` metadata value overrides any `partition_by`
            # set inside `delta_write_options`.
            partition_by = context_metadata.get("partition_by")

            if partition_by is not None:
                assert (
                    context.partition_key is not None
                ), 'can\'t set "partition_by" for an asset without partitions'

                if isinstance(partition_by, dict) and isinstance(
                    context.partition_key, MultiPartitionKey
                ):
                    delta_write_options["partition_by"] = list(partition_by.values())
                # MultiPartitionKey subclasses str, so it must be excluded explicitly here
                elif isinstance(partition_by, str) and not isinstance(
                    context.partition_key, MultiPartitionKey
                ):
                    delta_write_options["partition_by"] = partition_by
                else:
                    raise ValueError(
                        "partition_by metadata value must be a string for single-partitioned assets or a dictionary for multi-partitioned assets"
                    )

                # restrict the write (e.g. an overwrite) to the current partition only
                delta_write_options["partition_filters"] = self.get_partition_filters(context)

        if delta_write_options is not None:
            context.log.debug(f"Writing with delta_write_options: {pformat(delta_write_options)}")

        storage_options = self.storage_options

        dt: Union[DeltaTable, str]
        try:
            dt = DeltaTable(str(path), storage_options=storage_options)
        except TableNotFoundError:
            # first write: let polars create the table at this location
            dt = str(path)

        # metadata takes precedence over config, including an explicit `False`
        overwrite_schema = context_metadata.get("overwrite_schema")
        if overwrite_schema is None:
            overwrite_schema = self.overwrite_schema

        df.write_delta(
            dt,
            mode=self._get_write_mode(context),
            overwrite_schema=overwrite_schema,
            storage_options=storage_options,
            delta_write_options=delta_write_options,
        )

        if isinstance(dt, DeltaTable):
            current_version = dt.version()
        else:
            current_version = DeltaTable(
                str(path), storage_options=storage_options, without_files=True
            ).version()
        context.add_output_metadata({"version": current_version})

    def scan_df_from_path(
        self,
        path: "UPath",
        context: InputContext,
    ) -> pl.LazyFrame:
        """Scan a DeltaLake table into a `polars.LazyFrame`.

        It can be called in 3 different situations:
        1. with an unpartitioned asset
        2. with a partitioned asset without native partitioning enabled - multiple times on
           nested .delta tables
        3. with a partitioned asset and with native partitioning enabled - a single time on
           the .delta table.

        In case (3), partition filters are applied so that only the mapped partitions are loaded.
        No filter is added when every partition of the upstream asset is requested.
        """
        assert context.upstream_output is not None
        assert context.upstream_output.definition_metadata is not None

        context_metadata = context.definition_metadata or {}

        version = self.get_delta_version_to_load(path, context)

        context.log.debug(f"Reading Delta table with version: {version}")

        # Copy, so the partition filter added below never leaks into the (shared) definition
        # metadata and, through it, into later loads.
        pyarrow_options = dict(context_metadata.get("pyarrow_options") or {})

        partition_by = context.upstream_output.definition_metadata.get("partition_by")

        # we want to apply partition filters when loading some partitions, but not all partitions
        if (
            partition_by
            and context.has_asset_key
            and context.has_asset_partitions
            and len(context.asset_partition_keys) > 0
        ):
            all_partition_keys = set(
                context.upstream_output.asset_partitions_def.get_partition_keys(
                    dynamic_partitions_store=context.instance
                )
            )
            # asset_partition_keys is a sequence: compare as sets, not list-vs-set
            if set(context.asset_partition_keys) != all_partition_keys:
                pyarrow_options["partitions"] = self.get_partition_filters(context)

        if pyarrow_options:
            context.log.debug(f"Reading with pyarrow_options: {pyarrow_options}")

        delta_table_options = context_metadata.get("delta_table_options")

        if delta_table_options:
            context.log.debug(f"Reading with delta_table_options: {delta_table_options}")

        return pl.scan_delta(
            str(path),
            version=version,
            delta_table_options=delta_table_options,
            pyarrow_options=pyarrow_options,
            storage_options=self.storage_options,
        )

    def load_partitions(self, context: InputContext):
        """Load upstream partitions.

        With native partitioning and a non-dict annotation, all partitions are read from the
        single shared table. Otherwise loading is delegated to the base implementation.
        """
        assert context.upstream_output is not None

        # any partition would work as they all are stored in the same DeltaLake table
        path = self._get_path_without_extension(context)
        context.log.debug(
            f"Loading {len(context.asset_partition_keys)} partitions from {path} using {self.__class__.__name__}..."
        )

        upstream_metadata = context.upstream_output.definition_metadata or {}
        if upstream_metadata.get("partition_by") and not is_dict_type(
            context.dagster_type.typing_type
        ):
            # user enabled native partitioning and wants a `pl.DataFrame` or `pl.LazyFrame`
            return self.load_from_path(context, self._with_extension(path))
        else:
            # default behaviour
            return super().load_partitions(context)

    def get_path_for_partition(
        self, context: Union[InputContext, OutputContext], path: "UPath", partition: str
    ) -> "UPath":
        """Return the storage path for ``partition``.

        Natively partitioned assets (with `partition_by` metadata) share ``path`` itself.
        Other assets get one nested path per partition.
        """
        if isinstance(context, InputContext):
            if (
                context.upstream_output is not None
                and context.upstream_output.definition_metadata is not None
                and context.upstream_output.definition_metadata.get("partition_by") is not None
            ):
                # upstream asset has "partition_by" metadata set, so partitioning for it is handled by DeltaLake itself
                return path

        if isinstance(context, OutputContext):
            if (
                context.definition_metadata is not None
                and context.definition_metadata.get("partition_by") is not None
            ):
                # this asset has "partition_by" metadata set, so partitioning for it is handled by DeltaLake itself
                return path

        return path / partition  # partitioning is handled by the IOManager

    @staticmethod
    def get_partition_filters(
        context: Union[InputContext, OutputContext],
    ) -> Sequence[Tuple[str, str, Any]]:
        """Build `(column, "in", keys)` filters for the context's partition keys.

        `partition_by` is read from the output's metadata, or from the upstream output's
        metadata for an ``InputContext``. An empty list is returned when `partition_by` is
        unset or the context has no asset partitions.

        Raises:
            DagsterInvariantViolationError: for an unsupported context, or an ``InputContext``
                without an upstream output.
            NotImplementedError: if `partition_by` is neither a string nor a dict.
        """
        if isinstance(context, OutputContext):
            partition_by = (context.definition_metadata or {}).get("partition_by")
        elif isinstance(context, InputContext) and context.upstream_output is not None:
            partition_by = (context.upstream_output.definition_metadata or {}).get(
                "partition_by"
            )
        else:
            raise DagsterInvariantViolationError(f"Invalid context type: {type(context)}")

        if partition_by is None or not context.has_asset_partitions:
            filters = []
        elif isinstance(partition_by, dict):
            # group the keys of every requested multi-partition by dimension
            all_keys_by_dim = defaultdict(list)
            for partition_key in context.asset_partition_keys:
                assert isinstance(
                    partition_key, MultiPartitionKey
                ), f"received dict `partition_by` metadata value {partition_by}, but the partition_key is not a `MultiPartitionKey`: {partition_key}"
                for dim, key in partition_key.keys_by_dimension.items():
                    all_keys_by_dim[dim].append(key)

            filters = [(partition_by[dim], "in", keys) for dim, keys in all_keys_by_dim.items()]
        elif isinstance(partition_by, str):
            assert not isinstance(
                context.asset_partition_keys[0], MultiPartitionKey
            ), f"received string `partition_by` metadata value {partition_by}, but the partition_key is a `MultiPartitionKey`: {context.asset_partition_keys[0]}"
            filters = [(partition_by, "in", context.asset_partition_keys)]
        else:
            raise NotImplementedError(f"Unsupported `partition_by` metadata value: {partition_by}")

        return filters

    def get_metadata(
        self, context: OutputContext, obj: Union[pl.DataFrame, pl.LazyFrame, None]
    ) -> Dict[str, MetadataValue]:
        """Extend the base output metadata with `partition_by` and append-aware row counts.

        In append mode (set via metadata or config) on an unpartitioned table, any
        `dagster/row_count` from the base becomes `append_row_count`. `dagster/row_count` is
        then set to the row count of the full table.
        """
        context_metadata = context.definition_metadata or {}

        metadata = super().get_metadata(context, obj)

        if context.has_asset_partitions:
            partition_by = context_metadata.get("partition_by")
            if partition_by is not None:
                metadata["partition_by"] = partition_by

        if self._get_write_mode(context) == DeltaWriteMode.append.value:
            # modify the metadata to reflect the fact that we are appending to the table
            if context.has_asset_partitions:
                # TODO: reporting the full table length is wrong for a partitioned table, and it is
                #  unclear how to get the length of the partition being appended to, so row_count
                #  metadata is left untouched here.
                pass
            elif obj is not None:
                # the base metadata may not contain a row count (e.g. for lazy outputs)
                if "dagster/row_count" in metadata:
                    metadata["append_row_count"] = metadata["dagster/row_count"]

                path = self._get_path(context)
                # we need to get row_count from the full table
                metadata["dagster/row_count"] = MetadataValue.int(
                    DeltaTable(str(path), storage_options=self.storage_options)
                    .to_pyarrow_dataset()
                    .count_rows()
                )

        return metadata

    def get_delta_version_to_load(self, path: "UPath", context: InputContext) -> int:
        """Return the table version to read.

        Precedence is the `version` input metadata, then the `version` config field, then the
        latest version of the table at ``path``.
        """
        context_metadata = context.definition_metadata or {}
        version_from_metadata = context_metadata.get("version")
        version_from_config = self.version

        if version_from_metadata is not None:
            if version_from_config is not None:
                context.log.warning(
                    f"Both version from metadata ({version_from_metadata}) "
                    f"and config ({version_from_config}) are set. Using version from metadata."
                )
            return int(version_from_metadata)

        if version_from_config is not None:
            return int(version_from_config)

        return DeltaTable(
            str(path), storage_options=self.storage_options, without_files=True
        ).version()