import base64
import gzip
import os
import random
import string
import sys
from contextlib import contextmanager
from datetime import datetime
from threading import Event, Thread
from typing import (
IO,
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
Iterator,
List,
Optional,
Sequence,
Tuple,
cast,
)
import boto3
import dagster._check as check
from botocore.exceptions import ClientError
from dagster import DagsterInvariantViolationError
from dagster._annotations import experimental
from dagster._core.pipes.client import PipesLaunchedData, PipesMessageReader, PipesParams
from dagster._core.pipes.context import PipesMessageHandler
from dagster._core.pipes.utils import (
PipesBlobStoreMessageReader,
PipesChunkedLogReader,
PipesLogReader,
PipesThreadedMessageReader,
extract_message_or_forward_to_stdout,
forward_only_logs_to_file,
)
from dagster_pipes import PipesDefaultMessageWriter
if TYPE_CHECKING:
from mypy_boto3_logs import CloudWatchLogsClient
from mypy_boto3_logs.type_defs import OutputLogEventTypeDef
from mypy_boto3_s3 import S3Client
def _can_read_from_s3(client: "S3Client", bucket: Optional[str], key: Optional[str]):
if not bucket or not key:
return False
else:
try:
client.head_object(Bucket=bucket, Key=key)
return True
except ClientError:
return False
def default_log_decode_fn(contents: bytes) -> str:
    """Decode raw log bytes as UTF-8 text.

    Args:
        contents: The bytes to decode.

    Returns:
        The decoded text.
    """
    text = contents.decode(encoding="utf-8")
    return text
def gzip_log_decode_fn(contents: bytes) -> str:
    """Decompress gzip-compressed log bytes and decode the result as UTF-8 text.

    Args:
        contents: Gzip-compressed bytes.

    Returns:
        The decompressed, decoded text.
    """
    raw = gzip.decompress(contents)
    return raw.decode(encoding="utf-8")
class PipesS3LogReader(PipesChunkedLogReader):
    """Log reader that forwards the text of a single S3 object as the object grows.

    Every call to ``download_log_chunk`` downloads the whole object, decodes it, and
    returns only the text that was not already returned by a previous call.

    Args:
        bucket (str): Bucket that holds the log object.
        key (str): Key of the log object.
        client (Optional[S3Client]): S3 client to use. Defaults to ``boto3.client("s3")``.
        interval (float): Forwarded to ``PipesChunkedLogReader`` (presumably the number
            of seconds between downloads -- confirm against the base class).
        target_stream (Optional[IO[str]]): Forwarded to the base class. Defaults to
            ``sys.stdout``.
        decode_fn (Optional[Callable[[bytes], str]]): Converts the downloaded bytes to
            text. Defaults to ``default_log_decode_fn``.
        debug_info (Optional[str]): Forwarded to the base class unchanged.
    """

    def __init__(
        self,
        *,
        bucket: str,
        key: str,
        client: Optional["S3Client"] = None,
        interval: float = 10,
        target_stream: Optional[IO[str]] = None,
        # TODO: maybe move this parameter to a different scope
        decode_fn: Optional[Callable[[bytes], str]] = None,
        debug_info: Optional[str] = None,
    ):
        self.bucket = bucket
        self.key = key
        self.client: "S3Client" = client or boto3.client("s3")
        self.decode_fn = decode_fn or default_log_decode_fn

        # Number of decoded characters that have already been handed out.
        self.log_position = 0

        super().__init__(
            interval=interval, target_stream=target_stream or sys.stdout, debug_info=debug_info
        )

    @property
    def name(self) -> str:
        """Human-readable identifier that includes the object's S3 location."""
        return f"PipesS3LogReader(s3://{os.path.join(self.bucket, self.key)})"

    def target_is_readable(self, params: PipesParams) -> bool:
        """Return whether the log object currently answers a ``head_object`` request."""
        return _can_read_from_s3(
            client=self.client,
            bucket=self.bucket,
            key=self.key,
        )

    def download_log_chunk(self, params: PipesParams) -> Optional[str]:
        """Download the object and return the text appended since the previous call.

        Args:
            params: Pipes params; not used by this reader.

        Returns:
            The portion of the decoded object located after the offset recorded by the
            previous call (an empty string when nothing new was appended).
        """
        body = self.client.get_object(Bucket=self.bucket, Key=self.key)["Body"].read()
        text = self.decode_fn(body)

        previous_position = self.log_position
        # The whole object is downloaded on every call, so the offset for the next call
        # is the length of the text seen now. Accumulating the lengths would overshoot
        # after the second call and silently drop newly appended text.
        self.log_position = len(text)

        return text[previous_position:]
