Using dfguard in Airflow¶
In Airflow, each task is a Python callable. Tasks receive file paths and configuration, create their own SparkSession, do the work, and write results back to storage. DataFrames never cross task boundaries; they are too large for XCom.
This means @dfg.enforce on the task function itself gives you little: the arguments are strings and integers, not DataFrames. The useful pattern is:

- @dfg.enforce on the transformation helpers the task calls internally
- Schema.assert_valid(df) right after loading from storage, to catch schema drift from upstream before processing starts
The working example lives at examples/airflow/ in the dfguard repository.
File structure¶
airflow/
├── requirements.txt
├── pipeline/
│ ├── __init__.py # arms enforcement globally
│ ├── schemas.py # SparkSchema definitions
│ └── transforms.py # transformation helpers
└── dags/
└── orders_dag.py # DAG and task callables
pipeline/__init__.py¶
Arm once here. Every function in the package with a schema annotation is enforced automatically – no decorator needed on each one.
# pipeline/__init__.py
import dfguard.pyspark as dfg

dfg.arm()
# dfg.arm(subset=False)  # strict: no extra columns anywhere in the package

# To disable enforcement globally (e.g. in tests or non-prod environments):
# dfg.disarm()
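A common refinement is to gate arming on an environment variable so production enforces while ad-hoc or test environments can opt out. The variable name DFGUARD_ENFORCE below is our own convention, not part of dfguard, and the dfg calls are shown as comments — a minimal sketch:

```python
import os

# Hypothetical convention: enforcement is on unless DFGUARD_ENFORCE=0.
# (The variable name is our own choice, not part of dfguard.)
def enforcement_enabled(env=os.environ):
    return env.get("DFGUARD_ENFORCE", "1") != "0"

# In pipeline/__init__.py you would then do:
# import dfguard.pyspark as dfg
# dfg.arm() if enforcement_enabled() else dfg.disarm()
```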
schemas.py¶
Define schema contracts once, shared by all tasks.
# pipeline/schemas.py
import dfguard.pyspark as dfg
from pyspark.sql import types as T

class RawOrderSchema(dfg.SparkSchema):
    order_id: T.LongType()
    customer_id: T.LongType()
    amount: T.DoubleType()
    quantity: T.IntegerType()
    status: T.StringType()

class EnrichedOrderSchema(RawOrderSchema):
    revenue: T.DoubleType()
    is_high_value: T.BooleanType()

class SummarySchema(dfg.SparkSchema):
    customer_id: T.LongType()
    total_revenue: T.DoubleType()
    order_count: T.LongType()
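Schema classes like these are ordinary Python classes whose field declarations live in __annotations__ — the standard mechanism annotation-driven libraries use to discover fields, and presumably how dfguard collects them (including inherited ones, as in EnrichedOrderSchema). A rough illustration with plain classes and string type names, so no pyspark is required:

```python
class RawOrder:
    # stand-ins for T.LongType() etc., using plain strings
    order_id: "bigint"
    customer_id: "bigint"
    amount: "double"

class EnrichedOrder(RawOrder):
    revenue: "double"

def collect_fields(cls):
    # Walk the MRO base-first so inherited fields come before subclass additions;
    # vars(base) avoids picking up a parent's annotations twice via attribute lookup.
    fields = {}
    for base in reversed(cls.__mro__):
        fields.update(vars(base).get("__annotations__", {}))
    return fields
```

collect_fields(EnrichedOrder) yields all four fields in declaration order, parent fields first — the same field set EnrichedOrderSchema carries above.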
transforms.py¶
Pure transformation functions. dfg.arm() in __init__.py covers them
all. @dfg.enforce(subset=False) is added where an exact schema is required
– no extra columns allowed, useful before writing to a fixed-schema sink.
# pipeline/transforms.py
import dfguard.pyspark as dfg
from pyspark.sql import DataFrame, functions as F

from pipeline.schemas import EnrichedOrderSchema, RawOrderSchema, SummarySchema

# Covered by dfg.arm() -- no decorator needed
def enrich(raw: RawOrderSchema) -> DataFrame:
    return (
        raw
        .withColumn("revenue", F.col("amount") * F.col("quantity"))
        .withColumn("is_high_value", F.col("revenue") > 500.0)
    )

# subset=False: the summary written to storage must match SummarySchema exactly
@dfg.enforce(subset=False)
def summarise(enriched: EnrichedOrderSchema) -> DataFrame:
    return (
        enriched
        .groupBy("customer_id")
        .agg(
            F.sum("revenue").alias("total_revenue"),
            F.count("*").alias("order_count"),
        )
    )
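Because enrich and summarise are pure column logic, the arithmetic they encode can be sanity-checked row-wise in plain Python before involving Spark. The helper names below are our own, for illustration only:

```python
def enrich_row(row):
    # Mirrors enrich(): revenue = amount * quantity, high-value above 500
    revenue = row["amount"] * row["quantity"]
    return {**row, "revenue": revenue, "is_high_value": revenue > 500.0}

def summarise_rows(rows):
    # Mirrors summarise(): per-customer revenue total and order count
    out = {}
    for r in rows:
        agg = out.setdefault(r["customer_id"], {"total_revenue": 0.0, "order_count": 0})
        agg["total_revenue"] += r["revenue"]
        agg["order_count"] += 1
    return out
```

Keeping the Spark transforms this thin makes such parallel checks cheap to maintain alongside the real pipeline.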
orders_dag.py¶
Each task validates the DataFrame immediately after loading, then calls the appropriate transform. Tasks are completely independent; each creates and stops its own SparkSession.
# dags/orders_dag.py
import os
import sys
from datetime import datetime

from airflow import DAG
from airflow.operators.python import PythonOperator

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

INPUT_PATH = "/data/orders/raw_orders.parquet"
ENRICH_PATH = "/data/orders/enriched_orders.parquet"
SUMMARY_PATH = "/data/orders/customer_summary.parquet"

def task_enrich(**context):
    from pyspark.sql import SparkSession
    from pipeline.schemas import RawOrderSchema
    from pipeline.transforms import enrich

    spark = SparkSession.builder.appName("orders-enrich").getOrCreate()
    raw = spark.read.parquet(INPUT_PATH)
    RawOrderSchema.assert_valid(raw)  # fail fast if upstream schema changed
    enriched = enrich(raw)  # @dfg.enforce guards the function
    enriched.write.mode("overwrite").parquet(ENRICH_PATH)
    spark.stop()

def task_summarise(**context):
    from pyspark.sql import SparkSession
    from pipeline.schemas import EnrichedOrderSchema
    from pipeline.transforms import summarise

    spark = SparkSession.builder.appName("orders-summarise").getOrCreate()
    enriched = spark.read.parquet(ENRICH_PATH)
    EnrichedOrderSchema.assert_valid(enriched)
    summary = summarise(enriched)
    summary.write.mode("overwrite").parquet(SUMMARY_PATH)
    spark.stop()

with DAG(
    dag_id="orders_pipeline",
    start_date=datetime(2024, 1, 1),
    schedule="@daily",
    catchup=False,
) as dag:
    enrich_task = PythonOperator(task_id="enrich_orders", python_callable=task_enrich)
    summarise_task = PythonOperator(task_id="summarise_orders", python_callable=task_summarise)

    enrich_task >> summarise_task
Running locally¶
pip install -r examples/airflow/requirements.txt
export AIRFLOW_HOME=~/airflow
export AIRFLOW__CORE__DAGS_FOLDER="$(pwd)/examples/airflow/dags"
airflow db migrate
airflow dags trigger orders_pipeline
Note that airflow dags trigger only queues a run; a scheduler must be running for the tasks to execute. airflow standalone starts a scheduler and webserver in one process, or run the DAG once in-process with airflow dags test orders_pipeline.
What a schema error looks like¶
If the enrich task is skipped and task_summarise runs directly on raw
data, it fails at the function call:
TypeError: Schema mismatch in summarise() argument 'enriched':
  expected: order_id:bigint, customer_id:bigint, amount:double,
            quantity:int, status:string, revenue:double, is_high_value:boolean
  received: order_id:bigint, customer_id:bigint, amount:double,
            quantity:int, status:string
The message tells you which function, which argument, what schema was expected, and what was actually passed.
If upstream data changes shape and assert_valid catches it:
SchemaValidationError: Schema validation failed:
  ✗ Missing column 'revenue' (expected double, nullable=False)
assert_valid reports all missing and mismatched fields in one message,
not just the first one it finds.
Tip
Use @dfg.enforce on every transformation helper. Use
Schema.assert_valid(df) at the start of every task that reads from
storage. Together they give you two layers: one at the storage boundary
(schema drift from upstream), one at the function boundary (wrong data
passed to wrong transform).