import _thread as thread
import contextlib
import contextvars
import datetime
import errno
import functools
import inspect
import multiprocessing
import os
import re
import signal
import socket
import subprocess
import sys
import tempfile
import threading
import time
import uuid
from datetime import timezone
from enum import Enum
from pathlib import Path
from signal import Signals
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Callable,
ContextManager,
Dict,
Generator,
Generic,
Hashable,
Iterable,
Iterator,
List,
Mapping,
NamedTuple,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Union,
cast,
overload,
)
import packaging.version
from filelock import FileLock
from pydantic import BaseModel
from typing_extensions import Literal, TypeAlias, TypeGuard
import dagster._check as check
import dagster._seven as seven
from dagster._utils.internal_init import IHasInternalInit as IHasInternalInit
if sys.version_info > (3,):
from pathlib import Path
else:
from pathlib2 import Path
if TYPE_CHECKING:
from dagster._core.definitions.definitions_class import Definitions
from dagster._core.definitions.repository_definition.repository_definition import (
RepositoryDefinition,
)
from dagster._core.events import DagsterEvent
K = TypeVar("K")
T = TypeVar("T")
U = TypeVar("U")
V = TypeVar("V")
EPOCH = datetime.datetime.fromtimestamp(0, timezone.utc).replace(tzinfo=None)
PICKLE_PROTOCOL = 4
DEFAULT_WORKSPACE_YAML_FILENAME = "workspace.yaml"
PrintFn: TypeAlias = Callable[[Any], None]
SingleInstigatorDebugCrashFlags: TypeAlias = Mapping[str, Union[int, Exception]]
DebugCrashFlags: TypeAlias = Mapping[str, SingleInstigatorDebugCrashFlags]
def check_for_debug_crash(
    debug_crash_flags: Optional[SingleInstigatorDebugCrashFlags], key: str
) -> None:
    """Simulate a crash at checkpoint ``key`` if the debug flags ask for one.

    Does nothing when ``debug_crash_flags`` is empty/None or has no truthy entry for ``key``.
    If the entry is an ``Exception`` it is raised. Otherwise it is treated as a signal number and
    sent to the current process.

    Raises:
        Exception: the configured exception, or a generic ``Exception`` if the process is still
            alive 10 seconds after the signal was sent.
    """
    crash_action = debug_crash_flags.get(key) if debug_crash_flags else None
    if not crash_action:
        return

    if isinstance(crash_action, Exception):
        raise crash_action

    os.kill(os.getpid(), crash_action)
    # Give the signal time to take effect; surviving this long means the crash did not happen.
    time.sleep(10)
    raise Exception("Process didn't terminate after sending crash signal")
# Use this to get the "library version" (pre-1.0 version) from the "core version" (post 1.0
# version). 16 is from the 0.16.0 that library versions stayed on when core went to 1.0.0.
def library_version_from_core_version(core_version: str) -> str:
    """Translate a Dagster core version into the matching integration-library version.

    Core ``X.Y.Z`` with ``X >= 1`` maps to library ``0.(16 + Y).Z``. A pre-release segment (e.g.
    ``rc1``) and a post-release segment (``postN``) are carried over. A dev segment is not carried
    over. Versions whose major component is below 1 are returned unchanged.

    Args:
        core_version (str): A PEP 440 version string.

    Returns:
        str: The library version string.
    """
    parsed_version = parse_package_version(core_version)
    # Pad short releases such as "1.5" so the minor/micro components always exist.
    major, minor, micro = (parsed_version.release + (0, 0))[:3]
    if major < 1:
        return core_version

    library_version = ".".join(["0", str(16 + minor), str(micro)])
    # `is_prerelease` is also True for dev-only versions (e.g. "1.5.0.dev0"), where `pre` is None.
    # Test `pre` directly so those versions no longer crash.
    # NOTE(review): dev segments are dropped here, matching how "rcN.devM" versions were already
    # handled; confirm that is the desired mapping.
    if parsed_version.pre is not None:
        library_version += "".join(str(part) for part in parsed_version.pre)
    if parsed_version.post is not None:
        library_version += "post" + str(parsed_version.post)
    return library_version
def parse_package_version(version_str: str) -> packaging.version.Version:
    """Parse ``version_str`` into a :class:`packaging.version.Version`.

    Depending on the installed ``packaging`` release, an invalid string either raises
    ``packaging.version.InvalidVersion`` or parses to a non-``Version`` object.

    Raises:
        AssertionError: if the parsed result is not a ``packaging.version.Version``.
    """
    parsed_version = packaging.version.parse(version_str)
    # Explicit check instead of `assert` so the return-type guarantee survives `python -O`.
    if not isinstance(parsed_version, packaging.version.Version):
        raise AssertionError(f"Expected a PEP 440 version, got {version_str!r}")
    return parsed_version
