
Source code for dagster_polars.io_managers.parquet

from typing import TYPE_CHECKING, Any, Dict, Optional, cast

import polars as pl
import pyarrow.dataset as ds
from dagster import InputContext, OutputContext
from dagster._annotations import experimental
from fsspec.implementations.local import LocalFileSystem
from packaging.version import Version

from dagster_polars.io_managers.base import BasePolarsUPathIOManager

if TYPE_CHECKING:
    from upath import UPath


DAGSTER_POLARS_STORAGE_METADATA_KEY = "dagster_polars_metadata"

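The DAGSTER_POLARS_STORAGE_METADATA_KEY constant names the Parquet schema-metadata entry under which dagster-polars stores json-serialized custom metadata (see the class docstring below). A minimal standalone sketch, not part of this module, of reading such an entry back with pyarrow; the file path is a hypothetical example and the exact serialization is handled by the base IO manager:

import json

import pyarrow.parquet as pq

# Inspect the schema metadata of a Parquet file assumed to have been written by dagster-polars.
schema_metadata = pq.read_schema("/tmp/my_asset.parquet").metadata or {}
raw = schema_metadata.get(DAGSTER_POLARS_STORAGE_METADATA_KEY.encode())
if raw is not None:
    custom_metadata = json.loads(raw)  # the stored value is json-serialized bytes
    print(custom_metadata)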

def get_pyarrow_dataset(path: "UPath", context: InputContext) -> ds.Dataset:
    """Open the (possibly partitioned) Parquet data at the given path as a pyarrow dataset,
    configured from the input definition metadata (format, partitioning, etc.).
    """
    context_metadata = context.definition_metadata or {}

    fs = path.fs if hasattr(path, "fs") else None

    if context_metadata.get("partitioning") is not None:
        context.log.warning(
            f'"partitioning" metadata value for PolarsParquetIOManager is deprecated '
            f'in favor of "partition_by" (loading from {path})'
        )

    dataset = ds.dataset(
        str(path),
        filesystem=fs,
        format=context_metadata.get("format", "parquet"),
        partitioning=context_metadata.get("partitioning") or context_metadata.get("partition_by"),
        partition_base_dir=context_metadata.get("partition_base_dir"),
        exclude_invalid_files=context_metadata.get("exclude_invalid_files", True),
        ignore_prefixes=context_metadata.get("ignore_prefixes", [".", "_"]),
    )

    return dataset
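For reference, a standalone sketch of roughly what these options amount to when the dataset is consumed with Polars; the path and partition columns below are hypothetical, and the actual scan is performed by the base IO manager rather than by this snippet:

import polars as pl
import pyarrow.dataset as ds

# Open a (possibly partitioned) Parquet dataset; "year"/"month"/"day" play the role
# of the "partition_by" metadata key handled above.
dataset = ds.dataset(
    "data/my_dataset",
    format="parquet",
    partitioning=["year", "month", "day"],
    exclude_invalid_files=True,
    ignore_prefixes=[".", "_"],
)
lf = pl.scan_pyarrow_dataset(dataset)  # lazily scan the dataset with Polars
print(lf.head(5).collect())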


def scan_parquet(path: "UPath", context: InputContext) -> pl.LazyFrame:
    """Scan a parquet file and return a lazy frame (uses polars native reader).

    :param path:
    :param context:
    :return:
    """
    context_metadata = context.definition_metadata or {}

    storage_options = cast(
        Optional[Dict[str, Any]],
        (path.storage_options if hasattr(path, "storage_options") else None),
    )

    kwargs = dict(
        n_rows=context_metadata.get("n_rows", None),
        cache=context_metadata.get("cache", True),
        parallel=context_metadata.get("parallel", "auto"),
        rechunk=context_metadata.get("rechunk", True),
        low_memory=context_metadata.get("low_memory", False),
        use_statistics=context_metadata.get("use_statistics", True),
        hive_partitioning=context_metadata.get("hive_partitioning", True),
        retries=context_metadata.get("retries", 0),
    )
    if Version(pl.__version__) >= Version("0.20.4"):
        kwargs["row_index_name"] = context_metadata.get("row_index_name", None)
        kwargs["row_index_offset"] = context_metadata.get("row_index_offset", 0)
    else:
        kwargs["row_count_name"] = context_metadata.get("row_count_name", None)
        kwargs["row_count_offset"] = context_metadata.get("row_count_offset", 0)

    return pl.scan_parquet(str(path), storage_options=storage_options, **kwargs)  # type: ignore
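The metadata keys above come from the input definition. A hedged sketch of how they might be set on a downstream asset, assuming metadata declared on AssetIn is surfaced as context.definition_metadata and that a pl.LazyFrame input annotation is handled by the base IO manager (as scan_df_from_path below suggests); the asset names and io_manager_key are placeholders:

import polars as pl
from dagster import AssetIn, asset

@asset(
    ins={
        "upstream_table": AssetIn(
            metadata={
                "n_rows": 1000,             # forwarded to pl.scan_parquet(n_rows=...)
                "hive_partitioning": True,  # forwarded to pl.scan_parquet(hive_partitioning=...)
            }
        )
    },
    io_manager_key="polars_parquet_io_manager",
)
def downstream_table(upstream_table: pl.LazyFrame) -> pl.DataFrame:
    # The upstream Parquet file is scanned lazily; collect only what is needed.
    return upstream_table.collect()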


@experimental
class PolarsParquetIOManager(BasePolarsUPathIOManager):
    """Implements reading and writing Polars DataFrames in Apache Parquet format.

    Features:
     - All features provided by :py:class:`~dagster_polars.BasePolarsUPathIOManager`.
     - All read/write options can be set via corresponding metadata or config parameters
       (metadata takes precedence).
     - Supports reading partitioned Parquet datasets (for example, often produced by Spark).
     - Supports reading/writing custom metadata in the Parquet file's schema as json-serialized
       bytes at the `"dagster_polars_metadata"` key.

    Examples:
        .. code-block:: python

            from dagster import Definitions, asset
            from dagster_polars import PolarsParquetIOManager
            import polars as pl

            @asset(
                io_manager_key="polars_parquet_io_manager",
                key_prefix=["my_dataset"]
            )
            def my_asset() -> pl.DataFrame:
                # data will be stored at <base_dir>/my_dataset/my_asset.parquet
                ...

            defs = Definitions(
                assets=[my_asset],
                resources={
                    "polars_parquet_io_manager": PolarsParquetIOManager(base_dir="s3://my-bucket/my-dir")
                }
            )

        Reading partitioned Parquet datasets:

        .. code-block:: python

            from dagster import SourceAsset

            my_asset = SourceAsset(
                key=["path", "to", "dataset"],
                io_manager_key="polars_parquet_io_manager",
                metadata={
                    "partition_by": ["year", "month", "day"]
                }
            )
    """

    extension: str = ".parquet"

    def sink_df_to_path(
        self,
        context: OutputContext,
        df: pl.LazyFrame,
        path: "UPath",
    ):
        context_metadata = context.definition_metadata or {}

        fs = path.fs if hasattr(path, "fs") else None
        if isinstance(fs, LocalFileSystem):
            compression = context_metadata.get("compression", "zstd")
            compression_level = context_metadata.get("compression_level")
            statistics = context_metadata.get("statistics", False)
            row_group_size = context_metadata.get("row_group_size")
            df.sink_parquet(
                str(path),
                compression=compression,
                compression_level=compression_level,
                statistics=statistics,
                row_group_size=row_group_size,
            )
        else:
            # TODO(ion): add sink_parquet once this PR gets merged: https://github.com/pola-rs/polars/pull/11519
            context.log.warning(
                "Cloud sink is not possible yet, instead it's dispatched to pyarrow writer which collects it into memory first.",
            )
            return self.write_df_to_path(context, df.collect(), path)

    def write_df_to_path(
        self,
        context: OutputContext,
        df: pl.DataFrame,
        path: "UPath",
    ):
        context_metadata = context.definition_metadata or {}
        compression = context_metadata.get("compression", "zstd")
        compression_level = context_metadata.get("compression_level")
        statistics = context_metadata.get("statistics", False)
        row_group_size = context_metadata.get("row_group_size")
        pyarrow_options = context_metadata.get("pyarrow_options", None)

        fs = path.fs if hasattr(path, "fs") else None

        if pyarrow_options is not None:
            pyarrow_options["filesystem"] = fs
            df.write_parquet(
                str(path),
                compression=compression,  # type: ignore
                compression_level=compression_level,
                statistics=statistics,
                row_group_size=row_group_size,
                use_pyarrow=True,
                pyarrow_options=pyarrow_options,
            )
        elif fs is not None:
            with fs.open(str(path), mode="wb") as f:
                df.write_parquet(
                    f,
                    compression=compression,  # type: ignore
                    compression_level=compression_level,
                    statistics=statistics,
                    row_group_size=row_group_size,
                )
        else:
            df.write_parquet(
                str(path),
                compression=compression,  # type: ignore
                compression_level=compression_level,
                statistics=statistics,
                row_group_size=row_group_size,
            )

    def scan_df_from_path(
        self,
        path: "UPath",
        context: InputContext,
        partition_key: Optional[str] = None,
    ) -> pl.LazyFrame:
        return scan_parquet(path, context)
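On the write side, a hedged sketch of setting the Parquet options read by write_df_to_path / sink_df_to_path via asset definition metadata; the asset, metadata values, and base_dir are examples only, not recommended defaults:

import polars as pl
from dagster import Definitions, asset
from dagster_polars import PolarsParquetIOManager

@asset(
    io_manager_key="polars_parquet_io_manager",
    metadata={
        "compression": "snappy",    # overrides the default "zstd"
        "statistics": True,         # write column statistics
        "row_group_size": 100_000,  # rows per row group
    },
)
def events() -> pl.DataFrame:
    return pl.DataFrame({"event_id": [1, 2, 3], "value": [0.1, 0.2, 0.3]})

defs = Definitions(
    assets=[events],
    resources={"polars_parquet_io_manager": PolarsParquetIOManager(base_dir="/tmp/dagster_storage")},
)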