Source code for dagster_airlift.core.airflow_instance

import datetime
import time
from abc import ABC
from typing import Any, Dict, List, Optional, Sequence

import requests
from dagster import _check as check
from dagster._annotations import public
from dagster._core.definitions.utils import check_valid_name
from dagster._core.errors import DagsterError
from dagster._record import record
from dagster._time import get_current_datetime

from dagster_airlift.core.serialization.serialized_data import DagInfo, TaskInfo

TERMINAL_STATES = {"success", "failed", "skipped", "up_for_retry", "up_for_reschedule"}
# Limits the number of task ids we attempt to query from Airflow's task instance REST API at a time.
# Airflow's batch task instance retrieval API doesn't have a limit parameter, but since we query a
# single run at a time, we expect at most one task instance per task id.
# Airflow task instance batch API: https://airflow.apache.org/docs/apache-airflow/stable/stable-rest-api-ref.html#operation/get_task_instances_batch
DEFAULT_BATCH_TASK_RETRIEVAL_LIMIT = 100
# This corresponds directly to the page_limit parameter on airflow's batch dag runs rest API.
# Airflow dag run batch API: https://airflow.apache.org/docs/apache-airflow/stable/stable-rest-api-ref.html#operation/get_dag_runs_batch
DEFAULT_BATCH_DAG_RUNS_LIMIT = 100
SLEEP_SECONDS = 1


class AirflowAuthBackend(ABC):
    """An abstract class that represents an authentication backend for an Airflow instance.

    Requires two methods to be implemented by subclasses:

    - get_session: Returns a requests.Session object that can be used to make requests to the Airflow instance, and handles authentication.
    - get_webserver_url: Returns the base URL of the Airflow webserver.

    The `dagster-airlift` package provides the following default implementations:

    - :py:class:`dagster_airlift.core.AirflowBasicAuthBackend`: An authentication backend that uses Airflow's basic auth to authenticate with the Airflow instance.
    - :py:class:`dagster_airlift.mwaa.MwaaSessionAuthBackend`: An authentication backend that uses AWS MWAA's web login token to authenticate with the Airflow instance (requires `dagster-airlift[mwaa]`).
    """

    def get_session(self) -> requests.Session:
        raise NotImplementedError("This method must be implemented by subclasses.")

    def get_webserver_url(self) -> str:
        raise NotImplementedError("This method must be implemented by subclasses.")
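
# Illustrative sketch (not part of this module): a minimal custom auth backend.
# A subclass only needs to supply an authenticated requests.Session and the
# webserver's base URL. The bearer-token scheme and constructor arguments here
# are hypothetical placeholders, not an API shipped with dagster-airlift.
class ExampleBearerTokenAuthBackend(AirflowAuthBackend):
    def __init__(self, webserver_url: str, token: str) -> None:
        self._webserver_url = webserver_url
        self._token = token

    def get_session(self) -> requests.Session:
        # Attach the token to every request made against the Airflow REST API.
        session = requests.Session()
        session.headers.update({"Authorization": f"Bearer {self._token}"})
        return session

    def get_webserver_url(self) -> str:
        return self._webserver_url
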
class AirflowInstance:
    """A class that represents a running Airflow Instance and provides methods for interacting with its REST API.

    Args:
        auth_backend (AirflowAuthBackend): The authentication backend to use when making requests to the Airflow instance.
        name (str): The name of the Airflow instance. This will be prefixed to any assets automatically created using this instance.
        batch_task_instance_limit (int): The number of task instances to query at a time when fetching task instances. Defaults to 100.
        batch_dag_runs_limit (int): The number of dag runs to query at a time when fetching dag runs. Defaults to 100.
    """

    def __init__(
        self,
        auth_backend: AirflowAuthBackend,
        name: str,
        batch_task_instance_limit: int = DEFAULT_BATCH_TASK_RETRIEVAL_LIMIT,
        batch_dag_runs_limit: int = DEFAULT_BATCH_DAG_RUNS_LIMIT,
    ) -> None:
        self.auth_backend = auth_backend
        self.name = check_valid_name(name)
        self.batch_task_instance_limit = batch_task_instance_limit
        self.batch_dag_runs_limit = batch_dag_runs_limit

    @property
    def normalized_name(self) -> str:
        return self.name.replace(" ", "_").replace("-", "_")

    def get_api_url(self) -> str:
        return f"{self.auth_backend.get_webserver_url()}/api/v1"

    def list_dags(self) -> List["DagInfo"]:
        response = self.auth_backend.get_session().get(
            f"{self.get_api_url()}/dags", params={"limit": 1000}
        )
        if response.status_code == 200:
            dags = response.json()
            webserver_url = self.auth_backend.get_webserver_url()
            return [
                DagInfo(
                    webserver_url=webserver_url,
                    dag_id=dag["dag_id"],
                    metadata=dag,
                )
                for dag in dags["dags"]
            ]
        else:
            raise DagsterError(
                f"Failed to fetch DAGs. Status code: {response.status_code}, Message: {response.text}"
            )

    def list_variables(self) -> List[Dict[str, Any]]:
        response = self.auth_backend.get_session().get(f"{self.get_api_url()}/variables")
        if response.status_code == 200:
            return response.json()["variables"]
        else:
            raise DagsterError(
                f"Failed to fetch variables. Status code: {response.status_code}, Message: {response.text}"
            )

    def get_task_instance_batch(
        self, dag_id: str, task_ids: Sequence[str], run_id: str, states: Sequence[str]
    ) -> List["TaskInstance"]:
        """Get all task instances for a given dag_id, task_ids, and run_id."""
        task_instances = []
        task_id_chunks = [
            task_ids[i : i + self.batch_task_instance_limit]
            for i in range(0, len(task_ids), self.batch_task_instance_limit)
        ]
        for task_id_chunk in task_id_chunks:
            response = self.auth_backend.get_session().post(
                f"{self.get_api_url()}/dags/~/dagRuns/~/taskInstances/list",
                json={
                    "dag_ids": [dag_id],
                    "task_ids": task_id_chunk,
                    "dag_run_ids": [run_id],
                },
            )
            if response.status_code == 200:
                for task_instance_json in response.json()["task_instances"]:
                    task_id = task_instance_json["task_id"]
                    task_instance = TaskInstance(
                        webserver_url=self.auth_backend.get_webserver_url(),
                        dag_id=dag_id,
                        task_id=task_id,
                        run_id=run_id,
                        metadata=task_instance_json,
                    )
                    if task_instance.state in states:
                        task_instances.append(task_instance)
            else:
                raise DagsterError(
                    f"Failed to fetch task instances for {dag_id}/{task_id_chunk}/{run_id}. Status code: {response.status_code}, Message: {response.text}"
                )
        return task_instances

    def get_task_instance(self, dag_id: str, task_id: str, run_id: str) -> "TaskInstance":
        response = self.auth_backend.get_session().get(
            f"{self.get_api_url()}/dags/{dag_id}/dagRuns/{run_id}/taskInstances/{task_id}"
        )
        if response.status_code == 200:
            return TaskInstance(
                webserver_url=self.auth_backend.get_webserver_url(),
                dag_id=dag_id,
                task_id=task_id,
                run_id=run_id,
                metadata=response.json(),
            )
        else:
            raise DagsterError(
                f"Failed to fetch task instance for {dag_id}/{task_id}/{run_id}. Status code: {response.status_code}, Message: {response.text}"
            )

    def get_task_infos(self, *, dag_id: str) -> List["TaskInfo"]:
        response = self.auth_backend.get_session().get(f"{self.get_api_url()}/dags/{dag_id}/tasks")
        if response.status_code != 200:
            raise DagsterError(
                f"Failed to fetch task infos for {dag_id}. Status code: {response.status_code}, Message: {response.text}"
            )
        dag_json = response.json()
        webserver_url = self.auth_backend.get_webserver_url()
        return [
            TaskInfo(
                webserver_url=webserver_url,
                dag_id=dag_id,
                metadata=task_data,
                task_id=task_data["task_id"],
            )
            for task_data in dag_json["tasks"]
        ]

    def get_task_info(self, *, dag_id: str, task_id: str) -> "TaskInfo":
        response = self.auth_backend.get_session().get(
            f"{self.get_api_url()}/dags/{dag_id}/tasks/{task_id}"
        )
        if response.status_code == 200:
            return TaskInfo(
                webserver_url=self.auth_backend.get_webserver_url(),
                dag_id=dag_id,
                task_id=task_id,
                metadata=response.json(),
            )
        else:
            raise DagsterError(
                f"Failed to fetch task info for {dag_id}/{task_id}. Status code: {response.status_code}, Message: {response.text}"
            )

    def get_dag_source_code(self, file_token: str) -> str:
        response = self.auth_backend.get_session().get(
            f"{self.get_api_url()}/dagSources/{file_token}"
        )
        if response.status_code == 200:
            return response.text
        else:
            raise DagsterError(
                f"Failed to fetch source code. Status code: {response.status_code}, Message: {response.text}"
            )

    def get_dag_runs(
        self, dag_id: str, start_date: datetime.datetime, end_date: datetime.datetime
    ) -> List["DagRun"]:
        response = self.auth_backend.get_session().get(
            f"{self.get_api_url()}/dags/{dag_id}/dagRuns",
            params={
                "updated_at_gte": start_date.isoformat(),
                "updated_at_lte": end_date.isoformat(),
                "state": ["success"],
            },
        )
        if response.status_code == 200:
            webserver_url = self.auth_backend.get_webserver_url()
            return [
                DagRun(
                    webserver_url=webserver_url,
                    dag_id=dag_id,
                    run_id=dag_run["dag_run_id"],
                    metadata=dag_run,
                )
                for dag_run in response.json()["dag_runs"]
            ]
        else:
            raise DagsterError(
                f"Failed to fetch dag runs for {dag_id}. Status code: {response.status_code}, Message: {response.text}"
            )

    def get_dag_runs_batch(
        self,
        dag_ids: Sequence[str],
        end_date_gte: datetime.datetime,
        end_date_lte: datetime.datetime,
        offset: int = 0,
    ) -> List["DagRun"]:
        """Return a batch of dag runs for a list of dag_ids. Ordered by end_date."""
        response = self.auth_backend.get_session().post(
            f"{self.get_api_url()}/dags/~/dagRuns/list",
            json={
                "dag_ids": dag_ids,
                "end_date_gte": end_date_gte.isoformat(),
                "end_date_lte": end_date_lte.isoformat(),
                "order_by": "end_date",
                "states": ["success"],
                "page_offset": offset,
                "page_limit": self.batch_dag_runs_limit,
            },
        )
        if response.status_code == 200:
            webserver_url = self.auth_backend.get_webserver_url()
            return [
                DagRun(
                    webserver_url=webserver_url,
                    dag_id=dag_run["dag_id"],
                    run_id=dag_run["dag_run_id"],
                    metadata=dag_run,
                )
                for dag_run in response.json()["dag_runs"]
            ]
        else:
            raise DagsterError(
                f"Failed to fetch dag runs for {dag_ids}. Status code: {response.status_code}, Message: {response.text}"
            )

    @public
    def trigger_dag(self, dag_id: str, logical_date: Optional[datetime.datetime] = None) -> str:
        """Trigger a dag run for the given dag_id.

        Does not wait for the run to finish. To wait for the completed run to finish, use
        :py:meth:`wait_for_run_completion`.

        Args:
            dag_id (str): The dag id to trigger.
            logical_date (Optional[datetime.datetime]): The Airflow logical_date to use for the dag run. If not provided, the current time will be used. Previously known as execution_date in Airflow; find more information in the Airflow docs: https://airflow.apache.org/docs/apache-airflow/stable/faq.html#what-does-execution-date-mean

        Returns:
            str: The dag run id.
        """
        params = {} if not logical_date else {"logical_date": logical_date.isoformat()}
        response = self.auth_backend.get_session().post(
            f"{self.get_api_url()}/dags/{dag_id}/dagRuns",
            json=params,
        )
        if response.status_code != 200:
            raise DagsterError(
                f"Failed to launch run for {dag_id}. Status code: {response.status_code}, Message: {response.text}"
            )
        return response.json()["dag_run_id"]

    def get_dag_run(self, dag_id: str, run_id: str) -> "DagRun":
        response = self.auth_backend.get_session().get(
            f"{self.get_api_url()}/dags/{dag_id}/dagRuns/{run_id}"
        )
        if response.status_code != 200:
            raise DagsterError(
                f"Failed to fetch dag run for {dag_id}/{run_id}. Status code: {response.status_code}, Message: {response.text}"
            )
        return DagRun(
            webserver_url=self.auth_backend.get_webserver_url(),
            dag_id=dag_id,
            run_id=run_id,
            metadata=response.json(),
        )

    def unpause_dag(self, dag_id: str) -> None:
        response = self.auth_backend.get_session().patch(
            f"{self.get_api_url()}/dags",
            json={"is_paused": False},
            params={"dag_id_pattern": dag_id},
        )
        if response.status_code != 200:
            raise DagsterError(
                f"Failed to unpause dag {dag_id}. Status code: {response.status_code}, Message: {response.text}"
            )

    @public
    def wait_for_run_completion(self, dag_id: str, run_id: str, timeout: int = 30) -> None:
        """Given a run ID of an airflow dag, wait for that run to reach a completed state.

        Args:
            dag_id (str): The dag id.
            run_id (str): The run id.
            timeout (int): The number of seconds to wait before timing out.

        Returns:
            None
        """
        start_time = get_current_datetime()
        while get_current_datetime() - start_time < datetime.timedelta(seconds=timeout):
            dag_run = self.get_dag_run(dag_id, run_id)
            if dag_run.finished:
                return
            # Sleep for a second before checking again. This way we don't flood the REST API with requests.
            time.sleep(SLEEP_SECONDS)
        raise DagsterError(f"Timed out waiting for airflow run {run_id} to finish.")

    @public
    def get_run_state(self, dag_id: str, run_id: str) -> str:
        """Given a run ID of an airflow dag, return the state of that run.

        Args:
            dag_id (str): The dag id.
            run_id (str): The run id.

        Returns:
            str: The state of the run. Will be one of the states defined by Airflow.
        """
        return self.get_dag_run(dag_id, run_id).state

    def delete_run(self, dag_id: str, run_id: str) -> None:
        response = self.auth_backend.get_session().delete(
            f"{self.get_api_url()}/dags/{dag_id}/dagRuns/{run_id}"
        )
        if response.status_code != 204:
            raise DagsterError(
                f"Failed to delete run for {dag_id}/{run_id}. Status code: {response.status_code}, Message: {response.text}"
            )
        return None
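
# Illustrative sketch (not part of this module): trigger a dag run and block
# until it reaches a terminal state. Assumes a locally running Airflow
# webserver with basic auth and a DAG named "example_dag"; the URL and
# credentials below are placeholders.
def _example_trigger_and_wait() -> None:
    from dagster_airlift.core import AirflowBasicAuthBackend

    instance = AirflowInstance(
        auth_backend=AirflowBasicAuthBackend(
            webserver_url="http://localhost:8080", username="admin", password="admin"
        ),
        name="example_instance",
    )
    run_id = instance.trigger_dag("example_dag")
    instance.wait_for_run_completion("example_dag", run_id, timeout=300)
    # get_run_state returns one of Airflow's run states, e.g. "success".
    print(instance.get_run_state("example_dag", run_id))
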

@record
class TaskInstance:
    webserver_url: str
    dag_id: str
    task_id: str
    run_id: str
    metadata: Dict[str, Any]

    @property
    def state(self) -> str:
        return self.metadata["state"]

    @property
    def note(self) -> str:
        return self.metadata.get("note") or ""

    @property
    def details_url(self) -> str:
        return f"{self.webserver_url}/dags/{self.dag_id}/grid?dag_run_id={self.run_id}&task_id={self.task_id}"

    @property
    def log_url(self) -> str:
        return f"{self.details_url}&tab=logs"

    @property
    def logical_date(self) -> datetime.datetime:
        """Returns the airflow-coined "logical date" from the task instance metadata.

        The logical date refers to the starting time of the "data interval" that the overall dag run is processing.
        In airflow < 2.2, this was set as the execution_date parameter in the task instance metadata.
        """
        # In airflow < 2.2, execution_date is set instead of logical_date.
        logical_date_str = check.not_none(
            self.metadata.get("logical_date") or self.metadata.get("execution_date"),
            "Expected one of execution_date or logical_date to be returned from the airflow rest API when querying for task information.",
        )
        return datetime.datetime.fromisoformat(logical_date_str)

    @property
    def start_date(self) -> datetime.datetime:
        return datetime.datetime.fromisoformat(self.metadata["start_date"])

    @property
    def end_date(self) -> datetime.datetime:
        return datetime.datetime.fromisoformat(self.metadata["end_date"])


@record
class DagRun:
    webserver_url: str
    dag_id: str
    run_id: str
    metadata: Dict[str, Any]

    @property
    def note(self) -> str:
        return self.metadata.get("note") or ""

    @property
    def url(self) -> str:
        return f"{self.webserver_url}/dags/{self.dag_id}/grid?dag_run_id={self.run_id}&tab=details"

    @property
    def success(self) -> bool:
        return self.metadata["state"] == "success"

    @property
    def finished(self) -> bool:
        return self.state in TERMINAL_STATES

    @property
    def state(self) -> str:
        return self.metadata["state"]

    @property
    def run_type(self) -> str:
        return self.metadata["run_type"]

    @property
    def config(self) -> Dict[str, Any]:
        return self.metadata["conf"]

    @property
    def logical_date(self) -> datetime.datetime:
        """Returns the airflow-coined "logical date" from the dag run metadata.

        The logical date refers to the starting time of the "data interval" that the dag run is processing.
        In airflow < 2.2, this was set as the execution_date parameter in the dag run metadata.
        """
        # In airflow < 2.2, execution_date is set instead of logical_date.
        logical_date_str = check.not_none(
            self.metadata.get("logical_date") or self.metadata.get("execution_date"),
            "Expected one of execution_date or logical_date to be returned from the airflow rest API when querying for dag information.",
        )
        return datetime.datetime.fromisoformat(logical_date_str)

    @property
    def start_date(self) -> datetime.datetime:
        return datetime.datetime.fromisoformat(self.metadata["start_date"])

    @property
    def end_date(self) -> datetime.datetime:
        return datetime.datetime.fromisoformat(self.metadata["end_date"])
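
# Illustrative sketch (not part of this module): fetch the past day of
# successful runs for a DAG and read them through DagRun's typed accessors
# rather than the raw `metadata` dict. `instance` is assumed to be a
# configured AirflowInstance, and "example_dag" is a placeholder dag id.
def _example_inspect_recent_runs(instance: AirflowInstance) -> None:
    end_date = get_current_datetime()
    start_date = end_date - datetime.timedelta(days=1)
    for dag_run in instance.get_dag_runs("example_dag", start_date, end_date):
        # Each property below reads one field of the REST payload in `metadata`.
        print(dag_run.run_id, dag_run.state, dag_run.logical_date, dag_run.url)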