def convert_dagster_submodule_name(name: str, mode: Literal["private", "public"]) -> str:
    """This function was introduced when all Dagster submodules were marked private by
    underscore-prefixing the root submodules (e.g. `dagster._core`). The function provides
    backcompatibility by converting modules between the old and new (i.e. public and private) forms.
    This is needed when reading older data or communicating with older versions of Dagster.

    Args:
        name (str): A dotted module path.
        mode: ``"private"`` turns ``dagster.core...`` into ``dagster._core...``. ``"public"`` does
            the reverse. Names not rooted at ``dagster.`` are returned unchanged.
    """
    if mode == "private":
        return re.sub(r"^dagster\.([^_])", r"dagster._\1", name)
    elif mode == "public":
        # The dot must be escaped. Otherwise names like "dagsterX_core" or "dagster__x" would be
        # rewritten as well.
        return re.sub(r"^dagster\._", "dagster.", name)
    else:
        check.failed("`mode` must be 'private' or 'public'")
def file_relative_path(dunderfile: str, relative_path: str) -> str:
    """Get a path relative to the currently executing Python file.

    This function is useful when one needs to load a file that is relative to the position of
    the current file. (Such as when you encode a configuration file path in source file and want
    it runnable in any current working directory.)

    Args:
        dunderfile (str): Should always be ``__file__``.
        relative_path (str): Path to get relative to the currently executing file.

    Returns:
        str: The resolved (absolute, symlink-free) path.

    **Examples**:

    .. code-block:: python

        file_relative_path(__file__, 'path/relative/to/file')
    """
    check.str_param(dunderfile, "dunderfile")
    check.str_param(relative_path, "relative_path")
    # "<file>/.." is the file's directory once resolved.
    return os.fspath(Path(dunderfile, "..", relative_path).resolve())
def script_relative_path(file_path: str) -> str:
    """Resolve ``file_path`` against the directory of the *calling* script.

    Intended for tests that load local fixture files: the result is independent of the working
    directory the test runner was started from.

    Note: inspecting the call stack makes this very expensive (on the order of a millisecond per
    call). Use it only in performance-insensitive contexts and prefer ``file_relative_path``
    elsewhere.
    """
    # from http://bit.ly/2snyC6s
    check.str_param(file_path, "file_path")
    # Frame 0 is this function; frame 1 is whoever called it.
    caller_file = inspect.stack()[1].filename
    caller_dir = os.path.dirname(os.path.abspath(caller_file))
    return os.path.abspath(os.path.join(caller_dir, file_path))
# Adapted from https://github.com/okunishinishi/python-stringcase/blob/master/stringcase.py
def camelcase(string: str) -> str:
    """Convert a ``snake_case``/``kebab-case``/``dotted`` string to ``PascalCase``.

    A single leading separator is dropped. The first character is upper-cased. Every separator
    (``-``, ``_``, ``.``, whitespace) that precedes a lowercase letter is removed and the letter is
    upper-cased.
    """
    check.str_param(string, "string")
    trimmed = re.sub(r"^[\-_\.]", "", str(string))
    if not trimmed:
        return trimmed

    head, tail = trimmed[0], trimmed[1:]

    def _upper_following(match):
        return str(match.group(1)).upper()

    return str(head).upper() + re.sub(r"[\-_\.\s]([a-z])", _upper_following, tail)
def ensure_single_item(ddict: Mapping[T, U]) -> Tuple[T, U]:
    """Return the only ``(key, value)`` pair of a one-entry mapping.

    Raises (via ``check``) if ``ddict`` is not a mapping or does not have exactly one item.
    """
    check.mapping_param(ddict, "ddict")
    check.param_invariant(len(ddict) == 1, "ddict", "Expected dict with single item")
    key, value = next(iter(ddict.items()))
    return key, value
@contextlib.contextmanager
def pushd(path: str) -> Iterator[str]:
    """Change the working directory to ``path`` for the ``with`` block, yielding ``path``.

    The previous working directory is restored on exit, even if the block raises.
    """
    previous_dir = os.getcwd()
    os.chdir(path)
    try:
        yield path
    finally:
        os.chdir(previous_dir)
