Ask AI

Source code for dagster_polars.io_managers.base

import sys
from abc import abstractmethod
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    Literal,
    Mapping,
    Optional,
    Tuple,
    Union,
    cast,
    get_args,
    get_origin,
    overload,
)

import polars as pl
from dagster import (
    ConfigurableIOManager,
    EnvVar,
    InitResourceContext,
    InputContext,
    MetadataValue,
    OutputContext,
    UPathIOManager,
    _check as check,
)
from pydantic import PrivateAttr
from pydantic.fields import Field

from dagster_polars.io_managers.utils import get_polars_metadata
from dagster_polars.types import (
    DataFramePartitions,
    DataFramePartitionsWithMetadata,
    LazyFramePartitions,
    LazyFramePartitionsWithMetadata,
    LazyFrameWithMetadata,
    StorageMetadata,
)

if TYPE_CHECKING:
    from upath import UPath

# Type annotations that are handled eagerly: membership in this list makes
# `dump_to_path` call `write_df_to_path` and makes `load_from_path` `.collect()` the
# scanned LazyFrame before returning it.
# This is a mutable list on purpose: more entries are appended further below,
# depending on the running Python version.
POLARS_EAGER_FRAME_ANNOTATIONS = [
    pl.DataFrame,
    Optional[pl.DataFrame],
    # common default types (presumably what un-annotated inputs/outputs resolve to,
    # so they fall back to eager handling -- TODO confirm against dagster's typing_type)
    Any,
    type(None),
    None,
    # multiple partitions
    Dict[str, pl.DataFrame],
    Mapping[str, pl.DataFrame],
    DataFramePartitions,
    # DataFrame + metadata
    Tuple[pl.DataFrame, StorageMetadata],
    Optional[Tuple[pl.DataFrame, StorageMetadata]],
    # multiple partitions + metadata
    DataFramePartitionsWithMetadata,
]

# Type annotations that are handled lazily: membership in this list makes
# `dump_to_path` call `sink_df_to_path` and makes `load_from_path` return the scanned
# LazyFrame without collecting it.
# This is a mutable list on purpose: more entries are appended further below,
# depending on the running Python version.
POLARS_LAZY_FRAME_ANNOTATIONS = [
    pl.LazyFrame,
    Optional[pl.LazyFrame],
    # multiple partitions
    Dict[str, pl.LazyFrame],
    Mapping[str, pl.LazyFrame],
    LazyFramePartitions,
    # LazyFrame + metadata
    Tuple[pl.LazyFrame, StorageMetadata],
    Optional[Tuple[pl.LazyFrame, StorageMetadata]],
    # multiple partitions + metadata
    LazyFramePartitionsWithMetadata,
]


# Subscripting the builtin `dict` (PEP 585) is only possible on Python 3.9+, so the
# `dict[...]` spellings of the partition annotations are registered conditionally,
# in addition to the `typing.Dict[...]` spellings already listed.
if sys.version_info >= (3, 9):
    POLARS_EAGER_FRAME_ANNOTATIONS.extend(
        [
            dict[str, pl.DataFrame],
            dict[str, Optional[pl.DataFrame]],
        ]
    )
    POLARS_LAZY_FRAME_ANNOTATIONS.extend(
        [
            dict[str, pl.LazyFrame],
            dict[str, Optional[pl.LazyFrame]],
        ]
    )


def annotation_is_typing_optional(annotation) -> bool:
    """Tell whether ``annotation`` is a ``typing.Union`` that includes ``None``.

    ``Optional[X]`` is represented as ``Union[X, None]``, so it satisfies this check.
    Anything whose origin is not ``typing.Union`` yields ``False``.
    """
    if get_origin(annotation) != Union:
        return False
    none_type = type(None)
    return none_type in get_args(annotation)


def annotation_is_tuple(annotation) -> bool:
    """Tell whether ``annotation`` is a tuple generic such as ``Tuple[A, B]``.

    The decision is based on ``typing.get_origin``; annotations without an origin
    (plain classes, ``None``) yield ``False``.
    """
    origin = get_origin(annotation)
    tuple_origins = (Tuple, tuple)
    return origin in tuple_origins


