import sys
import time
from typing import TYPE_CHECKING, Any, Dict, Optional, Union, cast
import boto3
import dagster._check as check
from dagster import DagsterInvariantViolationError, PipesClient
from dagster._annotations import experimental, public
from dagster._core.definitions.metadata import RawMetadataMapping
from dagster._core.definitions.resource_annotation import TreatAsResourceParam
from dagster._core.errors import DagsterExecutionInterruptedError
from dagster._core.execution.context.asset_execution_context import AssetExecutionContext
from dagster._core.execution.context.compute import OpExecutionContext
from dagster._core.pipes.client import (
PipesClientCompletedInvocation,
PipesContextInjector,
PipesMessageReader,
)
from dagster._core.pipes.context import PipesSession
from dagster._core.pipes.utils import PipesEnvContextInjector, open_pipes_session
from dagster._utils.merger import deep_merge_dicts
from dagster_aws.pipes.message_readers import PipesCloudWatchLogReader, PipesCloudWatchMessageReader
if TYPE_CHECKING:
from mypy_boto3_emr_serverless.client import EMRServerlessClient
from mypy_boto3_emr_serverless.literals import JobRunStateType
from mypy_boto3_emr_serverless.type_defs import (
GetJobRunResponseTypeDef,
MonitoringConfigurationTypeDef,
StartJobRunRequestRequestTypeDef,
StartJobRunResponseTypeDef,
)
AWS_SERVICE_NAME = "EMR Serverless"
@public
@experimental
class PipesEMRServerlessClient(PipesClient, TreatAsResourceParam):
"""A pipes client for running workloads on AWS EMR Serverless.
Args:
client (Optional[boto3.client]): The boto3 AWS EMR Serverless client used to interact with AWS EMR Serverless.
        context_injector (Optional[PipesContextInjector]): A context injector to use to inject
            context into the AWS EMR Serverless workload. Defaults to :py:class:`PipesEnvContextInjector`.
message_reader (Optional[PipesMessageReader]): A message reader to use to read messages
from the AWS EMR Serverless workload. Defaults to :py:class:`PipesCloudWatchMessageReader`.
forward_termination (bool): Whether to cancel the AWS EMR Serverless workload if the Dagster process receives a termination signal.
poll_interval (float): The interval in seconds to poll the AWS EMR Serverless workload for status updates. Defaults to 5 seconds.
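
    Example (a minimal sketch; the application ID, role ARN, and entry point below are placeholders):

    .. code-block:: python

        from dagster import AssetExecutionContext, Definitions, asset
        from dagster_aws.pipes import PipesEMRServerlessClient

        @asset
        def emr_serverless_asset(
            context: AssetExecutionContext, pipes_client: PipesEMRServerlessClient
        ):
            return pipes_client.run(
                context=context,
                start_job_run_params={
                    "applicationId": "00example123",  # placeholder
                    "executionRoleArn": "arn:aws:iam::123456789012:role/example",  # placeholder
                    "jobDriver": {
                        "sparkSubmit": {"entryPoint": "s3://example-bucket/script.py"}  # placeholder
                    },
                },
            ).get_materialize_result()

        defs = Definitions(
            assets=[emr_serverless_asset],
            resources={"pipes_client": PipesEMRServerlessClient()},
        )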
"""
AWS_SERVICE_NAME = AWS_SERVICE_NAME
def __init__(
self,
client: Optional["EMRServerlessClient"] = None,
context_injector: Optional[PipesContextInjector] = None,
message_reader: Optional[PipesMessageReader] = None,
forward_termination: bool = True,
poll_interval: float = 5.0,
):
self._client = client or boto3.client("emr-serverless")
self._context_injector = context_injector or PipesEnvContextInjector()
self._message_reader = message_reader or PipesCloudWatchMessageReader()
self.forward_termination = check.bool_param(forward_termination, "forward_termination")
self.poll_interval = poll_interval
@property
def client(self) -> "EMRServerlessClient":
return self._client
@property
def context_injector(self) -> PipesContextInjector:
return self._context_injector
@property
def message_reader(self) -> PipesMessageReader:
return self._message_reader
@classmethod
def _is_dagster_maintained(cls) -> bool:
return True
@public
def run(
self,
*,
context: Union[OpExecutionContext, AssetExecutionContext],
start_job_run_params: "StartJobRunRequestRequestTypeDef",
extras: Optional[Dict[str, Any]] = None,
) -> PipesClientCompletedInvocation:
"""Run a workload on AWS EMR Serverless, enriched with the pipes protocol.
Args:
context (Union[OpExecutionContext, AssetExecutionContext]): The context of the currently executing Dagster op or asset.
            start_job_run_params (dict): Parameters for the ``start_job_run`` boto3 AWS EMR Serverless client call.
See `Boto3 EMR Serverless API Documentation <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/emr-serverless/client/start_job_run.html>`_
extras (Optional[Dict[str, Any]]): Additional information to pass to the Pipes session in the external process.
Returns:
PipesClientCompletedInvocation: Wrapper containing results reported by the external
process.
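
        Example (a minimal sketch of the call; all parameter values are placeholders):

        .. code-block:: python

            pipes_client.run(
                context=context,
                start_job_run_params={
                    "applicationId": "00example123",
                    "executionRoleArn": "arn:aws:iam::123456789012:role/example",
                    "jobDriver": {
                        "sparkSubmit": {"entryPoint": "s3://example-bucket/script.py"}
                    },
                },
            ).get_results()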
"""
with open_pipes_session(
context=context,
message_reader=self.message_reader,
context_injector=self.context_injector,
extras=extras,
) as session:
start_job_run_params = self._enrich_start_params(context, session, start_job_run_params)
start_response = self._start(context, start_job_run_params)
try:
completion_response = self._wait_for_completion(context, start_response)
context.log.info(f"[pipes] {self.AWS_SERVICE_NAME} workload is complete!")
self._read_messages(context, session, completion_response)
return PipesClientCompletedInvocation(
session, metadata=self._extract_dagster_metadata(completion_response)
)
except DagsterExecutionInterruptedError:
if self.forward_termination:
context.log.warning(
f"[pipes] Dagster process interrupted! Will terminate external {self.AWS_SERVICE_NAME} workload."
)
self._terminate(context, start_response)
raise
def _enrich_start_params(
self,
context: Union[OpExecutionContext, AssetExecutionContext],
session: PipesSession,
params: "StartJobRunRequestRequestTypeDef",
) -> "StartJobRunRequestRequestTypeDef":
# inject Dagster tags
tags = params.get("tags", {})
params["tags"] = {**tags, **session.default_remote_invocation_info}
        # inject env variables via --conf spark.emr-serverless.driverEnv.<key>=<value>
dagster_env_vars = {}
dagster_env_vars.update(session.get_bootstrap_env_vars())
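        # dagster-pipes running inside the Spark driver reads these variables
        # (DAGSTER_PIPES_CONTEXT, DAGSTER_PIPES_MESSAGES) to bootstrap the Pipes session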
if "jobDriver" not in params:
params["jobDriver"] = {}
if "sparkSubmit" not in params["jobDriver"]:
params["jobDriver"]["sparkSubmit"] = {} # pyright: ignore[reportGeneralTypeIssues]
params["jobDriver"]["sparkSubmit"]["sparkSubmitParameters"] = params.get(
"jobDriver", {}
).get("sparkSubmit", {}).get("sparkSubmitParameters", "") + "".join(
[
f" --conf spark.emr-serverless.driverEnv.{key}={value}"
for key, value in dagster_env_vars.items()
]
)
return cast("StartJobRunRequestRequestTypeDef", params)
def _start(
self,
context: Union[OpExecutionContext, AssetExecutionContext],
params: "StartJobRunRequestRequestTypeDef",
) -> "StartJobRunResponseTypeDef":
response = self.client.start_job_run(**params)
application_id = response["applicationId"]
job_run_id = response["jobRunId"]
# this URL is only valid for an hour
# so we don't include it in the output metadata
        dashboard_url = self.client.get_dashboard_for_job_run(
            applicationId=application_id, jobRunId=job_run_id
        )["url"]
context.log.info(
f"[pipes] {self.AWS_SERVICE_NAME} job started with job_run_id {job_run_id}. Dashboard URL: {dashboard_url}"
)
return response
def _wait_for_completion(
self,
context: Union[OpExecutionContext, AssetExecutionContext],
start_response: "StartJobRunResponseTypeDef",
) -> "GetJobRunResponseTypeDef": # pyright: ignore[reportReturnType]
job_run_id = start_response["jobRunId"]
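        # poll get_job_run every poll_interval seconds until the job reaches a terminal state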
while response := self.client.get_job_run(
applicationId=start_response["applicationId"],
jobRunId=job_run_id,
):
state: "JobRunStateType" = response["jobRun"]["state"]
if state in ["FAILED", "CANCELLED", "CANCELLING"]:
context.log.error(
f"[pipes] {self.AWS_SERVICE_NAME} job {job_run_id} terminated with state: {state}. Details:\n{response['jobRun'].get('stateDetails')}"
)
raise RuntimeError(
f"{self.AWS_SERVICE_NAME} job failed"
) # TODO: introduce something like DagsterPipesRemoteExecutionError
elif state == "SUCCESS":
context.log.info(
f"[pipes] {self.AWS_SERVICE_NAME} job {job_run_id} completed with state: {state}"
)
return response
elif state in ["PENDING", "SUBMITTED", "SCHEDULED", "RUNNING"]:
time.sleep(self.poll_interval)
continue
else:
raise DagsterInvariantViolationError(
f"Unexpected state for AWS EMR Serverless job {job_run_id}: {state}"
)
def _read_messages(
self,
context: Union[OpExecutionContext, AssetExecutionContext],
session: PipesSession,
response: "GetJobRunResponseTypeDef",
):
application_id = response["jobRun"]["applicationId"]
job_id = response["jobRun"]["jobRunId"]
application = self.client.get_application(applicationId=application_id)["application"]
# merge base monitoring configuration from application
# with potential overrides from the job run
application_monitoring_configuration = application.get("monitoringConfiguration", {})
job_monitoring_configuration = (
response["jobRun"].get("configurationOverrides", {}).get("monitoringConfiguration", {})
)
monitoring_configuration = cast(
"MonitoringConfigurationTypeDef",
deep_merge_dicts(application_monitoring_configuration, job_monitoring_configuration),
)
application_type = application["type"]
if application_type == "Spark":
worker_type = "SPARK_DRIVER"
elif application_type == "Hive":
worker_type = "HIVE_DRIVER"
else:
raise NotImplementedError(f"Application type {application_type} is not supported")
if not isinstance(self.message_reader, PipesCloudWatchMessageReader):
context.log.warning(
f"[pipes] {self.message_reader} is not supported for {self.AWS_SERVICE_NAME}. Dagster won't be able to receive logs and messages from the job."
)
return
# https://docs.aws.amazon.com/emr/latest/EMR-Serverless-UserGuide/logging.html#jobs-log-storage-cw
# we can get cloudwatch logs from the known log group
if (
monitoring_configuration.get("cloudWatchLoggingConfiguration", {}).get("enabled")
is not True
):
context.log.warning(
f"[pipes] Recieved {self.message_reader}, but CloudWatch logging is not enabled for {self.AWS_SERVICE_NAME} job. Dagster won't be able to receive logs and messages from the job."
)
return
if log_types := monitoring_configuration.get("cloudWatchLoggingConfiguration", {}).get(
"logTypes"
):
            # get the configured output streams for this worker type,
            # limited to "stdout" and "stderr"
            # (lowercase before intersecting so uppercase log types like "STDOUT" still match)
            output_streams = list(
                {log_type.lower() for log_type in log_types.get(worker_type, ["STDOUT", "STDERR"])}
                & {"stdout", "stderr"}
            )
else:
output_streams = ["stdout", "stderr"]
log_group = monitoring_configuration.get("logGroupName") or "/aws/emr-serverless"
attempt = response["jobRun"].get("attempt")
if attempt is not None and attempt > 1:
log_stream = (
f"/applications/{application_id}/jobs/{job_id}/attempts/{attempt}/{worker_type}"
)
else:
log_stream = f"/applications/{application_id}/jobs/{job_id}/{worker_type}"
if log_stream_prefix := monitoring_configuration.get(
"cloudWatchLoggingConfiguration", {}
).get("logStreamNamePrefix"):
log_stream = f"{log_stream_prefix}{log_stream}"
output_files = {
"stdout": sys.stdout,
"stderr": sys.stderr,
}
# update MessageReader params so it can start receiving messages
if isinstance(self.message_reader, PipesCloudWatchMessageReader):
session.report_launched(
{
"extras": {
"log_group": log_group,
"log_stream": f"{log_stream}/stdout",
}
}
)
# now add LogReaders for stdout and stderr logs
for output_stream in output_streams:
output_file = output_files[output_stream]
context.log.debug(
f"[pipes] Adding PipesCloudWatchLogReader for group {log_group} stream {log_stream}/{output_stream}"
)
self.message_reader.add_log_reader(
PipesCloudWatchLogReader(
client=self.message_reader.client,
log_group=log_group,
log_stream=f"{log_stream}/{output_stream}",
target_stream=output_file,
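                    # CloudWatch expects timestamps as epoch milliseconds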
start_time=int(session.created_at.timestamp() * 1000),
debug_info=output_stream,
),
)
def _extract_dagster_metadata(self, response: "GetJobRunResponseTypeDef") -> RawMetadataMapping:
metadata: RawMetadataMapping = {}
job_run = response["jobRun"]
metadata["AWS EMR Serverless Application ID"] = job_run["applicationId"]
metadata["AWS EMR Serverless Job Run ID"] = job_run["jobRunId"]
# TODO: it would be great to add a url to EMR Studio page for this run
# such urls look like: https://es-638xhdetxum2td9nc3a45evmn.emrstudio-prod.eu-north-1.amazonaws.com/#/serverless-applications/00fm4oe0607u5a1d
# but we need to get the Studio ID from the application_id
# which is not possible with the current AWS API
return metadata
def _terminate(
self,
context: Union[OpExecutionContext, AssetExecutionContext],
start_response: "StartJobRunResponseTypeDef",
):
job_run_id = start_response["jobRunId"]
application_id = start_response["applicationId"]
context.log.info(f"[pipes] Terminating {self.AWS_SERVICE_NAME} job run {job_run_id}")
self.client.cancel_job_run(applicationId=application_id, jobRunId=job_run_id)