def safe_isfile(path: str) -> bool:
    """Like ``os.path.isfile``, but return False instead of raising ``ValueError``.

    Backports the Python 3.8 behavior (https://docs.python.org/dev/whatsnew/3.8.html#os-path).
    On Unix a path containing a null byte is the known trigger; Windows has others. Only
    ``ValueError`` is masked. Letting it escape would surface an unexpected, uncaught error from
    deep inside unrelated logic.
    """
    with contextlib.suppress(ValueError):
        return os.path.isfile(path)
    return False
def mkdir_p(path: str) -> str:
    """Create ``path`` and any missing parents, like ``mkdir -p``, and return ``path``.

    An existing directory at ``path`` is accepted. Any other ``OSError`` (including "path exists
    but is not a directory") is re-raised.
    """
    try:
        os.makedirs(path)
    except OSError as err:
        already_a_dir = err.errno == errno.EEXIST and os.path.isdir(path)
        if not already_a_dir:
            raise
    return path
def hash_collection(
    collection: Union[
        Mapping[Hashable, Any], Sequence[Any], AbstractSet[Any], Tuple[Any, ...], NamedTuple
    ],
) -> int:
    """Hash a mutable collection or immutable collection containing mutable elements.

    This is useful for hashing Dagster-specific NamedTuples that contain mutable lists or dicts.
    The default NamedTuple __hash__ function assumes the contents of the NamedTuple are themselves
    hashable, and will throw an error if they are not. This can occur when trying to e.g. compute a
    cache key for the tuple for use with `lru_cache`.

    This alternative implementation will recursively process collection elements to convert basic
    lists and dicts to tuples prior to hashing. It is recommended to cache the result:

    Example:
        .. code-block:: python

            def __hash__(self):
                if not hasattr(self, '_hash'):
                    self._hash = hash_collection(self)
                return self._hash

    Raises:
        AssertionError: if ``collection`` is not a list, dict, set, or tuple (NamedTuples are
            tuples). Despite the ``Mapping`` annotation, non-dict mappings are rejected.
    """
    assert isinstance(
        collection, (list, dict, set, tuple)
    ), f"Cannot hash collection of type {type(collection)}"
    return hash(make_hashable(collection))
@overload
def make_hashable(value: Union[List[Any], Set[Any]]) -> Tuple[Any, ...]: ...


@overload
def make_hashable(value: Dict[Any, Any]) -> Tuple[Tuple[Any, Any], ...]: ...


@overload
def make_hashable(value: Any) -> Any: ...


def make_hashable(value: Any) -> Any:
    """Recursively convert ``value`` into a hashable equivalent.

    Conversions:
        - dict: tuple of ``(key, make_hashable(item))`` pairs sorted by key.
        - dagster record: tuple of its converted field values.
        - list / tuple / set: tuple of converted elements, in iteration order.
        - pydantic ``BaseModel``: converted ``.dict()``.
        - anything else: returned unchanged.

    NOTE(review): sets are converted in iteration order. Two equal sets are not guaranteed to
    iterate identically, so they may produce different results. Consider whether sets should be
    sorted or turned into frozensets.
    """
    from dagster._record import as_dict, is_record

    if isinstance(value, dict):
        # Keys are unique, so sorting the pairs only ever compares keys.
        return tuple(sorted((key, make_hashable(item)) for key, item in value.items()))
    elif is_record(value):
        return tuple(make_hashable(field) for field in as_dict(value).values())
    elif isinstance(value, (list, tuple, set)):
        return tuple(make_hashable(item) for item in value)
    elif isinstance(value, BaseModel):
        return make_hashable(value.dict())
    else:
        return value
def get_prop_or_key(elem: object, key: str) -> object:
    """Read ``key`` from ``elem``.

    Mappings use ``.get(key)`` (None when missing). Other objects use attribute access, which
    raises ``AttributeError`` when missing.
    """
    if not isinstance(elem, Mapping):
        return getattr(elem, key)
    return elem.get(key)
def list_pull(alist: Iterable[object], key: str) -> Sequence[object]:
    """Return a list of ``key`` pulled from each element via ``get_prop_or_key``."""
    return [get_prop_or_key(elem, key) for elem in alist]
def all_none(kwargs: Mapping[object, object]) -> bool:
    """Return True if every value in ``kwargs`` is None (vacuously True when empty)."""
    return all(value is None for value in kwargs.values())
