Source code for dagster_dask.executor

from typing import Any, Mapping, Optional, Sequence

import dask
import dask.distributed
from dagster import (
    Executor,
    Field,
    Permissive,
    Selector,
    StringSource,
    _check as check,
    _seven,
    multiple_process_executor_requirements,
)
from dagster._core.definitions.executor_definition import executor
from dagster._core.definitions.reconstruct import ReconstructableJob
from dagster._core.errors import raise_execution_interrupts
from dagster._core.events import DagsterEvent
from dagster._core.execution.api import create_execution_plan, execute_plan
from dagster._core.execution.context.system import PlanOrchestrationContext
from dagster._core.execution.plan.plan import ExecutionPlan
from dagster._core.execution.plan.state import KnownExecutionState
from dagster._core.execution.retries import RetryMode
from dagster._core.instance import DagsterInstance
from dagster._core.instance.ref import InstanceRef
from dagster._core.storage.dagster_run import DagsterRun
from dagster._utils import iterate_with_context

# Dask resource requirements are specified under this key
DASK_RESOURCE_REQUIREMENTS_KEY = "dagster-dask/resource_requirements"
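
# Illustrative example (not part of this module): an op can request Dask worker
# resources by attaching a JSON dict under this tag key. The value is parsed by
# get_dask_resource_requirements() below and forwarded to
# client.submit(..., resources=...), so Dask only schedules the step on workers
# advertising matching resources. The "GPU" resource name is hypothetical.
#
#   @op(tags={DASK_RESOURCE_REQUIREMENTS_KEY: '{"GPU": 1}'})
#   def train_model():
#       ...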


@executor(
    name="dask",
    requirements=multiple_process_executor_requirements(),
    config_schema={
        "cluster": Field(
            Selector(
                {
                    "existing": Field(
                        {"address": StringSource},
                        description="Connect to an existing scheduler.",
                    ),
                    "local": Field(
                        Permissive(), is_required=False, description="Local cluster configuration."
                    ),
                    "yarn": Field(
                        Permissive(), is_required=False, description="YARN cluster configuration."
                    ),
                    "ssh": Field(
                        Permissive(), is_required=False, description="SSH cluster configuration."
                    ),
                    "pbs": Field(
                        Permissive(), is_required=False, description="PBS cluster configuration."
                    ),
                    "moab": Field(
                        Permissive(), is_required=False, description="Moab cluster configuration."
                    ),
                    "sge": Field(
                        Permissive(), is_required=False, description="SGE cluster configuration."
                    ),
                    "lsf": Field(
                        Permissive(), is_required=False, description="LSF cluster configuration."
                    ),
                    "slurm": Field(
                        Permissive(), is_required=False, description="SLURM cluster configuration."
                    ),
                    "oar": Field(
                        Permissive(), is_required=False, description="OAR cluster configuration."
                    ),
                    "kube": Field(
                        Permissive(),
                        is_required=False,
                        description="Kubernetes cluster configuration.",
                    ),
                }
            )
        )
    },
)
def dask_executor(init_context):
    """Dask-based executor.

    The 'cluster' can be one of the following:
    ('existing', 'local', 'yarn', 'ssh', 'pbs', 'moab', 'sge', 'lsf', 'slurm', 'oar', 'kube').

    If the Dask executor is used without providing executor-specific config, a local Dask cluster
    will be created (as when calling :py:class:`dask.distributed.Client() <dask:distributed.Client>`
    with :py:class:`dask.distributed.LocalCluster() <dask:distributed.LocalCluster>`).

    The Dask executor optionally takes the following config:

    .. code-block:: none

        cluster:
            {
                local?: # takes distributed.LocalCluster parameters
                    {
                        timeout?: 5, # Timeout duration for initial connection to the scheduler
                        n_workers?: 4  # Number of workers to start
                        threads_per_worker?: 1 # Number of threads per worker
                    }
            }

    To use the `dask_executor`, set it as the `executor_def` when defining a job:

    .. code-block:: python

        from dagster import job
        from dagster_dask import dask_executor

        @job(executor_def=dask_executor)
        def dask_enabled_job():
            pass
    """
    ((cluster_type, cluster_configuration),) = init_context.executor_config["cluster"].items()
    return DaskExecutor(cluster_type, cluster_configuration)
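
# Illustrative example (not part of this module): executor config for connecting
# to an already-running Dask scheduler via the "existing" selector. The address
# is a placeholder; the shape mirrors the config_schema above and is supplied as
# the dask executor's config when launching a run.
#
#   cluster:
#     existing:
#       address: "127.0.0.1:8786"
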

def query_on_dask_worker(
    dependencies: Any,
    recon_job: ReconstructableJob,
    dagster_run: DagsterRun,
    run_config: Optional[Mapping[str, object]],
    step_keys: Optional[Sequence[str]],
    instance_ref: InstanceRef,
    known_state: Optional[KnownExecutionState],
) -> Sequence[DagsterEvent]:
    """Note that we need to pass "dependencies" to ensure Dask sequences futures during task
    scheduling, even though we do not use this argument within the function.
    """
    with DagsterInstance.from_ref(instance_ref) as instance:
        subset_job = recon_job.get_subset(op_selection=dagster_run.resolved_op_selection)

        execution_plan = create_execution_plan(
            subset_job,
            run_config=run_config,
            step_keys_to_execute=step_keys,
            known_state=known_state,
        )

        return execute_plan(
            execution_plan, subset_job, instance, dagster_run, run_config=run_config
        )


def get_dask_resource_requirements(tags: Mapping[str, str]):
    check.mapping_param(tags, "tags", key_type=str, value_type=str)
    req_str = tags.get(DASK_RESOURCE_REQUIREMENTS_KEY)
    if req_str is not None:
        return _seven.json.loads(req_str)

    return {}


class DaskExecutor(Executor):
    def __init__(self, cluster_type, cluster_configuration):
        self.cluster_type = check.opt_str_param(cluster_type, "cluster_type", default="local")
        self.cluster_configuration = check.opt_dict_param(
            cluster_configuration, "cluster_configuration"
        )

    @property
    def retries(self):
        return RetryMode.DISABLED

    def execute(self, plan_context: PlanOrchestrationContext, execution_plan: ExecutionPlan):
        check.inst_param(plan_context, "plan_context", PlanOrchestrationContext)
        check.inst_param(execution_plan, "execution_plan", ExecutionPlan)
        check.param_invariant(
            isinstance(plan_context.executor, DaskExecutor),
            "plan_context",
            f"Expected executor to be DaskExecutor got {plan_context.executor}",
        )

        check.invariant(
            plan_context.instance.is_persistent,
            "Dask execution requires a persistent DagsterInstance",
        )

        step_levels = execution_plan.get_steps_to_execute_by_level()

        job_name = plan_context.job_name

        instance = plan_context.instance

        cluster_type = self.cluster_type
        if cluster_type == "existing":
            # address passed directly to Client() below to connect to existing Scheduler
            cluster = self.cluster_configuration["address"]
        elif cluster_type == "local":
            from dask.distributed import LocalCluster

            cluster = LocalCluster(**self.build_dict(job_name))
        elif cluster_type == "yarn":
            from dask_yarn import YarnCluster

            cluster = YarnCluster(**self.build_dict(job_name))
        elif cluster_type == "ssh":
            from dask.distributed import SSHCluster

            cluster = SSHCluster(**self.build_dict(job_name))
        elif cluster_type == "pbs":
            from dask_jobqueue import PBSCluster

            cluster = PBSCluster(**self.build_dict(job_name))
        elif cluster_type == "moab":
            from dask_jobqueue import MoabCluster

            cluster = MoabCluster(**self.build_dict(job_name))
        elif cluster_type == "sge":
            from dask_jobqueue import SGECluster

            cluster = SGECluster(**self.build_dict(job_name))
        elif cluster_type == "lsf":
            from dask_jobqueue import LSFCluster

            cluster = LSFCluster(**self.build_dict(job_name))
        elif cluster_type == "slurm":
            from dask_jobqueue import SLURMCluster

            cluster = SLURMCluster(**self.build_dict(job_name))
        elif cluster_type == "oar":
            from dask_jobqueue import OARCluster

            cluster = OARCluster(**self.build_dict(job_name))
        elif cluster_type == "kube":
            from dask_kubernetes import KubeCluster

            cluster = KubeCluster(**self.build_dict(job_name))
        else:
            raise ValueError(
                "Must provide one of the following cluster types: ('existing', 'local', 'yarn',"
                f" 'ssh', 'pbs', 'moab', 'sge', 'lsf', 'slurm', 'oar', 'kube'), not {cluster_type}"
            )

        with dask.distributed.Client(cluster) as client:
            execution_futures = []
            execution_futures_dict = {}

            for step_level in step_levels:
                for step in step_level:
                    # We ensure correctness in sequencing by letting Dask schedule futures and
                    # awaiting dependencies within each step.
                    dependencies = []
                    for step_input in step.step_inputs:
                        for key in step_input.dependency_keys:
                            dependencies.append(execution_futures_dict[key])

                    run_config = plan_context.run_config

                    dask_task_name = "%s.%s" % (job_name, step.key)

                    recon_job = plan_context.reconstructable_job

                    future = client.submit(
                        query_on_dask_worker,
                        dependencies,
                        recon_job,
                        plan_context.dagster_run,
                        run_config,
                        [step.key],
                        instance.get_ref(),
                        execution_plan.known_state,
                        key=dask_task_name,
                        resources=get_dask_resource_requirements(step.tags),
                    )

                    execution_futures.append(future)
                    execution_futures_dict[step.key] = future

            # This tells Dask to await the step executions and retrieve their results on the
            # master
            futures = dask.distributed.as_completed(execution_futures, with_results=True)

            # Allow interrupts while waiting for the results from Dask
            for future, result in iterate_with_context(raise_execution_interrupts, futures):
                for step_event in result:
                    yield check.inst(step_event, DagsterEvent)

    def build_dict(self, job_name):
        """Returns a dict we can use for kwargs passed to dask client instantiation.

        Intended to be used like:

        with dask.distributed.Client(**cfg.build_dict()) as client:
            << use client here >>
        """
        if self.cluster_type in ["yarn", "pbs", "moab", "sge", "lsf", "slurm", "oar", "kube"]:
            dask_cfg = {"name": job_name}
        else:
            dask_cfg = {}

        if self.cluster_configuration:
            for k, v in self.cluster_configuration.items():
                dask_cfg[k] = v

        # if address is set, don't add LocalCluster args
        # context: https://github.com/dask/distributed/issues/3313
        if (self.cluster_type == "local") and ("address" not in dask_cfg):
            # We set threads_per_worker because Dagster is not thread-safe. Even though
            # processes=True by default, there is a clever piece of machinery
            # (dask.distributed.deploy.local.nprocesses_nthreads) that automagically makes execution
            # multithreaded by default when the number of available cores is greater than 4.
            # See: https://github.com/dagster-io/dagster/issues/2181
            # We may want to try to figure out a way to enforce this on remote Dask clusters against
            # which users run Dagster workloads.
            dask_cfg["threads_per_worker"] = 1

        return dask_cfg