Introduction
Apache Airflow provides robust capabilities for designing and managing workflows. However, some external integrations call for a more tailored approach than what's available out of the box. This article focuses on the practical implementation of custom Airflow operators, using Databricks integration as a case study. We’ll create a custom operator and then make it deferrable for better resource utilization.
Setup
To follow this example, you will need:
- Airflow:
pip install apache-airflow
- Databricks Python SDK:
pip install databricks-sdk
- A Databricks account
Writing the Hook
The best practice for interacting with an external service from Airflow is the Hook abstraction. Hooks provide a unified interface for acquiring connections and integrate with Airflow's built-in connection management. We’ll create a hook for connecting to Databricks via the Databricks Python SDK:
```python
from airflow.hooks.base import BaseHook
from databricks.sdk import WorkspaceClient


class DatabricksHook(BaseHook):
    def __init__(self, dbx_conn_id: str = "databricks_default") -> None:
        super().__init__()
        self.dbx_conn_id = dbx_conn_id
        self.client = None  # created lazily on the first get_conn() call

    def get_conn(self) -> WorkspaceClient:
        """
        Returns a Databricks API client built from the Airflow connection.
        """
        if self.client is None:
            # The connection's host is the workspace URL; its password holds the access token.
            conn = self.get_connection(self.dbx_conn_id)
            self.client = WorkspaceClient(host=conn.host, token=conn.password)
        return self.client
```
Writing the Operator
Operators define the tasks we run in our workflows. Let’s use our hook to access the Databricks API, submit a blocking request to run a job, and determine the state of our task from the run's result state. If the job has failed, we indicate a task failure to the framework by raising an AirflowException:
```python
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.utils.context import Context
from databricks.sdk.service.jobs import RunResultState


class DatabricksRunNowOperator(BaseOperator):
    def __init__(self, dbx_conn_id: str, job_id: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.dbx_conn_id = dbx_conn_id
        self.job_id = job_id

    def execute(self, context: Context) -> None:
        hook = DatabricksHook(self.dbx_conn_id)
        # Blocks this worker until the run finishes. The SDK's *_and_wait helpers
        # apply their own timeout (20 minutes by default); pass timeout=... for longer jobs.
        response = hook.get_conn().jobs.run_now_and_wait(self.job_id)
        # result_state is a RunResultState enum member, not a plain string.
        if response.state.result_state != RunResultState.SUCCESS:
            raise AirflowException(f"Databricks job failed: {response.as_dict()}")
        self.log.info(f"Databricks job succeeded: {response.as_dict()}")
```
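A note on the success check: the SDK's Run object exposes state.result_state as an enum member (RunResultState), whereas as_dict() converts states to plain strings. The SDK models these states as plain Enum subclasses (worth confirming for your SDK version), and a plain Enum member never compares equal to its raw string value, so each check has to match the representation it receives. A minimal illustration, using a hypothetical stand-in enum rather than the SDK class itself:

```python
from enum import Enum


# Hypothetical stand-in mirroring the SDK's RunResultState: a plain Enum
# whose member values happen to be strings.
class RunResultState(Enum):
    SUCCESS = "SUCCESS"
    FAILED = "FAILED"


state = RunResultState.SUCCESS

print(state == "SUCCESS")               # False: a plain Enum member is not its value
print(state == RunResultState.SUCCESS)  # True: compare members with members
print(state.value == "SUCCESS")         # True: or compare the raw value
```

This is why a typed Run response is checked against RunResultState.SUCCESS, while a dictionary produced by as_dict() is checked against the string "SUCCESS".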
Making the Operator Deferrable
Our analytics workload on Databricks may take a long time to execute. In our previous implementation, we’d occupy an Airflow worker the entire time our task waited for run_now_and_wait to return. To alleviate this, we can use the Airflow concept of a “deferrable” operator. These operators can pause their work and free up worker resources until an external event occurs, at which point they resume.
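Conceptually, deferral is a hand-off: execute() starts the work, then hands control back to Airflow along with a trigger and the name of a method to resume in, and the trigger's first event is later passed to that method. The toy model below captures only that sequence in plain asyncio. Every name in it is a hypothetical stand-in: in real Airflow, self.defer() raises a TaskDeferred exception, the serialized trigger runs in a separate triggerer process, and the resume method runs on a worker that is not necessarily the one that started the task.

```python
import asyncio
from typing import Any, AsyncGenerator

# A toy model of the defer/resume hand-off. The names are hypothetical;
# this is not how Airflow implements deferral internally.


async def toy_trigger(polls_until_done: int) -> AsyncGenerator[dict[str, Any], None]:
    """Pretend to poll an external job, then emit a single event."""
    for _ in range(polls_until_done):
        await asyncio.sleep(0)  # give control back to the event loop between polls
    yield {"state": {"result_state": "SUCCESS"}}


class ToyOperator:
    def execute(self) -> tuple[AsyncGenerator[dict[str, Any], None], str]:
        # Submit work, then "defer": return a trigger plus the name of the
        # method that should resume the task once the trigger fires.
        return toy_trigger(polls_until_done=3), "execute_complete"

    def execute_complete(self, event: dict[str, Any]) -> str:
        # Resume with the trigger's event payload.
        return event["state"]["result_state"]


async def run_deferred(op: ToyOperator) -> str:
    trigger, method_name = op.execute()
    # Only the first event matters; the trigger is shut down afterwards.
    event = await trigger.__anext__()
    await trigger.aclose()
    # Resume the task by looking up the named method and passing the event.
    return getattr(op, method_name)(event)


print(asyncio.run(run_deferred(ToyOperator())))  # SUCCESS
```

Because the resume step looks the method up by name, the method_name given to defer() must match a method that actually exists on the operator.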
The code below introduces two main components:
- DatabricksJobTrigger: This is the "event watcher". It periodically checks the Databricks run status, and once the run finishes (or errors out), it signals that it's time to move forward.
- DeferrableDatabricksRunNowOperator: This is the main task. It starts a Databricks job, then pauses itself using the trigger. Once the trigger reports that the job is done, the operator checks and logs the outcome and finishes up.
This approach saves resources in Airflow because the operator isn't occupying a worker while the job runs. Instead, the lightweight trigger waits for the signal to continue in Airflow's separate triggerer process:
```python
# Assumes this code lives in databricks_example.py alongside the hook and the
# operator above, matching the classpath returned by serialize().
import asyncio
from typing import Any, AsyncIterator

from airflow.triggers.base import BaseTrigger, TriggerEvent
from databricks.sdk.service.jobs import RunLifeCycleState


class DatabricksJobTrigger(BaseTrigger):
    def __init__(self, dbx_conn_id: str, run_id: int) -> None:
        super().__init__()
        self.dbx_conn_id = dbx_conn_id
        self.run_id = run_id

    def serialize(self) -> tuple[str, dict[str, Any]]:
        # The triggerer re-creates this trigger from the classpath and kwargs.
        return (
            "databricks_example.DatabricksJobTrigger",
            {
                "dbx_conn_id": self.dbx_conn_id,
                "run_id": self.run_id,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        hook = DatabricksHook(self.dbx_conn_id)
        while True:
            # Improvement: get_run is a blocking call; an async client would
            # avoid stalling the triggerer's event loop.
            response = hook.get_conn().jobs.get_run(self.run_id)
            self.log.info(f"Got {response.as_dict()}")
            status = response.state.life_cycle_state
            self.log.info(status)
            if status in [
                RunLifeCycleState.TERMINATED,
                RunLifeCycleState.SKIPPED,
                RunLifeCycleState.INTERNAL_ERROR,
            ]:
                # Hand the final run state back to the operator and stop.
                yield TriggerEvent(response.as_dict())
                return
            await asyncio.sleep(15)  # poll every 15 seconds


class DeferrableDatabricksRunNowOperator(BaseOperator):
    def __init__(self, dbx_conn_id: str, job_id: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.dbx_conn_id = dbx_conn_id
        self.job_id = job_id
        self.run_id = None  # set once the run has been submitted

    def execute(self, context: Context) -> None:
        hook = DatabricksHook(self.dbx_conn_id)
        # run_now submits the run and returns without waiting for it to finish.
        response = hook.get_conn().jobs.run_now(self.job_id)
        self.run_id = response.run_id
        self.log.info(f"Submitted Databricks run with ID: {self.run_id}")
        # defer() raises TaskDeferred, freeing the worker; Airflow later calls
        # execute_complete with the trigger's event.
        self.defer(
            trigger=DatabricksJobTrigger(self.dbx_conn_id, self.run_id),
            method_name="execute_complete",
        )

    def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
        # as_dict() produced plain strings; result_state may be absent
        # (e.g. for skipped runs), so read it defensively.
        if event.get("state", {}).get("result_state") != "SUCCESS":
            raise AirflowException(f"Databricks job failed: {event}")
        self.log.info(f"Databricks job succeeded: {event}")
```
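The classpath returned by serialize() is more than a label: Airflow stores the trigger's classpath and keyword arguments, and the triggerer process later imports the class by that path and re-creates the trigger from those arguments. So "databricks_example.DatabricksJobTrigger" has to be importable inside the triggerer (assuming, as above, that the code lives in a databricks_example module on its import path), and the arguments should be simple serializable values such as dbx_conn_id and run_id. A simplified model of that re-creation step, using a hypothetical rebuild helper and a stdlib class in place of the trigger:

```python
import importlib
from datetime import timedelta
from typing import Any


def rebuild(classpath: str, kwargs: dict[str, Any]) -> Any:
    """Hypothetical helper: re-create an object from a (classpath, kwargs)
    pair, the shape BaseTrigger.serialize() returns. Simplified model only."""
    module_name, _, class_name = classpath.rpartition(".")
    module = importlib.import_module(module_name)  # fails if the module isn't importable
    cls = getattr(module, class_name)
    return cls(**kwargs)


# A stdlib class stands in for DatabricksJobTrigger here.
restored = rebuild("datetime.timedelta", {"minutes": 5})
print(restored == timedelta(minutes=5))  # True
```

If the import fails in the triggerer, the trigger cannot be re-created, so the classpath string and the module's location are worth double-checking.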
While running this operator, you may notice logs like:
Triggerer's async thread was blocked for 0.23 seconds
This happens because get_run is a synchronous (blocking) function that we invoke from our async trigger run method, so it stalls the triggerer's event loop, which is shared by every trigger the triggerer runs. In a production setting, it would be best practice to make this function asynchronous end to end. For the sake of this article, we will use the Databricks SDK as is. If you’re curious, you can see an asynchronous Databricks API client in the official Airflow integration.
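Short of a fully asynchronous client, one common mitigation is to run each blocking SDK call in a worker thread with asyncio.to_thread (Python 3.9+), which keeps the event loop responsive while the call is in flight. A minimal sketch of the polling loop with that change; get_status and make_fake_status are hypothetical stand-ins for the jobs.get_run(...) lookup, and the terminal states mirror the ones our trigger checks:

```python
import asyncio
import time
from typing import AsyncIterator, Callable

# Mirrors the life-cycle states the trigger treats as final.
TERMINAL_STATES = {"TERMINATED", "SKIPPED", "INTERNAL_ERROR"}


async def poll_until_terminal(
    get_status: Callable[[], str], poll_interval: float = 15.0
) -> AsyncIterator[str]:
    """Poll a blocking status function without blocking the event loop."""
    while True:
        # The blocking call runs in a thread, so other coroutines keep running.
        status = await asyncio.to_thread(get_status)
        if status in TERMINAL_STATES:
            yield status
            return
        await asyncio.sleep(poll_interval)


def make_fake_status(statuses: list[str]) -> Callable[[], str]:
    """Hypothetical blocking stand-in for a jobs.get_run(...) status lookup."""
    remaining = iter(statuses)

    def get_status() -> str:
        time.sleep(0.01)  # simulate a slow, blocking network call
        return next(remaining)

    return get_status


async def main() -> list[str]:
    get_status = make_fake_status(["PENDING", "RUNNING", "TERMINATED"])
    return [s async for s in poll_until_terminal(get_status, poll_interval=0.01)]


print(asyncio.run(main()))  # ['TERMINATED']
```

Inside DatabricksJobTrigger.run, the equivalent change would be response = await asyncio.to_thread(hook.get_conn().jobs.get_run, self.run_id). Note that the first hook.get_conn() call, which reads the Airflow connection and builds the client, would still run directly on the event loop.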
Error Handling
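The above operators work in the happy path. But what happens when a user stops a running task from the Airflow UI, for example by marking it as failed? BaseOperator provides an overridable on_kill method for these cases. When integrating with Databricks, we need to pass this kill signal through and cancel the Databricks run; otherwise the run keeps going after Airflow has stopped the task. To do that, the operator must know the run ID while it executes. The deferrable operator has it as soon as run_now returns; the blocking operator only receives it when run_now_and_wait returns, so it would need to call run_now, keep the run ID, and then block on the returned waiter's result(). Also note that Airflow calls on_kill in the worker process that is running the task, so it does not fire while a task is deferred.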
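```python
def on_kill(self) -> None:
    # Assumes execute() stored the Databricks run ID as self.run_id.
    run_id = getattr(self, "run_id", None)
    if run_id is not None:
        self.log.info(f"Cancelling Databricks run {run_id} of job {self.job_id}")
        DatabricksHook(self.dbx_conn_id).get_conn().jobs.cancel_run(run_id)
```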
Example DAG
First, set up the connection credentials in your environment:
```shell
export AIRFLOW_CONN_DATABRICKS_DEFAULT='{
    "conn_type": "my-databricks",
    "host": "<your_databricks_host>",
    "password": "<your_databricks_token>"
}'
```
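Airflow resolves a connection ID from an environment variable named AIRFLOW_CONN_ followed by the upper-cased ID, which is why databricks_default maps to AIRFLOW_CONN_DATABRICKS_DEFAULT. If you generate the value from a script, serializing it with json.dumps avoids quoting mistakes. A small sketch; the conn_env_var helper is hypothetical and the values are placeholders:

```python
import json


def conn_env_var(conn_id: str, host: str, token: str) -> tuple[str, str]:
    """Hypothetical helper: build the (name, value) pair for a JSON-style
    Airflow connection environment variable."""
    name = f"AIRFLOW_CONN_{conn_id.upper()}"
    value = json.dumps({"conn_type": "my-databricks", "host": host, "password": token})
    return name, value


name, value = conn_env_var(
    "databricks_default", "<your_databricks_host>", "<your_databricks_token>"
)
print(name)  # AIRFLOW_CONN_DATABRICKS_DEFAULT
```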
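Note that in production, the credentials should come from a secrets backend rather than a plain environment variable. The deferrable task also needs a running triggerer (for example, started with the airflow triggerer command) in order to resume. Now, we can define a DAG for testing our operators: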
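```python
# Assumes the hook, operators, and trigger above live in this same file
# (databricks_example.py).
import pendulum
from airflow.decorators import dag

# Placeholder: replace with the ID of an existing Databricks job.
MY_JOB_ID = 123


@dag(
    schedule=None,
    start_date=pendulum.datetime(2023, 1, 1),
    catchup=False,
)
def databricks_example():
    # Run the blocking operator first, then the deferrable one.
    DatabricksRunNowOperator(
        task_id="example", dbx_conn_id="databricks_default", job_id=MY_JOB_ID
    ) >> DeferrableDatabricksRunNowOperator(
        task_id="example_deferrable", dbx_conn_id="databricks_default", job_id=MY_JOB_ID
    )


databricks_example()
```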
Bada-bing, bada-boom. Now you’re cooking with gas!
Conclusion
Our basic implementation leaves a variety of things unhandled:
- Fully asynchronous trigger function.
- Code reuse across deferrable and non-deferrable operators.
- Configurable timeouts.
- Full handling of Databricks job state info.
If you’re curious about these things, I recommend taking a look at the Airflow repo’s Databricks integration.