import os
from contextlib import contextmanager
from typing import Any, Iterator, Mapping, Optional, Sequence
import boto3
import dagster._seven as seven
from botocore.errorfactory import ClientError
from dagster import (
Field,
Permissive,
StringSource,
_check as check,
)
from dagster._config.config_type import Noneable
from dagster._core.storage.cloud_storage_compute_log_manager import (
CloudStorageComputeLogManager,
PollingComputeLogSubscriptionManager,
)
from dagster._core.storage.compute_log_manager import CapturedLogContext, ComputeIOType
from dagster._core.storage.local_compute_log_manager import (
IO_TYPE_EXTENSION,
LocalComputeLogManager,
)
from dagster._serdes import ConfigurableClass, ConfigurableClassData
from dagster._utils import ensure_dir, ensure_file
from typing_extensions import Self
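
# Polling interval, in seconds, for compute log subscriptions (assumed to be consumed by
# the polling subscription machinery; kept as defined in the original source).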
POLLING_INTERVAL = 5
class S3ComputeLogManager(CloudStorageComputeLogManager, ConfigurableClass):
"""Logs compute function stdout and stderr to S3.
Users should not instantiate this class directly. Instead, use a YAML block in ``dagster.yaml``
such as the following:
.. code-block:: YAML
compute_logs:
module: dagster_aws.s3.compute_log_manager
class: S3ComputeLogManager
config:
bucket: "mycorp-dagster-compute-logs"
local_dir: "/tmp/cool"
prefix: "dagster-test-"
use_ssl: true
verify: true
verify_cert_path: "/path/to/cert/bundle.pem"
endpoint_url: "http://alternate-s3-host.io"
skip_empty_files: true
upload_interval: 30
upload_extra_args:
ServerSideEncryption: "AES256"
show_url_only: false
region: "us-west-1"
    Args:
        bucket (str): The name of the S3 bucket to which to log.
        local_dir (Optional[str]): Path to the local directory in which to stage logs. Default:
            ``dagster._seven.get_system_temp_directory()``.
        prefix (Optional[str]): Prefix for the log file keys.
        use_ssl (Optional[bool]): Whether or not to use SSL. Default True.
        verify (Optional[bool]): Whether or not to verify SSL certificates. Default True.
        verify_cert_path (Optional[str]): Path to a CA cert bundle to use for SSL verification.
            Only used if ``verify`` is True.
        endpoint_url (Optional[str]): Override for the S3 endpoint URL.
        skip_empty_files (Optional[bool]): Skip upload of empty log files. Default False.
        upload_interval (Optional[int]): Interval in seconds at which to upload partial log
            files to S3. By default, uploads only occur when the capture is complete.
        upload_extra_args (Optional[dict]): Extra args for the S3 file upload.
        show_url_only (Optional[bool]): Only show the URL of the log file in the UI, instead of
            fetching and displaying the full content. Default False.
        region (Optional[str]): The region of the S3 bucket. If not specified, the default
            region of the AWS session is used.
        inst_data (Optional[ConfigurableClassData]): Serializable representation of the compute
            log manager when newed up from config.
"""
def __init__(
self,
bucket,
local_dir=None,
inst_data: Optional[ConfigurableClassData] = None,
prefix="dagster",
use_ssl=True,
verify=True,
verify_cert_path=None,
endpoint_url=None,
skip_empty_files=False,
upload_interval=None,
upload_extra_args=None,
show_url_only=False,
region=None,
):
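        # boto3's `verify` kwarg is tri-state: False disables SSL verification entirely,
        # a string is treated as the path to a custom CA cert bundle, and None uses the
        # default certificates. `verify_cert_path` therefore only takes effect when
        # `verify` is True.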
_verify = False if not verify else verify_cert_path
self._s3_session = boto3.resource(
"s3", use_ssl=use_ssl, verify=_verify, endpoint_url=endpoint_url
).meta.client
self._s3_bucket = check.str_param(bucket, "bucket")
self._s3_prefix = self._clean_prefix(check.str_param(prefix, "prefix"))
# proxy calls to local compute log manager (for subscriptions, etc)
if not local_dir:
local_dir = seven.get_system_temp_directory()
self._local_manager = LocalComputeLogManager(local_dir)
self._subscription_manager = PollingComputeLogSubscriptionManager(self)
self._inst_data = check.opt_inst_param(inst_data, "inst_data", ConfigurableClassData)
self._skip_empty_files = check.bool_param(skip_empty_files, "skip_empty_files")
self._upload_interval = check.opt_int_param(upload_interval, "upload_interval")
check.opt_dict_param(upload_extra_args, "upload_extra_args")
self._upload_extra_args = upload_extra_args
self._show_url_only = show_url_only
if region is None:
            # if unspecified, fall back to the region of the current boto3 session
self._region = self._s3_session.meta.region_name
else:
self._region = region
@property
def inst_data(self):
return self._inst_data
@classmethod
def config_type(cls):
return {
"bucket": StringSource,
"local_dir": Field(StringSource, is_required=False),
"prefix": Field(StringSource, is_required=False, default_value="dagster"),
"use_ssl": Field(bool, is_required=False, default_value=True),
"verify": Field(bool, is_required=False, default_value=True),
"verify_cert_path": Field(StringSource, is_required=False),
"endpoint_url": Field(StringSource, is_required=False),
"skip_empty_files": Field(bool, is_required=False, default_value=False),
"upload_interval": Field(Noneable(int), is_required=False, default_value=None),
"upload_extra_args": Field(
Permissive(), is_required=False, description="Extra args for S3 file upload"
),
"show_url_only": Field(bool, is_required=False, default_value=False),
"region": Field(StringSource, is_required=False),
}
@classmethod
def from_config_value(
cls, inst_data: ConfigurableClassData, config_value: Mapping[str, Any]
) -> Self:
return cls(inst_data=inst_data, **config_value)
@property
def local_manager(self) -> LocalComputeLogManager:
return self._local_manager
@property
def upload_interval(self) -> Optional[int]:
return self._upload_interval if self._upload_interval else None
def _clean_prefix(self, prefix):
parts = prefix.split("/")
return "/".join([part for part in parts if part])
def _resolve_path_for_namespace(self, namespace):
return [self._s3_prefix, "storage", *namespace]
def _s3_key(self, log_key, io_type, partial=False):
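        # Build the S3 object key "<prefix>/storage/<namespace...>/<filebase>.<ext>";
        # a ".partial" suffix marks uploads from still-in-progress captures.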
check.inst_param(io_type, "io_type", ComputeIOType)
extension = IO_TYPE_EXTENSION[io_type]
[*namespace, filebase] = log_key
filename = f"{filebase}.{extension}"
if partial:
filename = f"{filename}.partial"
paths = [*self._resolve_path_for_namespace(namespace), filename]
return "/".join(paths) # s3 path delimiter
@contextmanager
def capture_logs(self, log_key: Sequence[str]) -> Iterator[CapturedLogContext]:
with super().capture_logs(log_key) as local_context:
if not self._show_url_only:
yield local_context
else:
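                # show_url_only: rather than streaming captured log content back through
                # the UI, surface links to the uploaded objects in the S3 console.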
out_key = self._s3_key(log_key, ComputeIOType.STDOUT)
err_key = self._s3_key(log_key, ComputeIOType.STDERR)
s3_base = f"https://s3.console.aws.amazon.com/s3/object/{self._s3_bucket}?region={self._region}"
yield CapturedLogContext(
local_context.log_key,
external_stdout_url=f"{s3_base}&prefix={out_key}",
external_stderr_url=f"{s3_base}&prefix={err_key}",
)
def delete_logs(
self, log_key: Optional[Sequence[str]] = None, prefix: Optional[Sequence[str]] = None
):
self.local_manager.delete_logs(log_key=log_key, prefix=prefix)
s3_keys_to_remove = None
if log_key:
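            # remove both the complete and partial uploads for stdout and stderr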
s3_keys_to_remove = [
self._s3_key(log_key, ComputeIOType.STDOUT),
self._s3_key(log_key, ComputeIOType.STDERR),
self._s3_key(log_key, ComputeIOType.STDOUT, partial=True),
self._s3_key(log_key, ComputeIOType.STDERR, partial=True),
]
elif prefix:
            # join with a trailing '' so the prefix ends in '/', ensuring that ['a'] does
            # not match ['apple']
s3_prefix = "/".join([self._s3_prefix, "storage", *prefix, ""])
matching = self._s3_session.list_objects(Bucket=self._s3_bucket, Prefix=s3_prefix)
s3_keys_to_remove = [obj["Key"] for obj in matching.get("Contents", [])]
else:
check.failed("Must pass in either `log_key` or `prefix` argument to delete_logs")
if s3_keys_to_remove:
to_delete = [{"Key": key} for key in s3_keys_to_remove]
self._s3_session.delete_objects(Bucket=self._s3_bucket, Delete={"Objects": to_delete})
def download_url_for_type(self, log_key: Sequence[str], io_type: ComputeIOType):
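        # only generate a presigned GET URL once the capture is complete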
if not self.is_capture_complete(log_key):
return None
s3_key = self._s3_key(log_key, io_type)
return self._s3_session.generate_presigned_url(
ClientMethod="get_object", Params={"Bucket": self._s3_bucket, "Key": s3_key}
)
def display_path_for_type(self, log_key: Sequence[str], io_type: ComputeIOType):
if not self.is_capture_complete(log_key):
return None
s3_key = self._s3_key(log_key, io_type)
return f"s3://{self._s3_bucket}/{s3_key}"
def cloud_storage_has_logs(
self, log_key: Sequence[str], io_type: ComputeIOType, partial: bool = False
) -> bool:
s3_key = self._s3_key(log_key, io_type, partial=partial)
try: # https://stackoverflow.com/a/38376288/14656695
self._s3_session.head_object(Bucket=self._s3_bucket, Key=s3_key)
except ClientError:
return False
return True
def upload_to_cloud_storage(
self, log_key: Sequence[str], io_type: ComputeIOType, partial=False
):
path = self.local_manager.get_captured_local_path(log_key, IO_TYPE_EXTENSION[io_type])
ensure_file(path)
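        # always skip empty partial uploads; skip empty complete files only when configured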
if (self._skip_empty_files or partial) and os.stat(path).st_size == 0:
return
s3_key = self._s3_key(log_key, io_type, partial=partial)
with open(path, "rb") as data:
extra_args = {
"ContentType": "text/plain",
**(self._upload_extra_args if self._upload_extra_args else {}),
}
self._s3_session.upload_fileobj(data, self._s3_bucket, s3_key, ExtraArgs=extra_args)
def download_from_cloud_storage(
self, log_key: Sequence[str], io_type: ComputeIOType, partial=False
):
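        # mirror the S3 object into the local staging directory so the local compute log
        # manager can serve it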
path = self._local_manager.get_captured_local_path(
log_key, IO_TYPE_EXTENSION[io_type], partial=partial
)
ensure_dir(os.path.dirname(path))
s3_key = self._s3_key(log_key, io_type, partial=partial)
with open(path, "wb") as fileobj:
self._s3_session.download_fileobj(self._s3_bucket, s3_key, fileobj)
def get_log_keys_for_log_key_prefix(
self, log_key_prefix: Sequence[str], io_type: ComputeIOType
) -> Sequence[Sequence[str]]:
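        # list the objects under the namespace prefix and keep those whose file extension
        # matches the requested IO type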
directory = self._resolve_path_for_namespace(log_key_prefix)
objects = self._s3_session.list_objects_v2(
Bucket=self._s3_bucket, Prefix="/".join(directory)
)
results = []
list_key_prefix = list(log_key_prefix)
if objects["KeyCount"] == 0:
return []
for obj in objects["Contents"]:
full_key = obj["Key"]
            # split off at most one extension so keys like "<name>.out.partial" unpack
            # cleanly; their compound extension then fails the match below and is skipped
            filename, obj_io_type = full_key.split("/")[-1].split(".", 1)
            if obj_io_type != IO_TYPE_EXTENSION[io_type]:
                continue
results.append(list_key_prefix + [filename])
return results
def on_subscribe(self, subscription):
self._subscription_manager.add_subscription(subscription)
def on_unsubscribe(self, subscription):
self._subscription_manager.remove_subscription(subscription)
def dispose(self):
self._subscription_manager.dispose()
self._local_manager.dispose()