[docs]
class PipesS3MessageReader(PipesBlobStoreMessageReader):
    """Message reader that periodically downloads numbered message chunks from an S3 bucket.

    If `log_readers` is passed, this reader will also start the passed readers
    when the first message is received from the external process.

    Args:
        interval (float): interval in seconds between attempts to download a chunk
        bucket (str): The S3 bucket to read from.
        client (boto3.client): A boto3 S3 client.
        log_readers (Optional[Sequence[PipesLogReader]]): A set of log readers for logs on S3.
    """

    def __init__(
        self,
        *,
        interval: float = 10,
        bucket: str,
        client: boto3.client,  # pyright: ignore (reportGeneralTypeIssues)
        log_readers: Optional[Sequence[PipesLogReader]] = None,
    ):
        super().__init__(interval=interval, log_readers=log_readers)
        self.bucket = check.str_param(bucket, "bucket")
        self.client = client

    @contextmanager
    def get_params(self) -> Iterator[PipesParams]:
        """Yield the bucket name together with a fresh random 30-letter key prefix."""
        letters = random.choices(string.ascii_letters, k=30)
        yield {"bucket": self.bucket, "key_prefix": "".join(letters)}

    def messages_are_readable(self, params: PipesParams) -> bool:
        """Return whether the first chunk (``<key_prefix>/1.json``) exists in the bucket."""
        prefix = params.get("key_prefix")
        if prefix is None:
            return False

        first_chunk_key = f"{prefix}/1.json"
        try:
            self.client.head_object(Bucket=self.bucket, Key=first_chunk_key)
        except ClientError:
            return False
        return True

    def download_messages_chunk(self, index: int, params: PipesParams) -> Optional[str]:
        """Fetch chunk ``index`` as UTF-8 text, or ``None`` when S3 raises ``ClientError``."""
        chunk_key = f"{params['key_prefix']}/{index}.json"
        try:
            response = self.client.get_object(Bucket=self.bucket, Key=chunk_key)
            payload = response["Body"].read()
            return payload.decode("utf-8")
        except ClientError:
            return None

    def no_messages_debug_text(self) -> str:
        """Describe where messages were expected, for use when none were received."""
        text = (
            f"Attempted to read messages from S3 bucket {self.bucket}. Expected"
            " PipesS3MessageWriter to be explicitly passed to open_dagster_pipes in the external"
            " process."
        )
        return text
class PipesLambdaLogsMessageReader(PipesMessageReader):
    """Message reader that consumes buffered pipes messages that were flushed on exit from the
    final 4k of logs that are returned from issuing a sync lambda invocation. This means messages
    emitted during the computation will only be processed once the lambda completes.

    Limitations: If the volume of pipes messages exceeds 4k, messages will be lost and it is
    recommended to switch to PipesS3MessageWriter & PipesS3MessageReader.
    """

    # Handler for the active ``read_messages`` scope. The class-level default keeps the
    # attribute defined before ``read_messages`` has ever been entered, so
    # ``consume_lambda_logs`` fails with its explanatory check error rather than with an
    # ``AttributeError``.
    _handler: Optional[PipesMessageHandler] = None

    @contextmanager
    def read_messages(
        self,
        handler: PipesMessageHandler,
    ) -> Iterator[PipesParams]:
        """Hold ``handler`` for the duration of the scope and yield the writer params.

        Args:
            handler: Handler that receives the messages extracted by ``consume_lambda_logs``.

        Yields:
            Params that select buffered stderr output for the default message writer.
        """
        self._handler = handler
        try:
            # use buffered stdio to shift the pipes messages to the tail of logs
            yield {PipesDefaultMessageWriter.BUFFERED_STDIO_KEY: PipesDefaultMessageWriter.STDERR}
        finally:
            self._handler = None

    def consume_lambda_logs(self, response) -> None:
        """Decode the log tail of a lambda invocation response and process it line by line.

        Args:
            response: Invocation response; its ``"LogResult"`` entry must hold the
                base64-encoded log tail.

        Raises:
            A check error when called outside of the ``read_messages`` scope.
        """
        handler = check.not_none(
            self._handler, "Can only consume logs within context manager scope."
        )

        log_result = base64.b64decode(response["LogResult"]).decode("utf-8")

        for log_line in log_result.splitlines():
            extract_message_or_forward_to_stdout(handler, log_line)

    def no_messages_debug_text(self) -> str:
        """Describe where messages were expected, for use when none were received."""
        return (
            "Attempted to read messages by extracting them from the tail of lambda logs directly."
        )
