PySpark, from the ground up Lesson 15 / 60

Adding columns: withColumn, lit, and the chaining trap

How to add or modify columns, why withColumn calls in a loop are a known performance pitfall, and when to use select instead.

select picks columns, filter keeps rows. The third everyday operation is deriving a new column from existing ones: a tax on top of a price, a flag based on a threshold, a normalized version of a string. PySpark’s primary tool for this is withColumn, and like most PySpark methods, it’s simple at the surface and has one well-known sharp edge underneath.

We’ll cover the surface first — withColumn, lit, conditional logic with when().otherwise(), type casts — and then the sharp edge: what happens when you call withColumn 50 times in a loop, why that’s a known performance pitfall, and what to use instead.

withColumn: add or replace a single column

withColumn(name, expr) returns a new DataFrame with the named column added or replaced. If the name already exists, the existing column is overwritten; if not, it’s appended.

from pyspark.sql.functions import col

df2 = orders.withColumn("amount_with_vat", col("amount") * 1.22)

That’s the whole API. One name, one Column expression, one new DataFrame. The original orders is unchanged — DataFrames are immutable, every transformation returns a new one.

The expression on the right side can be anything that evaluates to a Column: arithmetic, function calls, when().otherwise(), even literal values:

from pyspark.sql.functions import col, upper, length

df3 = (
    orders
    .withColumn("amount_with_vat",  col("amount") * 1.22)
    .withColumn("country_upper",    upper(col("country")))
    .withColumn("name_length",      length(col("customer_name")))
)

To replace an existing column, use the same name. Common case: type-fixing an inferred-string column.

df4 = orders.withColumn("amount", col("amount").cast("double"))

That doesn’t add a column called amount next to the old one — it overwrites the existing amount with the cast version. The schema position stays the same.

lit: turning Python values into Column expressions

When you mix a Python value with a Column in an expression, PySpark usually figures it out. col("amount") * 1.22 works because PySpark’s Column class overloads * and knows how to handle the numeric literal 1.22.

But sometimes that auto-promotion isn’t enough — usually when you’re calling a function that explicitly expects a Column, or when the Python value is itself ambiguous (a None, a date string). For those cases, you wrap the literal explicitly with lit:

from pyspark.sql.functions import col, lit, when

df5 = orders.withColumn(
    "discount",
    when(col("amount") > 100, lit(10.0)).otherwise(lit(0.0))
)

lit(10.0) produces a Column expression that evaluates to the constant 10.0 for every row. Strictly speaking, when() and otherwise() accept a plain Python value as well as a Column, so when(col("amount") > 100, 10.0).otherwise(0.0) also works. Wrapping the values in lit keeps the intent and the type explicit, and it’s the habit that carries over to functions that don’t promote literals for you.

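Rule of thumb: inside when().otherwise(), array(), struct(), and most function calls, wrap literals in lit(). It matters most for functions such as array() and struct(), which read a bare Python string as a column name rather than as a string value. Outside of those, when you’re doing simple arithmetic like col("x") + 1, you don’t need lit. PySpark figures it out.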

lit(None) is the standard way to add a literal null column:

df6 = orders.withColumn("not_yet_processed_at", lit(None).cast("timestamp"))

Note the cast — lit(None) on its own is typed as null/void, which can confuse downstream operations. Always cast lit(None) to the type you want.

when().otherwise(): the conditional column

The PySpark equivalent of SQL’s CASE WHEN. Reads top-down: first matching branch wins, falls through to .otherwise(...) if nothing matches.

from pyspark.sql.functions import when, col, lit

df7 = orders.withColumn(
    "size_bucket",
    when(col("amount") < 50,  lit("small"))
    .when(col("amount") < 200, lit("medium"))
    .otherwise(lit("large"))
)

Chain as many .when(...) calls as you need. If you skip .otherwise(...), unmatched rows get null for that column, which is sometimes what you want and often a bug waiting to happen. I always include .otherwise(...) even if it’s just lit(None).cast("string"); the explicitness is worth it.

You can use any Column expression as the predicate, not just simple equality:

df8 = orders.withColumn(
    "fraud_risk",
    when(
        (col("amount") > 1000) & col("email").isNull(),
        lit("high")
    )
    .when(col("country").isin("XX", "YY"), lit("medium"))
    .otherwise(lit("low"))
)

Boolean operators follow the same rules as in where: &, |, ~, with parentheses around each comparison. (Lesson 14 covers this; if (col(...) > 100) & (col(...) == "IT") is starting to feel like a tic, that’s the right tic to develop.)

Inline type casts

col("x").cast(type) returns a Column that evaluates to the cast value. You’ll combine it with withColumn constantly when cleaning up data:

df9 = (
    raw
    .withColumn("amount",      col("amount").cast("double"))
    .withColumn("customer_id", col("customer_id").cast("int"))
    .withColumn("ts",          col("ts").cast("timestamp"))
)

The cast argument is either a string DDL type name ("double", "int", "timestamp", "string") or a DataType object. Both work; the string form is shorter and reads better.

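If a value can’t be cast (a non-numeric string going to double, an unparseable timestamp), the result is null: silent, no error. That is the behavior with ANSI mode off (spark.sql.ansi.enabled set to false), which was the default through Spark 3.x. With ANSI mode on, which is the default from Spark 4.0, the same cast raises an error instead. If you want to detect bad casts in the silent mode, count nulls before and after: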

nulls_before = raw.filter(col("amount").isNull()).count()
casted = raw.withColumn("amount", col("amount").cast("double"))
nulls_after = casted.filter(col("amount").isNull()).count()
print(f"Cast lost {nulls_after - nulls_before} rows to nulls")

The chaining trap

Here’s the sharp edge. withColumn is convenient, and convenience leads to writing code like this:

# 50 columns from a config
new_df = df
for col_spec in feature_specs:
    new_df = new_df.withColumn(col_spec.name, build_expr(col_spec))

Looks innocent. Each withColumn call returns a new DataFrame without computing any data; it just adds one more node to the logical plan. No physical work happens. The only cost so far is on the driver, which re-analyzes the growing plan on every call.

Except: each withColumn call adds one Project node to the plan. After 50 calls, you have 50 nested Project nodes. Catalyst (Spark’s query optimizer — we’ll cover it in lesson 41) tries to collapse them, and most of the time succeeds. But “most of the time” hides cases where it doesn’t, and when it doesn’t, you get:

  • Plan analysis time that grows non-linearly with the number of withColumn calls. On pipelines with hundreds of derived columns, plan-time alone can take minutes before any data is touched.
  • Stack-overflow errors in the planner with very deep chains (this is documented in Spark’s JIRA history; the threshold has moved over versions, but the pattern is real).
  • Optimizer rules that don’t fire because they pattern-match shapes the deep chain doesn’t expose.

The official Spark docs explicitly call this out and recommend a select with all expressions at once when you have many columns to add or modify.

Use select for many columns at once

The fix is to do all the column work in a single select:

from pyspark.sql.functions import col, lit, upper, when

# Build the list of expressions: keep existing columns, add new ones
projection = [
    col("*"),                                       # keep everything
    (col("amount") * 1.22).alias("amount_with_vat"),
    upper(col("country")).alias("country_upper"),
    when(col("amount") > 100, lit(10.0)).otherwise(lit(0.0)).alias("discount"),
]
result = orders.select(*projection)

Or, if you want full control over the output schema (replacing columns rather than just appending):

result = orders.select(
    "order_id",
    "customer_id",
    col("amount").cast("double").alias("amount"),
    upper(col("country")).alias("country"),
    (col("amount") * 1.22).alias("amount_with_vat"),
    when(col("amount") > 100, lit(10.0)).otherwise(lit(0.0)).alias("discount"),
)

One plan node. One projection. Catalyst gets a flat list of expressions to optimize, not a 50-deep chain to flatten first. Plan-time stays fast, the optimizer behaves consistently.

The same pattern in a loop, generating a list:

exprs = [col("*")]
for spec in feature_specs:
    exprs.append(build_expr(spec).alias(spec.name))

result = df.select(*exprs)

A single select with len(feature_specs) + 1 expressions, regardless of how many specs there are.

When chaining withColumn is fine

I’m not telling you to never use withColumn. For a handful of columns — say, up to ten — chaining is perfectly readable and there’s no performance issue. The trap is specifically when:

  • You’re adding many columns (rough threshold: more than 20 or so).
  • The new columns are being added in a loop driven by config or metadata.
  • Plan-analysis time is showing up as a problem (you’ll notice it as a long delay on the driver before any job or stage appears in the Spark UI).

For occasional, hand-written transformations, withColumn is the right tool. For programmatic feature engineering with hundreds of columns, build a select list.

Comparing the plans with explain

You can see the difference yourself:

# 20 chained withColumn calls
chained = df
for i in range(20):
    chained = chained.withColumn(f"f{i}", col("amount") * (i + 1))

chained.explain(extended=True)
# The Analyzed plan shows 20 nested Project nodes, which Catalyst
# (usually) collapses to one in the Optimized and Physical plans.

# Same thing as a single select
exprs = [col("*")] + [(col("amount") * (i + 1)).alias(f"f{i}") for i in range(20)]
flat = df.select(*exprs)

flat.explain(extended=True)
# One Project node from the start.

For 20 columns the flattening succeeds and the physical plans are identical. Increase to 200 and you’ll start seeing differences in the analysis time, even if the final physical plan still looks the same. (We’ll dig into reading explain output properly in lesson 41 on the Catalyst optimizer.)

Run this on your own machine

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit, when, upper, length

spark = SparkSession.builder.appName("withcolumn-demo").getOrCreate()

data = [
    (1, "anne",   59.00,  "it", "anne@example.com"),
    (2, "bob",    149.00, "de", "bob@example.com"),
    (3, "claire", 12.00,  "es",  None),
    (4, "diego",  999.00, "it", "diego@example.com"),
    (5, "elena",  29.00,  "fr", "elena@example.com"),
]
schema = "order_id INT, name STRING, amount DOUBLE, country STRING, email STRING"
orders = spark.createDataFrame(data, schema)

# Q1: simple withColumn — add a derived column
orders.withColumn("amount_with_vat", col("amount") * 1.22).show()

# Q2: replace a column (uppercase the country)
orders.withColumn("country", upper(col("country"))).show()

# Q3: when().otherwise() with lit
orders.withColumn(
    "size",
    when(col("amount") < 50,  lit("small"))
    .when(col("amount") < 200, lit("medium"))
    .otherwise(lit("large"))
).show()

# Q4: lit(None) typed correctly
orders.withColumn("processed_at", lit(None).cast("timestamp")).printSchema()

# Q5: many columns the wrong way (chained)
chained = orders
for i in range(15):
    chained = chained.withColumn(f"f{i}", col("amount") * (i + 1))
print("=== chained ===")
chained.explain(True)    # extended output; look for the nested Projects in the analyzed plan

# Q6: many columns the right way (single select)
exprs = [col("*")] + [(col("amount") * (i + 1)).alias(f"f{i}") for i in range(15)]
flat = orders.select(*exprs)
print("=== flat ===")
flat.explain(True)

# The physical plans should match. The analyzed plans differ in depth.

Run both explain calls. The physical plan at the bottom is what actually runs; for moderate column counts it’ll be the same. The analyzed plan above it shows the structural difference. That difference is what bites at scale.

If you’re curious how bad it can get, push the loop count to 500 and time the explain itself. On a laptop the chained version takes seconds; the flat version stays sub-second. You haven’t even processed any data — that’s pure plan-time overhead. In a job that runs hundreds of times a day, that overhead adds up to real money in cluster minutes.

A note on withColumns (plural)

Recent Spark versions added withColumns (plural), which takes a dict of {name: expr} and adds them all at once:

df.withColumns({
    "amount_with_vat": col("amount") * 1.22,
    "country_upper":   upper(col("country")),
})

This is essentially a convenience wrapper around the single-select pattern from above. It produces one Project node, not many. If your Spark version supports it (3.3+), it’s a nicer-reading alternative to building a select list by hand for the “add many columns” case. The chaining trap doesn’t apply because there’s no chain — it’s one call.

Next lesson: aggregations and groupBy — counting, summing, averaging, and the surprises that come with agg versus the shorthand methods.


Reference: Apache Spark Python API (https://spark.apache.org/docs/latest/api/python/), retrieved 2026-05-01.