def annotation_for_multiple_partitions(annotation) -> bool:
    """Tell whether ``annotation`` describes a mapping of partitions.

    A partitions annotation is a (possibly ``Optional``) ``Dict[K, V]`` /
    ``dict[K, V]`` / ``Mapping[K, V]`` with both type arguments given.

    Args:
        annotation: the type annotation to inspect.

    Returns:
        ``True`` if, after unwrapping ``Optional``, the annotation is a subscripted
        dict or mapping with exactly two type arguments; ``False`` otherwise.
    """
    # Local import: `typing.get_origin(Mapping[...])` returns `collections.abc.Mapping`,
    # which does not compare equal to `typing.Mapping`.
    import collections.abc

    if annotation_is_typing_optional(annotation):
        annotation = get_args(annotation)[0]

    # The origin of an annotation cannot be a tuple and a dict at the same time, so
    # only the dict/mapping origin is checked here (requiring both made this
    # function return False for every input).
    if get_origin(annotation) not in (dict, collections.abc.Mapping):
        return False

    # Require both the key and the value type so callers can safely read
    # `get_args(annotation)[1]`; bare `Dict` / `Mapping` have no type arguments.
    return len(get_args(annotation)) == 2


def annotation_is_tuple_with_metadata(annotation) -> bool:
    """Tell whether ``annotation`` is a ``Tuple[<frame>, <dict-like metadata>]``.

    Args:
        annotation: the type annotation to inspect; an ``Optional[...]`` wrapper is
            unwrapped first.

    Returns:
        ``True`` if the annotation is a tuple generic with at least two type arguments
        whose second argument is a subscripted dict or mapping (for example
        ``Dict[str, Any]``); ``False`` otherwise, including for bare ``Tuple`` and
        one-element tuples.
    """
    # Local import: `typing.get_origin(Mapping[...])` returns `collections.abc.Mapping`,
    # which does not compare equal to `typing.Mapping`.
    import collections.abc

    if annotation_is_typing_optional(annotation):
        annotation = get_args(annotation)[0]

    if not annotation_is_tuple(annotation):
        return False

    tuple_args = get_args(annotation)
    # Bare `Tuple` and `Tuple[X]` have fewer than two arguments; indexing [1] on
    # them would raise IndexError instead of answering "no metadata".
    if len(tuple_args) < 2:
        return False

    return get_origin(tuple_args[1]) in (dict, collections.abc.Mapping)


def annotation_for_storage_metadata(annotation) -> bool:
    """Tell whether ``annotation`` carries storage metadata next to the frame.

    An outer ``Optional`` is removed first. For a partitions annotation (as decided
    by ``annotation_for_multiple_partitions``) the mapping's value type is inspected;
    otherwise the annotation itself is inspected.
    """
    # strip the Optional wrapper, if there is one
    unwrapped = (
        get_args(annotation)[0] if annotation_is_typing_optional(annotation) else annotation
    )

    if annotation_for_multiple_partitions(unwrapped):
        # partitions mapping: the second type argument is the per-partition type
        per_partition_type = get_args(unwrapped)[1]
        return annotation_is_tuple_with_metadata(per_partition_type)

    return annotation_is_tuple_with_metadata(unwrapped)


