Spark | February 8, 2026 | 10 min read

PySpark for Data Engineers: Production Patterns Beyond the Tutorial

Spark tutorials show you how to run a word count. Production Spark work involves partitioning strategies, skew handling, and knowing when to stop using Spark entirely.

PySpark is one of the most powerful tools in the data engineering toolkit and one of the easiest to misuse. The API is expressive and the tutorials make it look simple. Production Spark work reveals the gap: skewed partitions that bring jobs to a crawl, UDFs that serialize every row to Python and erase most of the performance gains, joins that cause out-of-memory errors on executors, and cluster configurations that nobody understands well enough to tune.

This guide covers the patterns and mental models that separate Spark code that works in dev from Spark code that works reliably in production at scale.

DataFrame API vs. RDDs

Use the DataFrame API. Full stop. RDDs (Resilient Distributed Datasets) are the lower-level abstraction that the DataFrame API is built on. They exist and are sometimes necessary for very specific custom operations, but for data engineering work the DataFrame API is almost always the right choice.

The DataFrame API benefits from the Catalyst query optimizer, which rewrites your query into an efficient execution plan. RDDs bypass this entirely. A DataFrame join can be optimized by Catalyst to use a broadcast join when one side is small; the equivalent RDD join cannot be.
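To make the optimizer's value concrete, here is a toy model in plain Python (no Spark involved) of one rewrite Catalyst applies to DataFrame plans: pushing a filter below a join so that fewer rows reach the join. The data and the `hash_join` helper are hypothetical. An RDD pipeline would execute the unoptimized order exactly as written.

```python
# Toy model (plain Python, not Spark): how many rows reach a join when a
# filter runs after the join versus before it. Catalyst performs this kind
# of "predicate pushdown" automatically for DataFrame plans.

orders = [
    {"order_id": i, "customer_id": i % 100, "amount": float(i)}
    for i in range(10_000)
]
customers = [
    {"customer_id": c, "region": "EU" if c % 2 else "US"} for c in range(100)
]


def hash_join(left, right, key):
    """Inner-join two lists of dicts on `key`; return (rows, left_rows_probed)."""
    index = {}
    for row in right:
        index.setdefault(row[key], []).append(row)
    joined, probed = [], 0
    for row in left:
        probed += 1
        for match in index.get(row[key], []):
            joined.append({**row, **match})
    return joined, probed


# Plan as written: join everything, then filter.
joined_all, probed_naive = hash_join(orders, customers, "customer_id")
result_naive = [r for r in joined_all if r["amount"] >= 9_000]

# Rewritten plan: filter first, then join only the surviving rows.
recent = [r for r in orders if r["amount"] >= 9_000]
result_pushed, probed_pushed = hash_join(recent, customers, "customer_id")

print(probed_naive, probed_pushed)  # rows probed by each plan
```

Both plans return the same rows. The rewritten plan probes a tenth as many, which is the kind of saving Catalyst finds without being asked.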

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, TimestampType

spark = (
    SparkSession.builder
    .appName("orders-pipeline")
    .config("spark.sql.adaptive.enabled", "true")
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
    .getOrCreate()
)

# Prefer an explicit schema over inference: it skips the extra pass inferSchema makes over CSV/JSON and pins the types you expect
schema = StructType([
    StructField("order_id", StringType(), False),
    StructField("customer_id", StringType(), True),
    StructField("amount", DoubleType(), True),
    StructField("order_date", TimestampType(), True)
])

df = spark.read.schema(schema).parquet("s3://bucket/orders/")
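The "safer" half of that comment is easy to demonstrate. Below is a deliberately naive type inferrer in plain Python. It works in the spirit of CSV `inferSchema` but is not Spark's implementation: it picks the narrowest type that parses every sample. The `infer_type` helper and the ZIP-code column are hypothetical. The point is that inference guesses from whatever data it sees, while an explicit schema encodes what you know about the column.

```python
# Toy type inference (plain Python, not Spark's implementation): choose the
# narrowest type that can parse every sampled value.

def infer_type(values):
    for candidate in (int, float):
        try:
            for v in values:
                candidate(v)
            return candidate
        except ValueError:
            continue
    return str


zip_codes = ["02139", "10001", "94105"]

inferred = infer_type(zip_codes)                 # every value parses as an int
parsed_inferred = [inferred(v) for v in zip_codes]  # leading zeros are lost
parsed_explicit = [str(v) for v in zip_codes]    # explicit schema: string column

print(inferred.__name__, parsed_inferred, parsed_explicit)
```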

Partitioning: The Most Important Performance Lever

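Spark parallelism is determined by the number of partitions. Too few partitions and you underutilize the cluster; too many and you spend more time on scheduling overhead than on computation. When reading files, Spark targets about 128 MB per input partition by default (spark.sql.files.maxPartitionBytes). After a shuffle, the partition count comes from spark.sql.shuffle.partitions, which defaults to 200.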

# Check the current partition count
df.rdd.getNumPartitions()

# repartition() always performs a full shuffle; it can raise or lower the
# partition count and hash-partitions rows by the given column(s)
df_repartitioned = df.repartition(200, "customer_id")

# coalesce() only lowers the partition count, merging existing partitions
# without a full shuffle
df_coalesced = df.coalesce(10)

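# Write partitioned by date for efficient reads (derive the partition columns first)
(
    df.withColumn("year", F.year("order_date"))
    .withColumn("month", F.month("order_date"))
    .write
    .partitionBy("year", "month")
    .mode("overwrite")
    .parquet("s3://bucket/processed-orders/")
)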

# Read with partition pruning (Spark only reads matching partitions)
df_march = (
    spark.read.parquet("s3://bucket/processed-orders/")
    .filter((F.col("year") == 2026) & (F.col("month") == 3))
)
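For a rough starting point on partition counts, the arithmetic behind the 128 MB guideline fits in a few lines. `estimate_partitions` is a hypothetical helper, not a Spark API. Rounding up to a multiple of the cluster's total cores is a common heuristic (an assumption here, not a Spark default) that keeps the last wave of tasks from leaving most cores idle.

```python
import math

MiB = 1024 * 1024
GiB = 1024 * MiB


def estimate_partitions(total_bytes, target_bytes=128 * MiB, total_cores=None):
    """Hypothetical sizing helper: about target_bytes per partition, at least 1.

    If total_cores is given, round up to a multiple of it so every wave of
    tasks keeps all cores busy.
    """
    partitions = max(1, math.ceil(total_bytes / target_bytes))
    if total_cores:
        partitions = math.ceil(partitions / total_cores) * total_cores
    return partitions


print(estimate_partitions(10 * GiB))                  # 10 GiB / 128 MiB
print(estimate_partitions(10 * GiB, total_cores=64))  # rounded up to 2 waves of 64
```

You could pass the result to `df.repartition(n, "customer_id")`. Treat it as a first guess to validate in the Spark UI.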

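Adaptive Query Execution (AQE) automatically coalesces small partitions after shuffles. It can also convert a sort-merge join into a broadcast join at runtime when one side turns out to be small, and it splits skewed partitions in sort-merge joins (spark.sql.adaptive.skewJoin.enabled). AQE has been enabled by default since Spark 3.2 (spark.sql.adaptive.enabled=true). Set it explicitly on older versions, and keep it on in all production jobs.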
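Here is a simplified model of the coalescing step. After a shuffle, AQE merges adjacent small partitions until each merged group approaches an advisory size (`spark.sql.adaptive.advisoryPartitionSizeInBytes`, 64 MB by default). The greedy loop below sketches that idea in plain Python. It is not Spark's actual algorithm, which also weighs additional settings.

```python
def coalesce_adjacent(sizes_mb, advisory_mb=64):
    """Greedily merge adjacent partition sizes (MB) without exceeding advisory_mb.

    Returns a list of groups, each a list of original partition indexes.
    A single partition larger than advisory_mb stays in its own group.
    """
    groups, current, current_size = [], [], 0
    for index, size in enumerate(sizes_mb):
        if current and current_size + size > advisory_mb:
            groups.append(current)
            current, current_size = [], 0
        current.append(index)
        current_size += size
    if current:
        groups.append(current)
    return groups


shuffle_output = [10, 10, 10, 50, 5, 5, 70]  # post-shuffle partition sizes in MB
print(coalesce_adjacent(shuffle_output))
```

Seven shuffle partitions become three tasks. This is why AQE can keep the default of 200 shuffle partitions from producing 200 tiny tasks on small data.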

Join Strategies: Avoiding the OOM Shuffle

Joins in Spark require shuffling data across the network to co-locate matching keys. Shuffle joins on large tables are expensive. Broadcast joins, where a small table is copied to every executor, eliminate the shuffle entirely.
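The planner's choice can be approximated with a small decision function. This is a simplified, hypothetical model: Spark's real planner also considers hints, shuffle hash joins, and runtime statistics under AQE. It does capture two real constraints. First, the 10 MB default of `spark.sql.autoBroadcastJoinThreshold` applies, and `-1` disables automatic broadcasting. Second, an outer join can only broadcast the side whose rows are not all preserved (the right side of a left join).

```python
MB = 1024 * 1024
DEFAULT_THRESHOLD = 10 * MB  # default of spark.sql.autoBroadcastJoinThreshold

# Which side may be broadcast (the "build" side) for each modeled join type.
BROADCASTABLE_SIDES = {
    "inner": ("left", "right"),
    "left": ("right",),   # a left join keeps every left row, so build on the right
    "right": ("left",),
}


def choose_join_strategy(left_bytes, right_bytes, how="inner", threshold=DEFAULT_THRESHOLD):
    """Simplified model: broadcast the smallest eligible side under the threshold."""
    if how not in BROADCASTABLE_SIDES:
        raise ValueError(f"join type not modeled: {how}")
    sizes = {"left": left_bytes, "right": right_bytes}
    # A threshold of -1 disables broadcasting, because no size is <= -1.
    eligible = [s for s in BROADCASTABLE_SIDES[how] if sizes[s] <= threshold]
    if not eligible:
        return "sort_merge"
    build_side = min(eligible, key=lambda s: sizes[s])
    return f"broadcast_{build_side}"


print(choose_join_strategy(50_000 * MB, 2 * MB, how="left"))  # small dimension on the right
print(choose_join_strategy(2 * MB, 50_000 * MB, how="left"))  # small side is the preserved side
```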

from pyspark.sql.functions import broadcast

# Automatic broadcast: Spark broadcasts tables under spark.sql.autoBroadcastJoinThreshold (default 10MB)
# Manual broadcast hint for tables slightly above threshold
result = large_orders.join(
    broadcast(small_dim_table),
    on="customer_id",
    how="left"
)

# Sort-merge join for large-large joins (the default when neither side is broadcastable)
# Spark shuffles and sorts both sides on the join key by itself, so this manual
# step is optional; it mainly pays off when the same pre-partitioned DataFrame
# is reused across several joins or aggregations on that key
orders_sorted = orders.repartition("customer_id").sortWithinPartitions("customer_id")
customers_sorted = customers.repartition("customer_id").sortWithinPartitions("customer_id")

result = orders_sorted.join(customers_sorted, on="customer_id", how="left")

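# Skew join: one key has disproportionately many rows
# Salting technique: add a random suffix to the skewed key
from pyspark.sql.functions import concat, lit, floor, rand

SALT_BUCKETS = 10

# Salt the large table: append a random salt in [0, SALT_BUCKETS) to the key
orders_salted = orders.withColumn(
    "salted_key",
    concat(F.col("customer_id"), lit("_"), floor(rand() * SALT_BUCKETS).cast("string"))
)

# Replicate the small table once per salt value so every salted key finds a match
customers_exploded = customers.crossJoin(
    spark.range(SALT_BUCKETS).toDF("salt")
).withColumn(
    "salted_key",
    concat(F.col("customer_id"), lit("_"), F.col("salt").cast("string"))
)

# Join on the salted key, dropping the duplicate helper columns from the small side
result = orders_salted.join(
    customers_exploded.drop("customer_id", "salt"),
    on="salted_key",
    how="left"
).drop("salted_key")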
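You can check why salting helps without a cluster. The plain-Python model below hash-partitions a skewed key column before and after salting. It then confirms that replicating the small side once per salt value leaves every salted key with a match. `zlib.crc32` stands in for Spark's Murmur3 hash, and the data and bucket counts are hypothetical. The exact partition sizes depend on the hash, but the largest partition should shrink sharply.

```python
import random
import zlib
from collections import Counter

NUM_PARTITIONS = 8
SALT_BUCKETS = 8


def partition_of(key: str) -> int:
    """Stand-in for hash partitioning (Spark itself uses Murmur3, not crc32)."""
    return zlib.crc32(key.encode("utf-8")) % NUM_PARTITIONS


rng = random.Random(42)  # seeded so the example is reproducible

# One "hot" customer with 9,000 orders plus 1,000 orders spread over 100 others.
order_keys = ["hot"] * 9_000 + [f"c{i % 100}" for i in range(1_000)]
customer_ids = ["hot"] + [f"c{i}" for i in range(100)]

# Without salting, every "hot" row lands in the same partition.
unsalted = Counter(partition_of(k) for k in order_keys)

# Large side: append a random salt in [0, SALT_BUCKETS).
salted_keys = [f"{k}_{rng.randrange(SALT_BUCKETS)}" for k in order_keys]
salted = Counter(partition_of(k) for k in salted_keys)

# Small side: replicate each row once per salt value.
replicated = {f"{c}_{s}" for c in customer_ids for s in range(SALT_BUCKETS)}
unmatched = [k for k in salted_keys if k not in replicated]

print("largest partition before:", max(unsalted.values()))
print("largest partition after:", max(salted.values()))
print("unmatched salted keys:", len(unmatched))
```

The trade-off is visible in `replicated`: the small side grows by a factor of `SALT_BUCKETS`. Pick the smallest salt count that breaks up the hot key.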

Avoiding UDF Performance Traps

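Python UDFs (User Defined Functions) ship data out of the JVM to a separate Python worker process and back, and they invoke your function once per row. Catalyst also treats the UDF as a black box it cannot optimize through. For large datasets, this serialization and per-call overhead can make a UDF 10-100x slower than an equivalent native Spark function.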

from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

# SLOW: Python UDF — serializes row by row to Python
@udf(StringType())
def get_revenue_tier_python(amount):
    if amount is None:
        return "unknown"
    if amount >= 10000:
        return "enterprise"
    if amount >= 1000:
        return "growth"
    return "starter"

df_slow = df.withColumn("tier", get_revenue_tier_python(F.col("amount")))

# FAST: Native Spark functions — stays in JVM
df_fast = df.withColumn(
    "tier",
    F.when(F.col("amount").isNull(), "unknown")
     .when(F.col("amount") >= 10000, "enterprise")
     .when(F.col("amount") >= 1000, "growth")
     .otherwise("starter")
)

# If you must use a Python UDF, use a pandas UDF (vectorized; requires PyArrow)
# It operates on batches of rows as pandas Series instead of one row at a time
from pyspark.sql.functions import pandas_udf
import numpy as np
import pandas as pd

@pandas_udf(StringType())
def get_revenue_tier_vectorized(amounts: pd.Series) -> pd.Series:
    # np.select picks the first matching condition, so check nulls first
    conditions = [amounts.isna(), amounts >= 10000, amounts >= 1000]
    choices = ["unknown", "enterprise", "growth"]
    return pd.Series(
        np.select(conditions, choices, default="starter"),
        index=amounts.index,
    )

df_vectorized = df.withColumn("tier", get_revenue_tier_vectorized(F.col("amount")))
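Much of the gap between the two UDF styles is per-call overhead. The sketch below is a rough plain-Python model, not a benchmark. It counts how many times each style invokes the user function for the same rows: once per row for a classic UDF, versus once per Arrow batch for a pandas UDF. The 10,000-row batch size mirrors the default of `spark.sql.execution.arrow.maxRecordsPerBatch`. The helper names are hypothetical.

```python
# Rough model (plain Python, not a benchmark): count how many times the user
# function is invoked for the same rows under each UDF style.

def tier(amount):
    """Same business rule as the UDFs above, as plain Python."""
    if amount is None:
        return "unknown"
    if amount >= 10000:
        return "enterprise"
    if amount >= 1000:
        return "growth"
    return "starter"


def run_row_at_a_time(amounts):
    """Classic Python UDF: one call per row."""
    calls, out = 0, []
    for amount in amounts:
        calls += 1
        out.append(tier(amount))
    return out, calls


def run_batched(amounts, batch_size=10_000):
    """pandas UDF: one call per batch of up to batch_size rows."""
    calls, out = 0, []
    for start in range(0, len(amounts), batch_size):
        batch = amounts[start:start + batch_size]
        calls += 1  # a single call receives the whole batch
        out.extend(tier(a) for a in batch)  # stand-in for the vectorized body
    return out, calls


amounts = [float(i) for i in range(25_000)] + [None]
row_out, row_calls = run_row_at_a_time(amounts)
batch_out, batch_calls = run_batched(amounts)
print(row_calls, batch_calls)
```

Both styles produce identical tiers. The batched style makes three calls instead of 25,001, and inside each call pandas and NumPy do the work in compiled code.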

Reading and Writing Efficiently

# Always use Parquet or Delta for intermediate data (not CSV)
# Parquet: columnar, compressed, schema stored in the file
# Delta: Parquet + ACID transactions + schema enforcement and evolution

# Read with predicate pushdown: Spark pushes filters to the storage layer
df = (
    spark.read.parquet("s3://bucket/orders/")
    .filter(F.col("order_date") >= "2026-01-01")
)

# Column pruning: only read the needed columns
df = (
    spark.read.parquet("s3://bucket/orders/")
    .select("order_id", "customer_id", "amount")
)

# Write with compression (snappy is already the default for Parquet)
# Assumes df already has year, month, and day columns
(
    df.write
    .option("compression", "snappy")
    .mode("overwrite")
    .partitionBy("year", "month", "day")
    .parquet("s3://bucket/output/")
)

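# For small output, coalesce before writing to avoid many tiny files
# Note: coalesce(1) can also collapse the upstream computation into a single task;
# if that work is heavy, repartition(1) keeps it parallel at the cost of a shuffle
df.coalesce(1).write.mode("overwrite").parquet("s3://bucket/small-output/")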
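The small-files problem follows from how partitioned writes work: each write task opens a separate file for every `partitionBy` value present in its rows. The rough, hypothetical estimator below shows why repartitioning on the partition columns before writing (for example `df.repartition("year", "month", "day")`) is a common way to cut file counts. It ignores any further splitting by `spark.sql.files.maxRecordsPerFile`, which is unlimited by default.

```python
def estimate_output_files(values_per_task):
    """Hypothetical estimator: one file per (write task, partition value) pair.

    values_per_task holds, for each write task, the set of partitionBy values
    (e.g. dates) present in that task's rows.
    """
    return sum(len(values) for values in values_per_task)


days = [f"2026-03-{d:02d}" for d in range(1, 31)]

# 200 tasks after an unrelated shuffle: every task holds rows for every day.
scattered = [set(days) for _ in range(200)]

# After repartitioning on the date columns: each day's rows sit in one task.
by_day = [{day} for day in days]

print(estimate_output_files(scattered))  # 200 tasks x 30 days
print(estimate_output_files(by_day))     # one file per day
```

The cost of the second layout is that each day becomes a single task. A very large day then turns into a skewed write.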

Debugging Spark Jobs

The Spark UI is the primary debugging tool. The driver serves it on port 4040 while an application is running, whether locally or on a cluster. Completed applications can be inspected through the Spark History Server. Key things to look for:

Stage duration and task distribution. If one task in a stage takes 10x longer than the others, you have partition skew. On the stage detail page, the summary metrics show a max task duration far above the median, and the event timeline shows a long tail of a few tasks.

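Shuffle read/write. High shuffle volume points to expensive joins or groupBy operations. You can often reduce it by filtering and pruning columns before the shuffle, broadcasting small tables, or reusing one repartition across several operations on the same key.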

Spill to disk. If executors spill to disk during a shuffle, the data each task handles does not fit in its execution memory. Increase executor memory, or increase the number of shuffle partitions (spark.sql.shuffle.partitions) so each task processes a smaller slice.
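You can apply the "one task takes 10x longer" rule of thumb to the per-task durations shown for a stage. Below is a minimal sketch in plain Python. It assumes you have copied the durations out of the Spark UI or its monitoring REST API. The `find_skewed_tasks` helper is hypothetical, and its default ratio of 10 simply mirrors the rule of thumb above rather than any Spark setting.

```python
from statistics import median


def find_skewed_tasks(durations_s, ratio=10.0):
    """Return indexes of tasks whose duration is at least `ratio` x the median."""
    if not durations_s:
        return []
    typical = median(durations_s)
    if typical <= 0:
        return []
    return [i for i, d in enumerate(durations_s) if d >= ratio * typical]


stage_durations = [12, 11, 13, 12, 14, 11, 12, 190]  # seconds per task
print(find_skewed_tasks(stage_durations))
```

Comparing against the median rather than the mean keeps a single straggler from inflating the baseline it is measured against.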

# Explain query plan — check for broadcast joins, sort-merge joins, partitioning
df.join(dim_table, on="customer_id").explain(extended=True)

# Cache expensive intermediate results that are reused
df_enriched = df.join(broadcast(dim_customers), on="customer_id").cache()
df_enriched.count()  # Trigger caching

# Check partitioning after operations
print(df_enriched.rdd.getNumPartitions())

# Unpersist when done to free memory
df_enriched.unpersist()
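Caching matters because Spark DataFrames are lazy: every action re-executes the full lineage unless an intermediate result is persisted. The toy class below models that in plain Python by counting how often an expensive step runs across two actions, with and without caching. `LazyFrame` is hypothetical and only mimics the behavior. Spark's real cache also has storage levels, and cached blocks can be evicted under memory pressure.

```python
class LazyFrame:
    """Toy lazy dataset: actions recompute the lineage unless cached."""

    def __init__(self, compute):
        self._compute = compute
        self._cache_enabled = False
        self._cached = None

    def cache(self):
        self._cache_enabled = True  # like Spark, nothing is computed yet
        return self

    def unpersist(self):
        self._cache_enabled, self._cached = False, None
        return self

    def _rows(self):
        if self._cache_enabled:
            if self._cached is None:
                self._cached = self._compute()  # first action fills the cache
            return self._cached
        return self._compute()

    def count(self):
        return len(self._rows())

    def collect(self):
        return list(self._rows())


runs = {"expensive_join": 0}


def expensive_join():
    runs["expensive_join"] += 1
    return [(i, i * 2) for i in range(1_000)]


uncached = LazyFrame(expensive_join)
uncached.count()
uncached.collect()
runs_without_cache = runs["expensive_join"]  # 2: once per action

runs["expensive_join"] = 0
cached = LazyFrame(expensive_join).cache()
cached.count()    # triggers the computation and fills the cache
cached.collect()  # served from the cache
runs_with_cache = runs["expensive_join"]
```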

When Not to Use Spark

Spark is optimized for large-scale distributed processing. It has significant overhead for small datasets: cluster startup time, JVM initialization, and shuffle operations that are trivial for a single-node system. For datasets under a few gigabytes, DuckDB, pandas, or Polars will usually be faster and simpler.

Spark also requires a cluster, which means operational overhead: sizing executors, managing driver memory, handling cluster failures. For teams without dedicated infrastructure support, this overhead can outweigh the benefits.

Use Spark when: your dataset exceeds what a single machine can process comfortably (typically hundreds of gigabytes to terabytes), you need distributed streaming with Spark Structured Streaming, or you are on a platform (Databricks, EMR, Dataproc) where the cluster is managed for you. For everything else, simpler tools are often the right call.