def check_script(path: str, return_code: int = 0) -> None:
    """Run the Python script at ``path`` with the current interpreter.

    A non-zero exit raises ``subprocess.CalledProcessError`` unless ``return_code`` is non-zero and
    equals the script's exit code.

    NOTE(review): a script that exits 0 passes even when a non-zero ``return_code`` was expected.
    The expected code is only compared when the script fails.
    """
    try:
        subprocess.check_output([sys.executable, path])
    except subprocess.CalledProcessError as exc:
        expected_failure = return_code != 0 and exc.returncode == return_code
        if not expected_failure:
            raise
def check_cli_execute_file_job(
    path: str, pipeline_fn_name: str, env_file: Optional[str] = None
) -> None:
    """Run ``python -m dagster pipeline execute -f <path> -a <pipeline_fn_name>`` in a subprocess.

    The command runs inside a temporary test instance. ``env_file``, when given, is passed with
    ``-c``. A failing command has its error printed and re-raised.
    """
    from dagster._core.test_utils import instance_for_test

    with instance_for_test():
        command = [
            sys.executable,
            "-m",
            "dagster",
            "pipeline",
            "execute",
            "-f",
            path,
            "-a",
            pipeline_fn_name,
        ]
        if env_file:
            command.extend(["-c", env_file])

        try:
            subprocess.check_output(command)
        except subprocess.CalledProcessError as cpe:
            print(cpe)  # noqa: T201
            raise cpe
def safe_tempfile_path_unmanaged() -> str:
    """Return a fresh temporary-file path (POSIX separators) that does not currently exist.

    A ``NamedTemporaryFile`` reserves a unique name and is deleted again when its context exits,
    so only the path is handed back. Nothing stops another process from creating a file there
    afterwards.

    Returning a path is preferable to passing an open ``NamedTemporaryFile`` around, because such
    files cannot be opened a second time while still open on Windows NT or later:
    https://docs.python.org/3.8/library/tempfile.html#tempfile.NamedTemporaryFile
    https://github.com/dagster-io/dagster/issues/1582
    """
    with tempfile.NamedTemporaryFile() as scratch:
        reserved_name = scratch.name
    return Path(reserved_name).as_posix()
@contextlib.contextmanager
def safe_tempfile_path() -> Iterator[str]:
    """Yield a temporary-file path and delete the file on exit if one was created there."""
    temp_path = None
    try:
        temp_path = safe_tempfile_path_unmanaged()
        yield temp_path
    finally:
        should_remove = temp_path is not None and os.path.exists(temp_path)
        if should_remove:
            os.unlink(temp_path)
@overload
def ensure_gen(thing_or_gen: Generator[T, Any, Any]) -> Generator[T, Any, Any]: ...


@overload
def ensure_gen(thing_or_gen: T) -> Generator[T, Any, Any]: ...


def ensure_gen(
    thing_or_gen: Union[T, Iterator[T], Generator[T, Any, Any]],
) -> Generator[T, Any, Any]:
    """Return ``thing_or_gen`` unchanged if it is a generator, else a generator yielding it once.

    NOTE(review): only true generators pass through. Other iterators (e.g. ``iter([...])``) are
    wrapped and yielded as a single item, despite ``Iterator[T]`` appearing in the annotation.
    """
    if inspect.isgenerator(thing_or_gen):
        return thing_or_gen

    single = cast(T, thing_or_gen)

    def _single_item_gen():
        yield single

    return _single_item_gen()
def ensure_dir(file_path: str) -> str:
    """Create directory ``file_path`` (and missing parents) if needed, and return it.

    An empty path means "the current directory". It is returned unchanged instead of being passed
    to ``os.makedirs``, which would raise. This lets callers pass ``os.path.dirname(...)`` of a bare
    filename. An existing entry at ``file_path`` (``EEXIST``) is accepted without checking that it
    is a directory. Any other ``OSError`` propagates.
    """
    if not file_path:
        return file_path
    try:
        os.makedirs(file_path)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise
    return file_path
def ensure_file(path: str) -> str:
    """Make sure ``path`` exists as a file, creating its parent dir and an empty file if missing.

    Returns ``path``. An existing file is left untouched (its mtime is not updated).
    """
    ensure_dir(os.path.dirname(path))
    if os.path.exists(path):
        return path
    touch_file(path)
    return path
def touch_file(path: str) -> None:
    """Create ``path`` if missing (plus its parent dir) and set its access/modify times to now."""
    parent_dir = os.path.dirname(path)
    ensure_dir(parent_dir)
    # Append mode creates the file without truncating existing content.
    with open(path, "a", encoding="utf8"):
        os.utime(path)
