Ask AI

Source code for dagster._core.definitions.multi_dimensional_partitions

import hashlib
import itertools
from datetime import datetime
from functools import lru_cache, reduce
from typing import (
    Dict,
    List,
    Mapping,
    NamedTuple,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    Union,
    cast,
)

import dagster._check as check
from dagster._annotations import public
from dagster._core.definitions.partition import (
    DefaultPartitionsSubset,
    DynamicPartitionsDefinition,
    PartitionsDefinition,
    PartitionsSubset,
    StaticPartitionsDefinition,
)
from dagster._core.definitions.partition_key_range import PartitionKeyRange
from dagster._core.definitions.time_window_partitions import (
    TimeWindow,
    TimeWindowPartitionsDefinition,
)
from dagster._core.errors import (
    DagsterInvalidDefinitionError,
    DagsterInvalidInvocationError,
    DagsterUnknownPartitionError,
)
from dagster._core.instance import DynamicPartitionsStore
from dagster._core.storage.tags import (
    MULTIDIMENSIONAL_PARTITION_PREFIX,
    get_multidimensional_partition_tag,
)
from dagster._time import get_current_datetime

INVALID_STATIC_PARTITIONS_KEY_CHARACTERS = set(["|", ",", "[", "]"])

MULTIPARTITION_KEY_DELIMITER = "|"


class PartitionDimensionKey(
    NamedTuple("_PartitionDimensionKey", [("dimension_name", str), ("partition_key", str)])
):
    """Representation of a single dimension of a multi-dimensional partition key."""

    def __new__(cls, dimension_name: str, partition_key: str):
        return super(PartitionDimensionKey, cls).__new__(
            cls,
            dimension_name=check.str_param(dimension_name, "dimension_name"),
            partition_key=check.str_param(partition_key, "partition_key"),
        )


