Ask AI

Source code for dagster._core.launcher.default_run_launcher

import time
from typing import TYPE_CHECKING, Any, Mapping, Optional, cast

from typing_extensions import Self

import dagster._seven as seven
from dagster import _check as check
from dagster._config.config_schema import UserConfigSchema
from dagster._core.errors import (
    DagsterInvariantViolationError,
    DagsterLaunchFailedError,
    DagsterUserCodeProcessError,
)
from dagster._core.launcher.base import LaunchRunContext, RunLauncher
from dagster._core.storage.dagster_run import DagsterRun
from dagster._core.storage.tags import GRPC_INFO_TAG
from dagster._serdes import ConfigurableClass, deserialize_value
from dagster._serdes.config_class import ConfigurableClassData
from dagster._utils.merger import merge_dicts

if TYPE_CHECKING:
    from dagster._core.instance import DagsterInstance
    from dagster._grpc.client import DagsterGrpcClient


# note: this class is a top level export, so we defer many imports til use for performance
[docs] class DefaultRunLauncher(RunLauncher, ConfigurableClass): """Launches runs against running GRPC servers.""" def __init__( self, inst_data: Optional[ConfigurableClassData] = None, ): self._inst_data = inst_data self._run_ids = set() super().__init__() @property def inst_data(self) -> Optional[ConfigurableClassData]: return self._inst_data @classmethod def config_type(cls) -> UserConfigSchema: return {} @classmethod def from_config_value( cls, inst_data: ConfigurableClassData, config_value: Mapping[str, Any] ) -> Self: return cls(inst_data=inst_data) @staticmethod def launch_run_from_grpc_client( instance: "DagsterInstance", run: DagsterRun, grpc_client: "DagsterGrpcClient" ): # defer for perf from dagster._grpc.types import ExecuteExternalJobArgs, StartRunResult instance.add_run_tags( run.run_id, { GRPC_INFO_TAG: seven.json.dumps( merge_dicts( {"host": grpc_client.host}, ( {"port": grpc_client.port} if grpc_client.port else {"socket": grpc_client.socket} ), ({"use_ssl": True} if grpc_client.use_ssl else {}), ) ) }, ) res = deserialize_value( grpc_client.start_run( ExecuteExternalJobArgs( job_origin=run.remote_job_origin, # type: ignore # (possible none) run_id=run.run_id, instance_ref=instance.get_ref(), ) ), StartRunResult, ) if not res.success: raise ( DagsterLaunchFailedError( res.message, serializable_error_info=res.serializable_error_info ) ) def launch_run(self, context: LaunchRunContext) -> None: # defer for perf from dagster._core.remote_representation.code_location import GrpcServerCodeLocation run = context.dagster_run check.inst_param(run, "run", DagsterRun) if not context.workspace: raise DagsterInvariantViolationError( "DefaultRunLauncher requires a workspace to be included in its LaunchRunContext" ) remote_job_origin = check.not_none(run.remote_job_origin) code_location = context.workspace.get_code_location( remote_job_origin.repository_origin.code_location_origin.location_name ) check.inst( code_location, GrpcServerCodeLocation, "DefaultRunLauncher: Can't launch runs for pipeline not loaded from a GRPC server", ) DefaultRunLauncher.launch_run_from_grpc_client( self._instance, run, cast(GrpcServerCodeLocation, code_location).client ) self._run_ids.add(run.run_id) def _get_grpc_client_for_termination(self, run_id): # defer for perf from dagster._grpc.client import DagsterGrpcClient if not self.has_instance: return None run = self._instance.get_run_by_id(run_id) if not run or run.is_finished: return None tags = run.tags if GRPC_INFO_TAG not in tags: return None grpc_info = seven.json.loads(tags.get(GRPC_INFO_TAG)) return DagsterGrpcClient( port=grpc_info.get("port"), socket=grpc_info.get("socket"), host=grpc_info.get("host"), use_ssl=bool(grpc_info.get("use_ssl", False)), ) def terminate(self, run_id): # defer for perf from dagster._grpc.types import CancelExecutionRequest, CancelExecutionResult check.str_param(run_id, "run_id") if not self.has_instance: return False run = self._instance.get_run_by_id(run_id) if not run or run.is_finished: return False self._instance.report_run_canceling(run) client = self._get_grpc_client_for_termination(run_id) if not client: self._instance.report_engine_event( message="Unable to get grpc client to send termination request to.", dagster_run=run, cls=self.__class__, ) return False res = deserialize_value( client.cancel_execution(CancelExecutionRequest(run_id=run_id)), CancelExecutionResult ) if res.serializable_error_info: raise DagsterUserCodeProcessError.from_error_info(res.serializable_error_info) return res.success def join(self, timeout=30): # If this hasn't been initialized at all, we can just do a noop if not self.has_instance: return total_time = 0 interval = 0.01 while True: active_run_ids = [ run_id for run_id in self._run_ids if ( self._instance.get_run_by_id(run_id) and not self._instance.get_run_by_id(run_id).is_finished ) ] if len(active_run_ids) == 0: return if total_time >= timeout: raise Exception(f"Timed out waiting for these runs to finish: {active_run_ids!r}") total_time += interval time.sleep(interval) interval = interval * 2