def _termination_handler(
    should_stop_event: threading.Event,
    is_done_event: threading.Event,
) -> None:
    """Block until ``should_stop_event`` is set, then interrupt the main thread unless done."""
    should_stop_event.wait()
    if is_done_event.is_set():
        return
    # Asked to stop before the work finished: interrupt the MainThread.
    send_interrupt()
def send_interrupt() -> None:
    """Interrupt the main thread of the current process."""
    if not seven.IS_WINDOWS:
        # On Unix an OS-level SIGINT can break out of any situation we may be stuck in.
        os.kill(os.getpid(), signal.SIGINT)
        return
    # On Windows this raises KeyboardInterrupt in Python land only, so it cannot interrupt
    # blocking calls such as sleep().
    thread.interrupt_main()
# Function to be invoked by daemon thread in processes which seek to be cancellable.
# The motivation for this approach is to be able to exit cleanly on Windows. An alternative
# path is to change how the processes are opened and send CTRL_BREAK signals, which at
# the time of authoring seemed a more costly approach.
#
# Reading for the curious:
# * https://stackoverflow.com/questions/35772001/how-to-handle-the-signal-in-python-on-windows-machine
# * https://stefan.sofa-rockers.org/2013/08/15/handling-sub-process-hierarchies-python-linux-os-x/
def start_termination_thread(
    should_stop_event: threading.Event, is_done_event: threading.Event
) -> None:
    """Start a daemon thread that interrupts the main thread once ``should_stop_event`` fires.

    The interrupt is skipped if ``is_done_event`` is already set at that point.

    NOTE(review): despite the ``threading.Event`` annotation, ``should_stop_event`` is checked to
    be a ``multiprocessing`` Event.
    """
    import multiprocessing.synchronize

    # `multiprocessing.Event()` instances are `multiprocessing.synchronize.Event`. Naming the
    # class directly avoids allocating a throwaway OS semaphore just to learn its type.
    check.inst_param(
        should_stop_event, "should_stop_event", ttype=multiprocessing.synchronize.Event
    )
    int_thread = threading.Thread(
        target=_termination_handler,
        args=(should_stop_event, is_done_event),
        name="termination-handler",
        daemon=True,
    )
    int_thread.start()
# Executes the next() function within an instance of the supplied context manager class
# (leaving the context before yielding each result)
def iterate_with_context(
    context_fn: Callable[[], ContextManager[Any]], iterator: Iterator[T]
) -> Iterator[T]:
    """Drain ``iterator``, calling ``next`` inside a fresh ``context_fn()`` each time.

    The context is exited before each item is yielded. Code run by the consumer between items is
    therefore outside the context.
    """
    exhausted = object()
    while True:
        # Allow interrupts during user code so that we can terminate slow/hanging steps
        with context_fn():
            item = next(iterator, exhausted)
        if item is exhausted:
            return
        yield item
T_GeneratedContext = TypeVar("T_GeneratedContext")


class EventGenerationManager(Generic[T_GeneratedContext]):
    """Utility class that wraps an event generator function, that also yields a single instance of
    a typed object. All events yielded before the typed object are yielded through the method
    `generate_setup_events` and all events yielded after the typed object are yielded through the
    method `generate_teardown_events`.

    This is used to help replace the context managers used in pipeline initialization with
    generators so that we can begin emitting initialization events AND construct a pipeline context
    object, while managing explicit setup/teardown.

    This does require calling `generate_setup_events` AND `generate_teardown_events` in order to
    get the typed object.
    """

    def __init__(
        self,
        generator: Iterator[Union["DagsterEvent", T_GeneratedContext]],
        object_cls: Type[T_GeneratedContext],
        require_object: Optional[bool] = True,
    ):
        self.generator = check.generator(generator)
        self.object_cls: Type[T_GeneratedContext] = check.class_param(object_cls, "object_cls")
        self.require_object = check.bool_param(require_object, "require_object")
        # Populated by `generate_setup_events` once an `object_cls` instance is yielded.
        self.object: Optional[T_GeneratedContext] = None
        self.did_setup = False
        self.did_teardown = False

    def generate_setup_events(self) -> Iterator["DagsterEvent"]:
        """Yield events produced before the typed object and capture the object itself.

        If the generator is exhausted first and ``require_object`` is set, a ``check`` failure is
        raised.
        """
        self.did_setup = True
        try:
            while self.object is None:
                obj = next(self.generator)
                if isinstance(obj, self.object_cls):
                    self.object = obj
                else:
                    yield obj
        except StopIteration:
            if self.require_object:
                check.inst_param(
                    self.object,
                    "self.object",
                    self.object_cls,
                    f"generator never yielded object of type {self.object_cls.__name__}",
                )

    def get_object(self) -> T_GeneratedContext:
        """Return the captured object. Only valid after `generate_setup_events` has run."""
        if not self.did_setup:
            check.failed("Called `get_object` before `generate_setup_events`")
        return cast(T_GeneratedContext, self.object)

    def generate_teardown_events(self) -> Iterator["DagsterEvent"]:
        """Yield the remaining (post-object) events, if an object was captured during setup."""
        self.did_teardown = True
        # Compare against None rather than testing truthiness, matching `generate_setup_events`.
        # A context object that happens to be falsy must still get its teardown run.
        if self.object is not None:
            yield from self.generator