def _process_env_vars(config: Mapping[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``config`` with ``{"env": <name>}`` entries resolved.

    A value is treated as an environment variable reference only when it is a ``dict``
    with exactly one item whose key equals ``"env"``; such a value is replaced by
    ``EnvVar(<name>).get_value()``. Every other value is copied over unchanged.
    Only top-level values are examined.
    """
    resolved: Dict[str, Any] = {}
    for name, raw in config.items():
        if not isinstance(raw, dict) or len(raw) != 1:
            resolved[name] = raw
            continue

        # exactly one item: unpack it to look at its key
        ((only_key, only_value),) = raw.items()
        if only_key == "env":
            resolved[name] = EnvVar(only_value).get_value()
        else:
            resolved[name] = raw
    return resolved


class BasePolarsUPathIOManager(ConfigurableIOManager, UPathIOManager):
    """Base class for `dagster-polars` IOManagers.

    Doesn't define a specific storage format.

    To implement a specific storage format (parquet, csv, etc), inherit from this class and implement the `write_df_to_path`, `sink_df_to_path` and `scan_df_from_path` methods.

    Features:
     - All the features of :py:class:`~dagster.UPathIOManager` - works with local and remote filesystems (like S3), supports loading multiple partitions with respect to :py:class:`~dagster.PartitionMapping`, and more
     - loads the correct type - `polars.DataFrame`, `polars.LazyFrame`, or other types defined in :py:mod:`dagster_polars.types` - based on the input type annotation (or `dagster.DagsterType`'s `typing_type`)
     - can sink lazy `pl.LazyFrame` DataFrames
     - handles `Nones` with `Optional` types by skipping loading missing inputs or saving `None` outputs
     - logs various metadata about the DataFrame - size, schema, sample, stats, ...
     - the `"columns"` input metadata value can be used to select a subset of columns to load
    """

    # Method calling chain:
    # 1. Non-partitioned: UPathIOManager.load_input -> UPathIOManager._load_single_input -> UPathIOManager.load_from_path -> BasePolarsUPathIOManager.scan_df_from_path
    # 2. Partitioned: UPathIOManager.load_input -> UPathIOManager.load_partitions -> UPathIOManager.load_from_path -> BasePolarsUPathIOManager.scan_df_from_path
    #
    # If a child IOManager supports loading multiple partitions at once, it should override
    # .load_partitions to immediately return a LazyFrame (by using scan_df_from_path)

    base_dir: Optional[str] = Field(default=None, description="Base directory for storing files.")
    cloud_storage_options: Optional[Mapping[str, Any]] = Field(
        default=None,
        description="Storage authentication for cloud object store",
        alias="storage_options",
    )
    _base_path = PrivateAttr()

    def setup_for_execution(self, context: InitResourceContext) -> None:
        """Compute and store the base path (``self._base_path``) for this IO manager.

        With ``base_dir`` set, the base path is ``UPath(base_dir, **storage_options)``,
        where the storage options are ``cloud_storage_options`` passed through
        ``_process_env_vars``. Without it, the instance's storage directory is used.
        """
        # imported here rather than at module level (the module-level import is
        # guarded by TYPE_CHECKING)
        from upath import UPath

        # The options are processed even when `base_dir` is unset.
        if self.cloud_storage_options is None:
            storage_options = {}
        else:
            storage_options = _process_env_vars(self.cloud_storage_options)

        if self.base_dir is None:
            self._base_path = UPath(check.not_none(context.instance).storage_directory())
        else:
            self._base_path = UPath(self.base_dir, **storage_options)

    @abstractmethod
    def write_df_to_path(
        self,
        context: OutputContext,
        df: pl.DataFrame,
        path: "UPath",
        metadata: Optional[StorageMetadata] = None,
    ):
        """Storage-format hook: called by `dump_to_path` for eagerly-typed outputs."""
        ...

    @abstractmethod
    def sink_df_to_path(
        self,
        context: OutputContext,
        df: pl.LazyFrame,
        path: "UPath",
        metadata: Optional[StorageMetadata] = None,
    ):
        """Storage-format hook: called by `dump_to_path` for lazily-typed outputs."""
        ...

    @overload
    @abstractmethod
    def scan_df_from_path(
        self, path: "UPath", context: InputContext, with_metadata: Literal[None, False]
    ) -> pl.LazyFrame: ...

    @overload
    @abstractmethod
    def scan_df_from_path(
        self, path: "UPath", context: InputContext, with_metadata: Literal[True]
    ) -> LazyFrameWithMetadata: ...

    @abstractmethod
    def scan_df_from_path(
        self, path: "UPath", context: InputContext, with_metadata: Optional[bool] = False
    ) -> Union[pl.LazyFrame, LazyFrameWithMetadata]:
        """Storage-format hook: called by `load_from_path` to scan the data at `path`.

        `load_from_path` expects a `pl.LazyFrame`, or a `(LazyFrame, metadata)` pair
        when it passes `with_metadata=True`.
        """
        ...

    def dump_to_path(
        self,
        context: OutputContext,
        obj: Union[
            pl.DataFrame,
            Optional[pl.DataFrame],
            Tuple[pl.DataFrame, Dict[str, Any]],
            pl.LazyFrame,
            Optional[pl.LazyFrame],
            Tuple[pl.LazyFrame, Dict[str, Any]],
        ],
        path: "UPath",
    ):
        """Save `obj` to `path`, dispatching on the output's type annotation.

        - An `Optional` output whose value (or whose frame, for `(frame, metadata)`
          outputs) is `None` is not saved; a warning is logged instead.
        - Eager annotations go through `write_df_to_path`, lazy annotations through
          `sink_df_to_path`; for `(frame, metadata)` outputs the metadata is forwarded.

        Raises:
            NotImplementedError: if the annotation is neither an eager nor a lazy one.
        """
        typing_type = context.dagster_type.typing_type
        is_optional = annotation_is_typing_optional(typing_type)

        if is_optional and (
            obj is None or (annotation_for_storage_metadata(typing_type) and obj[0] is None)
        ):
            context.log.warning(self.get_optional_output_none_log_message(context, path))
            return

        assert obj is not None, "output should not be None if it's type is not Optional"

        has_storage_metadata = annotation_for_storage_metadata(typing_type)
        if has_storage_metadata:
            # (frame, metadata) output: the frame type is the tuple's first argument,
            # looked up after stripping the Optional wrapper when present
            tuple_type = get_args(typing_type)[0] if is_optional else typing_type
            frame_type = get_args(tuple_type)[0]
        else:
            frame_type = typing_type

        if frame_type in POLARS_EAGER_FRAME_ANNOTATIONS:
            writer = self.write_df_to_path
        elif frame_type in POLARS_LAZY_FRAME_ANNOTATIONS:
            writer = self.sink_df_to_path
        else:
            raise NotImplementedError(f"dump_df_to_path for {typing_type} is not implemented")

        if has_storage_metadata:
            df, metadata = obj
            writer(context=context, df=df, path=path, metadata=metadata)
        else:
            writer(context=context, df=obj, path=path)

    def needs_output_metadata(self, context: Union[InputContext, OutputContext]) -> bool:
        """Tell whether the context's type annotation includes storage metadata."""
        typing_type = context.dagster_type.typing_type
        return annotation_for_storage_metadata(typing_type)

    def load_from_path(
        self, context: InputContext, path: "UPath"
    ) -> Union[
        pl.DataFrame,
        pl.LazyFrame,
        Tuple[pl.DataFrame, Dict[str, Any]],
        Tuple[pl.LazyFrame, Dict[str, Any]],
        None,
    ]:
        """Load the object stored at `path` according to the input's type annotation.

        - Returns `None` (and logs a warning) for an `Optional` input whose path does
          not exist.
        - If the input's definition metadata has a `"columns"` entry, only those
          columns are selected.
        - Eager annotations get a collected `pl.DataFrame`, lazy annotations get the
          `pl.LazyFrame`; annotations with storage metadata get a `(frame, metadata)`
          tuple.

        Raises:
            NotImplementedError: if the annotation is neither an eager nor a lazy one.
        """
        typing_type = context.dagster_type.typing_type

        if annotation_is_typing_optional(typing_type) and not path.exists():
            context.log.warning(self.get_missing_optional_input_log_message(context, path))
            return None

        assert context.definition_metadata is not None

        metadata: Optional[StorageMetadata] = None
        return_storage_metadata = self.needs_output_metadata(context)

        if return_storage_metadata:
            ldf, metadata = self.scan_df_from_path(path=path, context=context, with_metadata=True)
        else:
            ldf = self.scan_df_from_path(path=path, context=context)  # type: ignore

        # NOTE: the local name `columns` is part of the debug message below (`{columns=}`)
        columns = context.definition_metadata.get("columns")
        if columns is not None:
            context.log.debug(f"Loading {columns=}")
            ldf = ldf.select(columns)

        if typing_type in POLARS_EAGER_FRAME_ANNOTATIONS:
            collect_frame = True
        elif typing_type in POLARS_LAZY_FRAME_ANNOTATIONS:
            collect_frame = False
        else:
            raise NotImplementedError(
                f"Can't load object for type annotation {context.dagster_type.typing_type}"
            )

        if return_storage_metadata:
            assert metadata is not None

        frame = ldf.collect() if collect_frame else ldf
        if return_storage_metadata:
            return frame, metadata
        return frame

    def get_metadata(
        self, context: OutputContext, obj: Union[pl.DataFrame, pl.LazyFrame, None]
    ) -> Dict[str, MetadataValue]:
        """Build the metadata to log for an output.

        Returns `{"missing": True}` when `obj` is `None`, or when the frame inside a
        `(frame, metadata)` output is `None`; otherwise returns the result of
        `get_polars_metadata` for the frame.
        """
        if obj is None:
            return {"missing": MetadataValue.bool(True)}

        # for (frame, metadata) outputs only the frame is described
        if annotation_for_storage_metadata(context.dagster_type.typing_type):
            df = obj[0]
        else:
            df = obj

        if df is None:
            return {"missing": MetadataValue.bool(True)}
        return get_polars_metadata(context, df)

    def get_missing_optional_input_log_message(self, context: InputContext, path: "UPath") -> str:
        """Format the warning logged when an optional input's path does not exist."""
        return f"Optional input {context.name} at {path} doesn't exist in the filesystem and won't be loaded!"

    def get_optional_output_none_log_message(self, context: OutputContext, path: "UPath") -> str:
        """Format the warning logged when an optional output is `None` and is not saved."""
        return f"The object for the optional output {context.name} is None, so it won't be saved to {path}!"