Source code for dagster_aws.ecs.executor

import json
import os
from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, cast

import boto3
from dagster import (
    DagsterInvariantViolationError,
    DagsterRun,
    Field,
    IntSource,
    Permissive,
    _check as check,
    executor,
)
from dagster._annotations import experimental
from dagster._core.definitions.executor_definition import multiple_process_executor_requirements
from dagster._core.definitions.metadata import MetadataValue
from dagster._core.events import DagsterEvent, EngineEventData
from dagster._core.execution.retries import RetryMode, get_retries_config
from dagster._core.execution.tags import get_tag_concurrency_limits_config
from dagster._core.executor.base import Executor
from dagster._core.executor.init import InitExecutorContext
from dagster._core.executor.step_delegating import (
    CheckStepHealthResult,
    StepDelegatingExecutor,
    StepHandler,
    StepHandlerContext,
)
from dagster._utils.backoff import backoff
from dagster._utils.merger import deep_merge_dicts

from dagster_aws.ecs.container_context import EcsContainerContext
from dagster_aws.ecs.launcher import STOPPED_STATUSES, EcsRunLauncher
from dagster_aws.ecs.tasks import (
    get_current_ecs_task,
    get_current_ecs_task_metadata,
    get_task_kwargs_from_current_task,
)
from dagster_aws.ecs.utils import RetryableEcsException, run_ecs_task

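# Default retry count for the run_task call made when launching a step; launch_step
# reads the STEP_TASK_RETRIES environment variable to override it.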
DEFAULT_STEP_TASK_RETRIES = "5"


_ECS_EXECUTOR_CONFIG_SCHEMA = {
    "run_task_kwargs": Field(
        Permissive({}),
        is_required=False,
        description=(
            "Additional arguments to pass to the boto3 run_task call. These will override"
            " values inherited from the ECS run launcher. See"
            " https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.run_task"
            " for the available parameters."
        ),
    ),
    "cpu": Field(IntSource, is_required=False),
    "memory": Field(IntSource, is_required=False),
    "ephemeral_storage": Field(IntSource, is_required=False),
    "task_overrides": Field(
        Permissive({}),
        is_required=False,
    ),
    "retries": get_retries_config(),
    "max_concurrent": Field(
        IntSource,
        is_required=False,
        description=(
            "Limit on the number of tasks that will run concurrently within the scope "
            "of a Dagster run. Note that this limit is per run, not global."
        ),
    ),
    "tag_concurrency_limits": get_tag_concurrency_limits_config(),
}


@executor(
    name="ecs",
    config_schema=_ECS_EXECUTOR_CONFIG_SCHEMA,
    requirements=multiple_process_executor_requirements(),
)
@experimental
def ecs_executor(init_context: InitExecutorContext) -> Executor:
    """Executor which launches steps as ECS tasks.

    To use the `ecs_executor`, set it as the `executor_def` when defining a job:

    .. literalinclude:: ../../../../../../python_modules/libraries/dagster-aws/dagster_aws_tests/ecs_tests/launcher_tests/executor_tests/test_example_executor_mode_def.py
       :start-after: start_marker
       :end-before: end_marker
       :language: python

    Then you can configure the executor with run config as follows:

    .. code-block:: YAML

        execution:
          config:
            cpu: 1024
            memory: 2048
            ephemeral_storage: 10
            task_overrides:
              containerOverrides:
                - name: run
                  environment:
                    - name: MY_ENV_VAR
                      value: "my_value"

    `max_concurrent` limits the number of ECS tasks that will execute concurrently for one run.
    By default there is no limit; steps run with as much parallelism as the DAG allows. Note
    that this is not a global limit.

    Configuration set on the ECS tasks created by the `EcsRunLauncher` will also be set on the
    tasks created by the `ecs_executor`.

    Configuration set using `tags` on a `@job` only applies at the `run` level. For
    configuration to apply to each `step`, it must be set using `tags` on each `@op`.
    """
    run_launcher = init_context.instance.run_launcher

    check.invariant(
        isinstance(run_launcher, EcsRunLauncher),
        "Using the ecs_executor currently requires that the run be launched in an ECS task"
        " via the EcsRunLauncher.",
    )

    exc_cfg = init_context.executor_config

    return StepDelegatingExecutor(
        EcsStepHandler(
            run_launcher=run_launcher,  # type: ignore
            run_task_kwargs=exc_cfg.get("run_task_kwargs"),  # type: ignore
            cpu=exc_cfg.get("cpu"),  # type: ignore
            memory=exc_cfg.get("memory"),  # type: ignore
            ephemeral_storage=exc_cfg.get("ephemeral_storage"),  # type: ignore
            task_overrides=exc_cfg.get("task_overrides"),  # type: ignore
        ),
        retries=RetryMode.from_config(exc_cfg["retries"]),  # type: ignore
        max_concurrent=check.opt_int_elem(exc_cfg, "max_concurrent"),
        tag_concurrency_limits=check.opt_list_elem(exc_cfg, "tag_concurrency_limits"),
        should_verify_step=True,
    )
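
# Example usage (a minimal sketch, not part of the original module): attach the executor
# to a job via `executor_def`, mirroring the docstring's literalinclude example. The op
# and job names below are hypothetical.
#
#     from dagster import job, op
#
#     @op
#     def example_op():
#         ...
#
#     @job(executor_def=ecs_executor)
#     def example_ecs_job():
#         example_op()
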
@experimental
class EcsStepHandler(StepHandler):
    @property
    def name(self):
        return "EcsStepHandler"

    def __init__(
        self,
        run_launcher: EcsRunLauncher,
        run_task_kwargs: Optional[Mapping[str, Any]],
        cpu: Optional[int],
        memory: Optional[int],
        ephemeral_storage: Optional[int],
        task_overrides: Optional[Mapping[str, Any]],
    ):
        super().__init__()

        run_task_kwargs = run_task_kwargs or {}

        self.ecs = boto3.client("ecs")
        self.ec2 = boto3.resource("ec2")

        # confusingly, run_task expects cpu and memory values as strings
        self._cpu = str(cpu) if cpu else None
        self._memory = str(memory) if memory else None

        self._ephemeral_storage = ephemeral_storage
        self._task_overrides = check.opt_mapping_param(task_overrides, "task_overrides")

        current_task_metadata = get_current_ecs_task_metadata()
        current_task = get_current_ecs_task(
            self.ecs, current_task_metadata.task_arn, current_task_metadata.cluster
        )

        if run_launcher.use_current_ecs_task_config:
            current_task_kwargs = get_task_kwargs_from_current_task(
                self.ec2,
                current_task_metadata.cluster,
                current_task,
            )
        else:
            current_task_kwargs = {}

        run_launcher_kwargs = {**current_task_kwargs, **run_launcher.run_task_kwargs}

        self._cluster_arn = current_task["clusterArn"]
        self._task_definition_arn = current_task["taskDefinitionArn"]

        self._run_task_kwargs = {
            "taskDefinition": current_task["taskDefinitionArn"],
            **run_launcher_kwargs,
            **run_task_kwargs,
        }

        # TODO: change launch_step to return the task ARN. This will be a breaking change,
        # so we need to wait for a minor release to do it.
        self._launched_tasks = {}

    def _get_run_task_kwargs(
        self,
        run: DagsterRun,
        args: Sequence[str],
        step_key: str,
        step_tags: Mapping[str, str],
        step_handler_context: StepHandlerContext,
        container_context: EcsContainerContext,
    ):
        run_launcher = check.inst(
            step_handler_context.instance.run_launcher,
            EcsRunLauncher,
            "ECS executor can only be enabled with the ECS run launcher",
        )

        run_task_kwargs = self._run_task_kwargs

        kwargs_from_tags = step_tags.get("ecs/run_task_kwargs")
        if kwargs_from_tags:
            run_task_kwargs = {**run_task_kwargs, **json.loads(kwargs_from_tags)}

        # convert tags to a dictionary for easy value overriding
        tags = {
            **{tag["key"]: tag["value"] for tag in run_task_kwargs.get("tags", [])},
            **{
                tag["key"]: tag["value"]
                for tag in run_launcher.build_ecs_tags_for_run_task(run, container_context)
            },
            **step_handler_context.dagster_run.dagster_execution_info,
            "dagster/step-key": step_key,
            "dagster/step-id": self._get_step_id(step_handler_context),
        }

        run_task_kwargs["tags"] = [
            {
                "key": key,
                "value": value,
            }
            for key, value in tags.items()
        ]

        task_overrides = self._get_task_overrides(step_tags) or {}

        task_overrides["containerOverrides"] = task_overrides.get("containerOverrides", [])

        # container name has to match since we are assuming we are using the same task
        executor_container_name = run_launcher.get_container_name(container_context)

        executor_env_vars = [
            {"name": env["name"], "value": env["value"]}
            for env in step_handler_context.execute_step_args.get_command_env()
        ]

        # Inject the executor command and env vars into the container overrides if they are
        # defined; otherwise, create new container overrides for the executor container.
        for container_overrides in task_overrides["containerOverrides"]:
            # try to update existing container overrides for the executor container
            if container_overrides["name"] == executor_container_name:
                if "command" in container_overrides and container_overrides["command"] != args:
                    raise DagsterInvariantViolationError(
                        f"The 'command' field for {executor_container_name} container is not"
                        " allowed in the 'containerOverrides' field of the task overrides."
                    )

                # update environment variables & command
                container_overrides["command"] = args
                container_overrides["environment"] = (
                    container_overrides.get("environment", []) + executor_env_vars
                )
                break
        # if there are no existing container overrides for the executor container, add new ones
        else:
            task_overrides["containerOverrides"].append(
                {
                    "name": executor_container_name,
                    "command": args,
                    "environment": executor_env_vars,
                }
            )

        run_task_kwargs["overrides"] = deep_merge_dicts(
            run_task_kwargs.get("overrides", {}), task_overrides
        )

        return run_task_kwargs

    def _get_task_overrides(self, step_tags: Mapping[str, str]) -> Dict[str, Any]:
        overrides = {**self._task_overrides}

        cpu = step_tags.get("ecs/cpu", self._cpu)
        memory = step_tags.get("ecs/memory", self._memory)

        if cpu:
            overrides["cpu"] = cpu
        if memory:
            overrides["memory"] = memory

        ephemeral_storage = step_tags.get("ecs/ephemeral_storage", self._ephemeral_storage)

        if ephemeral_storage:
            overrides["ephemeralStorage"] = {"sizeInGiB": int(ephemeral_storage)}

        if tag_overrides := step_tags.get("ecs/task_overrides"):
            overrides = deep_merge_dicts(overrides, json.loads(tag_overrides))

        return overrides

    def _get_step_id(self, step_handler_context: StepHandlerContext):
        """Step ID is used to identify the ECS task in the ECS cluster.

        It is unique to the specific step being executed and takes into account op-level
        retries. It's used as a workaround to avoid having to return the task ARN from
        launch_step.
        """
        step_key = self._get_step_key(step_handler_context)

        if step_handler_context.execute_step_args.known_state:
            retry_count = step_handler_context.execute_step_args.known_state.get_retry_state().get_attempt_count(
                step_key
            )
        else:
            retry_count = 0

        return "%s-%d" % (step_key, retry_count)

    def _get_step_key(self, step_handler_context: StepHandlerContext) -> str:
        step_keys_to_execute = cast(
            List[str], step_handler_context.execute_step_args.step_keys_to_execute
        )
        assert len(step_keys_to_execute) == 1, "Launching multiple steps is not currently supported"
        return step_keys_to_execute[0]

    def _get_container_context(
        self, step_handler_context: StepHandlerContext
    ) -> EcsContainerContext:
        return EcsContainerContext.create_for_run(
            step_handler_context.dagster_run,
            cast(EcsRunLauncher, step_handler_context.instance.run_launcher),
        )

    def _run_task(self, **run_task_kwargs):
        return run_ecs_task(self.ecs, run_task_kwargs)

    def launch_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]:
        step_key = self._get_step_key(step_handler_context)

        step_tags = step_handler_context.step_tags[step_key]

        container_context = self._get_container_context(step_handler_context)

        run = step_handler_context.dagster_run

        args = step_handler_context.execute_step_args.get_command_args(
            skip_serialized_namedtuple=True
        )

        run_task_kwargs = self._get_run_task_kwargs(
            run,
            args,
            step_key,
            step_tags,
            step_handler_context=step_handler_context,
            container_context=container_context,
        )

        task = backoff(
            self._run_task,
            retry_on=(RetryableEcsException,),
            kwargs=run_task_kwargs,
            max_retries=int(
                os.getenv("STEP_TASK_RETRIES", DEFAULT_STEP_TASK_RETRIES),
            ),
        )

        yield DagsterEvent.step_worker_starting(
            step_handler_context.get_step_context(step_key),
            message=f'Executing step "{step_key}" in ECS task.',
            metadata={
                "Task ARN": MetadataValue.text(task["taskArn"]),
            },
        )

        step_id = self._get_step_id(step_handler_context)

        self._launched_tasks[step_id] = task["taskArn"]

    def check_step_health(self, step_handler_context: StepHandlerContext) -> CheckStepHealthResult:
        step_key = self._get_step_key(step_handler_context)
        step_id = self._get_step_id(step_handler_context)

        try:
            task_arn = self._launched_tasks[step_id]
        except KeyError:
            return CheckStepHealthResult.unhealthy(
                reason=(
                    f"Task ARN for step {step_key} could not be found in executor's task map."
                    " This is likely a bug."
                )
            )

        cluster_arn = self._cluster_arn

        tasks = self.ecs.describe_tasks(tasks=[task_arn], cluster=cluster_arn).get("tasks")

        if not tasks:
            return CheckStepHealthResult.unhealthy(
                reason=f"Task {task_arn} for step {step_key} could not be found."
            )

        t = tasks[0]

        if t.get("lastStatus") in STOPPED_STATUSES:
            failed_containers = []
            for c in t.get("containers"):
                if c.get("exitCode") != 0:
                    failed_containers.append(c)

            if len(failed_containers) > 0:
                cluster_failure_info = (
                    f"Task {t.get('taskArn')} failed.\n"
                    f"Stop code: {t.get('stopCode')}.\n"
                    f"Stop reason: {t.get('stoppedReason')}.\n"
                )
                for c in failed_containers:
                    exit_code = c.get("exitCode")
                    exit_code_msg = f" - exit code {exit_code}" if exit_code is not None else ""
                    cluster_failure_info += f"Container '{c.get('name')}' failed{exit_code_msg}.\n"

                return CheckStepHealthResult.unhealthy(reason=cluster_failure_info)

        return CheckStepHealthResult.healthy()

    def terminate_step(
        self,
        step_handler_context: StepHandlerContext,
    ) -> None:
        step_id = self._get_step_id(step_handler_context)
        step_key = self._get_step_key(step_handler_context)

        try:
            task_arn = self._launched_tasks[step_id]
        except KeyError:
            raise DagsterInvariantViolationError(
                f"Task ARN for step {step_key} could not be found in executor's task map."
                " This is likely a bug."
            )

        cluster_arn = self._cluster_arn

        DagsterEvent.engine_event(
            step_handler_context.get_step_context(step_key),
            message=f"Stopping task {task_arn} for step",
            event_specific_data=EngineEventData(),
        )

        self.ecs.stop_task(task=task_arn, cluster=cluster_arn)
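
# Per-step overrides (a minimal sketch, not part of the original module): the tags read by
# EcsStepHandler._get_task_overrides and _get_run_task_kwargs above can be set on an op.
# Tag values are strings; "ecs/task_overrides" is parsed with json.loads. The op name and
# values here are hypothetical, and the container name "run" matches the docstring example.
#
#     import json
#
#     from dagster import op
#
#     @op(
#         tags={
#             "ecs/cpu": "2048",
#             "ecs/memory": "4096",
#             "ecs/ephemeral_storage": "50",
#             "ecs/task_overrides": json.dumps(
#                 {
#                     "containerOverrides": [
#                         {
#                             "name": "run",
#                             "environment": [{"name": "MY_ENV_VAR", "value": "my_value"}],
#                         }
#                     ]
#                 }
#             ),
#         }
#     )
#     def resource_heavy_op():
#         ...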