import os
import shutil
import sys
from collections import defaultdict
from contextlib import contextmanager
from pathlib import Path
from typing import IO, Generator, Iterator, Mapping, Optional, Sequence, Tuple
from typing_extensions import Final
from watchdog.events import PatternMatchingEventHandler
from watchdog.observers.polling import PollingObserver
from dagster import (
Field,
Float,
StringSource,
_check as check,
)
from dagster._config.config_schema import UserConfigSchema
from dagster._core.execution.compute_logs import mirror_stream_to_file
from dagster._core.storage.compute_log_manager import (
CapturedLogContext,
CapturedLogData,
CapturedLogMetadata,
CapturedLogSubscription,
ComputeIOType,
ComputeLogManager,
)
from dagster._serdes import ConfigurableClass, ConfigurableClassData
from dagster._seven import json
from dagster._utils import ensure_dir, ensure_file, touch_file
from dagster._utils.security import non_secure_md5_hash_str
DEFAULT_WATCHDOG_POLLING_TIMEOUT: Final = 2.5
IO_TYPE_EXTENSION: Final[Mapping[ComputeIOType, str]] = {
ComputeIOType.STDOUT: "out",
ComputeIOType.STDERR: "err",
}
MAX_FILENAME_LENGTH: Final = 255
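# Captured logs are laid out on disk as "<base_dir>/<namespace...>/<filebase>.<extension>",
# where the extension is "out" or "err" per IO_TYPE_EXTENSION above. A ".partial" suffix
# marks an in-progress capture and a zero-byte ".complete" sentinel marks a finished one.
# 255 characters is the filename limit on most common filesystems; longer basenames are
# hashed (see LocalComputeLogManager.get_captured_local_path).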
class LocalComputeLogManager(ComputeLogManager, ConfigurableClass):
"""Stores copies of stdout & stderr for each compute step locally on disk."""
def __init__(
self,
base_dir: str,
polling_timeout: Optional[float] = None,
inst_data: Optional[ConfigurableClassData] = None,
):
self._base_dir = base_dir
self._polling_timeout = check.opt_float_param(
polling_timeout, "polling_timeout", DEFAULT_WATCHDOG_POLLING_TIMEOUT
)
self._subscription_manager = LocalComputeLogSubscriptionManager(self)
self._inst_data = check.opt_inst_param(inst_data, "inst_data", ConfigurableClassData)
@property
def inst_data(self) -> Optional[ConfigurableClassData]:
return self._inst_data
@property
def polling_timeout(self) -> float:
return self._polling_timeout
@classmethod
def config_type(cls) -> UserConfigSchema:
return {
"base_dir": StringSource,
"polling_timeout": Field(Float, is_required=False),
}
@classmethod
def from_config_value(
cls, inst_data: Optional[ConfigurableClassData], config_value
) -> "LocalComputeLogManager":
return LocalComputeLogManager(inst_data=inst_data, **config_value)
@contextmanager
def capture_logs(self, log_key: Sequence[str]) -> Generator[CapturedLogContext, None, None]:
outpath = self.get_captured_local_path(log_key, IO_TYPE_EXTENSION[ComputeIOType.STDOUT])
errpath = self.get_captured_local_path(log_key, IO_TYPE_EXTENSION[ComputeIOType.STDERR])
with mirror_stream_to_file(sys.stdout, outpath), mirror_stream_to_file(sys.stderr, errpath):
yield CapturedLogContext(log_key)
        # leave a sentinel file on the filesystem so we know the capture is complete
touch_file(self.complete_artifact_path(log_key))
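    # Usage sketch (the three-part log key is illustrative; Dagster itself supplies one):
    #
    #     manager = LocalComputeLogManager(base_dir="/tmp/dagster-logs")
    #     with manager.capture_logs(["run_id", "compute_logs", "step_key"]):
    #         print("captured")  # mirrored into .../run_id/compute_logs/step_key.out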
@contextmanager
def open_log_stream(
self, log_key: Sequence[str], io_type: ComputeIOType
) -> Iterator[Optional[IO]]:
path = self.get_captured_local_path(log_key, IO_TYPE_EXTENSION[io_type])
ensure_file(path)
with open(path, "+a", encoding="utf-8") as f:
yield f
def is_capture_complete(self, log_key: Sequence[str]) -> bool:
return os.path.exists(self.complete_artifact_path(log_key))
def get_log_data(
self, log_key: Sequence[str], cursor: Optional[str] = None, max_bytes: Optional[int] = None
) -> CapturedLogData:
stdout_cursor, stderr_cursor = self.parse_cursor(cursor)
stdout, stdout_offset = self._read_bytes(
log_key, ComputeIOType.STDOUT, offset=stdout_cursor, max_bytes=max_bytes
)
stderr, stderr_offset = self._read_bytes(
log_key, ComputeIOType.STDERR, offset=stderr_cursor, max_bytes=max_bytes
)
return CapturedLogData(
log_key=log_key,
stdout=stdout,
stderr=stderr,
cursor=self.build_cursor(stdout_offset, stderr_offset),
)
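    # Pagination sketch: feed each returned cursor into the next call to page through logs.
    #
    #     chunk = manager.get_log_data(log_key, max_bytes=1024)
    #     more = manager.get_log_data(log_key, cursor=chunk.cursor, max_bytes=1024)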
def get_log_metadata(self, log_key: Sequence[str]) -> CapturedLogMetadata:
return CapturedLogMetadata(
stdout_location=self.get_captured_local_path(
log_key, IO_TYPE_EXTENSION[ComputeIOType.STDOUT]
),
stderr_location=self.get_captured_local_path(
log_key, IO_TYPE_EXTENSION[ComputeIOType.STDERR]
),
stdout_download_url=self.get_captured_log_download_url(log_key, ComputeIOType.STDOUT),
stderr_download_url=self.get_captured_log_download_url(log_key, ComputeIOType.STDERR),
)
def delete_logs(
self, log_key: Optional[Sequence[str]] = None, prefix: Optional[Sequence[str]] = None
):
if log_key:
paths = [
self.get_captured_local_path(log_key, IO_TYPE_EXTENSION[ComputeIOType.STDOUT]),
self.get_captured_local_path(log_key, IO_TYPE_EXTENSION[ComputeIOType.STDERR]),
self.get_captured_local_path(
log_key, IO_TYPE_EXTENSION[ComputeIOType.STDOUT], partial=True
),
self.get_captured_local_path(
log_key, IO_TYPE_EXTENSION[ComputeIOType.STDERR], partial=True
),
self.get_captured_local_path(log_key, "complete"),
]
for path in paths:
                if os.path.isfile(path):
os.remove(path)
elif prefix:
dir_to_delete = os.path.join(self._base_dir, *prefix)
            if os.path.isdir(dir_to_delete):
# recursively delete all files in dir
shutil.rmtree(dir_to_delete)
else:
check.failed("Must pass in either `log_key` or `prefix` argument to delete_logs")
def _read_bytes(
self,
log_key: Sequence[str],
io_type: ComputeIOType,
offset: Optional[int] = 0,
max_bytes: Optional[int] = None,
):
path = self.get_captured_local_path(log_key, IO_TYPE_EXTENSION[io_type])
return self.read_path(path, offset or 0, max_bytes)
def parse_cursor(self, cursor: Optional[str] = None) -> Tuple[int, int]:
        # Translates a string cursor into a pair of byte offsets for stdout and stderr
        if not cursor:
            return 0, 0
        parts = cursor.split(":")
        if len(parts) != 2:
            return 0, 0
        try:
            stdout, stderr = [int(_) for _ in parts]
        except ValueError:
            # a malformed cursor restarts from the beginning of both streams
            return 0, 0
        return stdout, stderr
def build_cursor(self, stdout_offset: int, stderr_offset: int) -> str:
return f"{stdout_offset}:{stderr_offset}"
def complete_artifact_path(self, log_key):
return self.get_captured_local_path(log_key, "complete")
def read_path(
self,
path: str,
offset: int = 0,
max_bytes: Optional[int] = None,
):
        if not os.path.isfile(path):
return None, offset
with open(path, "rb") as f:
f.seek(offset, os.SEEK_SET)
if max_bytes is None:
data = f.read()
else:
data = f.read(max_bytes)
new_offset = f.tell()
return data, new_offset
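    # read_path returns (data, new_offset), so a caller can stream a growing file by
    # passing each returned offset back in as the next read's starting point:
    #
    #     data, offset = manager.read_path(path)          # first chunk
    #     more, offset = manager.read_path(path, offset)  # only bytes appended since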
def get_captured_log_download_url(self, log_key, io_type):
check.inst_param(io_type, "io_type", ComputeIOType)
url = "/logs"
for part in log_key:
url = f"{url}/{part}"
return f"{url}/{IO_TYPE_EXTENSION[io_type]}"
def get_captured_local_path(self, log_key: Sequence[str], extension: str, partial=False):
[*namespace, filebase] = log_key
filename = f"{filebase}.{extension}"
if partial:
filename = f"{filename}.partial"
if len(filename) > MAX_FILENAME_LENGTH:
filename = "{}.{}".format(non_secure_md5_hash_str(filebase.encode("utf-8")), extension)
base_dir_path = Path(self._base_dir).resolve()
log_path = base_dir_path.joinpath(*namespace, filename).resolve()
if base_dir_path not in log_path.parents:
raise ValueError("Invalid path")
return str(log_path)
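    # e.g. log_key ["run_id", "compute_logs", "step_key"] with extension "out" resolves to
    # "<base_dir>/run_id/compute_logs/step_key.out" (".partial" appended while in flight).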
def subscribe(
self, log_key: Sequence[str], cursor: Optional[str] = None
) -> CapturedLogSubscription:
subscription = CapturedLogSubscription(self, log_key, cursor)
self._subscription_manager.add_subscription(subscription)
return subscription
def unsubscribe(self, subscription):
self._subscription_manager.remove_subscription(subscription)
def get_log_keys_for_log_key_prefix(
self, log_key_prefix: Sequence[str], io_type: ComputeIOType
) -> Sequence[Sequence[str]]:
"""Returns the logs keys for a given log key prefix. This is determined by looking at the
directory defined by the log key prefix and creating a log_key for each file in the directory.
"""
base_dir_path = Path(self._base_dir).resolve()
        directory = base_dir_path.joinpath(*log_key_prefix)
        if not directory.is_dir():
            # no logs were captured under this prefix
            return []
        objects = directory.iterdir()
results = []
list_key_prefix = list(log_key_prefix)
for obj in objects:
if obj.is_file() and obj.suffix == "." + IO_TYPE_EXTENSION[io_type]:
results.append(list_key_prefix + [obj.stem])
return results
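    # e.g. with files "step_1.out" and "step_2.out" under <base_dir>/run_id/compute_logs,
    # get_log_keys_for_log_key_prefix(["run_id", "compute_logs"], ComputeIOType.STDOUT)
    # returns [["run_id", "compute_logs", "step_1"], ["run_id", "compute_logs", "step_2"]].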
def dispose(self) -> None:
self._subscription_manager.dispose()
class LocalComputeLogSubscriptionManager:
def __init__(self, manager):
self._manager = manager
self._subscriptions = defaultdict(list)
self._watchers = {}
self._observer = None
def add_subscription(self, subscription: CapturedLogSubscription) -> None:
check.inst_param(subscription, "subscription", CapturedLogSubscription)
if self.is_complete(subscription):
subscription.fetch()
subscription.complete()
else:
log_key = self._log_key(subscription)
watch_key = self._watch_key(log_key)
self._subscriptions[watch_key].append(subscription)
self.watch(subscription)
def is_complete(self, subscription: CapturedLogSubscription) -> bool:
check.inst_param(subscription, "subscription", CapturedLogSubscription)
return self._manager.is_capture_complete(subscription.log_key)
def remove_subscription(self, subscription: CapturedLogSubscription) -> None:
check.inst_param(subscription, "subscription", CapturedLogSubscription)
log_key = self._log_key(subscription)
watch_key = self._watch_key(log_key)
if subscription in self._subscriptions[watch_key]:
self._subscriptions[watch_key].remove(subscription)
subscription.complete()
def _log_key(self, subscription: CapturedLogSubscription) -> Sequence[str]:
check.inst_param(subscription, "subscription", CapturedLogSubscription)
return subscription.log_key
def _watch_key(self, log_key: Sequence[str]) -> str:
return json.dumps(log_key)
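    # e.g. _watch_key(["run_id", "compute_logs", "step_key"]) returns the JSON string
    # '["run_id", "compute_logs", "step_key"]', a stable, hashable dict key per log key.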
def remove_all_subscriptions(self, log_key: Sequence[str]) -> None:
watch_key = self._watch_key(log_key)
for subscription in self._subscriptions.pop(watch_key, []):
subscription.complete()
def watch(self, subscription: CapturedLogSubscription) -> None:
log_key = self._log_key(subscription)
watch_key = self._watch_key(log_key)
if watch_key in self._watchers:
return
update_paths = [
self._manager.get_captured_local_path(log_key, IO_TYPE_EXTENSION[ComputeIOType.STDOUT]),
self._manager.get_captured_local_path(log_key, IO_TYPE_EXTENSION[ComputeIOType.STDERR]),
self._manager.get_captured_local_path(
log_key, IO_TYPE_EXTENSION[ComputeIOType.STDOUT], partial=True
),
self._manager.get_captured_local_path(
log_key, IO_TYPE_EXTENSION[ComputeIOType.STDERR], partial=True
),
]
complete_paths = [self._manager.complete_artifact_path(log_key)]
directory = os.path.dirname(
            self._manager.get_captured_local_path(log_key, IO_TYPE_EXTENSION[ComputeIOType.STDERR]),
)
if not self._observer:
self._observer = PollingObserver(timeout=self._manager.polling_timeout)
self._observer.start()
ensure_dir(directory)
self._watchers[watch_key] = self._observer.schedule(
LocalComputeLogFilesystemEventHandler(self, log_key, update_paths, complete_paths),
str(directory),
)
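    # PollingObserver checks the filesystem on an interval (polling_timeout seconds) rather
    # than relying on native OS events, which keeps watching reliable on mounts (e.g. NFS)
    # where inotify-style notifications may not fire.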
def notify_subscriptions(self, log_key: Sequence[str]) -> None:
watch_key = self._watch_key(log_key)
for subscription in self._subscriptions[watch_key]:
subscription.fetch()
def unwatch(self, log_key: Sequence[str], handler) -> None:
watch_key = self._watch_key(log_key)
if watch_key in self._watchers:
self._observer.remove_handler_for_watch(handler, self._watchers[watch_key]) # type: ignore
del self._watchers[watch_key]
def dispose(self) -> None:
if self._observer:
self._observer.stop()
self._observer.join(15)
class LocalComputeLogFilesystemEventHandler(PatternMatchingEventHandler):
def __init__(self, manager, log_key, update_paths, complete_paths):
self.manager = manager
self.log_key = log_key
self.update_paths = update_paths
self.complete_paths = complete_paths
patterns = update_paths + complete_paths
        super().__init__(patterns=patterns)
def on_created(self, event):
if event.src_path in self.complete_paths:
self.manager.remove_all_subscriptions(self.log_key)
self.manager.unwatch(self.log_key, self)
def on_modified(self, event):
if event.src_path in self.update_paths:
self.manager.notify_subscriptions(self.log_key)
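if __name__ == "__main__":
    # Minimal smoke test, a sketch rather than part of the module proper: capture stdout
    # for a made-up log key into a temporary directory and read it back.
    import tempfile

    with tempfile.TemporaryDirectory() as tmpdir:
        manager = LocalComputeLogManager(base_dir=tmpdir)
        log_key = ["example_run", "compute_logs", "example_step"]
        with manager.capture_logs(log_key):
            print("hello from the captured step")
        assert manager.is_capture_complete(log_key)
        print(manager.get_log_data(log_key).stdout)  # b"hello from the captured step\n"
        manager.dispose()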