Source code for dagster._core.storage.local_compute_log_manager

import os
import shutil
import sys
from collections import defaultdict
from contextlib import contextmanager
from pathlib import Path
from typing import IO, TYPE_CHECKING, Generator, Iterator, Mapping, Optional, Sequence, Tuple

from typing_extensions import Final
from watchdog.events import PatternMatchingEventHandler
from watchdog.observers.polling import PollingObserver

from dagster import (
    Field,
    Float,
    StringSource,
    _check as check,
)
from dagster._config.config_schema import UserConfigSchema
from dagster._core.execution.compute_logs import mirror_stream_to_file
from dagster._core.storage.dagster_run import DagsterRun
from dagster._serdes import ConfigurableClass, ConfigurableClassData
from dagster._seven import json
from dagster._utils import ensure_dir, ensure_file, touch_file
from dagster._utils.security import non_secure_md5_hash_str

from .captured_log_manager import (
    CapturedLogContext,
    CapturedLogData,
    CapturedLogManager,
    CapturedLogMetadata,
    CapturedLogSubscription,
)
from .compute_log_manager import (
    MAX_BYTES_FILE_READ,
    ComputeIOType,
    ComputeLogFileData,
    ComputeLogManager,
    ComputeLogSubscription,
)

if TYPE_CHECKING:
    from dagster._core.storage.cloud_storage_compute_log_manager import LogSubscription

DEFAULT_WATCHDOG_POLLING_TIMEOUT: Final = 2.5

IO_TYPE_EXTENSION: Final[Mapping[ComputeIOType, str]] = {
    ComputeIOType.STDOUT: "out",
    ComputeIOType.STDERR: "err",
}

MAX_FILENAME_LENGTH: Final = 255
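
For orientation, these constants determine the on-disk layout: each element of a log key becomes a path component under the manager's base directory, the final element becomes the file basename, and the IO_TYPE_EXTENSION value becomes the file extension. A minimal sketch of that mapping, using a hypothetical base directory and log key (the real logic, including the MAX_FILENAME_LENGTH hashing fallback and path validation, lives in get_captured_local_path below):

from pathlib import Path

base_dir = Path("/tmp/dagster/compute_logs")  # hypothetical location
log_key = ["my_run_id", "compute_logs", "my_step_key"]  # hypothetical key

# The last element of the log key is the file basename; the preceding
# elements become namespacing subdirectories.
[*namespace, filebase] = log_key
stdout_path = base_dir.joinpath(*namespace, f"{filebase}.out")  # "out" extension for STDOUT
stderr_path = base_dir.joinpath(*namespace, f"{filebase}.err")  # "err" extension for STDERR
print(stdout_path)  # /tmp/dagster/compute_logs/my_run_id/compute_logs/my_step_key.out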


class LocalComputeLogManager(CapturedLogManager, ComputeLogManager, ConfigurableClass):
    """Stores copies of stdout & stderr for each compute step locally on disk."""

    def __init__(
        self,
        base_dir: str,
        polling_timeout: Optional[float] = None,
        inst_data: Optional[ConfigurableClassData] = None,
    ):
        self._base_dir = base_dir
        self._polling_timeout = check.opt_float_param(
            polling_timeout, "polling_timeout", DEFAULT_WATCHDOG_POLLING_TIMEOUT
        )
        self._subscription_manager = LocalComputeLogSubscriptionManager(self)
        self._inst_data = check.opt_inst_param(inst_data, "inst_data", ConfigurableClassData)

    @property
    def inst_data(self) -> Optional[ConfigurableClassData]:
        return self._inst_data

    @property
    def polling_timeout(self) -> float:
        return self._polling_timeout

    @classmethod
    def config_type(cls) -> UserConfigSchema:
        return {
            "base_dir": StringSource,
            "polling_timeout": Field(Float, is_required=False),
        }

    @classmethod
    def from_config_value(
        cls, inst_data: Optional[ConfigurableClassData], config_value
    ) -> "LocalComputeLogManager":
        return LocalComputeLogManager(inst_data=inst_data, **config_value)

    @contextmanager
    def capture_logs(self, log_key: Sequence[str]) -> Generator[CapturedLogContext, None, None]:
        outpath = self.get_captured_local_path(log_key, IO_TYPE_EXTENSION[ComputeIOType.STDOUT])
        errpath = self.get_captured_local_path(log_key, IO_TYPE_EXTENSION[ComputeIOType.STDERR])
        with mirror_stream_to_file(sys.stdout, outpath), mirror_stream_to_file(
            sys.stderr, errpath
        ):
            yield CapturedLogContext(log_key)

        # leave artifact on filesystem so that we know the capture is completed
        touch_file(self.complete_artifact_path(log_key))

    @contextmanager
    def open_log_stream(
        self, log_key: Sequence[str], io_type: ComputeIOType
    ) -> Iterator[Optional[IO]]:
        path = self.get_captured_local_path(log_key, IO_TYPE_EXTENSION[io_type])
        ensure_file(path)
        with open(path, "+a", encoding="utf-8") as f:
            yield f

    def is_capture_complete(self, log_key: Sequence[str]) -> bool:
        return os.path.exists(self.complete_artifact_path(log_key))

    def get_log_data(
        self, log_key: Sequence[str], cursor: Optional[str] = None, max_bytes: Optional[int] = None
    ) -> CapturedLogData:
        stdout_cursor, stderr_cursor = self.parse_cursor(cursor)
        stdout, stdout_offset = self._read_bytes(
            log_key, ComputeIOType.STDOUT, offset=stdout_cursor, max_bytes=max_bytes
        )
        stderr, stderr_offset = self._read_bytes(
            log_key, ComputeIOType.STDERR, offset=stderr_cursor, max_bytes=max_bytes
        )
        return CapturedLogData(
            log_key=log_key,
            stdout=stdout,
            stderr=stderr,
            cursor=self.build_cursor(stdout_offset, stderr_offset),
        )

    def get_log_metadata(self, log_key: Sequence[str]) -> CapturedLogMetadata:
        return CapturedLogMetadata(
            stdout_location=self.get_captured_local_path(
                log_key, IO_TYPE_EXTENSION[ComputeIOType.STDOUT]
            ),
            stderr_location=self.get_captured_local_path(
                log_key, IO_TYPE_EXTENSION[ComputeIOType.STDERR]
            ),
            stdout_download_url=self.get_captured_log_download_url(log_key, ComputeIOType.STDOUT),
            stderr_download_url=self.get_captured_log_download_url(log_key, ComputeIOType.STDERR),
        )

    def delete_logs(
        self, log_key: Optional[Sequence[str]] = None, prefix: Optional[Sequence[str]] = None
    ):
        if log_key:
            paths = [
                self.get_captured_local_path(log_key, IO_TYPE_EXTENSION[ComputeIOType.STDOUT]),
                self.get_captured_local_path(log_key, IO_TYPE_EXTENSION[ComputeIOType.STDERR]),
                self.get_captured_local_path(
                    log_key, IO_TYPE_EXTENSION[ComputeIOType.STDOUT], partial=True
                ),
                self.get_captured_local_path(
                    log_key, IO_TYPE_EXTENSION[ComputeIOType.STDERR], partial=True
                ),
self.get_captured_local_path(log_key, "complete"), ] for path in paths: if os.path.exists(path) and os.path.isfile(path): os.remove(path) elif prefix: dir_to_delete = os.path.join(self._base_dir, *prefix) if os.path.exists(dir_to_delete) and os.path.isdir(dir_to_delete): # recursively delete all files in dir shutil.rmtree(dir_to_delete) else: check.failed("Must pass in either `log_key` or `prefix` argument to delete_logs") def _read_bytes( self, log_key: Sequence[str], io_type: ComputeIOType, offset: Optional[int] = 0, max_bytes: Optional[int] = None, ): path = self.get_captured_local_path(log_key, IO_TYPE_EXTENSION[io_type]) return self.read_path(path, offset or 0, max_bytes) def parse_cursor(self, cursor: Optional[str] = None) -> Tuple[int, int]: # Translates a string cursor into a set of byte offsets for stdout, stderr if not cursor: return 0, 0 parts = cursor.split(":") if not parts or len(parts) != 2: return 0, 0 stdout, stderr = [int(_) for _ in parts] return stdout, stderr def build_cursor(self, stdout_offset: int, stderr_offset: int) -> str: return f"{stdout_offset}:{stderr_offset}" def complete_artifact_path(self, log_key): return self.get_captured_local_path(log_key, "complete") def read_path( self, path: str, offset: int = 0, max_bytes: Optional[int] = None, ): if not os.path.exists(path) or not os.path.isfile(path): return None, offset with open(path, "rb") as f: f.seek(offset, os.SEEK_SET) if max_bytes is None: data = f.read() else: data = f.read(max_bytes) new_offset = f.tell() return data, new_offset def get_captured_log_download_url(self, log_key, io_type): check.inst_param(io_type, "io_type", ComputeIOType) url = "/logs" for part in log_key: url = f"{url}/{part}" return f"{url}/{IO_TYPE_EXTENSION[io_type]}" def get_captured_local_path(self, log_key: Sequence[str], extension: str, partial=False): [*namespace, filebase] = log_key filename = f"{filebase}.{extension}" if partial: filename = f"{filename}.partial" if len(filename) > MAX_FILENAME_LENGTH: filename = "{}.{}".format(non_secure_md5_hash_str(filebase.encode("utf-8")), extension) base_dir_path = Path(self._base_dir).resolve() log_path = base_dir_path.joinpath(*namespace, filename).resolve() if base_dir_path not in log_path.parents: raise ValueError("Invalid path") return str(log_path) def subscribe( self, log_key: Sequence[str], cursor: Optional[str] = None ) -> CapturedLogSubscription: subscription = CapturedLogSubscription(self, log_key, cursor) self.on_subscribe(subscription) return subscription def unsubscribe(self, subscription): self.on_unsubscribe(subscription) ############################################### # # Methods for the ComputeLogManager interface # ############################################### @contextmanager def _watch_logs( self, dagster_run: DagsterRun, step_key: Optional[str] = None ) -> Iterator[None]: check.inst_param(dagster_run, "dagster_run", DagsterRun) check.opt_str_param(step_key, "step_key") log_key = self.build_log_key_for_run(dagster_run.run_id, step_key or dagster_run.job_name) with self.capture_logs(log_key): yield def get_local_path(self, run_id: str, key: str, io_type: ComputeIOType) -> str: """Legacy adapter from compute log manager to more generic captured log manager API.""" check.inst_param(io_type, "io_type", ComputeIOType) log_key = self.build_log_key_for_run(run_id, key) return self.get_captured_local_path(log_key, IO_TYPE_EXTENSION[io_type]) def read_logs_file( self, run_id: str, key: str, io_type: ComputeIOType, cursor: int = 0, max_bytes: int = MAX_BYTES_FILE_READ, ) -> 
        path = self.get_local_path(run_id, key, io_type)

        if not os.path.exists(path) or not os.path.isfile(path):
            return ComputeLogFileData(path=path, data=None, cursor=0, size=0, download_url=None)

        # See: https://docs.python.org/2/library/stdtypes.html#file.tell for Windows behavior
        with open(path, "rb") as f:
            f.seek(cursor, os.SEEK_SET)
            data = f.read(max_bytes)
            cursor = f.tell()
            stats = os.fstat(f.fileno())

        # local download path
        download_url = self.download_url(run_id, key, io_type)
        return ComputeLogFileData(
            path=path,
            data=data.decode("utf-8"),
            cursor=cursor,
            size=stats.st_size,
            download_url=download_url,
        )

    def get_key(self, dagster_run: DagsterRun, step_key: Optional[str]):
        check.inst_param(dagster_run, "dagster_run", DagsterRun)
        check.opt_str_param(step_key, "step_key")
        return step_key or dagster_run.job_name

    def is_watch_completed(self, run_id: str, key: str) -> bool:
        log_key = self.build_log_key_for_run(run_id, key)
        return self.is_capture_complete(log_key)

    def on_watch_start(self, dagster_run: DagsterRun, step_key: Optional[str]):
        pass

    def on_watch_finish(self, dagster_run: DagsterRun, step_key: Optional[str] = None):
        check.inst_param(dagster_run, "dagster_run", DagsterRun)
        check.opt_str_param(step_key, "step_key")
        log_key = self.build_log_key_for_run(dagster_run.run_id, step_key or dagster_run.job_name)
        touchpath = self.complete_artifact_path(log_key)
        touch_file(touchpath)

    def download_url(self, run_id: str, key: str, io_type: ComputeIOType):
        check.inst_param(io_type, "io_type", ComputeIOType)
        return f"/download/{run_id}/{key}/{io_type.value}"

    def on_subscribe(self, subscription: "LogSubscription") -> None:
        self._subscription_manager.add_subscription(subscription)

    def on_unsubscribe(self, subscription: "LogSubscription") -> None:
        self._subscription_manager.remove_subscription(subscription)

    def dispose(self) -> None:
        self._subscription_manager.dispose()
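
In normal operation Dagster constructs this manager from instance config through the ConfigurableClass machinery, but the captured-log API can also be exercised directly. A minimal usage sketch, assuming a throwaway base directory and a made-up log key:

import tempfile

from dagster._core.storage.local_compute_log_manager import LocalComputeLogManager

manager = LocalComputeLogManager(base_dir=tempfile.mkdtemp())
log_key = ["example_run", "compute_logs", "example_step"]  # hypothetical key

# Mirror this process's stdout/stderr into .out/.err files under base_dir;
# on exit, capture_logs touches the ".complete" sentinel file.
with manager.capture_logs(log_key):
    print("hello from the captured step")

assert manager.is_capture_complete(log_key)

# Read from the beginning; the returned cursor encodes "stdout_offset:stderr_offset".
log_data = manager.get_log_data(log_key)
print(log_data.stdout.decode("utf-8"))  # "hello from the captured step"
print(log_data.cursor)                  # e.g. "29:0"

manager.dispose()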
class LocalComputeLogSubscriptionManager:
    def __init__(self, manager):
        self._manager = manager
        self._subscriptions = defaultdict(list)
        self._watchers = {}
        self._observer = None

    def add_subscription(self, subscription: "LogSubscription") -> None:
        check.inst_param(
            subscription, "subscription", (ComputeLogSubscription, CapturedLogSubscription)
        )

        if self.is_complete(subscription):
            subscription.fetch()
            subscription.complete()
        else:
            log_key = self._log_key(subscription)
            watch_key = self._watch_key(log_key)
            self._subscriptions[watch_key].append(subscription)
            self.watch(subscription)

    def is_complete(self, subscription: "LogSubscription") -> bool:
        check.inst_param(
            subscription, "subscription", (ComputeLogSubscription, CapturedLogSubscription)
        )

        if isinstance(subscription, ComputeLogSubscription):
            return self._manager.is_watch_completed(subscription.run_id, subscription.key)
        return self._manager.is_capture_complete(subscription.log_key)

    def remove_subscription(self, subscription: "LogSubscription") -> None:
        check.inst_param(
            subscription, "subscription", (ComputeLogSubscription, CapturedLogSubscription)
        )
        log_key = self._log_key(subscription)
        watch_key = self._watch_key(log_key)
        if subscription in self._subscriptions[watch_key]:
            self._subscriptions[watch_key].remove(subscription)
            subscription.complete()

    def _log_key(self, subscription: "LogSubscription") -> Sequence[str]:
        check.inst_param(
            subscription, "subscription", (ComputeLogSubscription, CapturedLogSubscription)
        )

        if isinstance(subscription, ComputeLogSubscription):
            return self._manager.build_log_key_for_run(subscription.run_id, subscription.key)
        return subscription.log_key

    def _watch_key(self, log_key: Sequence[str]) -> str:
        return json.dumps(log_key)

    def remove_all_subscriptions(self, log_key: Sequence[str]) -> None:
        watch_key = self._watch_key(log_key)
        for subscription in self._subscriptions.pop(watch_key, []):
            subscription.complete()

    def watch(self, subscription: "LogSubscription") -> None:
        log_key = self._log_key(subscription)
        watch_key = self._watch_key(log_key)
        if watch_key in self._watchers:
            return

        update_paths = [
            self._manager.get_captured_local_path(
                log_key, IO_TYPE_EXTENSION[ComputeIOType.STDOUT]
            ),
            self._manager.get_captured_local_path(
                log_key, IO_TYPE_EXTENSION[ComputeIOType.STDERR]
            ),
            self._manager.get_captured_local_path(
                log_key, IO_TYPE_EXTENSION[ComputeIOType.STDOUT], partial=True
            ),
            self._manager.get_captured_local_path(
                log_key, IO_TYPE_EXTENSION[ComputeIOType.STDERR], partial=True
            ),
        ]
        complete_paths = [self._manager.complete_artifact_path(log_key)]
        directory = os.path.dirname(
            self._manager.get_captured_local_path(log_key, IO_TYPE_EXTENSION[ComputeIOType.STDERR]),
        )

        if not self._observer:
            self._observer = PollingObserver(self._manager.polling_timeout)
            self._observer.start()

        ensure_dir(directory)

        self._watchers[watch_key] = self._observer.schedule(
            LocalComputeLogFilesystemEventHandler(self, log_key, update_paths, complete_paths),
            str(directory),
        )

    def notify_subscriptions(self, log_key: Sequence[str]) -> None:
        watch_key = self._watch_key(log_key)
        for subscription in self._subscriptions[watch_key]:
            subscription.fetch()

    def unwatch(self, log_key: Sequence[str], handler) -> None:
        watch_key = self._watch_key(log_key)
        if watch_key in self._watchers:
            self._observer.remove_handler_for_watch(handler, self._watchers[watch_key])  # type: ignore
            del self._watchers[watch_key]

    def dispose(self) -> None:
        if self._observer:
            self._observer.stop()
            self._observer.join(15)


class LocalComputeLogFilesystemEventHandler(PatternMatchingEventHandler):
    def __init__(self, manager, log_key, update_paths, complete_paths):
        self.manager = manager
        self.log_key = log_key
        self.update_paths = update_paths
        self.complete_paths = complete_paths
        patterns = update_paths + complete_paths
        super(LocalComputeLogFilesystemEventHandler, self).__init__(patterns=patterns)

    def on_created(self, event):
        if event.src_path in self.complete_paths:
            self.manager.remove_all_subscriptions(self.log_key)
            self.manager.unwatch(self.log_key, self)

    def on_modified(self, event):
        if event.src_path in self.update_paths:
            self.manager.notify_subscriptions(self.log_key)
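
The subscription classes above push updates to consumers as the watchdog observer reports file events. A consumer can also poll incrementally by threading the cursor from each CapturedLogData back into the next call. A hypothetical tail loop built only on the public methods above (the helper name, chunk size, and sleep interval are invented for illustration):

import time

def tail_captured_logs(manager, log_key, chunk_bytes=1024):
    # Hypothetical helper: poll for new stdout bytes until the capture completes.
    cursor = None
    while True:
        data = manager.get_log_data(log_key, cursor=cursor, max_bytes=chunk_bytes)
        if data.stdout:
            print(data.stdout.decode("utf-8"), end="")
        cursor = data.cursor  # "stdout_offset:stderr_offset"
        if manager.is_capture_complete(log_key) and not data.stdout and not data.stderr:
            break
        time.sleep(0.5)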