
Source code for dagster_databricks.databricks

import base64
import logging
import os
import time
from enum import Enum
from importlib.metadata import version
from typing import IO, Any, List, Mapping, Optional, Tuple, Union, cast

import dagster
import dagster._check as check
import dagster_pyspark
import databricks_api
import databricks_cli.sdk
import requests.exceptions
from dagster._annotations import deprecated, public
from databricks.sdk import WorkspaceClient
from databricks.sdk.core import (
    Config,
    DefaultCredentials,
    azure_service_principal,
    oauth_service_principal,
    pat_auth,
)
from databricks.sdk.service import compute, jobs
from typing_extensions import Final

import dagster_databricks

from .types import (
    DatabricksRunState,
)
from .version import __version__

# wait at most 24 hours by default for run execution
DEFAULT_RUN_MAX_WAIT_TIME_SEC: Final = 24 * 60 * 60


class DatabricksError(Exception):
    pass


class AuthTypeEnum(Enum):
    OAUTH_M2M = "oauth-m2m"
    PAT = "pat"
    AZURE_CLIENT_SECRET = "azure-client-secret"
    DEFAULT = "default"


class WorkspaceClientFactory:
    def __init__(
        self,
        host: Optional[str],
        token: Optional[str],
        oauth_client_id: Optional[str],
        oauth_client_secret: Optional[str],
        azure_client_id: Optional[str],
        azure_client_secret: Optional[str],
        azure_tenant_id: Optional[str],
    ):
        """Initialize the Databricks Workspace client.

        Users may provide explicit credentials for a PAT, databricks service principal oauth
        credentials, or azure service principal credentials. If no credentials are provided, the
        underlying WorkspaceClient from `databricks.sdk` will attempt to read credentials from the
        environment or from the `~/.databrickscfg` file.

        For more information, see the `Databricks SDK docs
        <https://docs.databricks.com/en/dev-tools/auth.html#authentication-for-databricks-automation>`_
        on various ways you can authenticate with the WorkspaceClient, through which most
        interactions with the Databricks API occur.
        """
        self._raise_if_multiple_auth_types(
            token=token,
            oauth_client_id=oauth_client_id,
            oauth_client_secret=oauth_client_secret,
            azure_client_id=azure_client_id,
            azure_client_secret=azure_client_secret,
            azure_tenant_id=azure_tenant_id,
        )
        self._assert_valid_credentials_combos(
            oauth_client_id=oauth_client_id,
            oauth_client_secret=oauth_client_secret,
            azure_client_id=azure_client_id,
            azure_client_secret=azure_client_secret,
            azure_tenant_id=azure_tenant_id,
        )
        auth_type = self._get_auth_type(
            token,
            oauth_client_id,
            oauth_client_secret,
            azure_client_id,
            azure_client_secret,
            azure_tenant_id,
        )
        product_info = {"product": "dagster-databricks", "product_version": __version__}

        # Figure out what credentials provider to use based on any explicitly-provided credentials.
        # If none were provided, then fallback to the default credentials provider, which will
        # attempt to read credentials from the environment or from a `~/.databrickscfg` file, if it
        # exists.
        if auth_type == AuthTypeEnum.OAUTH_M2M:
            host = self._resolve_host(host)
            c = Config(
                host=host,
                client_id=oauth_client_id,
                client_secret=oauth_client_secret,
                credentials_provider=oauth_service_principal,
                **product_info,
            )
        elif auth_type == AuthTypeEnum.PAT:
            host = self._resolve_host(host)
            c = Config(host=host, token=token, credentials_provider=pat_auth, **product_info)
        elif auth_type == AuthTypeEnum.AZURE_CLIENT_SECRET:
            host = self._resolve_host(host)
            c = Config(
                host=host,
                azure_client_id=azure_client_id,
                azure_client_secret=azure_client_secret,
                azure_tenant_id=azure_tenant_id,
                credentials_provider=azure_service_principal,
                **product_info,
            )
        elif auth_type == AuthTypeEnum.DEFAULT:
            # Can be used to automatically read credentials from the environment or the
            # ~/.databrickscfg file. This is common when launching Databricks jobs from a laptop
            # development setting through Dagster.
            if host is not None:
                # This allows for explicit override of the host, while letting other credentials
                # be read from the environment or ~/.databrickscfg file.
                c = Config(host=host, credentials_provider=DefaultCredentials(), **product_info)
            else:
                # The initialization machinery in the Config object will look for the host and
                # other auth info in the environment, as long as no values are provided for those
                # attributes (including None).
                c = Config(credentials_provider=DefaultCredentials(), **product_info)
        else:
            raise ValueError(f"Unexpected auth type {auth_type}")

        self.config = c

    def _raise_if_multiple_auth_types(
        self,
        token: Optional[str] = None,
        oauth_client_id: Optional[str] = None,
        oauth_client_secret: Optional[str] = None,
        azure_client_id: Optional[str] = None,
        azure_client_secret: Optional[str] = None,
        azure_tenant_id: Optional[str] = None,
    ):
        more_than_one_auth_type_provided = (
            sum(
                [
                    True
                    for _ in [
                        token,
                        (oauth_client_id and oauth_client_secret),
                        (azure_client_id and azure_client_secret and azure_tenant_id),
                    ]
                    if _
                ]
            )
            > 1
        )
        if more_than_one_auth_type_provided:
            raise ValueError(
                "Can only provide one of token, oauth credentials, or azure credentials"
            )

    @staticmethod
    def _get_auth_type(
        token: Optional[str],
        oauth_client_id: Optional[str],
        oauth_client_secret: Optional[str],
        azure_client_id: Optional[str],
        azure_client_secret: Optional[str],
        azure_tenant_id: Optional[str],
    ) -> AuthTypeEnum:
        """Get the type of authentication used to initialize the WorkspaceClient."""
        if oauth_client_id and oauth_client_secret:
            auth_type = AuthTypeEnum.OAUTH_M2M
        elif token:
            auth_type = AuthTypeEnum.PAT
        elif azure_client_id and azure_client_secret and azure_tenant_id:
            auth_type = AuthTypeEnum.AZURE_CLIENT_SECRET
        else:
            auth_type = AuthTypeEnum.DEFAULT
        return auth_type

    @staticmethod
    def _assert_valid_credentials_combos(
        oauth_client_id: Optional[str] = None,
        oauth_client_secret: Optional[str] = None,
        azure_client_id: Optional[str] = None,
        azure_client_secret: Optional[str] = None,
        azure_tenant_id: Optional[str] = None,
    ):
        """Ensure that all required credentials are provided for the given auth type."""
        if (
            oauth_client_id
            and not oauth_client_secret
            or oauth_client_secret
            and not oauth_client_id
        ):
            raise ValueError(
                "If using databricks service principal oauth credentials, both oauth_client_id and"
                " oauth_client_secret must be provided"
            )
        if (
            (azure_client_id and not azure_client_secret and not azure_tenant_id)
            or (azure_client_secret and not azure_client_id and not azure_tenant_id)
            or (azure_tenant_id and not azure_client_id and not azure_client_secret)
        ):
            raise ValueError(
                "If using azure service principal auth, azure_client_id, azure_client_secret, and"
                " azure_tenant_id must be provided"
            )

    def get_workspace_client(self) -> WorkspaceClient:
        return WorkspaceClient(config=self.config)

    @staticmethod
    def _resolve_host(host: Optional[str]) -> str:
        host = host if host else os.getenv("DATABRICKS_HOST")
        if host is None:
            raise ValueError(
                "Must provide host explicitly or in DATABRICKS_HOST env var when providing"
                " credentials explicitly"
            )
        return host
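

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the library source). It shows how the
# factory above selects an auth type from whichever credentials are supplied.
# The host and token values are placeholders; in the DEFAULT case credentials
# are resolved by `databricks-sdk` from the environment or ~/.databrickscfg.
# The function below is never called at import time.
# ---------------------------------------------------------------------------
def _example_workspace_client_factory_usage() -> None:
    # PAT auth: host and token provided explicitly.
    pat_factory = WorkspaceClientFactory(
        host="https://my-workspace.cloud.databricks.com",  # placeholder host
        token="dapi-REDACTED",  # placeholder token
        oauth_client_id=None,
        oauth_client_secret=None,
        azure_client_id=None,
        azure_client_secret=None,
        azure_tenant_id=None,
    )
    pat_workspace_client = pat_factory.get_workspace_client()  # noqa: F841

    # DEFAULT auth: no explicit credentials, so the SDK reads them itself.
    default_factory = WorkspaceClientFactory(
        host=None,
        token=None,
        oauth_client_id=None,
        oauth_client_secret=None,
        azure_client_id=None,
        azure_client_secret=None,
        azure_tenant_id=None,
    )
    default_workspace_client = default_factory.get_workspace_client()  # noqa: F841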


class DatabricksClient:
    """A thin wrapper over the Databricks REST API."""

    def __init__(
        self,
        host: Optional[str] = None,
        token: Optional[str] = None,
        oauth_client_id: Optional[str] = None,
        oauth_client_secret: Optional[str] = None,
        azure_client_id: Optional[str] = None,
        azure_client_secret: Optional[str] = None,
        azure_tenant_id: Optional[str] = None,
        workspace_id: Optional[str] = None,
    ):
        self.host = host
        self.workspace_id = workspace_id

        workspace_client_factory = WorkspaceClientFactory(
            oauth_client_id=oauth_client_id,
            oauth_client_secret=oauth_client_secret,
            azure_client_id=azure_client_id,
            azure_client_secret=azure_client_secret,
            azure_tenant_id=azure_tenant_id,
            token=token,
            host=host,
        )
        self._workspace_client = workspace_client_factory.get_workspace_client()

        # TODO: This is the old shim client that we were previously using. Arguably this is
        # confusing for users to use since this is an unofficial wrapper around the documented
        # Databricks REST API. We should consider removing this in the next minor release.
        if token:
            self._client = databricks_api.DatabricksAPI(host=host, token=token)
            self.__setup_user_agent(self._client.client)

            # TODO: This is the old `databricks_cli` client that was previously recommended by
            # Databricks. It is no longer supported and should be removed in favour of
            # `databricks-sdk` in the next minor release.
            self._api_client = databricks_cli.sdk.ApiClient(host=host, token=token)
            self.__setup_user_agent(self._api_client)
        else:
            self._client = None
            self._api_client = None

    def __setup_user_agent(
        self,
        client: Union[WorkspaceClient, databricks_api.DatabricksAPI, databricks_cli.sdk.ApiClient],
    ) -> None:
        """Overrides the user agent for the Databricks API client."""
        client.default_headers["user-agent"] = f"dagster-databricks/{__version__}"

    @deprecated(
        breaking_version="0.21.0", additional_warn_text="Use `workspace_client` property instead."
    )
    @public
    @property
    def client(self) -> databricks_api.DatabricksAPI:
        """Retrieve the legacy Databricks API client.

        Note: accessing this property will throw an exception if oauth credentials are used to
        initialize the DatabricksClient, because oauth credentials are not supported by the legacy
        Databricks API client.
        """
        if self._client is None:
            raise ValueError(
                "Legacy Databricks API client from `databricks-api` was not initialized because"
                " oauth credentials were used instead of an access token. This legacy Databricks"
                " API client is not supported when using oauth credentials. Use the"
                " `workspace_client` property instead."
            )
        return self._client

    @client.setter
    def client(self, value: Optional[databricks_api.DatabricksAPI]) -> None:
        self._client = value

    @deprecated(
        breaking_version="0.21.0", additional_warn_text="Use `workspace_client` property instead."
    )
    @public
    @property
    def api_client(self) -> databricks_cli.sdk.ApiClient:
        """Retrieve a reference to the underlying Databricks API client. For more information,
        see the `Databricks Python API <https://docs.databricks.com/dev-tools/python-api.html>`_.

        Note: accessing this property will throw an exception if oauth credentials are used to
        initialize the DatabricksClient, because oauth credentials are not supported by the legacy
        Databricks API client.

        **Examples:**

        .. code-block:: python

            from dagster import op
            from databricks_cli.jobs.api import JobsApi
            from databricks_cli.runs.api import RunsApi
            from databricks.sdk import WorkspaceClient

            @op(required_resource_keys={"databricks_client"})
            def op1(context):
                # Initialize the Databricks Jobs API
                jobs_client = JobsApi(context.resources.databricks_client.api_client)
                runs_client = RunsApi(context.resources.databricks_client.api_client)
                client = context.resources.databricks_client.api_client

                # Example 1: Run a Databricks job with some parameters.
                jobs_client.run_now(...)
                client.jobs.run_now(...)

                # Example 2: Trigger a one-time run of a Databricks workload.
                runs_client.submit_run(...)
                client.jobs.submit(...)

                # Example 3: Get an existing run.
                runs_client.get_run(...)
                client.jobs.get_run(...)

                # Example 4: Cancel a run.
                runs_client.cancel_run(...)
                client.jobs.cancel_run(...)

        Returns:
            ApiClient: The authenticated Databricks API client.
        """
        if self._api_client is None:
            raise ValueError(
                "Legacy Databricks API client from `databricks-cli` was not initialized because"
                " oauth credentials were used instead of an access token. This legacy Databricks"
                " API client is not supported when using oauth credentials. Use the"
                " `workspace_client` property instead."
            )
        return self._api_client

    @public
    @property
    def workspace_client(self) -> WorkspaceClient:
        """Retrieve a reference to the underlying Databricks Workspace client. For more
        information, see the `Databricks SDK for Python
        <https://docs.databricks.com/dev-tools/sdk-python.html>`_.

        **Examples:**

        .. code-block:: python

            from dagster import op
            from databricks.sdk import WorkspaceClient

            @op(required_resource_keys={"databricks_client"})
            def op1(context):
                # Initialize the Databricks Jobs API
                client = context.resources.databricks_client.workspace_client

                # Example 1: Run a Databricks job with some parameters.
                client.jobs.run_now(...)

                # Example 2: Trigger a one-time run of a Databricks workload.
                client.jobs.submit(...)

                # Example 3: Get an existing run.
                client.jobs.get_run(...)

                # Example 4: Cancel a run.
                client.jobs.cancel_run(...)

        Returns:
            WorkspaceClient: The authenticated Databricks SDK Workspace Client.
        """
        return self._workspace_client

    def read_file(self, dbfs_path: str, block_size: int = 1024**2) -> bytes:
        """Read a file from DBFS to a **byte string**."""
        if dbfs_path.startswith("dbfs://"):
            dbfs_path = dbfs_path[7:]

        data = b""
        bytes_read = 0
        dbfs_service = self.workspace_client.dbfs

        jdoc = dbfs_service.read(path=dbfs_path, length=block_size)
        data += base64.b64decode(jdoc.data)
        while jdoc.bytes_read == block_size:
            bytes_read += jdoc.bytes_read
            jdoc = dbfs_service.read(path=dbfs_path, offset=bytes_read, length=block_size)
            data += base64.b64decode(jdoc.data)

        return data

    def put_file(
        self, file_obj: IO, dbfs_path: str, overwrite: bool = False, block_size: int = 1024**2
    ) -> None:
        """Upload an arbitrarily large file to DBFS.

        This doesn't use the DBFS `Put` API because that endpoint is limited to 1MB.
        """
        if dbfs_path.startswith("dbfs://"):
            dbfs_path = dbfs_path[7:]

        dbfs_service = self.workspace_client.dbfs

        create_response = dbfs_service.create(path=dbfs_path, overwrite=overwrite)
        handle = create_response.handle

        block = file_obj.read(block_size)
        while block:
            data = base64.b64encode(block).decode("utf-8")
            dbfs_service.add_block(data=data, handle=handle)
            block = file_obj.read(block_size)

        dbfs_service.close(handle=handle)

    def get_run_state(self, databricks_run_id: int) -> "DatabricksRunState":
        """Get the state of a run by Databricks run ID.

        Return a `DatabricksRunState` object. Note that the `result_state` attribute may be `None`
        if the run hasn't yet terminated.
        """
        run = self.workspace_client.jobs.get_run(databricks_run_id)
        return DatabricksRunState.from_databricks(run.state)

    def poll_run_state(
        self,
        logger: logging.Logger,
        start_poll_time: float,
        databricks_run_id: int,
        max_wait_time_sec: float,
        verbose_logs: bool = True,
    ) -> bool:
        run_state = self.get_run_state(databricks_run_id)

        if run_state.has_terminated():
            if run_state.is_successful():
                logger.info(f"Run `{databricks_run_id}` completed successfully.")
                return True
            if run_state.is_skipped():
                logger.info(f"Run `{databricks_run_id}` was skipped.")
                return True
            else:
                error_message = (
                    f"Run `{databricks_run_id}` failed with result state:"
                    f" `{run_state.result_state}`. Message: {run_state.state_message}."
                )
                logger.error(error_message)
                raise DatabricksError(error_message)
        else:
            if verbose_logs:
                logger.debug(f"Run `{databricks_run_id}` in state {run_state}.")
        if time.time() - start_poll_time > max_wait_time_sec:
            raise DatabricksError(
                f"Run `{databricks_run_id}` took more than {max_wait_time_sec}s to complete."
                " Failing the run."
            )
        return False

    def wait_for_run_to_complete(
        self,
        logger: logging.Logger,
        databricks_run_id: int,
        poll_interval_sec: float,
        max_wait_time_sec: int,
        verbose_logs: bool = True,
    ) -> None:
        logger.info(f"Waiting for Databricks run `{databricks_run_id}` to complete...")
        start_poll_time = time.time()
        while True:
            if self.poll_run_state(
                logger=logger,
                start_poll_time=start_poll_time,
                databricks_run_id=databricks_run_id,
                max_wait_time_sec=max_wait_time_sec,
                verbose_logs=verbose_logs,
            ):
                return
            time.sleep(poll_interval_sec)
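

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the library source). It exercises the
# DBFS helpers and run-polling helpers of DatabricksClient defined above. The
# host, token, DBFS path, and run id are placeholders. The function below is
# never called at import time.
# ---------------------------------------------------------------------------
def _example_databricks_client_usage() -> None:
    import io

    client = DatabricksClient(
        host="https://my-workspace.cloud.databricks.com",  # placeholder host
        token="dapi-REDACTED",  # placeholder token
    )

    # Upload a small payload to DBFS in 1MB blocks, then read it back as bytes.
    client.put_file(io.BytesIO(b"hello from dagster"), "dbfs://tmp/example.txt", overwrite=True)
    payload = client.read_file("dbfs://tmp/example.txt")
    assert payload == b"hello from dagster"

    # Block until a previously-submitted run (placeholder id) terminates,
    # polling every 10s and giving up after the default 24h ceiling.
    client.wait_for_run_to_complete(
        logger=logging.getLogger(__name__),
        databricks_run_id=123456,  # placeholder run id
        poll_interval_sec=10,
        max_wait_time_sec=DEFAULT_RUN_MAX_WAIT_TIME_SEC,
    )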


class DatabricksJobRunner:
    """Submits jobs created using Dagster config to Databricks, and monitors their progress.

    Attributes:
        host (str): Databricks host, e.g. https://uksouth.azuredatabricks.net.
        token (str): Databricks authentication token.
        poll_interval_sec (float): How often to poll Databricks for run status.
        max_wait_time_sec (int): How long to wait for a run to complete before failing.
    """

    def __init__(
        self,
        host: Optional[str] = None,
        token: Optional[str] = None,
        oauth_client_id: Optional[str] = None,
        oauth_client_secret: Optional[str] = None,
        azure_client_id: Optional[str] = None,
        azure_client_secret: Optional[str] = None,
        azure_tenant_id: Optional[str] = None,
        poll_interval_sec: float = 5,
        max_wait_time_sec: int = DEFAULT_RUN_MAX_WAIT_TIME_SEC,
    ):
        self.host = check.opt_str_param(host, "host")
        self.token = check.opt_str_param(token, "token")
        self.poll_interval_sec = check.numeric_param(poll_interval_sec, "poll_interval_sec")
        self.max_wait_time_sec = check.int_param(max_wait_time_sec, "max_wait_time_sec")
        oauth_client_id = check.opt_str_param(oauth_client_id, "oauth_client_id")
        oauth_client_secret = check.opt_str_param(oauth_client_secret, "oauth_client_secret")
        azure_client_id = check.opt_str_param(azure_client_id, "azure_client_id")
        azure_client_secret = check.opt_str_param(azure_client_secret, "azure_client_secret")
        azure_tenant_id = check.opt_str_param(azure_tenant_id, "azure_tenant_id")

        self._client: DatabricksClient = DatabricksClient(
            host=self.host,
            token=self.token,
            oauth_client_id=oauth_client_id,
            oauth_client_secret=oauth_client_secret,
            azure_client_id=azure_client_id,
            azure_client_secret=azure_client_secret,
            azure_tenant_id=azure_tenant_id,
        )

    @property
    def client(self) -> DatabricksClient:
        """Return the underlying `DatabricksClient` object."""
        return self._client

    def submit_run(self, run_config: Mapping[str, Any], task: Mapping[str, Any]) -> int:
        """Submit a new run using the 'Runs submit' API."""
        existing_cluster_id = run_config["cluster"].get("existing")

        new_cluster = run_config["cluster"].get("new")

        # The Databricks API needs different keys to be present in API calls depending
        # on new/existing cluster, so we need to process the new_cluster config first.
        if new_cluster:
            new_cluster = new_cluster.copy()

            nodes = new_cluster.pop("nodes")
            if "driver_instance_pool_id" in nodes:
                check.invariant(
                    "instance_pool_id" in nodes,
                    "Usage of `driver_instance_pool_id` requires that `instance_pool_id` be"
                    " specified for worker nodes",
                )
            if "instance_pool_id" in nodes:
                new_cluster["instance_pool_id"] = nodes["instance_pool_id"]
                if "driver_instance_pool_id" in nodes:
                    new_cluster["driver_instance_pool_id"] = nodes["driver_instance_pool_id"]
            else:
                node_types = nodes["node_types"]
                new_cluster["node_type_id"] = node_types["node_type_id"]
                if "driver_node_type_id" in node_types:
                    new_cluster["driver_node_type_id"] = node_types["driver_node_type_id"]

            cluster_size = new_cluster.pop("size")
            if "num_workers" in cluster_size:
                new_cluster["num_workers"] = cluster_size["num_workers"]
            else:
                new_cluster["autoscale"] = cluster_size["autoscale"]

            tags = new_cluster.get("custom_tags", {})
            if isinstance(tags, list):
                tags = {x["key"]: x["value"] for x in tags}
            tags["__dagster_version"] = dagster.__version__
            new_cluster["custom_tags"] = tags

        check.invariant(
            existing_cluster_id is not None or new_cluster is not None,
            "Invalid value for run_config.cluster",
        )

        # We'll always need some libraries, namely dagster/dagster_databricks/dagster_pyspark,
        # since they're imported by our scripts.
        # Add them if they're not already added by users in config.
        libraries = list(run_config.get("libraries", []))
        install_default_libraries = run_config.get("install_default_libraries", True)
        if install_default_libraries:
            python_libraries = {
                x["pypi"]["package"].split("==")[0].replace("_", "-")
                for x in libraries
                if "pypi" in x
            }

            for library_name, library in [
                ("dagster", dagster),
                ("dagster-databricks", dagster_databricks),
                ("dagster-pyspark", dagster_pyspark),
            ]:
                if library_name not in python_libraries:
                    libraries.append(
                        {"pypi": {"package": f"{library_name}=={library.__version__}"}}
                    )

            if "databricks-sdk" not in python_libraries:
                libraries.append(
                    {"pypi": {"package": f"databricks-sdk=={version('databricks-sdk')}"}}
                )

        # Only one task should be able to be chosen really; make sure of that here.
        check.invariant(
            sum(
                task.get(key) is not None
                for key in [
                    "notebook_task",
                    "spark_python_task",
                    "spark_jar_task",
                    "spark_submit_task",
                ]
            )
            == 1,
            "Multiple tasks specified in Databricks run",
        )

        notification_settings = self._get_notification_settings(run_config)
        job_health_settings = self._get_job_health_settings(run_config)

        return self.client.workspace_client.jobs.submit(
            run_name=run_config.get("run_name"),
            tasks=[
                jobs.SubmitTask.from_dict(
                    {
                        "new_cluster": new_cluster,
                        "existing_cluster_id": existing_cluster_id,
                        "libraries": libraries,
                        **task,
                        "task_key": "dagster-task",
                    },
                )
            ],
            idempotency_token=run_config.get("idempotency_token"),
            timeout_seconds=run_config.get("timeout_seconds"),
            health=job_health_settings,
            **notification_settings,
        ).bind()["run_id"]

    def _get_job_health_settings(
        self, run_config: Mapping[str, Any]
    ) -> Optional[List[jobs.JobsHealthRule]]:
        if "job_health_settings" in run_config:
            job_health_settings = [
                jobs.JobsHealthRule.from_dict(h) for h in run_config["job_health_settings"]
            ]
        else:
            job_health_settings = None
        return job_health_settings

    def _get_notification_settings(
        self, run_config: Mapping[str, Any]
    ) -> Mapping[
        str,
        Union[
            jobs.JobEmailNotifications, jobs.JobNotificationSettings, jobs.WebhookNotifications
        ],
    ]:
        notification_configs = {}
        if "email_notifications" in run_config:
            email_notifications = jobs.JobEmailNotifications.from_dict(
                run_config["email_notifications"]
            )
            notification_configs.update({"email_notifications": email_notifications})
        if "notification_settings" in run_config:
            notification_settings = jobs.JobNotificationSettings.from_dict(
                run_config["notification_settings"]
            )
            notification_configs.update({"notification_settings": notification_settings})
        if "webhook_notifications" in run_config:
            webhook_notifications = jobs.WebhookNotifications.from_dict(
                run_config["webhook_notifications"]
            )
            notification_configs.update({"webhook_notifications": webhook_notifications})
        return notification_configs

    def retrieve_logs_for_run_id(
        self, log: logging.Logger, databricks_run_id: int
    ) -> Optional[Tuple[Optional[str], Optional[str]]]:
        """Retrieve the stdout and stderr logs for a run."""
        run = self.client.workspace_client.jobs.get_run(databricks_run_id)

        # Run.cluster_instance can be None. In that case, fall back to cluster instance on first
        # task. Currently pyspark step launcher runs jobs with singleton tasks.
        cluster_instance = run.cluster_instance or run.tasks[0].cluster_instance
        cluster_id = check.inst(
            cluster_instance.cluster_id,
            str,
            "cluster_id should be string like `1234-123456-abcdefgh` got:"
            f" `{cluster_instance.cluster_id}`",
        )
        cluster = self.client.workspace_client.clusters.get(cluster_id)
        log_config = cluster.cluster_log_conf
        if log_config is None:
            log.warn(
                f"Logs not configured for cluster {cluster_id} used for run {databricks_run_id}"
            )
            return None
        if cast(Optional[compute.S3StorageInfo], log_config.s3) is not None:
            logs_prefix = log_config.s3.destination
            log.warn("Retrieving S3 logs not yet implemented")
            return None
        elif cast(Optional[compute.DbfsStorageInfo], log_config.dbfs) is not None:
            logs_prefix = log_config.dbfs.destination
            stdout = self.wait_for_dbfs_logs(log, logs_prefix, cluster_id, "stdout")
            stderr = self.wait_for_dbfs_logs(log, logs_prefix, cluster_id, "stderr")
            return stdout, stderr

    def wait_for_dbfs_logs(
        self,
        log: logging.Logger,
        prefix: str,
        cluster_id: str,
        filename: str,
        waiter_delay: int = 10,
        waiter_max_attempts: int = 10,
    ) -> Optional[str]:
        """Attempt up to `waiter_max_attempts` attempts to get logs from DBFS."""
        path = "/".join([prefix, cluster_id, "driver", filename])
        log.info(f"Retrieving logs from {path}")
        num_attempts = 0
        while num_attempts <= waiter_max_attempts:
            try:
                logs = self.client.read_file(path)
                return logs.decode("utf-8")
            except requests.exceptions.HTTPError:
                num_attempts += 1
                time.sleep(waiter_delay)
        log.warn("Could not retrieve cluster logs!")
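

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the library source). It shows the
# shape of the `run_config` and `task` mappings that `submit_run` above
# expects, using an existing cluster. The host, token, cluster id, and script
# path are placeholders. The function below is never called at import time.
# ---------------------------------------------------------------------------
def _example_job_runner_usage() -> None:
    runner = DatabricksJobRunner(
        host="https://my-workspace.cloud.databricks.com",  # placeholder host
        token="dapi-REDACTED",  # placeholder token
        poll_interval_sec=15,
    )

    run_config = {
        "run_name": "example-run",
        # Either reference an existing cluster id, or describe a new one under
        # "new" with the "nodes" and "size" keys handled in submit_run.
        "cluster": {"existing": "1234-567890-abcde123"},  # placeholder cluster id
        "libraries": [{"pypi": {"package": "requests"}}],
        "install_default_libraries": True,
    }
    task = {
        "spark_python_task": {
            "python_file": "dbfs:/tmp/example_script.py",  # placeholder script path
            "parameters": [],
        }
    }

    run_id = runner.submit_run(run_config, task)
    runner.client.wait_for_run_to_complete(
        logger=logging.getLogger(__name__),
        databricks_run_id=run_id,
        poll_interval_sec=runner.poll_interval_sec,
        max_wait_time_sec=runner.max_wait_time_sec,
    )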