Ask AI

Source code for dagster_azure.blob.compute_log_manager

import os
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
from typing import Any, Mapping, Optional, Sequence

import dagster._seven as seven
from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobSasPermissions, BlobServiceClient, UserDelegationKey
from dagster import (
    Field,
    Noneable,
    Permissive,
    StringSource,
    _check as check,
)
from dagster._core.storage.cloud_storage_compute_log_manager import (
    CloudStorageComputeLogManager,
    PollingComputeLogSubscriptionManager,
)
from dagster._core.storage.compute_log_manager import ComputeIOType
from dagster._core.storage.local_compute_log_manager import (
    IO_TYPE_EXTENSION,
    LocalComputeLogManager,
)
from dagster._serdes import ConfigurableClass, ConfigurableClassData
from dagster._utils import ensure_dir, ensure_file
from typing_extensions import Self

from dagster_azure.blob.utils import create_blob_client, generate_blob_sas


class AzureBlobComputeLogManager(CloudStorageComputeLogManager, ConfigurableClass):
    """Logs op compute function stdout and stderr to Azure Blob Storage.

    This is also compatible with Azure Data Lake Storage.

    Users should not instantiate this class directly. Instead, use a YAML block in
    ``dagster.yaml`` such as the following:

    .. code-block:: YAML

        compute_logs:
          module: dagster_azure.blob.compute_log_manager
          class: AzureBlobComputeLogManager
          config:
            storage_account: my-storage-account
            container: my-container
            # Provide either ``secret_key`` or ``default_azure_credential``.
            secret_key: my-storage-account-secret-key
            default_azure_credential:
              exclude_environment_credential: true
            prefix: "dagster-test-"
            local_dir: "/tmp/cool"
            upload_interval: 30

    Args:
        storage_account (str): The storage account name to which to log.
        container (str): The container (or ADLS2 filesystem) to which to log.
        secret_key (Optional[str]): Secret key for the storage account. SAS tokens are not
            supported because we need a secret key to generate a SAS token for a download URL.
        default_azure_credential (Optional[dict]): Use and configure DefaultAzureCredential.
            Cannot be used with sas token or secret key config.
        local_dir (Optional[str]): Path to the local directory in which to stage logs. Default:
            ``dagster._seven.get_system_temp_directory()``.
        prefix (Optional[str]): Prefix for the log file keys.
        upload_interval: (Optional[int]): Interval in seconds to upload partial log files to
            blob storage. By default, will only upload when the capture is complete.
        inst_data (Optional[ConfigurableClassData]): Serializable representation of the compute
            log manager when newed up from config.
    """

    # How long a generated SAS download URL stays valid.
    _SAS_URL_LIFETIME = timedelta(hours=6)
    # A cached URL is regenerated once it has less than this much validity left, so callers
    # are never handed a link that is about to expire.
    _SAS_URL_REFRESH_MARGIN = timedelta(minutes=5)
    # Azure Blob batch requests accept at most this many sub-requests.
    _MAX_BLOBS_PER_DELETE_BATCH = 256

    def __init__(
        self,
        storage_account,
        container,
        secret_key=None,
        local_dir=None,
        inst_data: Optional[ConfigurableClassData] = None,
        prefix="dagster",
        upload_interval=None,
        default_azure_credential=None,
    ):
        self._storage_account = check.str_param(storage_account, "storage_account")
        self._container = check.str_param(container, "container")
        self._blob_prefix = self._clean_prefix(check.str_param(prefix, "prefix"))
        self._default_azure_credential = check.opt_dict_param(
            default_azure_credential, "default_azure_credential"
        )
        check.opt_str_param(secret_key, "secret_key")

        if secret_key is not None:
            self._blob_client = create_blob_client(storage_account, secret_key)
        else:
            # No secret key: authenticate through DefaultAzureCredential, forwarding any
            # user-supplied keyword arguments.
            credential = DefaultAzureCredential(**self._default_azure_credential)
            self._blob_client = create_blob_client(storage_account, credential)

        self._container_client = self._blob_client.get_container_client(container)
        # blob key -> signed download URL, and blob key -> UTC expiry of that URL's signature
        self._download_urls = {}
        self._download_url_expiries = {}

        # proxy calls to local compute log manager (for subscriptions, etc)
        if not local_dir:
            local_dir = seven.get_system_temp_directory()

        self._local_manager = LocalComputeLogManager(local_dir)
        self._subscription_manager = PollingComputeLogSubscriptionManager(self)
        self._upload_interval = check.opt_int_param(upload_interval, "upload_interval")
        self._inst_data = check.opt_inst_param(inst_data, "inst_data", ConfigurableClassData)

    @contextmanager
    def _watch_logs(self, dagster_run, step_key=None):
        # proxy watching to the local compute log manager, interacting with the filesystem
        with self.local_manager._watch_logs(dagster_run, step_key):  # noqa: SLF001
            yield

    @property
    def inst_data(self):
        return self._inst_data

    @classmethod
    def config_type(cls):
        """Return the config schema accepted by :py:meth:`from_config_value`."""
        return {
            "storage_account": StringSource,
            "container": StringSource,
            "secret_key": Field(Noneable(StringSource), is_required=False, default_value=None),
            "default_azure_credential": Field(
                Noneable(Permissive(description="keyword arguments for DefaultAzureCredential")),
                is_required=False,
                default_value=None,
            ),
            "local_dir": Field(Noneable(StringSource), is_required=False, default_value=None),
            "prefix": Field(StringSource, is_required=False, default_value="dagster"),
            "upload_interval": Field(Noneable(int), is_required=False, default_value=None),
        }

    @classmethod
    def from_config_value(
        cls, inst_data: ConfigurableClassData, config_value: Mapping[str, Any]
    ) -> Self:
        return cls(inst_data=inst_data, **config_value)

    @property
    def local_manager(self) -> LocalComputeLogManager:
        return self._local_manager

    @property
    def upload_interval(self) -> Optional[int]:
        # A falsy interval (None or 0) means "no periodic partial uploads".
        return self._upload_interval if self._upload_interval else None

    def _clean_prefix(self, prefix):
        """Drop empty path segments (leading, trailing or doubled ``/``) from ``prefix``."""
        parts = prefix.split("/")
        return "/".join([part for part in parts if part])

    def _resolve_path_for_namespace(self, namespace):
        return [self._blob_prefix, "storage", *namespace]

    def _blob_key(self, log_key, io_type, partial=False):
        """Build the blob name for ``log_key``/``io_type``.

        The last element of ``log_key`` is the file base; everything before it is the
        namespace. Partial uploads get an extra ``.partial`` suffix.
        """
        check.inst_param(io_type, "io_type", ComputeIOType)
        extension = IO_TYPE_EXTENSION[io_type]
        [*namespace, filebase] = log_key
        filename = f"{filebase}.{extension}"
        if partial:
            filename = f"{filename}.partial"
        paths = [*self._resolve_path_for_namespace(namespace), filename]
        return "/".join(paths)  # blob path delimiter

    def delete_logs(
        self, log_key: Optional[Sequence[str]] = None, prefix: Optional[Sequence[str]] = None
    ):
        """Delete captured logs locally and in blob storage.

        Args:
            log_key: Delete only the stdout/stderr blobs (complete and partial) of this key.
            prefix: Delete every blob stored under this log key prefix. Ignored when
                ``log_key`` is given.

        Raises:
            CheckError: If neither ``log_key`` nor ``prefix`` is provided.
        """
        self.local_manager.delete_logs(log_key=log_key, prefix=prefix)

        # Validate before touching blob storage: without a prefix, ``list_blobs`` would
        # enumerate the entire container just to fail afterwards.
        if not log_key and not prefix:
            check.failed("Must pass in either `log_key` or `prefix` argument to delete_logs")

        if log_key:
            prefix_path = "/".join([self._blob_prefix, "storage", *log_key])
        else:
            # add the trailing '/' to make sure that ['a'] does not match ['apple']
            prefix_path = "/".join([self._blob_prefix, "storage", *(prefix or []), ""])

        existing_blobs = {
            blob.name for blob in self._container_client.list_blobs(name_starts_with=prefix_path)
        }

        if log_key:
            # filter to the known set of keys
            known_keys = [
                self._blob_key(log_key, ComputeIOType.STDOUT),
                self._blob_key(log_key, ComputeIOType.STDERR),
                self._blob_key(log_key, ComputeIOType.STDOUT, partial=True),
                self._blob_key(log_key, ComputeIOType.STDERR, partial=True),
            ]
            to_remove = [key for key in known_keys if key in existing_blobs]
        else:
            to_remove = sorted(existing_blobs)

        # A single batch delete request is capped by the service, so delete in chunks.
        batch_size = self._MAX_BLOBS_PER_DELETE_BATCH
        for start in range(0, len(to_remove), batch_size):
            self._container_client.delete_blobs(*to_remove[start : start + batch_size])

    def download_url_for_type(self, log_key: Sequence[str], io_type: ComputeIOType):
        """Return a time-limited, signed download URL for a completed log file.

        Returns ``None`` while the capture is still in progress. URLs are cached per blob and
        regenerated shortly before their signature expires.
        """
        if not self.is_capture_complete(log_key):
            return None

        blob_key = self._blob_key(log_key, io_type)
        now = datetime.now(timezone.utc)

        cached_url = self._download_urls.get(blob_key)
        cached_expiry = self._download_url_expiries.get(blob_key)
        if (
            cached_url is not None
            and cached_expiry is not None
            and now + self._SAS_URL_REFRESH_MARGIN < cached_expiry
        ):
            return cached_url

        blob = self._container_client.get_blob_client(blob_key)

        # Sign with the account key when the client was built from one; otherwise fall back
        # to a user delegation key obtained with the client's token credential.
        user_delegation_key = None
        account_key = None
        if hasattr(self._blob_client.credential, "account_key"):
            account_key = self._blob_client.credential.account_key
        else:
            user_delegation_key = self._request_user_delegation_key(self._blob_client)

        # Use an aware UTC timestamp: a naive local time would be interpreted as UTC and
        # could yield a signature that is already expired.
        expiry = now + self._SAS_URL_LIFETIME
        sas = generate_blob_sas(
            self._storage_account,
            self._container,
            blob_key,
            account_key=account_key,
            user_delegation_key=user_delegation_key,
            expiry=expiry,
            permission=BlobSasPermissions(read=True),
        )
        url = blob.url + "?" + sas
        self._download_urls[blob_key] = url
        self._download_url_expiries[blob_key] = expiry
        return url

    def _request_user_delegation_key(
        self,
        blob_service_client: BlobServiceClient,
    ) -> UserDelegationKey:
        """Creates user delegation key when a service principal is used
        or other authentication other than account key.
        """
        # Get a user delegation key that's valid for 1 day
        delegation_key_start_time = datetime.now(timezone.utc)
        delegation_key_expiry_time = delegation_key_start_time + timedelta(days=1)

        user_delegation_key = blob_service_client.get_user_delegation_key(
            key_start_time=delegation_key_start_time,
            key_expiry_time=delegation_key_expiry_time,
        )
        return user_delegation_key

    def display_path_for_type(self, log_key: Sequence[str], io_type: ComputeIOType):
        """Return the local path while capturing, or the blob URL once capture is complete."""
        if not self.is_capture_complete(log_key):
            return self.local_manager.get_captured_local_path(log_key, IO_TYPE_EXTENSION[io_type])

        blob_key = self._blob_key(log_key, io_type)
        return f"https://{self._storage_account}.blob.core.windows.net/{self._container}/{blob_key}"

    def cloud_storage_has_logs(
        self, log_key: Sequence[str], io_type: ComputeIOType, partial: bool = False
    ) -> bool:
        """Return whether a blob with exactly the expected name exists in the container."""
        blob_key = self._blob_key(log_key, io_type, partial=partial)
        # Listing is by name prefix, so compare names exactly to rule out longer blob names
        # such as the ``.partial`` sibling.
        candidates = self._container_client.list_blobs(name_starts_with=blob_key)
        return any(blob.name == blob_key for blob in candidates)

    def upload_to_cloud_storage(
        self, log_key: Sequence[str], io_type: ComputeIOType, partial=False
    ):
        """Upload the locally captured log file for ``log_key``/``io_type`` to blob storage."""
        path = self.local_manager.get_captured_local_path(log_key, IO_TYPE_EXTENSION[io_type])
        ensure_file(path)
        blob_key = self._blob_key(log_key, io_type, partial=partial)
        with open(path, "rb") as data:
            blob = self._container_client.get_blob_client(blob_key)
            # Only partial uploads are allowed to overwrite an existing blob.
            blob.upload_blob(data, overwrite=partial)

    def download_from_cloud_storage(
        self, log_key: Sequence[str], io_type: ComputeIOType, partial=False
    ):
        """Download the blob for ``log_key``/``io_type`` into the local staging path."""
        path = self.local_manager.get_captured_local_path(
            log_key, IO_TYPE_EXTENSION[io_type], partial=partial
        )
        ensure_dir(os.path.dirname(path))
        blob_key = self._blob_key(log_key, io_type, partial=partial)
        with open(path, "wb") as fileobj:
            blob = self._container_client.get_blob_client(blob_key)
            blob.download_blob().readinto(fileobj)

    def get_log_keys_for_log_key_prefix(
        self, log_key_prefix: Sequence[str], io_type: ComputeIOType
    ) -> Sequence[Sequence[str]]:
        """Return the log keys under ``log_key_prefix`` that have a blob of type ``io_type``.

        Blobs whose final extension is not the one for ``io_type`` (for example ``.partial``
        uploads or the other IO type) are skipped.
        """
        directory = self._resolve_path_for_namespace(log_key_prefix)
        blobs = self._container_client.list_blobs(name_starts_with="/".join(directory))
        expected_extension = IO_TYPE_EXTENSION[io_type]
        list_key_prefix = list(log_key_prefix)

        results = []
        for blob in blobs:
            basename = blob.name.split("/")[-1]
            # Split on the last dot only, so file bases containing dots and three-part names
            # like ``x.out.partial`` no longer break tuple unpacking.
            filebase, separator, extension = basename.rpartition(".")
            if not separator or extension != expected_extension:
                continue
            results.append([*list_key_prefix, filebase])

        return results

    def on_subscribe(self, subscription):
        self._subscription_manager.add_subscription(subscription)

    def on_unsubscribe(self, subscription):
        self._subscription_manager.remove_subscription(subscription)