Integrating Airflow with Databricks: Creating Custom Operators

Introduction

Apache Airflow provides robust capabilities for designing and managing workflows. However, there are times when external integrations require a more tailored approach than what's available out of the box. This article focuses on the practical implementation of custom Airflow operators, using Databricks integration as a case study. We’ll create a custom operator and then make it deferrable for better resource utilization.

Setup

To follow this example, you will need:

  1. Airflow: pip install apache-airflow
  2. Databricks Python SDK: pip install databricks-sdk
  3. A Databricks account

Writing the Hook

The best practice for interacting with an external service from Airflow is the Hook abstraction. Hooks provide a unified interface for acquiring connections and integrate with Airflow's built-in connection management. We’ll create a hook that connects to Databricks via the Databricks Python SDK:

from airflow.hooks.base import BaseHook
from databricks.sdk import WorkspaceClient


class DatabricksHook(BaseHook):
    def __init__(self, dbx_conn_id="databricks_default"):
        super().__init__()
        self.dbx_conn_id = dbx_conn_id
        self.client = None

    def get_conn(self):
        """
        Returns a Databricks API client.
        """
        if self.client is None:
            conn = self.get_connection(self.dbx_conn_id)
            self.client = WorkspaceClient(host=conn.host, token=conn.password)
        return self.client
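
Before wiring the hook into an operator, it can be handy to sanity-check it from a Python shell. The snippet below is an optional sketch; it assumes the databricks_default connection is already configured (we set it up via an environment variable in the Example DAG section later):

hook = DatabricksHook()
# current_user.me() is a lightweight call that confirms the client can authenticate.
print(hook.get_conn().current_user.me().user_name)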

Writing the Operator

Operators are the tasks we run in our workflows. Let’s use our hook to access the Databricks API, submit a blocking request to run a job, and determine the state of our task based on the returned code. If our task has failed, we can indicate an operator failure to the framework by raising an AirflowException:

from airflow.models import BaseOperator
from airflow.utils.context import Context
from airflow.exceptions import AirflowException
from databricks.sdk.service.jobs import RunResultState


class DatabricksRunNowOperator(BaseOperator):
    def __init__(self, dbx_conn_id: str, job_id: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.dbx_conn_id = dbx_conn_id
        self.job_id = job_id

    def execute(self, context: Context) -> None:
        hook = DatabricksHook(self.dbx_conn_id)
        # run_now_and_wait blocks until the run reaches a terminal state.
        response = hook.get_conn().jobs.run_now_and_wait(self.job_id)
        if response.state.result_state != RunResultState.SUCCESS:
            raise AirflowException(f"Databricks job failed: {response.as_dict()}")
        self.log.info(f"Databricks job succeeded: {response.as_dict()}")

Making the Operator Deferrable

Our analytics workload on Databricks could take a long time to execute. In the previous implementation, we block an Airflow worker slot for the entire time run_now_and_wait takes to return. To alleviate this, we can use the Airflow concept of a “deferrable” operator. Deferrable operators can pause their work and free up worker resources until an external event occurs, at which point they resume.

The code below introduces two main components:

  1. DatabricksJobTrigger: the "event watcher". It polls the Databricks run status and, once the run reaches a terminal state, emits a TriggerEvent to signal that it's time to move forward. Its serialize method returns the trigger's import path and constructor arguments, which the triggerer process uses to re-create the trigger.
  2. DeferrableDatabricksRunNowOperator: the main task. It submits a Databricks run, then defers itself and hands control to the trigger. When the trigger fires, Airflow resumes the operator at execute_complete, which checks the outcome and logs the result.

This approach saves resources in Airflow since the operator isn't occupying a worker slot while it waits. Instead, the trigger's async code runs in the dedicated triggerer process (started with airflow triggerer), and the operator only resumes once the trigger signals completion:

from airflow.triggers.base import BaseTrigger, TriggerEvent
from typing import Any
import asyncio
from databricks.sdk.service.jobs import RunLifeCycleState


class DatabricksJobTrigger(BaseTrigger):
    def __init__(self, dbx_conn_id: str, run_id: int):
        super().__init__()
        self.dbx_conn_id = dbx_conn_id
        self.run_id = run_id

    def serialize(self) -> tuple[str, dict[str, Any]]:
        return (
            "databricks_example.DatabricksJobTrigger",
            {
                "dbx_conn_id": self.dbx_conn_id,
                "run_id": self.run_id,
            },
        )

    async def run(self):
        hook = DatabricksHook(self.dbx_conn_id)
        while True:
            # Improvement - make get_run async.
            response = hook.get_conn().jobs.get_run(self.run_id)
            self.log.info(f"Got {response.as_dict()}")
            status = response.state.life_cycle_state
            self.log.info(status)
            if status in [
                RunLifeCycleState.TERMINATED,
                RunLifeCycleState.SKIPPED,
                RunLifeCycleState.INTERNAL_ERROR,
            ]:
                yield TriggerEvent(response.as_dict())
                return
            await asyncio.sleep(15)


class DeferrableDatabricksRunNowOperator(BaseOperator):
    def __init__(self, dbx_conn_id: str, job_id: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.dbx_conn_id = dbx_conn_id
        self.job_id = job_id
        # Set once a run is submitted; used by on_kill (see Error Handling below).
        self.run_id = None

    def execute(self, context: Context) -> None:
        hook = DatabricksHook(self.dbx_conn_id)
        # run_now returns immediately with a run ID instead of blocking until completion.
        response = hook.get_conn().jobs.run_now(self.job_id)
        self.run_id = response.run_id
        self.log.info(f"Submitted Databricks run with ID: {self.run_id}")

        self.defer(
            trigger=DatabricksJobTrigger(self.dbx_conn_id, self.run_id),
            method_name="execute_complete",
        )

    def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
        if event["state"]["result_state"] != "SUCCESS":
            raise AirflowException(f"Databricks job failed: {event}")
        self.log.info(f"Databricks job succeeded: {event}")

While running this operator, you may notice logs like:

Triggerer's async thread was blocked for 0.23 seconds

This happens because get_run is a synchronous call invoked from our async trigger run method, so it blocks the triggerer's shared event loop while waiting for a response. In a production setting, it would be best practice to make the entire run method asynchronous. For the sake of this article, we will use the Databricks SDK as is. If you’re curious, you can see an asynchronous Databricks API client in the official Airflow integration.
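
If you'd rather not adopt a separate async client, one lightweight alternative is to push the blocking call onto a worker thread so the triggerer's event loop stays free. The sketch below is one possible rewrite of the trigger's run method (it requires Python 3.9+ for asyncio.to_thread), not the approach the official provider takes:

async def run(self):
    hook = DatabricksHook(self.dbx_conn_id)
    while True:
        # Run the blocking SDK call in a worker thread so the shared
        # triggerer event loop is never blocked while we wait on the API.
        response = await asyncio.to_thread(
            lambda: hook.get_conn().jobs.get_run(self.run_id)
        )
        status = response.state.life_cycle_state
        self.log.info(f"Run {self.run_id} is in state {status}")
        if status in [
            RunLifeCycleState.TERMINATED,
            RunLifeCycleState.SKIPPED,
            RunLifeCycleState.INTERNAL_ERROR,
        ]:
            yield TriggerEvent(response.as_dict())
            return
        await asyncio.sleep(15)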

Error Handling

The above operators handle the happy path. But what happens when a user issues a termination from the Airflow UI? BaseOperator provides an overridable on_kill method for these cases. When integrating with Databricks, we need to pass the kill signal through and cancel our run, which is why the deferrable operator records self.run_id when it submits the job:

def on_kill(self) -> None:
    self.log.info(f"Killing Databricks job {self.job_id}")
    if self.run_id is not None:
        DatabricksHook(self.dbx_conn_id).get_conn().jobs.cancel_run(self.run_id)

Example DAG

First, set up the connection credentials in your environment:

export AIRFLOW_CONN_DATABRICKS_DEFAULT='{
        "conn_type": "my-databricks",
        "host": "<your_databricks_host>",
        "password": "<your_databricks_token>"
}'

Note that in production, a secrets manager should be used for the credentials. Now, we can define a DAG for testing our operators:

from airflow.decorators import dag
import pendulum


@dag(
    schedule=None,
    start_date=pendulum.datetime(2023, 1, 1),
    catchup=False,
)
def databricks_example():
    # MY_JOB_ID is a placeholder: replace it with the ID of an existing Databricks job.
    DatabricksRunNowOperator(
        task_id="example", dbx_conn_id="databricks_default", job_id=MY_JOB_ID
    ) >> DeferrableDatabricksRunNowOperator(
        task_id="example_deferrable", dbx_conn_id="databricks_default", job_id=MY_JOB_ID
    )


databricks_example()
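
To run this, make sure a scheduler and a triggerer (airflow triggerer) are both running; the deferrable task's trigger executes in the triggerer process rather than on a worker. You can then start the DAG from the UI, or kick it off from the CLI with airflow dags trigger databricks_example.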

Bada-bing, bada-boom. Now you’re cooking with gas!

Conclusion

There are a variety of things left unhandled in our basic implementation:

  • Fully asynchronous trigger function.
  • Code reuse across deferrable and non-deferrable operators.
  • Configurable timeouts.
  • Full handling of Databricks job state info.

If you’re curious about these things, I recommend taking a look at the Airflow repo’s Databricks integration.