def is_enum_value(value: object) -> bool:
    """Return True if ``value``'s class is a subclass of ``enum.Enum``."""
    return value is not None and issubclass(value.__class__, Enum)
def git_repository_root() -> str:
    """Return the top-level directory of the git work tree containing the current directory."""
    raw_output = subprocess.check_output(["git", "rev-parse", "--show-toplevel"])
    return raw_output.decode("utf-8").strip()
def segfault() -> None:
    """Crash the interpreter with a segmentation fault (reliable across Python versions).

    Dereferences address 0 via ``ctypes.string_at``; see
    https://bugs.python.org/issue1215#msg143236
    """
    from ctypes import string_at

    string_at(0)
def find_free_port() -> int:
    """Ask the OS for an unused TCP port by binding to port 0 and return the number it chose.

    The probe socket is closed before returning, so the port may be taken by someone else by the
    time the caller uses it.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.bind(("", 0))
        # NOTE(review): setting SO_REUSEADDR after bind has no effect on this socket's binding.
        probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        _, port = probe.getsockname()[:2]
        return port
def is_port_in_use(host: str, port: int) -> bool:
    """Return True if binding ``(host, port)`` fails with ``EADDRINUSE``.

    Other socket errors during bind (e.g. an unresolvable host) yield False.
    """
    # Similar to the socket options that uvicorn uses to bind ports:
    # https://github.com/encode/uvicorn/blob/62f19c1c39929c84968712c371c9b7b96a041dec/uvicorn/config.py#L565-L566
    # The `with` block closes the probe socket on every path, including a failing setsockopt.
    with socket.socket(family=socket.AF_INET) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            sock.bind((host, port))
        except OSError as e:
            return e.errno == errno.EADDRINUSE
        return False
@contextlib.contextmanager
def alter_sys_path(to_add: Sequence[str], to_remove: Sequence[str]) -> Iterator[None]:
    """Temporarily edit ``sys.path`` for the duration of the ``with`` block.

    The first occurrence of each entry in ``to_remove`` is removed, if present. Then each entry in
    ``to_add`` is inserted at the front, so the last element of ``to_add`` ends up first. On exit
    the original contents are restored in place.
    """
    to_restore = list(sys.path)

    for path in to_remove:
        if path in sys.path:
            sys.path.remove(path)

    for path in to_add:
        sys.path.insert(0, path)

    try:
        yield
    finally:
        # Slice assignment keeps the `sys.path` list object's identity. Rebinding would leave the
        # edits above permanently visible to anything holding a reference to the original list.
        sys.path[:] = to_restore
@contextlib.contextmanager
def restore_sys_modules() -> Iterator[None]:
    """Unregister, on exit, every ``sys.modules`` entry added during the ``with`` block.

    Only additions are undone. Entries that existed beforehand are not restored if they were
    removed or replaced inside the block.
    """
    snapshot = dict(sys.modules)
    try:
        yield
    finally:
        for name in set(sys.modules) - set(snapshot):
            del sys.modules[name]
def process_is_alive(pid: int) -> bool:
    """Return whether a process with id ``pid`` currently exists.

    On Windows this delegates to ``psutil.pid_exists``. Elsewhere it probes with ``os.kill(pid, 0)``,
    which sends no signal.

    Raises:
        ValueError: for ``pid == 0`` on non-Windows platforms. That value addresses the whole
            process group and cannot be checked portably.
        OSError: for unexpected ``os.kill`` errors other than ESRCH/EPERM.
    """
    if seven.IS_WINDOWS:
        import psutil

        return psutil.pid_exists(pid=pid)

    # https://stackoverflow.com/questions/568271/how-to-check-if-there-exists-a-process-with-a-given-pid-in-python
    if pid < 0:
        return False
    if pid == 0:
        # According to "man 2 kill" PID 0 refers to every process in the process group of the
        # calling process. On certain systems 0 is a valid PID but we have no way to know that
        # in a portable fashion.
        raise ValueError("invalid PID 0")

    try:
        os.kill(pid, 0)
    except OSError as err:
        if err.errno == errno.ESRCH:
            # No such process.
            return False
        if err.errno == errno.EPERM:
            # Permission denied means there is a process to deny access to.
            return True
        # "man 2 kill" lists EINVAL, EPERM, ESRCH; anything else is unexpected.
        raise
    return True
def compose(*args: Callable[[object], object]) -> Callable[[object], object]:
    """Compose python functions args such that compose(f, g)(x) is equivalent to f(g(x)).""" # noqa: D402

    def composed(x):
        # Apply right-to-left: the last function given runs first. No functions means identity.
        for fn in reversed(args):
            x = fn(x)
        return x

    return composed
def dict_without_keys(ddict: Mapping[K, V], *keys: K) -> Dict[K, V]:
    """Return a new dict with the entries of ``ddict`` except those whose key is in ``keys``.

    Keys listed in ``keys`` but absent from ``ddict`` are ignored.
    """
    # Build the exclusion set once instead of once per item.
    excluded = set(keys)
    return {key: value for key, value in ddict.items() if key not in excluded}
class Counter:
    """Thread-safe tally of how many times each string key has been incremented."""

    def __init__(self):
        self._lock = threading.Lock()
        self._counts = {}
        super().__init__()

    def increment(self, key: str) -> None:
        """Add one to the count for ``key`` (starting from zero)."""
        with self._lock:
            current = self._counts.get(key, 0)
            self._counts[key] = current + 1

    def counts(self) -> Mapping[str, int]:
        """Return a snapshot copy of the current counts."""
        with self._lock:
            return dict(self._counts)
# Per-context Counter used by `@traced`. With the default of None, traced calls are not counted.
traced_counter: contextvars.ContextVar[Optional[Counter]] = contextvars.ContextVar(
    "traced_counts", default=None
)

# Type variable for decorators that hand back a callable of the same type they received.
T_Callable = TypeVar("T_Callable", bound=Callable)
def traced(func: T_Callable) -> T_Callable:
    """A decorator that keeps track of how many times a function is called.

    Each call increments the active ``traced_counter`` Counter (if any), keyed by the function's
    ``__qualname__``.
    """

    @functools.wraps(func)
    def inner(*args, **kwargs):
        counter = traced_counter.get()
        should_count = bool(counter) and isinstance(counter, Counter)
        if should_count:
            counter.increment(func.__qualname__)
        return func(*args, **kwargs)

    return cast(T_Callable, inner)
def get_terminate_signal() -> signal.Signals:
    """Signal used to forcibly stop a process: SIGTERM on Windows, SIGKILL elsewhere.

    Python's ``signal`` module does not define SIGKILL on Windows.
    """
    return signal.SIGTERM if sys.platform == "win32" else signal.SIGKILL
def get_run_crash_explanation(prefix: str, exit_code: int) -> str:
    """Build a human-readable sentence explaining why a run process exited.

    A negative ``exit_code`` that matches a known signal is reported as termination by that
    signal, with an out-of-memory hint when it is the terminate signal. Any other code is reported
    as an unexpected exit.
    """
    # As per https://docs.python.org/3/library/subprocess.html#subprocess.CompletedProcess.returncode
    # negative exit code means a posix signal
    is_signal_exit = exit_code < 0 and -exit_code in {member.value for member in Signals}
    if not is_signal_exit:
        return prefix + " " + f"unexpectedly exited with code {exit_code}."

    posix_signal = -exit_code
    exit_clause = f"was terminated by signal {posix_signal} ({Signals(posix_signal).name})."
    if posix_signal == get_terminate_signal():
        exit_clause += (
            " This usually indicates that the process was"
            " killed by the operating system due to running out of"
            " memory. Possible solutions include increasing the"
            " amount of memory available to the run, reducing"
            " the amount of memory used by the ops in the run, or"
            " configuring the executor to run fewer ops concurrently."
        )
    return prefix + " " + exit_clause
def last_file_comp(path: str) -> str:
    """Return the final component of ``path`` after normalization (trailing slashes ignored)."""
    normalized = os.path.normpath(path)
    return os.path.basename(normalized)
def is_named_tuple_instance(obj: object) -> TypeGuard[NamedTuple]:
return isinstance(obj, tuple) and hasattr(obj, "_fields")
def is_named_tuple_subclass(klass: Type[object]) -> TypeGuard[Type[NamedTuple]]:
return isinstance(klass, type) and issubclass(klass, tuple) and hasattr(klass, "_fields")
@overload
def normalize_to_repository(
    definitions_or_repository: Optional[Union["Definitions", "RepositoryDefinition"]] = ...,
    repository: Optional["RepositoryDefinition"] = ...,
    error_on_none: Literal[True] = ...,
) -> "RepositoryDefinition": ...


@overload
def normalize_to_repository(
    definitions_or_repository: Optional[Union["Definitions", "RepositoryDefinition"]] = ...,
    repository: Optional["RepositoryDefinition"] = ...,
    error_on_none: Literal[False] = ...,
) -> Optional["RepositoryDefinition"]: ...


def normalize_to_repository(
    definitions_or_repository: Optional[Union["Definitions", "RepositoryDefinition"]] = None,
    repository: Optional["RepositoryDefinition"] = None,
    error_on_none: bool = True,
) -> Optional["RepositoryDefinition"]:
    """Normalizes the arguments that take a RepositoryDefinition or Definitions object to a
    RepositoryDefinition.

    This is intended to handle both the case where a single argument takes a
    `Union[RepositoryDefinition, Definitions]` or separate keyword arguments accept
    `RepositoryDefinition` or `Definitions`.

    A ``check`` failure is raised when both arguments are given, or when neither is given and
    ``error_on_none`` is True. ``Definitions`` are converted with ``get_repository_def()``. When
    neither is given and ``error_on_none`` is False, None is returned.
    """
    from dagster._core.definitions.definitions_class import Definitions

    if (definitions_or_repository and repository) or (
        error_on_none and not (definitions_or_repository or repository)
    ):
        check.failed("Exactly one of `definitions` or `repository_def` must be provided.")

    if isinstance(definitions_or_repository, Definitions):
        return definitions_or_repository.get_repository_def()

    # Whichever argument was supplied, or None if neither was.
    return definitions_or_repository or repository or None
def xor(a: object, b: object) -> bool:
    """Return True when exactly one of ``a`` and ``b`` is truthy."""
    return bool(a) ^ bool(b)
def tail_file(path_or_fd: Union[str, int], should_stop: Callable[[], bool]) -> Iterator[str]:
    """Yield lines from a file as they appear, polling every 10ms while no new data is available.

    ``should_stop`` is consulted only when no line is available, so all data already written is
    drained before stopping. When given an int file descriptor, ``open`` takes ownership and closes
    it on exit (``closefd`` defaults to True).
    """
    with open(path_or_fd, "r") as stream:
        while True:
            line = stream.readline()
            if not line:
                if should_stop():
                    return
                time.sleep(0.01)
                continue
            yield line
def is_uuid(value: str) -> bool:
    """Return True if ``uuid.UUID`` accepts ``value`` as a UUID string."""
    try:
        uuid.UUID(value)
    except ValueError:
        return False
    return True
def run_with_concurrent_update_guard(
    target_file_path: Path,
    update_fn: Callable[..., None],
    *,
    guard_timeout_seconds: float = 60,
    **kwargs,
) -> None:
    """This function prevents multiple processes attempting to update the same target artifacts
    from running concurrently. It uses a lock file to ensure that only one process can update the
    target file at a time.

    If the target file has been updated by another process while waiting for the lock, we skip
    running the update_fn, assuming we are about to do redundant work.

    Args:
        target_file_path (Path): The path to the target file that needs to be updated.
        update_fn (Callable[[Any], None]): The function that will update the target file.
        guard_timeout_seconds (float): The maximum time to wait for the lock to be released.
            Default: 60 seconds.
        **kwargs: The keyword arguments to pass to the function.
    """
    # Record the target's mtime before waiting, so an update made by another lock holder in the
    # meantime can be detected afterwards.
    start_mtime = target_file_path.lstat().st_mtime if target_file_path.exists() else 0
    # NOTE(review): `with_suffix` replaces the existing extension, so targets differing only in
    # extension share one lock file.
    lock_path = target_file_path.with_suffix(".concurrent-update-lock")

    with FileLock(lock_path).acquire(timeout=guard_timeout_seconds):
        updated_meanwhile = (
            target_file_path.exists() and target_file_path.lstat().st_mtime > start_mtime
        )
        if not updated_meanwhile:
            update_fn(**kwargs)