PySpark, from the ground up Lesson 28 / 60

The skew problem: when one key has 100x the rows

How data skew slows down jobs even when the total work is small, how to spot it in the Spark UI, and what symptoms look like in production.

A Spark stage is only as fast as its slowest task. That sentence is the whole lesson. Read it again, internalize it, and most of the rest of this is consequences.

In a healthy Spark stage, every task processes roughly the same amount of data. 200 partitions, each holding maybe 50 MB of rows, each finishing in 12 seconds. The whole stage finishes in maybe 14 seconds, because tasks run in parallel and the slowest one is barely slower than the median.

In a skewed stage, 199 tasks finish in 12 seconds and one task — the unlucky one — runs for 30 minutes. The whole stage takes 30 minutes. Adding more cores doesn’t help. Adding more memory doesn’t help. The bottleneck is one task, on one executor, processing one obscenely large partition. That’s data skew.
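To make "only as fast as its slowest task" concrete, here is a toy scheduler in plain Python. It is a simplification of my own, not Spark's scheduler: each task takes whichever core frees up first, and the stage ends when the last task finishes. The 12-second and 30-minute durations are the hypothetical figures from the paragraphs above.

```python
import heapq

def stage_wall_clock(task_seconds, cores):
    """Toy model: each task runs on whichever core frees up first, in list order.
    Returns the time at which the last task finishes, i.e. when the stage ends."""
    core_free_at = [0.0] * cores  # all cores idle at t=0 (already a valid min-heap)
    for duration in task_seconds:
        start = heapq.heappop(core_free_at)           # earliest-free core
        heapq.heappush(core_free_at, start + duration)
    return max(core_free_at)

healthy = [12.0] * 200                 # 200 evenly sized tasks
skewed = [12.0] * 199 + [30 * 60.0]    # one task holds the giant partition

print(stage_wall_clock(healthy, cores=200))   # 12.0
print(stage_wall_clock(skewed, cores=200))    # 1800.0
print(stage_wall_clock(skewed, cores=2000))   # 1800.0: 10x the cores, same answer
```

Extra cores only change how many of the short tasks overlap. The long task still runs start to finish on one core, so the stage can never end before it does.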

This lesson is about recognizing skew, finding it in the Spark UI, and understanding why it’s a problem in the first place. The fix — salting — gets its own lesson next.

How skew happens

Spark partitions data after a shuffle by hashing the join or group-by key:

partition_for(row) = hash(row.key) % num_partitions

If your keys are uniformly distributed, the rows spread roughly evenly across partitions. If they’re not — and in real data they almost never are — some partitions get way more rows than others.
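A toy version of that formula in plain Python shows the pile-up. It uses CRC32 as a stand-in hash because Python's built-in hash() of strings changes between runs; Spark itself uses a Murmur3 hash with a non-negative modulo, so the partition numbers here won't match Spark's. The key distribution (one user with 10,000 rows, 999 users with 10 rows each, plus 3,000 rows with a null key) is made up for illustration.

```python
import zlib
from collections import Counter

def partition_for(key, num_partitions):
    """Toy hash partitioner: CRC32 of the key's string form, modulo the partition count.
    Stand-in only; Spark uses Murmur3. CRC32 is non-negative, so the result is too."""
    return zlib.crc32(str(key).encode("utf-8")) % num_partitions

# Hypothetical keys: user 1 is heavy, users 2..1000 are light, 3,000 rows have no key.
keys = [1] * 10_000 + [u for u in range(2, 1001) for _ in range(10)] + [None] * 3_000

rows_per_partition = Counter(partition_for(k, 200) for k in keys)
sizes = sorted(rows_per_partition.values())

print("median partition:", sizes[len(sizes) // 2])
print("largest partition:", sizes[-1])
print("user 1 lands in partition", partition_for(1, 200))
print("every null lands in partition", partition_for(None, 200))
```

Every row with the same key produces the same hash, so all 10,000 rows for user 1 (and all 3,000 nulls) end up together no matter how the rest of the data is spread.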

Three common scenarios from production code:

Power users. An events table keyed by user_id. The top user has 100 million events. The median user has 12. After a join or window on user_id, every event from the top user lands in the same partition, which ends up orders of magnitude larger than a typical one.

Geographic concentration. Transactions keyed by country. 60% of traffic is from the US; the other 40% is spread across a long tail of other countries. After shuffling by country, the US partition holds 1.5x as many rows as all the other partitions combined.

Null-or-empty keys. A column where most rows have a real value but 30% have null or "". All the nulls hash to the same partition, all the empty strings hash to the same partition, and they pile up. This one is the most common and the most insidious because nobody plans for it: the job is just mysteriously slow, and people shrug and move on.

The pattern is identical in all three: the join key or group-by key is unbalanced, the shuffle preserves that imbalance, and one task ends up doing far more work than the rest.

Why “the whole stage waits”

Spark stages have a barrier. The next stage cannot start until every task in the current stage completes and writes its shuffle output to disk. If 199 of your 200 tasks finish in 12 seconds and one runs for 30 minutes, the next stage sits idle for 29 minutes and 48 seconds, waiting on that one task.

You can see this in the Spark UI’s stage timeline as a long, thin bar — almost everything finishes early, and there’s a single dragged-out task at the bottom. We’ll see what that looks like below.

It also means: total work in the cluster is not the right way to think about a skewed job. On a 20-core cluster, one busy core is 5% utilization, and that is what the cluster looks like for 29 of those 30 minutes. One core is pinned, the rest are idle, and a cluster-wide CPU dashboard shows a quiet cluster rather than a stuck job.

A clearly skewed dataset

Let’s manufacture some skew so we can look at it. I’ll build a 1.1-million-row events DataFrame where one user (user_id = 1) accounts for 1 million rows and the other 100,000 rows are spread across 999 users.

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

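spark = (SparkSession.builder
         .appName("TheSkewProblem")
         .master("local[*]")
         .config("spark.sql.shuffle.partitions", "200")
         # AQE (on by default since Spark 3.2) would likely coalesce these 200 small
         # partitions into a handful, which hides the skew on a dataset this small
         .config("spark.sql.adaptive.enabled", "false")
         .getOrCreate())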

# 1 million rows for user 1
big_user = spark.range(0, 1_000_000).select(
    F.lit(1).alias("user_id"),
    F.col("id").alias("event_id"),
)

# 100k rows distributed across user_ids 2..1000
other_users = spark.range(0, 100_000).select(
    ((F.col("id") % 999) + 2).alias("user_id"),
    F.col("id").alias("event_id"),
)

events = big_user.unionByName(other_users)
events.count()  # 1,100,000

Now run the diagnostic — group by user_id and count, sorted by count descending:

(events
 .groupBy("user_id")
 .count()
 .orderBy(F.desc("count"))
 .show(10))

# +-------+-------+
# |user_id|  count|
# +-------+-------+
# |      1|1000000|
# |      2|    101|
# |      3|    101|
# |      4|    101|
# |      5|    101|
# |      6|    101|
# |      7|    101|
# |      8|    101|
# |      9|    101|
# |     10|    101|
# +-------+-------+
# (users 2..101 are tied at 101 rows, so which nine of them appear can vary between runs)

User 1 has 1,000,000 rows. Every other user has about 100. That’s a 10,000x ratio. Anything keyed on user_id is going to send all of user 1’s rows to a single partition.
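Where do the 101s and the 10,000x come from? Here is the same arithmetic in plain Python, as a check of the numbers rather than something you'd run against Spark. The id % 999 spread gives 100 of the light users 101 rows and the other 899 users 100 rows, and the median across all 1,000 users is 100.

```python
from collections import Counter
from statistics import median

# Rebuild the per-user counts of the demo DataFrame without Spark:
# row i of the 100,000 "other" rows belongs to user (i % 999) + 2.
counts = Counter((i % 999) + 2 for i in range(100_000))
counts[1] = 1_000_000  # the heavy user

light = Counter(c for user, c in counts.items() if user != 1)
print(light)                                  # Counter({100: 899, 101: 100})
print(median(counts.values()))                # 100.0
print(counts[1] / median(counts.values()))    # 10000.0
```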

What it looks like in the Spark UI

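Run an operation that has to move every row for a user_id into one task. A plain groupBy("user_id").count() is a poor demo: Spark pre-aggregates counts on the map side before the shuffle, so each upstream task sends one partial count per user and the shuffle stays tiny. A window partitioned by user_id has to ship the raw rows:

from pyspark.sql import Window
w = Window.partitionBy("user_id").orderBy("event_id")
result = events.withColumn("rn", F.row_number().over(w))
result.write.mode("overwrite").parquet("/tmp/skew-demo")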

Open the Spark UI (http://localhost:4040 when running locally), click into the stage that reads the shuffle on user_id, and look at the Tasks table at the bottom. The columns to watch:

  • Duration — wall clock time per task
  • Shuffle Read Size / Records — how much data each task pulled from upstream

In a healthy stage these columns will be tightly clustered around a median. In a skewed stage you’ll see something like:

Min:    Median:   Max:
50 ms   80 ms     45 s    <- duration
2 KB    4 KB      40 MB   <- shuffle read size

When max duration is more than 5x the median, you have skew. When it’s 100x the median, you have severe skew. The stage page puts these numbers side by side in its Summary Metrics table (Min, 25th percentile, Median, 75th percentile, Max), so eyeball that table and the answer is usually right there.
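If you copy task durations out of the Tasks table, the rule of thumb is a couple of lines of arithmetic. A small plain-Python sketch: the 5x and 100x thresholds are this lesson's heuristics, not anything Spark enforces, and the durations are made up to match the summary above.

```python
from statistics import median

def classify_stage(task_durations_ms, skew_factor=5, severe_factor=100):
    """Apply the max/median rule of thumb to one stage's task durations."""
    ratio = max(task_durations_ms) / median(task_durations_ms)
    if ratio > severe_factor:
        label = "severe skew"
    elif ratio > skew_factor:
        label = "skew"
    else:
        label = "healthy"
    return ratio, label

# Hypothetical durations: 199 quick tasks and one 45 s straggler, then an even stage.
print(classify_stage([80] * 199 + [45_000]))       # (562.5, 'severe skew')
print(classify_stage([12_000] * 199 + [14_000]))   # (1.1666666666666667, 'healthy')
```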

The other tell is the stage timeline view. Healthy stages look like a tightly packed brick of horizontal bars all finishing around the same time. Skewed stages have one long bar sticking out, sometimes drawn out 10x past the rest. Once you’ve seen one skewed timeline, you recognize them at a glance.

In production: the long-tail symptom

You won’t always have the Spark UI handy; sometimes you’re debugging from logs only. The signature of skew there is the long-tail symptom. The progress lines below are a simplified illustration, not Spark’s exact log format:

[INFO] Stage 14: 195/200 tasks finished in 28s
[INFO] Stage 14: 198/200 tasks finished in 32s
[INFO] Stage 14: 199/200 tasks finished in 35s
[INFO] Stage 14: 199/200 tasks finished in 2m 14s
[INFO] Stage 14: 199/200 tasks finished in 5m 31s
[INFO] Stage 14: 199/200 tasks finished in 12m 09s
[INFO] Stage 14: 200/200 tasks finished in 24m 58s

195 tasks finish in 28 seconds. The 200th doesn’t finish until almost 25 minutes into the stage. That gap is one task processing one oversized partition. The job’s overall wall clock is dominated by that one task, even though 199 of the 200 tasks were done within the first 35 seconds.
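One way to put a number on the long tail is the share of the stage’s wall clock spent after nearly all tasks had finished. A plain-Python sketch: the (tasks finished, elapsed seconds) pairs are read by hand off the illustrative trace above, and the 95% cutoff is an arbitrary choice.

```python
def tail_share(progress, total_tasks, done_fraction=0.95):
    """progress: (tasks_finished, elapsed_seconds) pairs in time order.
    Returns the share of stage wall clock spent after done_fraction of tasks were done."""
    stage_end = progress[-1][1]
    mostly_done_at = next(t for done, t in progress if done >= done_fraction * total_tasks)
    return (stage_end - mostly_done_at) / stage_end

# Transcribed from the trace above (2m 14s = 134 s, ..., 24m 58s = 1498 s).
trace = [(195, 28), (198, 32), (199, 35), (199, 134), (199, 331), (199, 729), (200, 1498)]
print(f"{tail_share(trace, total_tasks=200):.0%}")  # 98%
```

In a healthy stage this number sits near zero; here 98% of the stage was spent waiting on the last few tasks.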

Why nothing else helps

When developers see this for the first time the instinct is to throw resources at it. None of the obvious moves work:

  • More executors? No. The bottleneck is a single task that runs on a single executor. Extra executors sit idle.
  • More cores per executor? No. A single task uses a single core. Multi-core only helps if there are multiple tasks to run.
  • More memory? Sometimes — if the skewed task was spilling to disk, more memory speeds it up. But the task is still running alone, and you’re still gated on it.
  • More shuffle partitions? Only if the skew is mild. Increasing spark.sql.shuffle.partitions from 200 to 2000 spreads the lighter keys across more partitions, but every row with the heavy key still hashes to the same single partition.
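  • Repartition? A plain repartition(2000) with no columns deals rows out round-robin, so for a moment the partitions are even. But the next groupBy or join on user_id shuffles again by hash and puts every user 1 row back into one partition. repartition(2000, "user_id") doesn’t even get that far: it hashes on the heavy key directly.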

The thing that does work is changing the shape of the key itself, so that the heavy key gets split across multiple partitions. That’s salting. That’s lesson 29.

Solutions, briefly

A complete map of skew remediations, ranked roughly by ease:

  1. Broadcast join, when one side is small. Lesson 27 covered this. If the table that has the heavy key is being joined against a small lookup, broadcast the lookup and the join becomes local — no shuffle, no skew.

  2. Filter out the obvious offender. If 30% of your rows have a null user_id and you don’t care about nulls in the join, filter them out before joining. Free win.

  3. Salting. Add a random suffix to the heavy keys so they hash into multiple partitions, do the join, then collapse back. Works for the both-sides-are-big case. Full coverage in lesson 29.

  4. AQE skew handling. Spark 3.x ships with Adaptive Query Execution, which can detect skewed shuffle partitions at runtime and split them automatically. It is controlled by spark.sql.adaptive.enabled and spark.sql.adaptive.skewJoin.enabled, and both default to true since Spark 3.2. It’s not magic: it mainly helps sort-merge joins, only kicks in past configurable thresholds, and only works on the join itself (not arbitrary group-bys). Within those limits it handles a lot of join skew without code changes. Lesson 59 goes deep on AQE.

  5. Pre-aggregate before joining. If the heavy key is heavy because it has many duplicates that you’re going to aggregate anyway, do the aggregation first. A .groupBy("user_id").agg(...) ahead of the join shrinks user 1’s row count from a million to one.

The order of operations in a real debugging session is usually: (1) confirm it’s skew with the diagnostic query, (2) check if AQE is on and has a threshold appropriate for your data, (3) if one side is a small lookup, broadcast it, (4) if not, filter null/garbage keys, (5) if not, salt.
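Step (2) matters because AQE’s skew test has two conditions. The Spark docs describe a partition as skewed when it is larger than spark.sql.adaptive.skewJoin.skewedPartitionFactor times the median partition size and also larger than spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes (defaults: 5 and 256 MB). Below is a plain-Python model of that rule, not Spark’s actual code, with made-up partition sizes.

```python
from statistics import median

MB = 1024 * 1024

def aqe_skewed_partitions(partition_bytes, factor=5.0, threshold_bytes=256 * MB):
    """Model of AQE's documented skew test for one side of a join: a partition counts
    as skewed only if it exceeds factor * median AND threshold_bytes.
    Returns the indices of partitions that would be split."""
    med = median(partition_bytes)
    return [i for i, size in enumerate(partition_bytes)
            if size > factor * med and size > threshold_bytes]

# Hypothetical sizes shaped like the earlier UI summary: 199 x 4 KB and one 40 MB partition.
sizes = [4 * 1024] * 199 + [40 * MB]
print(aqe_skewed_partitions(sizes))                           # []: 40 MB is under 256 MB
print(aqe_skewed_partitions(sizes, threshold_bytes=16 * MB))  # [199]
```

With the defaults, a 40 MB partition that is thousands of times the median is still left alone, so on mid-sized data you may need to lower the byte threshold before AQE does anything.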

What’s next

You now know what skew is, why it’s a stage-level problem rather than a cluster-level problem, and how to spot it from the UI or the logs. Lesson 29 walks through salting end-to-end with code, including the gotcha where badly implemented salting makes things worse. After that, lesson 30 closes out the joins-and-shuffles module by tying everything together: how to read the physical plan of a join and predict the runtime before you press go.

The skew diagnostic query — df.groupBy(key).count().orderBy(F.desc("count")).show(20) — is worth committing to memory. Run it on any DataFrame where you’re about to do something keyed. If the top key is more than ~10x the median, plan for skew before you debug the slow job.

A tighter diagnostic for production

The simple groupBy().count() is fine on small datasets. On real production data, a full group-by just to check for skew means scanning every row and shuffling a partial count for every distinct key, which can be expensive in its own right. A cheaper first pass is to sample:

sample = events.sample(fraction=0.01, seed=42)

(sample
 .groupBy("user_id")
 .count()
 .orderBy(F.desc("count"))
 .show(20))

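A 1% sample is usually enough to spot heavy keys, and the aggregation only has to handle about 1% of the rows (Spark still reads the input to draw the sample). Sample counts are roughly 1/fraction of the real ones, so multiply by 100 here to estimate true counts. If user 1 dominates in the sample, they dominate in the full set.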
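Sampling works because a heavy key stays heavy in a random sample, and its sample count divided by the fraction estimates its true count. A plain-Python simulation of the lesson’s data with Bernoulli sampling (each row kept independently with probability 0.01, which is how sample() without replacement behaves); the fixed seed and the scale-up step are additions for this sketch.

```python
import random
from collections import Counter

# The lesson's demo data as a plain list of user_ids (1.1 million entries).
user_ids = [1] * 1_000_000 + [(i % 999) + 2 for i in range(100_000)]

fraction = 0.01
rng = random.Random(42)  # fixed seed so the sketch is repeatable
sample = [u for u in user_ids if rng.random() < fraction]  # keep each row with prob. 0.01

sample_counts = Counter(sample)
top_user, top_in_sample = sample_counts.most_common(1)[0]
estimated_total = top_in_sample / fraction  # scale back up to estimate the full count

print(top_user, len(sample), round(estimated_total))
```

The estimate lands within a percent or two of the true 1,000,000, which is plenty to flag the key.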

Another useful one-liner is the skew ratio — top key versus median:

counts = events.groupBy("user_id").count().cache()  # reused by both actions below
top    = counts.agg(F.max("count")).first()[0]
median = counts.approxQuantile("count", [0.5], 0.01)[0]  # approximate median (relativeError=0.01)
print(f"top={top:,}  median={median:,}  ratio={top/median:.1f}x")

A ratio above 100 is severe skew. Above 10 is worth thinking about. Below 5 is fine; in between, keep an eye on it.

Both diagnostics are cheap enough to drop into a pipeline as a sanity check before any expensive shuffle. Future-you will thank present-you the first time the alternative is debugging a 4-hour stage from logs.


References: Apache Spark documentation on shuffle behavior and Adaptive Query Execution; Databricks engineering blog posts on identifying and remediating data skew. Retrieved 2026-05-01.
