Source code for dagster_airlift.in_airflow.proxying_fn

import logging
from typing import Any, Callable, Dict, List, Optional, Set

from airflow import DAG
from airflow.models import BaseOperator

from dagster_airlift.in_airflow.dag_proxy_operator import (
    BaseProxyDAGToDagsterOperator,
    DefaultProxyDAGToDagsterOperator,
)
from dagster_airlift.in_airflow.proxied_state import AirflowProxiedState
from dagster_airlift.in_airflow.task_proxy_operator import (
    BaseProxyTaskToDagsterOperator,
    DefaultProxyTaskToDagsterOperator,
)


def proxying_to_dagster(
    *,
    global_vars: Dict[str, Any],
    proxied_state: AirflowProxiedState,
    logger: Optional[logging.Logger] = None,
    build_from_task_fn: Callable[
        [BaseOperator], BaseProxyTaskToDagsterOperator
    ] = DefaultProxyTaskToDagsterOperator.build_from_task,
    build_from_dag_fn: Callable[
        [DAG], BaseProxyDAGToDagsterOperator
    ] = DefaultProxyDAGToDagsterOperator.build_from_dag,
) -> None:
    """Proxies tasks and dags to Dagster based on the provided proxied state.

    Expects a dictionary of in-scope global variables (typically retrieved with `globals()`),
    and a proxied state dictionary (typically retrieved with
    :py:func:`load_proxied_state_from_yaml`) for the dags in that global state.

    This function modifies the dictionary of global variables in place, replacing proxied tasks
    with the appropriate Dagster operators.

    In the case of task-level proxying, proxied tasks are replaced with new operators constructed
    by the provided `build_from_task_fn`. A default implementation is provided in
    `DefaultProxyTaskToDagsterOperator`.

    In the case of dag-level proxying, the entire dag structure is replaced with a single task
    constructed by the provided `build_from_dag_fn`. A default implementation is provided in
    `DefaultProxyDAGToDagsterOperator`.

    Args:
        global_vars (Dict[str, Any]): The global variables in the current context. In most cases,
            retrieved with `globals()` (no import required). This is equivalent to what Airflow
            already does to introspect the dags which exist in a given module context:
            https://airflow.apache.org/docs/apache-airflow/stable/core-concepts/dags.html#loading-dags
        proxied_state (AirflowProxiedState): The proxied state for the dags.
        logger (Optional[logging.Logger]): The logger to use. Defaults to
            ``logging.getLogger("dagster_airlift")``.

    Examples:
        Typically, this function is called at the end of a dag file, with `proxied_state` loaded
        from an accompanying `proxied_state` directory.

        .. code-block:: python

            from pathlib import Path

            from airflow import DAG
            from airflow.operators.python import PythonOperator
            from dagster._time import get_current_datetime_midnight
            from dagster_airlift.in_airflow import proxying_to_dagster
            from dagster_airlift.in_airflow.proxied_state import load_proxied_state_from_yaml

            with DAG(
                dag_id="daily_interval_dag",
                ...,
            ) as minute_dag:
                PythonOperator(task_id="my_task", python_callable=...)

            # At the end of the dag file, so we can ensure dags are loaded into globals.
            proxying_to_dagster(
                proxied_state=load_proxied_state_from_yaml(Path(__file__).parent / "proxied_state"),
                global_vars=globals(),
            )

        You can also provide a custom implementation of `build_from_task_fn` to customize the
        behavior of task-level proxying.

        .. code-block:: python

            import os
            from pathlib import Path

            import requests
            from airflow.models.operator import BaseOperator
            from airflow.utils.context import Context
            from dagster_airlift.in_airflow import BaseProxyTaskToDagsterOperator, proxying_to_dagster
            from dagster_airlift.in_airflow.proxied_state import load_proxied_state_from_yaml

            ...
            # Dag code here

            class CustomAuthTaskProxyOperator(BaseProxyTaskToDagsterOperator):
                def get_dagster_session(self, context: Context) -> requests.Session:
                    # Add custom headers to the session
                    session = requests.Session()
                    session.headers.update({"Authorization": "Bearer my_token"})
                    return session

                def get_dagster_url(self, context: Context) -> str:
                    # Use a custom environment variable for the dagster url
                    return os.environ["CUSTOM_DAGSTER_URL"]

                @classmethod
                def build_from_task(cls, task: BaseOperator) -> "CustomAuthTaskProxyOperator":
                    # Custom logic to build the operator from the task (task_id should remain the same)
                    if task.task_id == "my_task_needs_more_retries":
                        return CustomAuthTaskProxyOperator(task_id=task.task_id, retries=3)
                    else:
                        return CustomAuthTaskProxyOperator(task_id=task.task_id)

            proxying_to_dagster(
                proxied_state=load_proxied_state_from_yaml(Path(__file__).parent / "proxied_state"),
                global_vars=globals(),
                build_from_task_fn=CustomAuthTaskProxyOperator.build_from_task,
            )

        You can do the same for dag-level proxying by providing a custom implementation of
        `build_from_dag_fn`.

        .. code-block:: python

            import os
            from pathlib import Path

            import requests
            from airflow.models.dag import DAG
            from airflow.utils.context import Context
            from dagster_airlift.in_airflow import BaseProxyDAGToDagsterOperator, proxying_to_dagster
            from dagster_airlift.in_airflow.proxied_state import load_proxied_state_from_yaml

            ...
            # Dag code here

            class CustomAuthDAGProxyOperator(BaseProxyDAGToDagsterOperator):
                def get_dagster_session(self, context: Context) -> requests.Session:
                    # Add custom headers to the session
                    session = requests.Session()
                    session.headers.update({"Authorization": "Bearer my_token"})
                    return session

                def get_dagster_url(self, context: Context) -> str:
                    # Use a custom environment variable for the dagster url
                    return os.environ["CUSTOM_DAGSTER_URL"]

                @classmethod
                def build_from_dag(cls, dag: DAG) -> "CustomAuthDAGProxyOperator":
                    # Custom logic to build the operator from the dag (DAG id should remain the same)
                    if dag.dag_id == "my_dag_needs_more_retries":
                        return CustomAuthDAGProxyOperator(task_id="custom_override", retries=3, dag=dag)
                    else:
                        return CustomAuthDAGProxyOperator(task_id="basic_override", dag=dag)

            proxying_to_dagster(
                proxied_state=load_proxied_state_from_yaml(Path(__file__).parent / "proxied_state"),
                global_vars=globals(),
                build_from_dag_fn=CustomAuthDAGProxyOperator.build_from_dag,
            )
    """
    caller_module = global_vars.get("__module__")
    suffix = f" in module `{caller_module}`" if caller_module else ""
    if not logger:
        logger = logging.getLogger("dagster_airlift")
    logger.debug(f"Searching for dags proxied to dagster{suffix}...")
    task_level_proxying_dags: List[DAG] = []
    dag_level_proxying_dags: List[DAG] = []
    all_dag_ids: Set[str] = set()
    # Do a pass to collect dags and ensure that proxied information is set correctly.
    for obj in global_vars.values():
        if not isinstance(obj, DAG):
            continue
        dag: DAG = obj
        all_dag_ids.add(dag.dag_id)
        if not proxied_state.dag_has_proxied_state(dag.dag_id):
            logger.debug(f"Dag with id `{dag.dag_id}` has no proxied state. Skipping...")
            continue
        logger.debug(f"Dag with id `{dag.dag_id}` has proxied state.")
        proxied_state_for_dag = proxied_state.dags[dag.dag_id]
        if proxied_state_for_dag.proxied is not None:
            if proxied_state_for_dag.proxied is False:
                logger.debug(f"Dag with id `{dag.dag_id}` is not proxied. Skipping...")
                continue
            dag_level_proxying_dags.append(dag)
        else:
            for task_id in proxied_state_for_dag.tasks.keys():
                if task_id not in dag.task_dict:
                    raise Exception(
                        f"Task with id `{task_id}` not found in dag `{dag.dag_id}`."
                        f" Found tasks: {list(dag.task_dict.keys())}"
                    )
                if not isinstance(dag.task_dict[task_id], BaseOperator):
                    raise Exception(
                        f"Task with id `{task_id}` in dag `{dag.dag_id}` is not an instance of "
                        "BaseOperator. This likely means a MappedOperator was attempted, which is "
                        "not yet supported by airlift."
                    )
            task_level_proxying_dags.append(dag)

    if len(all_dag_ids) == 0:
        raise Exception(
            "No dags found in globals dictionary. Ensure that your dags are available from global "
            "context, and that the call to `proxying_to_dagster` is the last line in your dag file."
        )

    for dag in dag_level_proxying_dags:
        logger.debug(f"Tagging dag {dag.dag_id} as proxied.")
        dag.tags = [*dag.tags, "Dag overridden to Dagster"]
        dag.task_dict = {}
        dag.task_group.children = {}
        override_task = build_from_dag_fn(dag)
        dag.task_dict[override_task.task_id] = override_task

    for dag in task_level_proxying_dags:
        logger.debug(f"Tagging dag {dag.dag_id} as proxied.")
        proxied_state_for_dag = proxied_state.dags[dag.dag_id]
        num_proxied_tasks = len(
            [
                task_id
                for task_id, task_state in proxied_state_for_dag.tasks.items()
                if task_state.proxied
            ]
        )
        task_possessive = "Task" if num_proxied_tasks == 1 else "Tasks"
        dag.tags = [
            *dag.tags,
            f"{num_proxied_tasks} {task_possessive} Marked as Proxied to Dagster",
        ]
        proxied_tasks = set()
        for task_id, task_state in proxied_state_for_dag.tasks.items():
            if not task_state.proxied:
                logger.debug(
                    f"Task {task_id} in dag {dag.dag_id} has `proxied` set to False. Skipping..."
                )
                continue

            # At this point, we should be assured that the task exists within the task_dict of the
            # dag, and is a BaseOperator.
            original_op: BaseOperator = dag.task_dict[task_id]  # type: ignore  # confirmed above to be a BaseOperator
            del dag.task_dict[task_id]
            if original_op.task_group is not None:
                del original_op.task_group.children[task_id]
            logger.debug(
                f"Creating new operator for task {original_op.task_id} in dag {original_op.dag_id}"
            )
            new_op = build_from_task_fn(original_op)
            original_op.dag.task_dict[original_op.task_id] = new_op

            new_op.upstream_task_ids = original_op.upstream_task_ids
            new_op.downstream_task_ids = original_op.downstream_task_ids
            new_op.dag = original_op.dag
            original_op.dag = None
            proxied_tasks.add(task_id)
        logger.debug(f"Proxied tasks {proxied_tasks} in dag {dag.dag_id}.")
    logger.debug(f"Proxied {len(task_level_proxying_dags)} dags.")
    logger.debug(f"Completed switching proxied tasks to dagster{suffix}.")
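
The docstring examples load proxied state from a `proxied_state` directory next to the dag file via `load_proxied_state_from_yaml`, which reads one YAML file per dag, keyed by dag id. The sketch below shows the layout those examples assume; the file and task names are hypothetical, and the exact schema should be checked against your installed `dagster_airlift` version.

.. code-block:: python

    # Directory layout assumed by the docstring examples (hypothetical names):
    #
    #   dags/
    #       daily_interval_dag.py       # the dag file that calls proxying_to_dagster(...)
    #       proxied_state/
    #           daily_interval_dag.yaml
    #
    # Task-level proxying: the YAML file marks individual tasks, e.g.
    #
    #   tasks:
    #     - id: my_task
    #       proxied: True
    #
    # Dag-level proxying: the file instead sets a single top-level key, e.g.
    #
    #   proxied: True
    from pathlib import Path

    from dagster_airlift.in_airflow.proxied_state import load_proxied_state_from_yaml

    proxied_state = load_proxied_state_from_yaml(Path(__file__).parent / "proxied_state")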