from typing import TYPE_CHECKING, Any, Dict, Optional, cast
import polars as pl
import pyarrow.dataset as ds
from dagster import InputContext, OutputContext
from dagster._annotations import experimental
from fsspec.implementations.local import LocalFileSystem
from packaging.version import Version
from dagster_polars.io_managers.base import BasePolarsUPathIOManager
if TYPE_CHECKING:
from upath import UPath
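# Parquet schema key under which dagster-polars stores JSON-serialized custom metadata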
DAGSTER_POLARS_STORAGE_METADATA_KEY = "dagster_polars_metadata"
def get_pyarrow_dataset(path: "UPath", context: InputContext) -> ds.Dataset:
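"""Build a ``pyarrow.dataset.Dataset`` for ``path``, configured from the input definition metadata."""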
context_metadata = context.definition_metadata or {}
fs = path.fs if hasattr(path, "fs") else None
if context_metadata.get("partitioning") is not None:
context.log.warning(
f'"partitioning" metadata value for PolarsParquetIOManager is deprecated '
f'in favor of "partition_by" (loading from {path})'
)
dataset = ds.dataset(
str(path),
filesystem=fs,
format=context_metadata.get("format", "parquet"),
partitioning=context_metadata.get("partitioning") or context_metadata.get("partition_by"),
partition_base_dir=context_metadata.get("partition_base_dir"),
exclude_invalid_files=context_metadata.get("exclude_invalid_files", True),
ignore_prefixes=context_metadata.get("ignore_prefixes", [".", "_"]),
)
return dataset
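# Illustrative sketch (not part of this module's API): a dataset built as above is typically
# consumed lazily via polars' pyarrow integration. The path "data/partitioned" and the "year"
# column below are hypothetical:
#
#   dataset = ds.dataset("data/partitioned", format="parquet", partitioning="hive")
#   lf = pl.scan_pyarrow_dataset(dataset)  # lazy scan, nothing is loaded yet
#   df = lf.filter(pl.col("year") == 2023).collect()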
def scan_parquet(path: "UPath", context: InputContext) -> pl.LazyFrame:
"""Scan a parquet file and return a lazy frame (uses polars native reader).
:param path:
:param context:
:return:
"""
context_metadata = context.definition_metadata or {}
storage_options = cast(
Optional[Dict[str, Any]],
(path.storage_options if hasattr(path, "storage_options") else None),
)
kwargs = dict(
n_rows=context_metadata.get("n_rows", None),
cache=context_metadata.get("cache", True),
parallel=context_metadata.get("parallel", "auto"),
rechunk=context_metadata.get("rechunk", True),
low_memory=context_metadata.get("low_memory", False),
use_statistics=context_metadata.get("use_statistics", True),
hive_partitioning=context_metadata.get("hive_partitioning", True),
retries=context_metadata.get("retries", 0),
)
if Version(pl.__version__) >= Version("0.20.4"):
kwargs["row_index_name"] = context_metadata.get("row_index_name", None)
kwargs["row_index_offset"] = context_metadata.get("row_index_offset", 0)
else:
kwargs["row_count_name"] = context_metadata.get("row_count_name", None)
kwargs["row_count_offset"] = context_metadata.get("row_count_offset", 0)
return pl.scan_parquet(str(path), storage_options=storage_options, **kwargs) # type: ignore
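# Illustrative sketch: the metadata keys above map directly onto ``pl.scan_parquet`` arguments,
# so an asset reading "s3://my-bucket/data.parquet" (hypothetical path) with
# ``metadata={"n_rows": 1000, "rechunk": False}`` is roughly equivalent to:
#
#   pl.scan_parquet(
#       "s3://my-bucket/data.parquet",
#       n_rows=1000,
#       rechunk=False,
#       storage_options={...},  # taken from the UPath, if available
#   )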
@experimental
class PolarsParquetIOManager(BasePolarsUPathIOManager):
"""Implements reading and writing Polars DataFrames in Apache Parquet format.
Features:
- All features provided by :py:class:`~dagster_polars.BasePolarsUPathIOManager`.
- All read/write options can be set via corresponding metadata or config parameters (metadata takes precedence).
- Supports reading partitioned Parquet datasets (for example, often produced by Spark).
- Supports reading and writing custom metadata, stored in the Parquet file's schema as JSON-serialized bytes under the `"dagster_polars_metadata"` key.
Examples:
.. code-block:: python
from dagster import Definitions, asset
from dagster_polars import PolarsParquetIOManager
import polars as pl
@asset(
io_manager_key="polars_parquet_io_manager",
key_prefix=["my_dataset"]
)
def my_asset() -> pl.DataFrame: # data will be stored at <base_dir>/my_dataset/my_asset.parquet
...
defs = Definitions(
assets=[my_asset],
resources={
"polars_parquet_io_manager": PolarsParquetIOManager(base_dir="s3://my-bucket/my-dir")
}
)
Reading partitioned Parquet datasets:
.. code-block:: python
from dagster import SourceAsset
my_asset = SourceAsset(
key=["path", "to", "dataset"],
io_manager_key="polars_parquet_io_manager",
metadata={
"partition_by": ["year", "month", "day"]
}
)
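Write options can also be passed via asset metadata; an illustrative example (the key names mirror the options this IO manager reads when writing):
.. code-block:: python
@asset(
io_manager_key="polars_parquet_io_manager",
metadata={"compression": "snappy", "row_group_size": 100_000}
)
def my_compressed_asset() -> pl.DataFrame:
...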
"""
extension: str = ".parquet"
def sink_df_to_path(
self,
context: OutputContext,
df: pl.LazyFrame,
path: "UPath",
):
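"""Sink a LazyFrame to ``path``: streamed for local filesystems, collected and written eagerly otherwise."""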
context_metadata = context.definition_metadata or {}
fs = path.fs if hasattr(path, "fs") else None
if isinstance(fs, LocalFileSystem):
compression = context_metadata.get("compression", "zstd")
compression_level = context_metadata.get("compression_level")
statistics = context_metadata.get("statistics", False)
row_group_size = context_metadata.get("row_group_size")
df.sink_parquet(
str(path),
compression=compression,
compression_level=compression_level,
statistics=statistics,
row_group_size=row_group_size,
)
else:
# TODO(ion): add sink_parquet once this PR gets merged: https://github.com/pola-rs/polars/pull/11519
context.log.warning(
"Sinking to remote storage is not supported yet; collecting the LazyFrame into memory and writing the materialized DataFrame instead.",
)
return self.write_df_to_path(context, df.collect(), path)
def write_df_to_path(
self,
context: OutputContext,
df: pl.DataFrame,
path: "UPath",
):
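"""Write an eager DataFrame to ``path``, with write options taken from the output definition metadata."""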
context_metadata = context.definition_metadata or {}
compression = context_metadata.get("compression", "zstd")
compression_level = context_metadata.get("compression_level")
statistics = context_metadata.get("statistics", False)
row_group_size = context_metadata.get("row_group_size")
pyarrow_options = context_metadata.get("pyarrow_options", None)
fs = path.fs if hasattr(path, "fs") else None
if pyarrow_options is not None:
pyarrow_options["filesystem"] = fs
df.write_parquet(
str(path),
compression=compression, # type: ignore
compression_level=compression_level,
statistics=statistics,
row_group_size=row_group_size,
use_pyarrow=True,
pyarrow_options=pyarrow_options,
)
elif fs is not None:
with fs.open(str(path), mode="wb") as f:
df.write_parquet(
f,
compression=compression, # type: ignore
compression_level=compression_level,
statistics=statistics,
row_group_size=row_group_size,
)
else:
df.write_parquet(
str(path),
compression=compression, # type: ignore
compression_level=compression_level,
statistics=statistics,
row_group_size=row_group_size,
)
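# Illustrative note: supplying "pyarrow_options" via asset metadata routes the write through
# pyarrow (``use_pyarrow=True`` above); with pyarrow, keys such as "partition_cols" can be used
# for partitioned dataset writes. The column names below are hypothetical:
#
#   metadata={"pyarrow_options": {"partition_cols": ["year", "month"]}}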
def scan_df_from_path(
self,
path: "UPath",
context: InputContext,
partition_key: Optional[str] = None,
) -> pl.LazyFrame:
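"""Lazily scan the Parquet file at ``path`` using the native polars reader."""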
return scan_parquet(path, context)