PySpark, from the ground up · Lesson 30 / 60

PySpark joins that don't blow up the cluster

Why joins are the number-one source of Spark pain, what the shuffle actually does, and the broadcast and salting tricks that turn a 40-minute job into a 4-minute one.

Module 5 has, by this point, walked you through the full anatomy of a Spark join. Lesson 25 explained what a shuffle is. Lesson 26 covered the join strategies Spark picks between. Lesson 27 was broadcast joins. Lesson 28 was skew. Lesson 29 was salting. This lesson is the capstone — the from-the-trenches playbook for when a join job lands on your plate at 9 AM and the SLA is at noon. No new concepts. Just the order in which you check things, the order in which you fix things, and the traps that experienced data engineers keep falling into anyway.

If you only remember one thing from Module 5, make it the contents of this post.

The 30-second triage

Job is slow. You have thirty seconds before someone Slacks you about it. Open the Spark UI for the running (or last-failed) job. Click into the slowest stage — the one with the longest duration. You’re looking for three things, in this order.

1. Skew. Click the stage. Sort the task list by Duration descending. Look at the top task vs the median.

Task   Duration   Shuffle Read
0      32.1 min   18.4 GB     <-- the smoking gun
1      24 sec     220 MB
2      21 sec     215 MB
...

If the slowest task is more than ~3x the median, you have skew. The biggest task is processing one hot key while everyone else is done. Lessons 28-29 have the full diagnostic; the fix is salting or AQE skew handling.
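If you copy task durations out of the stage page, the ~3x rule is one line of arithmetic. Below is a minimal pure-Python sketch. The first three durations come from the table above (32.1 min = 1926 s). The rest are made-up stand-ins for the elided tasks, and the function names are hypothetical.

```python
from statistics import median

def skew_ratio(durations_sec):
    """Slowest task divided by the median task: the ratio the triage rule uses."""
    return max(durations_sec) / median(durations_sec)

def looks_skewed(durations_sec, threshold=3.0):
    # ~3x is a rule of thumb from this lesson, not a Spark setting.
    return skew_ratio(durations_sec) > threshold

# Task 0 (32.1 min) and tasks 1-2 are from the table; the rest are hypothetical.
tasks = [1926, 24, 21, 23, 22, 25, 20]
print(f"{skew_ratio(tasks):.1f}x", looks_skewed(tasks))  # 83.7x True
```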

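2. Broadcast threshold. Open the SQL tab, find the query, click into its plan. Search the text for BroadcastHashJoin and SortMergeJoin. If you see SortMergeJoin and you know one of the two sides is small at runtime (say, a few hundred MB or less after filtering), you’re probably leaving a broadcast on the table. There are two possible reasons. The side may be above spark.sql.autoBroadcastJoinThreshold, which is 10 MB by default. Or the size estimator may have looked at the source file size rather than the post-filter size and gone conservative. In either case, Spark won’t broadcast it without help.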
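You can script the search if you paste the plan text from the SQL tab (or from df.explain()) into a string. Here is a minimal sketch. The plan fragment is a hypothetical, abbreviated example shaped like Spark 3 physical-plan output; it is not output from a real job.

```python
JOIN_OPERATORS = (
    "BroadcastHashJoin",
    "SortMergeJoin",
    "ShuffledHashJoin",
    "BroadcastNestedLoopJoin",
    "CartesianProduct",
)

def join_strategies(plan_text):
    """Count each physical join operator that appears in a plan string."""
    counts = {op: plan_text.count(op) for op in JOIN_OPERATORS}
    return {op: n for op, n in counts.items() if n > 0}

# Hypothetical, abbreviated plan text.
plan = """
== Physical Plan ==
*(5) SortMergeJoin [country_id#12], [country_id#40], Inner
:- *(2) Sort [country_id#12 ASC NULLS FIRST], false, 0
:  +- Exchange hashpartitioning(country_id#12, 200)
+- *(4) Sort [country_id#40 ASC NULLS FIRST], false, 0
   +- Exchange hashpartitioning(country_id#40, 200)
"""

print(join_strategies(plan))                    # {'SortMergeJoin': 1}
print(plan.count("Exchange hashpartitioning"))  # 2 shuffles feeding the join
```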

3. Shuffle volume. Back on the stage page, look at “Shuffle Read” and “Shuffle Write” totals. Hundreds of GB shuffled means both sides got hashed and moved across the network. If your final result is 100 MB but the shuffle was 400 GB, you’re moving columns and rows that get thrown away later. Filter and project upstream of the join.
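The comparison is plain division: shuffle volume over output size. This tiny sketch uses the numbers from the paragraph, assumes binary units, and uses a hypothetical helper name.

```python
GB = 1024 ** 3
MB = 1024 ** 2

def shuffle_amplification(shuffle_bytes, output_bytes):
    """Bytes shuffled per byte of final output."""
    return shuffle_bytes / output_bytes

# 400 GB shuffled to produce a 100 MB result
ratio = shuffle_amplification(400 * GB, 100 * MB)
print(f"{ratio:.0f}x")  # 4096x
```

Thousands of bytes moved for every byte you keep is the signature of carrying rows and columns you discard later.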

That’s the triage. Skew, broadcast, shuffle volume. Three checks, thirty seconds, and you’ve narrowed the fix to one of three families before you’ve even started typing.

The 5-step fix order

Once you know roughly what’s wrong, fix in this order. The earlier fixes are cheaper and safer than the later ones; don’t reach for salting if you haven’t filtered yet.

Step 1: filter both sides early

The cheapest optimization in Spark is “do less work.” If your big fact table has five years of orders and the report only cares about 2026, filter to 2026 before the join, not after. Spark will sometimes do this for you. Catalyst can push a simple filter on one side’s columns below the join, and on Parquet sources predicate pushdown can skip data at scan time. It can’t push a filter that compares columns from both sides of the join, though, or one built on a non-deterministic expression, so don’t rely on it.

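from pyspark.sql.functions import col

# Slow: filter after the join, so all 5 years of orders are shuffled
# (unless the optimizer happens to push the filter down)
orders.join(customers, "customer_id").filter(col("year") == 2026)

# Fast: filter first, so only 1 year of orders is shuffled
(orders.filter(col("year") == 2026)
       .join(customers, "customer_id"))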

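Same answer. When the optimizer hasn’t already pushed the filter down for you, you get one-fifth the shuffle (assuming orders are spread evenly across the five years). This is free performance, and you should write it this way out of habit, not as an optimization.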
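A toy model makes the “same answer, less shuffle” claim concrete. This is plain Python, not Spark. hash_join is a hypothetical stand-in that counts every input row as “moved,” because a shuffle join hashes and ships both sides. The orders data is made up and spread evenly over five years.

```python
from collections import defaultdict

def hash_join(left, right, key):
    """Inner equi-join on `key`. Returns (rows, rows_in), where rows_in counts
    every input row, i.e. what a shuffle join would move across the network."""
    index = defaultdict(list)
    for r in right:
        index[r[key]].append(r)
    rows = [{**l, **r} for l in left for r in index.get(l[key], [])]
    return rows, len(left) + len(right)

# Hypothetical data: 100,000 orders spread evenly over 2022-2026.
orders = [{"order_id": i, "customer_id": i % 1_000, "year": 2022 + i % 5}
          for i in range(100_000)]
customers = [{"customer_id": c, "tier": c % 3} for c in range(1_000)]

# Filter after the join: all five years go through the join.
joined, moved_late = hash_join(orders, customers, "customer_id")
late = [r for r in joined if r["year"] == 2026]

# Filter before the join: only 2026 goes through the join.
orders_2026 = [o for o in orders if o["year"] == 2026]
early, moved_early = hash_join(orders_2026, customers, "customer_id")

assert late == early            # same answer
print(moved_late, moved_early)  # 101000 21000
```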

Step 2: project narrow

Joins shuffle every column you carry into them. A SELECT * before a join includes the 4 KB JSON blob and the 12 columns nobody asked for. Each one travels across the network on every shuffled row.

orders_narrow = orders.select("order_id", "customer_id", "total", "country")
customers_narrow = customers.select("customer_id", "tier", "signup_date")

joined = orders_narrow.join(customers_narrow, "customer_id")

If your shuffle volume on the Stage page is wildly larger than the size of the data you actually use, this is your fix. Project narrow on both sides before joining. You can always re-join the wide table later for the few columns you need on output.
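Rough shuffle volume is rows times bytes per row, which is why narrowing the projection matters so much. This is a back-of-envelope sketch. The 100M row count and every per-column width are hypothetical averages. Real shuffle files are serialized and compressed, so treat the output as an order-of-magnitude estimate.

```python
ROWS = 100_000_000  # hypothetical row count

# Hypothetical average bytes per column.
wide_row = {
    "order_id": 8, "customer_id": 8, "total": 8, "country": 16,
    "payload_json": 4096,                      # the 4 KB blob
    **{f"unused_{i}": 8 for i in range(12)},   # 12 columns nobody asked for
}
narrow_cols = ["order_id", "customer_id", "total", "country"]

def shuffle_bytes(rows, widths):
    """Every shuffled row carries every projected column."""
    return rows * sum(widths)

wide = shuffle_bytes(ROWS, wide_row.values())
narrow = shuffle_bytes(ROWS, [wide_row[c] for c in narrow_cols])
print(f"wide:   {wide / 1e9:.1f} GB")    # wide:   423.2 GB
print(f"narrow: {narrow / 1e9:.1f} GB")  # narrow: 4.0 GB
```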

Step 3: broadcast if you can

If either side is small at runtime — and “small” in 2026 means 100-200 MB comfortably, up to maybe 1 GB if your driver has the headroom — wrap it in broadcast():

from pyspark.sql.functions import broadcast

joined = orders.join(broadcast(dim_country), "country_id")

The size estimate Spark compares against spark.sql.autoBroadcastJoinThreshold (10 MB by default) comes from file size, not post-filter size. After a heavy filter, the small side might be tiny, but Spark won’t know. The hint overrides the estimate. A common pattern is a star-schema fact joined to several small dim tables: broadcast every dim, and the fact never shuffles.
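Conceptually, a broadcast hash join builds a hash table from the small side once and ships a copy to every executor. Each executor then probes that table with its own partition of the big side. This is a toy sketch in plain Python, not Spark’s implementation, and the partitions, column names, and values are hypothetical. The key point shows in the shape of the output: each big-side partition is joined where it sits.

```python
from collections import defaultdict

def broadcast_hash_join(big_partitions, small, key):
    # Build the hash table from the small side once (Spark collects the small
    # side to the driver, builds the relation there, and broadcasts it).
    table = defaultdict(list)
    for row in small:
        table[row[key]].append(row)
    # Every "executor" probes its local copy with its own partition;
    # big-side rows never move to another partition.
    return [
        [{**b, **s} for b in part for s in table.get(b[key], [])]
        for part in big_partitions
    ]

dim_country = [
    {"country_id": 1, "country": "DE"},
    {"country_id": 2, "country": "FR"},
]
orders_partitions = [
    [{"order_id": 1, "country_id": 1}, {"order_id": 2, "country_id": 2}],
    [{"order_id": 3, "country_id": 1}],
]

joined = broadcast_hash_join(orders_partitions, dim_country, "country_id")
print(joined[1])  # [{'order_id': 3, 'country_id': 1, 'country': 'DE'}]
```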

The trap is broadcasting something that isn’t actually small. The driver collects the whole “small” side into memory before shipping it to executors. Broadcast a 4 GB table off a 2 GB driver and the job dies with OutOfMemoryError or Total size of serialized results... is bigger than spark.driver.maxResultSize. Lesson 27 had the gory details. Always check the actual runtime size, not just the variable name that has “small” in it.

Step 4: salt if you can’t broadcast

You’ve filtered, projected, and the small side genuinely isn’t small. Both sides are big. Now look at skew. If one key dominates, salting redistributes it. Lesson 29 covered the mechanics. The compressed version:

from pyspark.sql import functions as F

SALT_BUCKETS = 32

big_salted = big.withColumn(
    "salt",
    (F.rand() * SALT_BUCKETS).cast("int"),
)

small_exploded = small.withColumn(
    "salt",
    F.explode(F.array([F.lit(i) for i in range(SALT_BUCKETS)])),
)

joined = big_salted.join(small_exploded, on=["join_key", "salt"]).drop("salt")

The salt range is a tuning knob. Too few buckets and you’ve barely spread the load; too many and the small side balloons proportionally. 16 to 64 covers most real cases. Re-check the stage in the UI after — task durations should now be much closer to each other.
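You can watch salting work without a cluster. This is a toy simulation in plain Python, not Spark, with three simplifying assumptions:

- partition_of is a hypothetical deterministic partitioner (Spark actually hashes with Murmur3).
- Key 0 is a made-up hot key.
- Load counts only big-side rows per partition.

The simulation checks the two properties that matter: the salted join returns exactly the same pairs as the plain join, and the hot key no longer lands on a single partition.

```python
import random
from collections import Counter

SALT_BUCKETS = 32
NUM_PARTITIONS = 64

def partition_of(key, salt=0):
    # Toy deterministic partitioner; Spark really hashes with Murmur3.
    return (key * 31 + salt) % NUM_PARTITIONS

def equi_join(left, right, key_fn):
    """Plain in-memory equi-join: index the right side by key, probe with left."""
    index = {}
    for r in right:
        index.setdefault(key_fn(r), []).append(r)
    return [(l, r) for l in left for r in index.get(key_fn(l), [])]

def pairs(rows):
    return sorted((l["id"], r["val"]) for l, r in rows)

random.seed(42)
# Big side: key 0 is hot (10,000 rows); keys 1..100 get 10 rows each.
big = [{"id": i, "join_key": 0} for i in range(10_000)]
big += [{"id": 10_000 + i, "join_key": 1 + i % 100} for i in range(1_000)]
# Small side: one row per key.
small = [{"join_key": k, "val": f"v{k}"} for k in range(101)]

# Unsalted: every hot-key row hashes to the same partition.
unsalted_load = Counter(partition_of(r["join_key"]) for r in big)
unsalted = equi_join(big, small, lambda r: r["join_key"])

# Salted: random salt on the big side, every salt value on the small side.
big_salted = [{**r, "salt": random.randrange(SALT_BUCKETS)} for r in big]
small_exploded = [{**r, "salt": s} for r in small for s in range(SALT_BUCKETS)]
salted_load = Counter(partition_of(r["join_key"], r["salt"]) for r in big_salted)
salted = equi_join(big_salted, small_exploded,
                   lambda r: (r["join_key"], r["salt"]))

assert pairs(unsalted) == pairs(salted)  # same answer
print(max(unsalted_load.values()), max(salted_load.values()))
```

The small side grows by a factor of SALT_BUCKETS (101 rows become 3,232 here), which is the cost the paragraph warns about.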

Step 5: AQE if you have Spark 3.x

If you’re on Spark 3.0+ (which means anything in active production in 2026), Adaptive Query Execution can fix several of these problems at runtime, after the first stages complete and Spark has real size information instead of estimates. Turn it on globally:

spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")

AQE will:

  • Convert a SortMergeJoin to a BroadcastHashJoin at runtime if the actual size after early stages turns out small enough.
  • Split skewed partitions into smaller pieces and replicate the matching side, automating the easy salting cases.
  • Coalesce small post-shuffle partitions to reduce task overhead.
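The skew-detection rule is simple enough to sketch. AQE treats a shuffle partition as skewed only when it exceeds two limits:

- spark.sql.adaptive.skewJoin.skewedPartitionFactor times the median partition size (default 5 in Spark 3.x)
- spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes (default 256MB in Spark 3.x)

The plain-Python function below is a simplified model of that check, not Spark’s code, and the partition sizes are hypothetical.

```python
from statistics import median

MB = 1024 ** 2
GB = 1024 ** 3

# Spark 3.x defaults for the two skew-join settings (simplified model).
SKEWED_PARTITION_FACTOR = 5        # ...skewJoin.skewedPartitionFactor
SKEWED_THRESHOLD_BYTES = 256 * MB  # ...skewJoin.skewedPartitionThresholdInBytes

def skewed_partitions(sizes):
    """Indexes of partitions this simplified rule would flag as skewed."""
    med = median(sizes)
    return [
        i for i, size in enumerate(sizes)
        if size > SKEWED_PARTITION_FACTOR * med and size > SKEWED_THRESHOLD_BYTES
    ]

# 199 ordinary 40 MB partitions plus one 18 GB hot partition: flagged.
sizes = [40 * MB] * 199 + [18 * GB]
print(skewed_partitions(sizes))  # [199]

# 10x the median but only 100 MB: below the absolute floor, not flagged.
print(skewed_partitions([10 * MB] * 9 + [100 * MB]))  # []
```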

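It’s not magic. Skew handling only kicks in when a partition is large both relative to the median and in absolute terms, so moderate skew can slip through. It can’t split the non-preserved side of an outer join (a skewed right side in a left outer join stays skewed), or either side of a full outer join. And it won’t broadcast a side that is still above the broadcast threshold at runtime, which is exactly where an explicit broadcast() hint still earns its keep. Even so, it handles enough of the easy cases that turning it off is a choice you’d want a reason for. AQE has been on by default since Spark 3.2, but you should still set the flags explicitly in your job config so that anyone reviewing the code can see the choice.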

The “join key explosion” trap

This one is the Spark equivalent of int * int overflow. You join two DataFrames on a key that you assumed was unique on at least one side. It isn’t. Both sides have duplicates. For each key, the join returns the cartesian product of the matching rows from both sides.

# orders has 1M rows, 1 row per order
# customers has been deduped... or has it?
joined = orders.join(customers, "customer_id")

# joined.count()
# 50_000_000_000   <-- excuse me?

Suppose the customers DataFrame has three rows for each customer_id because of an upstream bug: a slowly-changing dimension that was never deduped, soft-deleted rows still in the source, or a history table mistaken for the current snapshot. Every order now multiplies by three, so a 1M-row fact joined to a 3x-duplicated dim becomes 3M rows. Duplicates on both sides multiply: two copies on the fact side and three on the dim side gets you 6x. A single heavily duplicated key does far worse. Say one key appears 100,000 times on one side and 500,000 times on the other; that key alone produces 50 billion rows. That’s how you quietly turn a 1M-row job into a 50B-row job while the executors die.
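You can check the multiplication before running anything. For every key, an inner equi-join produces the left count times the right count. Here is a sketch with plain Python Counters; the key distributions are hypothetical and scaled down.

```python
from collections import Counter

def join_output_rows(left_counts, right_counts):
    """Rows an inner equi-join returns: sum over keys of left_n * right_n."""
    return sum(n * right_counts.get(k, 0) for k, n in left_counts.items())

# 1,000 orders over 100 customers (10 orders each).
orders = Counter(i % 100 for i in range(1_000))
customers = Counter(range(100))                     # deduped: 1 row per key
customers_x3 = Counter({k: 3 for k in range(100)})  # 3 rows per key
orders_x2 = Counter({k: 2 * n for k, n in orders.items()})

print(join_output_rows(orders, customers))        # 1000
print(join_output_rows(orders, customers_x3))     # 3000
print(join_output_rows(orders_x2, customers_x3))  # 6000

# One key duplicated 100,000 times on one side and 500,000 on the other.
print(join_output_rows({0: 100_000}, {0: 500_000}))  # 50000000000
```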

The diagnostic is to check uniqueness before joining, especially on dimension-shaped tables:

customers.groupBy("customer_id").count().filter("count > 1").show()

If anything comes back, you have duplicates and the join is going to misbehave. You have two options:

- Dedupe first: customers.dropDuplicates(["customer_id"]), which keeps an arbitrary row per key, or a window function that picks the latest row for SCD-2 tables.
- Use a left_semi or left_anti join when you only need filtering, not column attachment.

The left_semi join is wonderful and underused. It returns each row from the left that has at least one match on the right, exactly once, without exploding cardinality.
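The three options behave very differently on duplicated keys. This is a plain-Python model of the semantics, not Spark code. The SCD-2-style customer rows and their valid_from column are hypothetical. dedupe_latest mimics a window function ordered by valid_from descending that keeps row 1.

```python
from collections import Counter

orders = [
    {"order_id": 1, "customer_id": "a"},
    {"order_id": 2, "customer_id": "a"},
    {"order_id": 3, "customer_id": "b"},
    {"order_id": 4, "customer_id": "z"},  # no matching customer
]
customers = [  # "a" has three history rows
    {"customer_id": "a", "tier": "bronze", "valid_from": 2023},
    {"customer_id": "a", "tier": "silver", "valid_from": 2024},
    {"customer_id": "a", "tier": "gold", "valid_from": 2025},
    {"customer_id": "b", "tier": "silver", "valid_from": 2024},
]

def duplicate_keys(rows, key):
    # Same idea as groupBy(key).count().filter("count > 1")
    return [k for k, n in Counter(r[key] for r in rows).items() if n > 1]

def inner_join(left, right, key):
    return [{**l, **r} for l in left for r in right if l[key] == r[key]]

def left_semi(left, right, key):
    keys = {r[key] for r in right}
    return [l for l in left if l[key] in keys]  # each left row at most once

def left_anti(left, right, key):
    keys = {r[key] for r in right}
    return [l for l in left if l[key] not in keys]

def dedupe_latest(rows, key, order_col):
    # Keep the row with the largest order_col per key.
    latest = {}
    for r in rows:
        if r[key] not in latest or r[order_col] > latest[r[key]][order_col]:
            latest[r[key]] = r
    return list(latest.values())

current = dedupe_latest(customers, "customer_id", "valid_from")
print(duplicate_keys(customers, "customer_id"))           # ['a']
print(len(inner_join(orders, customers, "customer_id")))  # 7 (exploded)
print(len(left_semi(orders, customers, "customer_id")))   # 3
print([o["order_id"] for o in left_anti(orders, customers, "customer_id")])  # [4]
print(len(inner_join(orders, current, "customer_id")))    # 3
```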

When to break a complex join into intermediate stages

You have a six-way join. Catalyst is generating a plan with a stack of nested shuffles. The job runs but takes forever, and Spark UI’s SQL tab shows a thousand-line plan you can’t decipher.

Catalyst is a sophisticated planner, but on big chained joins it sometimes makes choices that are hard to reason about. The pragmatic move: cut the join in half, write the intermediate to Parquet, and read it back for the next half.

from pyspark.sql.functions import broadcast

# Stage 1: enrich orders with customer + country
stage1 = (orders
          .join(broadcast(customers), "customer_id")
          .join(broadcast(countries), "country_id"))
stage1.write.mode("overwrite").parquet("tmp/stage1/")

# Stage 2: read stage1 back, join the heavy stuff
enriched = spark.read.parquet("tmp/stage1/")
final = (enriched
         .join(products, "product_id")
         .join(stores, "store_id"))
final.write.parquet("out/")

This buys you three things:

- Better size estimates. When Spark reads tmp/stage1/, it sizes that input from the files actually on disk, instead of carrying a guess through a chain of joins and filters. That fixes a lot of broadcast-threshold misses.
- Debuggability. You can inspect the intermediate, count rows, and validate the join cardinality, all without rerunning the upstream work.
- Retry granularity. If stage 2 fails, you don’t recompute stage 1.

The cost is one extra round-trip through Parquet, which on cloud storage with reasonable file counts is in the seconds-to-low-minutes range. Often it’s a net win. Always a debugging win.

When joining is the wrong move

The unspoken last option. You’ve optimized the join and it’s still painful. You’re joining the same dim table to the same fact table in seventeen different downstream jobs. Each one pays the join cost.

Sometimes the answer is: don’t do the join at runtime. Denormalize at write time. Materialize a wide enriched table once, with the dim columns already attached, and let the downstream jobs read the flat table.

from pyspark.sql.functions import broadcast

# Run this once, on a schedule (overwrite so reruns don't fail on an existing path)
enriched = (orders
            .join(broadcast(customers), "customer_id")
            .join(broadcast(countries), "country_id")
            .join(broadcast(products),  "product_id"))
(enriched.write.mode("overwrite")
         .partitionBy("country_code", "year")
         .parquet("warehouse/orders_enriched/"))

# Downstream jobs read the wide table directly. No joins.
df = spark.read.parquet("warehouse/orders_enriched/")

Yes, this duplicates data. Yes, it complicates updates when a dimension changes. That’s the trade. In analytics workloads — where dim changes are infrequent, queries are frequent, and storage is cheap — the trade often goes in favor of the wide table. The lakehouse pattern (Delta Lake, Iceberg, Hudi) makes this even cleaner: you can update the dimension and have the changes propagate to the wide table on a cadence, instead of recomputing the join in every query.

This is the move you reach for when join performance is no longer the right thing to optimize. Sometimes the right optimization is to remove the join from the runtime path entirely.

The mental model, kept short

Joins are slow because they shuffle. Shuffles are slow because they move data over the network. Everything in this lesson is a way of moving less data: filter so there’s less to move, project so each row is smaller, broadcast so the big side doesn’t move, salt so no single executor is overloaded, AQE so Spark can fix the easy stuff for you, intermediate stages so a hard job becomes two manageable ones, and denormalization so the join doesn’t run at query time at all.

Module 6 starts in lesson 31 with partitioning, which is the other half of this story — how data is laid out before it hits a join, and how the layout decision either sets you up for cheap joins or guarantees expensive ones. The two modules are siblings. Read them as one.