def tail_cloudwatch_events(
    client: "CloudWatchLogsClient",
    log_group: str,
    log_stream: str,
    start_time: Optional[int] = None,
) -> Generator[List["OutputLogEventTypeDef"], None, None]:
    """Yield batches of events from a CloudWatch log stream, oldest first.

    The generator never finishes on its own; the consumer decides when to stop
    iterating. A batch is yielded only when a request returned at least one event;
    after an empty response the next request is issued immediately.

    Args:
        client: CloudWatch Logs client used to call ``get_log_events``.
        log_group: Name of the log group.
        log_stream: Name of the log stream.
        start_time: Optional value for the ``startTime`` request parameter.

    Yields:
        Non-empty lists of log events, in the order the service returned them.
    """
    request: Dict[str, Any] = {
        "logGroupName": log_group,
        "logStreamName": log_stream,
        # GetLogEvents requires startFromHead=True whenever a previous
        # nextForwardToken is used as nextToken. It also makes the first request
        # start at the oldest matching event instead of the most recent page.
        "startFromHead": True,
    }

    if start_time is not None:
        request["startTime"] = start_time

    while True:
        response = client.get_log_events(**request)

        events = response.get("events")
        if events:
            yield events

        # Continue from where this response ended.
        request["nextToken"] = response["nextForwardToken"]
@experimental
class PipesCloudWatchLogReader(PipesLogReader):
    """Log reader that forwards the lines of a CloudWatch log stream to a text stream.

    The log group and log stream may be given to the constructor or supplied later
    through the Pipes params (keys ``log_group`` and ``log_stream``); values found in
    the params take precedence.

    Args:
        client: CloudWatch Logs client. Defaults to ``boto3.client("logs")``.
        log_group (Optional[str]): Name of the log group.
        log_stream (Optional[str]): Name of the log stream.
        target_stream (Optional[IO[str]]): Stream the log lines are written to.
            Defaults to ``sys.stdout``.
        start_time (Optional[int]): Passed to ``tail_cloudwatch_events`` as ``start_time``.
            When unset, the ``start_time`` entry of the Pipes params is used instead.
        debug_info (Optional[str]): Text exposed through the ``debug_info`` property.
    """

    def __init__(
        self,
        client=None,
        log_group: Optional[str] = None,
        log_stream: Optional[str] = None,
        target_stream: Optional[IO[str]] = None,
        start_time: Optional[int] = None,
        debug_info: Optional[str] = None,
    ):
        self.client = client or boto3.client("logs")
        self.log_group = log_group
        self.log_stream = log_stream
        self.target_stream = target_stream or sys.stdout
        self.thread: Optional[Thread] = None
        self.start_time = start_time
        self._debug_info = debug_info

    @property
    def debug_info(self) -> Optional[str]:
        """Debug text supplied to the constructor, if any."""
        return self._debug_info

    def target_is_readable(self, params: PipesParams) -> bool:
        """Return whether the configured log stream exists.

        Args:
            params: Pipes params; ``log_group`` / ``log_stream`` entries override the
                values given to the constructor.

        Returns:
            ``True`` only when both names are known and ``describe_log_streams`` lists a
            stream whose name equals the requested one.
        """
        log_group = params.get("log_group") or self.log_group
        log_stream = params.get("log_stream") or self.log_stream

        if log_group is None or log_stream is None:
            return False

        try:
            response = self.client.describe_log_streams(
                logGroupName=log_group,
                logStreamNamePrefix=log_stream,
            )
        except self.client.exceptions.ResourceNotFoundException:
            return False

        # A successful call does not prove that the stream exists: a prefix query on an
        # existing group simply returns an empty list when nothing matches.
        return any(
            stream.get("logStreamName") == log_stream
            for stream in response.get("logStreams", [])
        )

    def start(self, params: PipesParams, is_session_closed: Event) -> None:
        """Start the background thread that forwards the log lines.

        Args:
            params: Pipes params, see ``target_is_readable``.
            is_session_closed: Event that signals the end of the Pipes session.

        Raises:
            DagsterInvariantViolationError: When ``target_is_readable`` returns ``False``.
        """
        if not self.target_is_readable(params):
            raise DagsterInvariantViolationError(
                "log_group and log_stream must be set either in the constructor or in Pipes params."
            )

        self.thread = Thread(
            target=self._start, kwargs={"params": params, "is_session_closed": is_session_closed}
        )
        self.thread.start()

    def _start(self, params: PipesParams, is_session_closed: Event) -> None:
        """Thread body: write every non-empty log line to ``target_stream``."""
        log_group = cast(str, params.get("log_group") or self.log_group)
        log_stream = cast(str, params.get("log_stream") or self.log_stream)
        start_time = cast(int, self.start_time or params.get("start_time"))

        for events in tail_cloudwatch_events(
            self.client, log_group, log_stream, start_time=start_time
        ):
            for event in events:
                for line in event.get("message", "").splitlines():
                    if line:
                        forward_only_logs_to_file(line, self.target_stream)

            # NOTE(review): tail_cloudwatch_events yields only when a request returned
            # events, so this check is reached only after a new batch has arrived; once
            # the stream goes quiet the thread keeps polling. Confirm this is intended.
            if is_session_closed.is_set():
                return

    def stop(self) -> None:
        """No-op; the thread ends on its own once it observes the closed session."""
        pass

    def is_running(self) -> bool:
        """Return whether the background thread has been started and is still alive."""
        return self.thread is not None and self.thread.is_alive()