[docs] class MultiPartitionKey(str): """A multi-dimensional partition key stores the partition key for each dimension. Subclasses the string class to keep partition key type as a string. Contains additional methods to access the partition key for each dimension. Creates a string representation of the partition key for each dimension, separated by a pipe (|). Orders the dimensions by name, to ensure consistent string representation. """ dimension_keys: List[PartitionDimensionKey] = [] def __new__(cls, keys_by_dimension: Mapping[str, str]): check.mapping_param( keys_by_dimension, "partitions_by_dimension", key_type=str, value_type=str ) dimension_keys: List[PartitionDimensionKey] = [ PartitionDimensionKey(dimension, keys_by_dimension[dimension]) for dimension in sorted(list(keys_by_dimension.keys())) ] str_key = super(MultiPartitionKey, cls).__new__( cls, MULTIPARTITION_KEY_DELIMITER.join( [dim_key.partition_key for dim_key in dimension_keys] ), ) str_key.dimension_keys = dimension_keys return str_key def __getnewargs__(self): # When this instance is pickled, replace the argument to __new__ with the # dimension key mapping instead of the string representation. return ({dim_key.dimension_name: dim_key.partition_key for dim_key in self.dimension_keys},) @property def keys_by_dimension(self) -> Mapping[str, str]: return {dim_key.dimension_name: dim_key.partition_key for dim_key in self.dimension_keys}
class PartitionDimensionDefinition( NamedTuple( "_PartitionDimensionDefinition", [ ("name", str), ("partitions_def", PartitionsDefinition), ], ) ): def __new__( cls, name: str, partitions_def: PartitionsDefinition, ): return super().__new__( cls, name=check.str_param(name, "name"), partitions_def=check.inst_param(partitions_def, "partitions_def", PartitionsDefinition), ) def __eq__(self, other: object) -> bool: return ( isinstance(other, PartitionDimensionDefinition) and self.name == other.name and self.partitions_def == other.partitions_def ) ALLOWED_PARTITION_DIMENSION_TYPES = ( StaticPartitionsDefinition, TimeWindowPartitionsDefinition, DynamicPartitionsDefinition, ) def _check_valid_partitions_dimensions( partitions_dimensions: Mapping[str, PartitionsDefinition], ) -> None: for dim_name, partitions_def in partitions_dimensions.items(): if not any(isinstance(partitions_def, t) for t in ALLOWED_PARTITION_DIMENSION_TYPES): raise DagsterInvalidDefinitionError( f"Invalid partitions definition type {type(partitions_def)}. " "Only the following partitions definition types are supported: " f"{ALLOWED_PARTITION_DIMENSION_TYPES}." ) if isinstance(partitions_def, DynamicPartitionsDefinition) and partitions_def.name is None: raise DagsterInvalidDefinitionError( "DynamicPartitionsDefinition must have a name to be used in a" " MultiPartitionsDefinition." ) if isinstance(partitions_def, StaticPartitionsDefinition): if any( [ INVALID_STATIC_PARTITIONS_KEY_CHARACTERS & set(key) for key in partitions_def.get_partition_keys() ] ): raise DagsterInvalidDefinitionError( f"Invalid character in partition key for dimension {dim_name}. " "A multi-partitions definition cannot contain partition keys with " "the following characters: |, [, ], ," )
[docs] class MultiPartitionsDefinition(PartitionsDefinition[MultiPartitionKey]): """Takes the cross-product of partitions from two partitions definitions. For example, with a static partitions definition where the partitions are ["a", "b", "c"] and a daily partitions definition, this partitions definition will have the following partitions: 2020-01-01|a 2020-01-01|b 2020-01-01|c 2020-01-02|a 2020-01-02|b ... We recommended limiting partition counts for each asset to 25,000 partitions or fewer. Args: partitions_defs (Mapping[str, PartitionsDefinition]): A mapping of dimension name to partitions definition. The total set of partitions will be the cross-product of the partitions from each PartitionsDefinition. Attributes: partitions_defs (Sequence[PartitionDimensionDefinition]): A sequence of PartitionDimensionDefinition objects, each of which contains a dimension name and a PartitionsDefinition. The total set of partitions will be the cross-product of the partitions from each PartitionsDefinition. This sequence is ordered by dimension name, to ensure consistent ordering of the partitions. """ def __init__(self, partitions_defs: Mapping[str, PartitionsDefinition]): if not len(partitions_defs.keys()) == 2: raise DagsterInvalidInvocationError( "Dagster currently only supports multi-partitions definitions with 2 partitions" " definitions. Your multi-partitions definition has" f" {len(partitions_defs.keys())} partitions definitions." ) check.mapping_param( partitions_defs, "partitions_defs", key_type=str, value_type=PartitionsDefinition ) _check_valid_partitions_dimensions(partitions_defs) self._partitions_defs: List[PartitionDimensionDefinition] = sorted( [ PartitionDimensionDefinition(name, partitions_def) for name, partitions_def in partitions_defs.items() ], key=lambda x: x.name, ) @property def partitions_subset_class(self) -> Type["PartitionsSubset"]: return DefaultPartitionsSubset def get_partition_keys_in_range( self, partition_key_range: PartitionKeyRange, dynamic_partitions_store: Optional[DynamicPartitionsStore] = None, ) -> Sequence[str]: start: MultiPartitionKey = self.get_partition_key_from_str(partition_key_range.start) end: MultiPartitionKey = self.get_partition_key_from_str(partition_key_range.end) partition_key_sequences = [ partition_dim.partitions_def.get_partition_keys_in_range( PartitionKeyRange( start.keys_by_dimension[partition_dim.name], end.keys_by_dimension[partition_dim.name], ), dynamic_partitions_store=dynamic_partitions_store, ) for partition_dim in self._partitions_defs ] return [ MultiPartitionKey( {self._partitions_defs[i].name: key for i, key in enumerate(partition_key_tuple)} ) for partition_key_tuple in itertools.product(*partition_key_sequences) ] def get_serializable_unique_identifier( self, dynamic_partitions_store: Optional[DynamicPartitionsStore] = None ) -> str: return hashlib.sha1( str( { dim_def.name: dim_def.partitions_def.get_serializable_unique_identifier( dynamic_partitions_store ) for dim_def in self.partitions_defs } ).encode("utf-8") ).hexdigest() @property def partition_dimension_names(self) -> List[str]: return [dim_def.name for dim_def in self._partitions_defs] @property def partitions_defs(self) -> Sequence[PartitionDimensionDefinition]: return self._partitions_defs def get_partitions_def_for_dimension(self, dimension_name: str) -> PartitionsDefinition: for dim_def in self._partitions_defs: if dim_def.name == dimension_name: return dim_def.partitions_def check.failed(f"Invalid dimension name {dimension_name}") # We override the default implementation of `has_partition_key` for performance. def has_partition_key( self, partition_key: Union[MultiPartitionKey, str], current_time: Optional[datetime] = None, dynamic_partitions_store: Optional[DynamicPartitionsStore] = None, ) -> bool: partition_key = ( partition_key if isinstance(partition_key, MultiPartitionKey) else self.get_partition_key_from_str(partition_key) ) if partition_key.keys_by_dimension.keys() != set(self.partition_dimension_names): raise DagsterUnknownPartitionError( f"Invalid partition key {partition_key}. The dimensions of the partition key are" " not the dimensions of the partitions definition." ) for dimension in self.partitions_defs: if not dimension.partitions_def.has_partition_key( partition_key.keys_by_dimension[dimension.name], current_time=current_time, dynamic_partitions_store=dynamic_partitions_store, ): return False return True # store results for repeated calls with the same current_time @lru_cache(maxsize=1) def _get_partition_keys( self, current_time: datetime, dynamic_partitions_store: Optional[DynamicPartitionsStore] ) -> Sequence[MultiPartitionKey]: partition_key_sequences = [ partition_dim.partitions_def.get_partition_keys( current_time=current_time, dynamic_partitions_store=dynamic_partitions_store ) for partition_dim in self._partitions_defs ] return [ MultiPartitionKey( {self._partitions_defs[i].name: key for i, key in enumerate(partition_key_tuple)} ) for partition_key_tuple in itertools.product(*partition_key_sequences) ]
[docs] @public def get_partition_keys( self, current_time: Optional[datetime] = None, dynamic_partitions_store: Optional[DynamicPartitionsStore] = None, ) -> Sequence[MultiPartitionKey]: """Returns a list of MultiPartitionKeys representing the partition keys of the PartitionsDefinition. Args: current_time (Optional[datetime]): A datetime object representing the current time, only applicable to time-based partition dimensions. dynamic_partitions_store (Optional[DynamicPartitionsStore]): The DynamicPartitionsStore object that is responsible for fetching dynamic partitions. Required when a dimension is a DynamicPartitionsDefinition with a name defined. Users can pass the DagsterInstance fetched via `context.instance` to this argument. Returns: Sequence[MultiPartitionKey] """ return self._get_partition_keys( current_time or get_current_datetime(), dynamic_partitions_store )
def filter_valid_partition_keys( self, partition_keys: Set[str], dynamic_partitions_store: DynamicPartitionsStore ) -> Set[MultiPartitionKey]: partition_keys_by_dimension = { dim.name: dim.partitions_def.get_partition_keys( dynamic_partitions_store=dynamic_partitions_store ) for dim in self.partitions_defs } validated_partitions = set() for partition_key in partition_keys: partition_key_strs = partition_key.split(MULTIPARTITION_KEY_DELIMITER) if len(partition_key_strs) != len(self.partitions_defs): continue multipartition_key = MultiPartitionKey( {dim.name: partition_key_strs[i] for i, dim in enumerate(self._partitions_defs)} ) if all( key in partition_keys_by_dimension.get(dim, []) for dim, key in multipartition_key.keys_by_dimension.items() ): validated_partitions.add(partition_key) return validated_partitions def __eq__(self, other): return ( isinstance(other, MultiPartitionsDefinition) and self.partitions_defs == other.partitions_defs ) def __hash__(self): return hash( tuple( [ (partitions_def.name, partitions_def.__repr__()) for partitions_def in self.partitions_defs ] ) ) def __str__(self) -> str: dimension_1 = self._partitions_defs[0] dimension_2 = self._partitions_defs[1] partition_str = ( "Multi-partitioned, with dimensions: \n" f"{dimension_1.name.capitalize()}: {dimension_1.partitions_def} \n" f"{dimension_2.name.capitalize()}: {dimension_2.partitions_def}" ) return partition_str def __repr__(self) -> str: return f"{type(self).__name__}(dimensions={[str(dim) for dim in self.partitions_defs]}" def get_partition_key_from_str(self, partition_key_str: str) -> MultiPartitionKey: """Given a string representation of a partition key, returns a MultiPartitionKey object.""" check.str_param(partition_key_str, "partition_key_str") partition_key_strs = partition_key_str.split(MULTIPARTITION_KEY_DELIMITER) check.invariant( len(partition_key_strs) == len(self.partitions_defs), f"Expected {len(self.partitions_defs)} partition keys in partition key string" f" {partition_key_str}, but got {len(partition_key_strs)}", ) return MultiPartitionKey( {dim.name: partition_key_strs[i] for i, dim in enumerate(self._partitions_defs)} ) def _get_primary_and_secondary_dimension( self, ) -> Tuple[PartitionDimensionDefinition, PartitionDimensionDefinition]: # Multipartitions subsets are serialized by primary dimension. If changing # the selection of primary/secondary dimension, will need to also update the # serialization of MultiPartitionsSubsets time_dimensions = self._get_time_window_dims() if len(time_dimensions) == 1: primary_dimension, secondary_dimension = ( time_dimensions[0], next(iter([dim for dim in self.partitions_defs if dim != time_dimensions[0]])), ) else: primary_dimension, secondary_dimension = ( self.partitions_defs[0], self.partitions_defs[1], ) return primary_dimension, secondary_dimension @property def primary_dimension(self) -> PartitionDimensionDefinition: return self._get_primary_and_secondary_dimension()[0] @property def secondary_dimension(self) -> PartitionDimensionDefinition: return self._get_primary_and_secondary_dimension()[1] def get_tags_for_partition_key(self, partition_key: str) -> Mapping[str, str]: partition_key = cast(MultiPartitionKey, self.get_partition_key_from_str(partition_key)) tags = {**super().get_tags_for_partition_key(partition_key)} tags.update(get_tags_from_multi_partition_key(partition_key)) return tags @property def time_window_dimension(self) -> PartitionDimensionDefinition: time_window_dims = self._get_time_window_dims() check.invariant( len(time_window_dims) == 1, "Expected exactly one time window partitioned dimension" ) return next(iter(time_window_dims)) def _get_time_window_dims(self) -> List[PartitionDimensionDefinition]: return [ dim for dim in self.partitions_defs if isinstance(dim.partitions_def, TimeWindowPartitionsDefinition) ] @property def has_time_window_dimension(self) -> bool: return bool(self._get_time_window_dims()) @property def time_window_partitions_def(self) -> TimeWindowPartitionsDefinition: check.invariant(self.has_time_window_dimension, "Must have time window dimension") return cast( TimeWindowPartitionsDefinition, check.inst(self.primary_dimension.partitions_def, TimeWindowPartitionsDefinition), ) def time_window_for_partition_key(self, partition_key: str) -> TimeWindow: if not isinstance(partition_key, MultiPartitionKey): partition_key = self.get_partition_key_from_str(partition_key) time_window_dimension = self.time_window_dimension return cast( TimeWindowPartitionsDefinition, time_window_dimension.partitions_def ).time_window_for_partition_key( cast(MultiPartitionKey, partition_key).keys_by_dimension[time_window_dimension.name] ) def get_multipartition_keys_with_dimension_value( self, dimension_name: str, dimension_partition_key: str, dynamic_partitions_store: Optional[DynamicPartitionsStore] = None, current_time: Optional[datetime] = None, ) -> Sequence[MultiPartitionKey]: check.str_param(dimension_name, "dimension_name") check.str_param(dimension_partition_key, "dimension_partition_key") matching_dimensions = [ dimension for dimension in self.partitions_defs if dimension.name == dimension_name ] other_dimensions = [ dimension for dimension in self.partitions_defs if dimension.name != dimension_name ] check.invariant( len(matching_dimensions) == 1, f"Dimension {dimension_name} not found in MultiPartitionsDefinition with dimensions" f" {[dim.name for dim in self.partitions_defs]}", ) partition_sequences = [ partition_dim.partitions_def.get_partition_keys( current_time=current_time, dynamic_partitions_store=dynamic_partitions_store ) for partition_dim in other_dimensions ] + [[dimension_partition_key]] # Names of partitions dimensions in the same order as partition_sequences partition_dim_names = [dim.name for dim in other_dimensions] + [dimension_name] return [ MultiPartitionKey( { partition_dim_names[i]: partition_key for i, partition_key in enumerate(partitions_tuple) } ) for partitions_tuple in itertools.product(*partition_sequences) ] def get_num_partitions( self, current_time: Optional[datetime] = None, dynamic_partitions_store: Optional[DynamicPartitionsStore] = None, ) -> int: # Static partitions definitions can contain duplicate keys (will throw error in 1.3.0) # In the meantime, relying on get_num_partitions to handle duplicates to display # correct counts in the Dagster UI. dimension_counts = [ dim.partitions_def.get_num_partitions( current_time=current_time, dynamic_partitions_store=dynamic_partitions_store ) for dim in self.partitions_defs ] return reduce(lambda x, y: x * y, dimension_counts, 1)
def get_tags_from_multi_partition_key(multi_partition_key: MultiPartitionKey) -> Mapping[str, str]: check.inst_param(multi_partition_key, "multi_partition_key", MultiPartitionKey) return { get_multidimensional_partition_tag(dimension.dimension_name): dimension.partition_key for dimension in multi_partition_key.dimension_keys } def get_multipartition_key_from_tags(tags: Mapping[str, str]) -> str: partitions_by_dimension: Dict[str, str] = {} for tag in tags: if tag.startswith(MULTIDIMENSIONAL_PARTITION_PREFIX): dimension = tag[len(MULTIDIMENSIONAL_PARTITION_PREFIX) :] partitions_by_dimension[dimension] = tags[tag] return MultiPartitionKey(partitions_by_dimension)