Ask AI

Source code for dagster_docker.docker_run_launcher

import json
from typing import Any, Mapping, Optional

import dagster._check as check
import docker
from dagster._core.launcher.base import (
    CheckRunHealthResult,
    LaunchRunContext,
    ResumeRunContext,
    RunLauncher,
    WorkerStatus,
)
from dagster._core.storage.dagster_run import DagsterRun
from dagster._core.storage.tags import DOCKER_IMAGE_TAG
from dagster._core.utils import parse_env_var
from dagster._grpc.types import ExecuteRunArgs, ResumeRunArgs
from dagster._serdes import ConfigurableClass
from dagster._serdes.config_class import ConfigurableClassData
from typing_extensions import Self

from dagster_docker.container_context import DockerContainerContext
from dagster_docker.utils import DOCKER_CONFIG_SCHEMA, validate_docker_config, validate_docker_image

DOCKER_CONTAINER_ID_TAG = "docker/container_id"


[docs] class DockerRunLauncher(RunLauncher, ConfigurableClass): """Launches runs in a Docker container.""" def __init__( self, inst_data: Optional[ConfigurableClassData] = None, image=None, registry=None, env_vars=None, network=None, networks=None, container_kwargs=None, ): self._inst_data = inst_data self.image = image self.registry = registry self.env_vars = env_vars validate_docker_config(network, networks, container_kwargs) if network: self.networks = [network] elif networks: self.networks = networks else: self.networks = [] self.container_kwargs = check.opt_dict_param( container_kwargs, "container_kwargs", key_type=str ) super().__init__() @property def inst_data(self): return self._inst_data @classmethod def config_type(cls): return DOCKER_CONFIG_SCHEMA @classmethod def from_config_value( cls, inst_data: ConfigurableClassData, config_value: Mapping[str, Any] ) -> Self: return cls(inst_data=inst_data, **config_value) def get_container_context(self, dagster_run: DagsterRun) -> DockerContainerContext: return DockerContainerContext.create_for_run(dagster_run, self) def _get_client(self, container_context: DockerContainerContext): client = docker.client.from_env() if container_context.registry: client.login( registry=container_context.registry["url"], username=container_context.registry["username"], password=container_context.registry["password"], ) return client def _get_docker_image(self, job_code_origin): docker_image = job_code_origin.repository_origin.container_image if not docker_image: docker_image = self.image if not docker_image: raise Exception("No docker image specified by the instance config or repository") validate_docker_image(docker_image) return docker_image def _launch_container_with_command(self, run, docker_image, command): container_context = self.get_container_context(run) docker_env = dict([parse_env_var(env_var) for env_var in container_context.env_vars]) docker_env["DAGSTER_RUN_JOB_NAME"] = run.job_name client = self._get_client(container_context) container_kwargs = {**container_context.container_kwargs} labels = container_kwargs.pop("labels", {}) if "stop_timeout" in container_kwargs: # This should work, but does not due to https://github.com/docker/docker-py/issues/3168 # Pull it out and apply it in the terminate() method instead del container_kwargs["stop_timeout"] if isinstance(labels, list): labels = {key: "" for key in labels} labels["dagster/run_id"] = run.run_id labels["dagster/job_name"] = run.job_name try: container = client.containers.create( image=docker_image, command=command, detach=True, environment=docker_env, network=container_context.networks[0] if len(container_context.networks) else None, labels=labels, **container_kwargs, ) except docker.errors.ImageNotFound: # pyright: ignore[reportAttributeAccessIssue] client.images.pull(docker_image) container = client.containers.create( image=docker_image, command=command, detach=True, environment=docker_env, network=container_context.networks[0] if len(container_context.networks) else None, labels=labels, **container_kwargs, ) if len(container_context.networks) > 1: for network_name in container_context.networks[1:]: network = client.networks.get(network_name) network.connect(container) self._instance.report_engine_event( message=f"Launching run in a new container {container.id} with image {docker_image}", dagster_run=run, cls=self.__class__, ) self._instance.add_run_tags( run.run_id, {DOCKER_CONTAINER_ID_TAG: container.id, DOCKER_IMAGE_TAG: docker_image}, # pyright: ignore[reportArgumentType] ) container.start() def launch_run(self, context: LaunchRunContext) -> None: run = context.dagster_run job_code_origin = check.not_none(context.job_code_origin) docker_image = self._get_docker_image(job_code_origin) command = ExecuteRunArgs( job_origin=job_code_origin, run_id=run.run_id, instance_ref=self._instance.get_ref(), ).get_command_args() self._launch_container_with_command(run, docker_image, command) @property def supports_resume_run(self): return True def resume_run(self, context: ResumeRunContext) -> None: run = context.dagster_run job_code_origin = check.not_none(context.job_code_origin) docker_image = self._get_docker_image(job_code_origin) command = ResumeRunArgs( job_origin=job_code_origin, run_id=run.run_id, instance_ref=self._instance.get_ref(), ).get_command_args() self._launch_container_with_command(run, docker_image, command) def _get_container(self, run): if not run: return None container_id = run.tags.get(DOCKER_CONTAINER_ID_TAG) if not container_id: return None container_context = self.get_container_context(run) try: return self._get_client(container_context).containers.get(container_id) except docker.errors.NotFound: # pyright: ignore[reportAttributeAccessIssue] return None def terminate(self, run_id): run = self._instance.get_run_by_id(run_id) if not run or run.is_finished: return False self._instance.report_run_canceling(run) container = self._get_container(run) container_context = self.get_container_context(run) stop_timeout = container_context.container_kwargs.get("stop_timeout") if not container: self._instance.report_engine_event( message="Unable to get docker container to send termination request to.", dagster_run=run, cls=self.__class__, ) return False container.stop(timeout=stop_timeout) return True @property def supports_check_run_worker_health(self): return True def check_run_worker_health(self, run: DagsterRun): container_id = run.tags.get(DOCKER_CONTAINER_ID_TAG) if not container_id: return CheckRunHealthResult(WorkerStatus.NOT_FOUND, msg="No container ID tag for run.") container = self._get_container(run) if container is None: return CheckRunHealthResult( WorkerStatus.NOT_FOUND, msg=f"Could not find container with ID {container_id}." ) if container.status == "running": return CheckRunHealthResult(WorkerStatus.RUNNING) container_state = container.attrs.get("State") failure_string = f"Container status is {container.status}." + ( f" Container state: {json.dumps(container_state)}" if container_state else "" ) return CheckRunHealthResult(WorkerStatus.FAILED, msg=failure_string)