Source code for dagster_databricks.databricks

import base64
import time

import dagster
import requests.exceptions
from dagster import check
from databricks_api import DatabricksAPI

from .types import (

# wait at most 24 hours by default for run execution

[docs]class DatabricksError(Exception): pass
class DatabricksClient: """A thin wrapper over the Databricks REST API.""" def __init__(self, host, token, workspace_id=None): = host self.workspace_id = workspace_id self.client = DatabricksAPI(host=host, token=token) def submit_run(self, *args, **kwargs): """Submit a run directly to the 'Runs Submit' API.""" return*args, **kwargs)["run_id"] # pylint: disable=no-member def read_file(self, dbfs_path, block_size=1024 ** 2): """Read a file from DBFS to a **byte string**.""" if dbfs_path.startswith("dbfs://"): dbfs_path = dbfs_path[7:] data = b"" bytes_read = 0 jdoc =, length=block_size) # pylint: disable=no-member data += base64.b64decode(jdoc["data"]) while jdoc["bytes_read"] == block_size: bytes_read += jdoc["bytes_read"] jdoc = # pylint: disable=no-member path=dbfs_path, offset=bytes_read, length=block_size ) data += base64.b64decode(jdoc["data"]) return data def put_file(self, file_obj, dbfs_path, overwrite=False, block_size=1024 ** 2): """Upload an arbitrary large file to DBFS. This doesn't use the DBFS `Put` API because that endpoint is limited to 1MB. """ if dbfs_path.startswith("dbfs://"): dbfs_path = dbfs_path[7:] create_response = self.client.dbfs.create( # pylint: disable=no-member path=dbfs_path, overwrite=overwrite ) handle = create_response["handle"] block = while block: data = base64.b64encode(block).decode("utf-8") self.client.dbfs.add_block(data=data, handle=handle) # pylint: disable=no-member block = self.client.dbfs.close(handle=handle) # pylint: disable=no-member def get_run_state(self, databricks_run_id): """Get the state of a run by Databricks run ID (_not_ dagster run ID). Return a `DatabricksRunState` object. Note that the `result_state` attribute may be `None` if the run hasn't yet terminated. """ run = # pylint: disable=no-member state = run["state"] result_state = state.get("result_state") if result_state: result_state = DatabricksRunResultState(result_state) return DatabricksRunState( life_cycle_state=DatabricksRunLifeCycleState(state["life_cycle_state"]), result_state=result_state, state_message=state["state_message"], ) class DatabricksRunState: """Represents the state of a Databricks job run.""" def __init__(self, life_cycle_state, result_state, state_message): self.life_cycle_state = life_cycle_state self.result_state = result_state self.state_message = state_message def has_terminated(self): """Has the job terminated?""" return self.life_cycle_state in DATABRICKS_RUN_TERMINATED_STATES def is_successful(self): """Was the job successful?""" return self.result_state == DatabricksRunResultState.Success def __repr__(self): return str(self.__dict__) class DatabricksJobRunner: """Submits jobs created using Dagster config to Databricks, and monitors their progress.""" def __init__( self, host, token, poll_interval_sec=10, max_wait_time_sec=_DEFAULT_RUN_MAX_WAIT_TIME_SEC ): """Args: host (str): Databricks host, e.g. token (str): Databricks token """ = check.str_param(host, "host") self.token = check.str_param(token, "token") self.poll_interval_sec = check.numeric_param(poll_interval_sec, "poll_interval_sec") self.max_wait_time_sec = check.int_param(max_wait_time_sec, "max_wait_time_sec") self._client = DatabricksClient(, token=self.token) @property def client(self): """Return the underlying `DatabricksClient` object.""" return self._client def submit_run(self, run_config, task): """Submit a new run using the 'Runs submit' API.""" existing_cluster_id = run_config["cluster"].get("existing") new_cluster = run_config["cluster"].get("new") # The Databricks API needs different keys to be present in API calls depending # on new/existing cluster, so we need to process the new_cluster # config first. if new_cluster: new_cluster = new_cluster.copy() nodes = new_cluster.pop("nodes") if "instance_pool_id" in nodes: new_cluster["instance_pool_id"] = nodes["instance_pool_id"] else: node_types = nodes["node_types"] new_cluster["node_type_id"] = node_types["node_type_id"] if "driver_node_type_id" in node_types: new_cluster["driver_node_type_id"] = node_types["driver_node_type_id"] cluster_size = new_cluster.pop("size") if "num_workers" in cluster_size: new_cluster["num_workers"] = cluster_size["num_workers"] else: new_cluster["autoscale"] = cluster_size["autoscale"] tags = new_cluster.get("custom_tags", []) tags.append({"key": "__dagster_version", "value": dagster.__version__}) new_cluster["custom_tags"] = tags check.invariant( existing_cluster_id is not None or new_cluster is not None, "Invalid value for run_config.cluster", ) # We'll always need some libraries, namely dagster/dagster_databricks/dagster_pyspark, # since they're imported by our scripts. # Add them if they're not already added by users in config. libraries = list(run_config.get("libraries", [])) python_libraries = {x["pypi"]["package"].split("==")[0] for x in libraries if "pypi" in x} for library in ["dagster", "dagster_databricks", "dagster_pyspark"]: if library not in python_libraries: libraries.append( {"pypi": {"package": "{}=={}".format(library, dagster.__version__)}} ) # Only one task should be able to be chosen really; make sure of that here. check.invariant( sum( task.get(key) is not None for key in [ "notebook_task", "spark_python_task", "spark_jar_task", "spark_submit_task", ] ) == 1, "Multiple tasks specified in Databricks run", ) config = dict( run_name=run_config.get("run_name"), new_cluster=new_cluster, existing_cluster_id=existing_cluster_id, libraries=libraries, **task, ) return self.client.submit_run(**config) def retrieve_logs_for_run_id(self, log, databricks_run_id): """Retrieve the stdout and stderr logs for a run.""" api_client = self.client.client run = # pylint: disable=no-member cluster = api_client.cluster.get_cluster( # pylint: disable=no-member run["cluster_instance"]["cluster_id"] ) log_config = cluster.get("cluster_log_conf") if log_config is None: log.warn( "Logs not configured for cluster {cluster} used for run {run}".format( cluster=cluster["cluster_id"], run=databricks_run_id ) ) return None if "s3" in log_config: logs_prefix = log_config["s3"]["destination"] log.warn("Retrieving S3 logs not yet implemented") return None elif "dbfs" in log_config: logs_prefix = log_config["dbfs"]["destination"] stdout = self.wait_for_dbfs_logs(log, logs_prefix, cluster["cluster_id"], "stdout") stderr = self.wait_for_dbfs_logs(log, logs_prefix, cluster["cluster_id"], "stderr") return stdout, stderr def wait_for_dbfs_logs( self, log, prefix, cluster_id, filename, waiter_delay=10, waiter_max_attempts=10 ): """Attempt up to `waiter_max_attempts` attempts to get logs from DBFS.""" path = "/".join([prefix, cluster_id, "driver", filename])"Retrieving logs from {}".format(path)) num_attempts = 0 while num_attempts <= waiter_max_attempts: try: logs = self.client.read_file(path) return logs.decode("utf-8") except requests.exceptions.HTTPError: num_attempts += 1 time.sleep(waiter_delay) log.warn("Could not retrieve cluster logs!") def wait_for_run_to_complete(self, log, databricks_run_id): return wait_for_run_to_complete( self.client, log, databricks_run_id, self.poll_interval_sec, self.max_wait_time_sec ) def wait_for_run_to_complete(client, log, databricks_run_id, poll_interval_sec, max_wait_time_sec): """Wait for a Databricks run to complete.""" check.int_param(databricks_run_id, "databricks_run_id")"Waiting for Databricks run %s to complete..." % databricks_run_id) start = time.time() while True: log.debug("Waiting %.1f seconds..." % poll_interval_sec) time.sleep(poll_interval_sec) run_state = client.get_run_state(databricks_run_id) if run_state.has_terminated(): if run_state.is_successful():"Run %s completed successfully" % databricks_run_id) return else: error_message = "Run %s failed with result state: %s. Message: %s" % ( databricks_run_id, run_state.result_state, run_state.state_message, ) log.error(error_message) raise DatabricksError(error_message) else:"Run %s in state %s" % (databricks_run_id, run_state)) if time.time() - start > max_wait_time_sec: raise DatabricksError( "Job run {} took more than {}s to complete; failing".format( databricks_run_id, max_wait_time_sec ) )