[docs]
@experimental
class PipesCloudWatchMessageReader(PipesThreadedMessageReader):
    """Message reader that consumes AWS CloudWatch logs to read pipes messages."""

    def __init__(
        self,
        client=None,
        log_group: Optional[str] = None,
        log_stream: Optional[str] = None,
        log_readers: Optional[Sequence[PipesLogReader]] = None,
    ):
        """Args:
        client (boto3.client): boto3 CloudWatch client. Defaults to ``boto3.client("logs")``.
        log_group (Optional[str]): Name of the log group; may also be provided later
            through ``on_launched``.
        log_stream (Optional[str]): Name of the log stream; may also be provided later
            through ``on_launched``.
        log_readers (Optional[Sequence[PipesLogReader]]): Forwarded to the base class.
        """
        self.client: "CloudWatchLogsClient" = client or boto3.client("logs")
        self.log_group = log_group
        self.log_stream = log_stream

        # Events older than the creation of this reader are not requested.
        self.start_time = datetime.now()

        super().__init__(log_readers=log_readers)

    def on_launched(self, launched_payload: PipesLaunchedData) -> None:
        """Pick up ``log_group`` / ``log_stream`` from the extras of the launch payload."""
        if "log_group" in launched_payload["extras"]:
            self.log_group = launched_payload["extras"]["log_group"]

        if "log_stream" in launched_payload["extras"]:
            self.log_stream = launched_payload["extras"]["log_stream"]

        self.launched_payload = launched_payload

    @contextmanager
    def get_params(self) -> Iterator[PipesParams]:
        """Yield params that make the default message writer write to stdout."""
        yield {PipesDefaultMessageWriter.STDIO_KEY: PipesDefaultMessageWriter.STDOUT}

    def messages_are_readable(self, params: PipesParams) -> bool:
        """Return whether the configured log stream exists.

        Returns:
            ``True`` only when both names are known and ``describe_log_streams`` lists a
            stream whose name equals ``self.log_stream``.
        """
        if self.log_group is None or self.log_stream is None:
            return False

        try:
            response = self.client.describe_log_streams(
                logGroupName=self.log_group,
                logStreamNamePrefix=self.log_stream,
            )
        except self.client.exceptions.ResourceNotFoundException:
            return False

        # A successful call does not prove that the stream exists: a prefix query on an
        # existing group simply returns an empty list when nothing matches.
        return any(
            stream.get("logStreamName") == self.log_stream
            for stream in response.get("logStreams", [])
        )

    def download_messages(
        self, cursor: Optional[str], params: PipesParams
    ) -> Optional[Tuple[str, str]]:
        """Fetch the next batch of log events and join their messages.

        Args:
            cursor: ``nextForwardToken`` returned by the previous call, or ``None`` for
                the first request.
            params: Pipes params; not used by this reader.

        Returns:
            ``None`` when the request returned no events, otherwise a tuple of the new
            cursor and the non-empty event messages joined by newlines.
        """
        request: Dict[str, Any] = {
            "logGroupName": self.log_group,
            "logStreamName": self.log_stream,
            "startTime": int(self.start_time.timestamp() * 1000),
            # GetLogEvents requires startFromHead=True whenever a previous
            # nextForwardToken is used as nextToken. It also makes the first request
            # start at the oldest matching event instead of the most recent page.
            "startFromHead": True,
        }

        if cursor is not None:
            request["nextToken"] = cursor

        response = self.client.get_log_events(**request)

        events = response.get("events")
        if not events:
            return None

        next_cursor = cast(str, response["nextForwardToken"])
        messages = [cast(str, event.get("message")) for event in events if event.get("message")]
        return next_cursor, "\n".join(messages)

    def no_messages_debug_text(self) -> str:
        """Describe where messages were expected, for use when none were received."""
        return "Attempted to read messages by extracting them from CloudWatch logs."