PySpark, from the ground up · Lesson 40 / 60

UDFs: when you need them, why you should avoid them

The Python serialization tax of regular UDFs, why pandas_udf saves you, and the rare cases where Scala is the only answer.

A UDF is a User Defined Function: a regular Python function that you ask Spark to apply to a column of a DataFrame. The first time you reach for one it feels like the most natural thing in the world. You have a column, you have a Python function that knows what to do with the values in it, you wrap one with the other and Spark runs it across the cluster. Done.

The problem is that “Spark runs it across the cluster” hides a very expensive detail. Every value your UDF sees has to leave the JVM where Spark lives, be serialized and shipped to a separate Python process, be processed, be serialized back, and re-enter the JVM. On a one-million-row job nobody notices. On a one-billion-row job it’s the difference between a coffee break and a layover.

This lesson is about the cost of that round trip, the two PySpark mechanisms that mitigate it, and the small set of cases where you genuinely cannot avoid a UDF and just have to write the cleanest one you can.

The serialization tax

When you write this:

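from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import StringType

spark = SparkSession.builder.getOrCreate()

@F.udf(returnType=StringType())
def shout(s):
    # Spark passes SQL NULL as None; propagate it instead of crashing
    if s is None:
        return None
    return s.upper() + "!"

df = spark.createDataFrame([("hello",), ("world",)], ["word"])
df.withColumn("loud", shout("word")).show()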

here is what physically happens for each row at execution time. Spark has the word column sitting in the JVM as a Tungsten-encoded UnsafeRow. To call your Python function, the executor has to:

  1. Pull the value out of the binary row format and convert it to a generic Java object.
  2. Serialize that object using PySpark’s pickling protocol.
  3. Ship the bytes over a local socket to a Python worker process that the executor manages.
  4. Have the Python worker unpickle the bytes into a Python object.
  5. Run your function.
  6. Pickle the result.
  7. Send it back across the socket to the JVM.
  8. Deserialize the bytes into a Java object.
  9. Re-encode that object into Tungsten format so the next operator can use it.

Nine steps per value. Per value. None of this work is parallelizable beyond what Spark already does at the partition level, and none of it is visible in the query plan — Catalyst sees a BatchEvalPython node and gives up trying to reason about what’s inside.

Three things follow from this. First, the throughput hit can be enormous: I’ve seen the same logical operation run sixty times slower as a UDF than as a built-in expression. Second, the optimizer can’t push filters through your UDF. If you have df.filter(shout("word").startswith("HELLO")), the filter cannot be pushed down to the Parquet scan because Catalyst has no idea what shout does. Third, the Python worker is a separate process with its own memory; if your UDF holds onto data, you can OOM the Python side without ever touching the executor’s JVM heap.
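To see why an opaque function blocks pushdown, here is a toy model of row-group skipping, the mechanism Parquet pushdown relies on. It is plain Python, not Spark: the StartsWith class and can_skip function are illustrative names, and the min/max strings stand in for the statistics Parquet keeps per row group. A predicate the planner understands can be checked against those statistics; a Python callable reveals nothing, so every row group has to be read.

```python
from dataclasses import dataclass
from typing import Callable, Union


@dataclass(frozen=True)
class StartsWith:
    """A predicate the toy planner understands: value starts with prefix."""
    prefix: str


# Either a transparent predicate or an opaque Python function
Predicate = Union[StartsWith, Callable[[str], bool]]


def can_skip(group_min: str, group_max: str, pred: Predicate) -> bool:
    """Return True if no value in [group_min, group_max] can satisfy pred."""
    if isinstance(pred, StartsWith):
        p = pred.prefix
        # Every string starting with p sorts at or after p, so the group
        # cannot match when even its largest value sorts before p.
        if group_max < p:
            return True
        # If min sorts after p without starting with it, every string that
        # starts with p sorts before min, so nothing in the group matches.
        if group_min > p and not group_min.startswith(p):
            return True
        return False
    # An opaque callable says nothing about its behavior: read everything.
    return False
```

A real engine also has to deal with nulls and truncated statistics; the toy ignores both.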

The cure starts with not writing UDFs.

The fix order

When someone says “I need a UDF,” the answer is almost always “no you don’t, you need to look harder at pyspark.sql.functions.” Spark’s built-in expressions are written in Scala, run inside the JVM with no serialization round trip, and participate in Catalyst optimization and Tungsten codegen (lessons 41 and 42). That routinely makes them many times faster than the equivalent UDF.

A few categories that catch people:

# String manipulation — almost always a built-in
F.regexp_extract("col", r"^([A-Z]+)-(\d+)$", 1)
F.regexp_replace("col", r"\s+", " ")
F.split("col", ",").getItem(0)
F.concat_ws("-", "year", "month", "day")
F.lower("col")
F.translate("col", "abc", "xyz")

# JSON — there's a parser
F.from_json("payload", schema)
F.get_json_object("payload", "$.user.email")
F.to_json("struct_col")

# Arrays: higher-order functions in SQL since Spark 2.4; F.transform/F.filter/F.aggregate in Python since 3.1
F.transform("arr", lambda x: x * 2)
F.filter("arr", lambda x: x > 0)
F.aggregate("arr", F.lit(0), lambda acc, x: acc + x)
F.array_distinct("arr")
F.array_intersect("a", "b")

# Dates — far more than people use
F.date_trunc("month", "ts")
F.unix_timestamp("ts", "yyyy-MM-dd HH:mm:ss")
F.date_add("dt", 7)
F.months_between("end", "start")

# Conditionals
F.when(F.col("x") > 0, "pos").when(F.col("x") < 0, "neg").otherwise("zero")

If you’ve written a UDF that does any of those things, you’ve left performance on the table. Before reaching for a UDF, search the functions module. The page is long, but the time you spend reading it once will pay for itself the next time you have a transformation to write.

If after a real search the built-ins still don’t cover what you need, your second move isn’t a regular UDF. It’s a pandas_udf.

pandas_udf: the Arrow shortcut

A pandas_udf is the same idea as a regular UDF (your Python function applied to a column), but it ships data in batches using Apache Arrow’s columnar format. Instead of turning every value into its own Python object and calling your function once per row, Spark packs a batch of rows into an Arrow record batch, hands the whole batch to Python as a pandas.Series, lets you process it with vectorized operations, and ships the result back as another Arrow batch.

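The per-row serialization overhead drops sharply, because a batch of numbers travels as a few contiguous buffers instead of thousands of individually pickled objects. The Python side runs vectorized NumPy or pandas operations instead of paying per-row interpreter overhead. And Arrow’s columnar layout is exactly what pandas and NumPy want to work on. Spark still converts its internal rows into Arrow columns and back, but that conversion happens inside the JVM and is much cheaper than pickling every value.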
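The difference in shape can be mimicked with the standard library alone. This is a toy, not Spark: pickle stands in for the plain-UDF wire format, array.array stands in for an Arrow buffer, and both function names are illustrative. What it shows is the part that matters: on the row path every value becomes a separate Python object and the function runs once per row, while on the batch path each batch travels as one packed buffer and the function runs once per batch.

```python
import array
import pickle
from typing import Callable, List, Tuple


def row_at_a_time(values: List[float],
                  fn: Callable[[float], float]) -> Tuple[List[float], int]:
    """Toy plain-UDF path: values cross as pickled Python objects and
    fn is called once per row."""
    rows = pickle.loads(pickle.dumps(values))  # each float is its own object
    calls, out = 0, []
    for v in rows:
        calls += 1
        out.append(fn(v))
    # Results go back the same way, one pickled object per value
    return pickle.loads(pickle.dumps(out)), calls


def batch_at_a_time(values: List[float],
                    batch_fn: Callable[[array.array], List[float]],
                    batch_size: int) -> Tuple[List[float], int]:
    """Toy Arrow-style path: each batch crosses as one packed buffer of
    doubles and batch_fn is called once per batch."""
    calls, out = 0, []
    for start in range(0, len(values), batch_size):
        wire = array.array("d", values[start:start + batch_size]).tobytes()
        column = array.array("d")
        column.frombytes(wire)  # rebuild the column from raw bytes
        calls += 1
        out.extend(batch_fn(column))
    return out, calls
```

With real pandas the batch function would hand the arithmetic to NumPy instead of looping in Python, which is where the rest of the speedup comes from.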

The most common flavor is Series → Series:

import pandas as pd
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType

@pandas_udf(DoubleType())
def standardize(s: pd.Series) -> pd.Series:
    # Caution: mean() and std() are computed over this batch only (see below)
    return (s - s.mean()) / s.std()

df.withColumn("z_score", standardize("amount"))

Spark will call standardize repeatedly, each time with a pandas.Series containing some chunk of the column (the size is governed by spark.sql.execution.arrow.maxRecordsPerBatch, default 10,000). Inside the function you have all of pandas: vectorized arithmetic, NumPy ufuncs, anything that operates on the Series as a whole.

A subtle point: each batch is one piece of the column, so functions like s.mean() give you the batch mean, not the global one. If you need a global statistic, compute it first with a regular aggregation and pass it as a literal, or use a groupby().applyInPandas pattern instead.
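The batch-versus-global pitfall is easy to reproduce without Spark. This sketch uses plain lists and the statistics module (statistics.stdev uses the same n - 1 denominator as pandas’ default std()). per_batch_standardize imitates what the standardize UDF actually does, and global_standardize is the fix of computing the statistics once up front; both names are illustrative.

```python
import statistics
from typing import List, Sequence


def zscores(values: Sequence[float], mean: float, std: float) -> List[float]:
    return [(v - mean) / std for v in values]


def per_batch_standardize(values: List[float], batch_size: int) -> List[float]:
    """What standardize() effectively does: stats recomputed inside every batch."""
    out: List[float] = []
    for start in range(0, len(values), batch_size):
        batch = values[start:start + batch_size]
        out.extend(zscores(batch, statistics.mean(batch), statistics.stdev(batch)))
    return out


def global_standardize(values: List[float]) -> List[float]:
    """Compute the statistics once over all values, then apply them everywhere."""
    mean, std = statistics.mean(values), statistics.stdev(values)
    return zscores(values, mean, std)
```

In Spark terms, the global version corresponds to running one aggregation first (for example df.agg(F.mean("amount"), F.stddev("amount"))) and letting the pandas_udf close over the two resulting numbers.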

The second flavor is Iterator[Series] → Iterator[Series]. Use this when you have expensive setup work that you want to amortize across all batches in a partition — loading a model, opening a database connection, allocating a large buffer:

from typing import Iterator

@pandas_udf(DoubleType())
def predict(batches: Iterator[pd.Series]) -> Iterator[pd.Series]:
    # Setup runs ONCE per task (one partition), not once per batch.
    # load_model_from_disk is a placeholder for your own model loader.
    model = load_model_from_disk("/mnt/models/v3.pkl")
    for batch in batches:
        # Each yielded Series must be as long as the batch it came from
        yield pd.Series(model.predict(batch.to_numpy()))

df.withColumn("pred", predict("features"))

This is the usual pattern for scoring a model inside Spark: the model loads once per task, and every batch in that task reuses it.

The third flavor is Iterator[Tuple[Series, ...]] → Iterator[Series] — same idea but for multi-argument functions:

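from typing import Iterator, Tuple

@pandas_udf(DoubleType())
def weighted(batches: Iterator[Tuple[pd.Series, pd.Series]]) -> Iterator[pd.Series]:
    # Each item is one batch of (price, qty) Series, aligned row for row
    for prices, weights in batches:
        yield prices * weights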

df.withColumn("rev", weighted("price", "qty"))

There are also grouped variants — groupby().applyInPandas(...) for full DataFrame-in, DataFrame-out group operations, and mapInPandas for partition-level transformations — but those are deeper than this lesson goes. The three above cover most production cases.

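pandas_udf always moves data through Arrow, so PyArrow has to be installed on the driver and on every executor. The flag you will often see set alongside it,

spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

controls something else: whether toPandas() and createDataFrame() on a pandas DataFrame use Arrow. Apache Spark leaves it off by default, though some managed platforms turn it on; check your environment rather than assuming either way.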

When you genuinely need a regular UDF

A short list of cases where neither built-ins nor pandas_udf cleanly help, and a regular UDF is honest:

A Python library with no vectorized API. You need to call a third-party parser — say, a custom industry binary format, or a niche legacy XML schema, or a domain-specific log line — and the library only accepts one record at a time. You can sometimes wrap it in a pandas_udf with a Python loop inside, which still beats a regular UDF because of Arrow batching, but if the library does its own per-call setup the win shrinks.

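Stateful row-by-row logic that doesn’t vectorize. Some parsers carry state across rows (a tokenizer with a mode flag, an unbalanced-delimiter detector). You can’t easily express that as a vectorized pandas operation. A pandas_udf with the iterator flavor still helps because you can carry state across batches, but inside each batch you still loop. Either way, treat that state as valid only within a single partition, and remember that row order inside a partition only means something if you arranged it yourself, for example with repartition followed by sortWithinPartitions.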
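Here is a sketch of carrying state across batches with the iterator shape, using the unbalanced-delimiter example. Lists of strings stand in for the pandas.Series batches a real pandas_udf would receive, and nesting_depth is an illustrative name. The point is that depth survives from one batch to the next, while each batch is still walked row by row.

```python
from typing import Iterator, List


def nesting_depth(batches: Iterator[List[str]]) -> Iterator[List[int]]:
    """For each line, emit the bracket depth before that line, carrying
    the running depth across batch boundaries within one partition."""
    depth = 0  # state lives for the whole iterator, i.e. one partition
    for batch in batches:
        out: List[int] = []
        for line in batch:  # inside a batch we still loop row by row
            out.append(depth)
            depth += line.count("(") - line.count(")")
        yield out
```

Note how the second batch starts at depth 2: that information came from rows in the first batch.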

Calling out to an external system per row. Calling a slow REST API per row from a UDF is almost always wrong, but sometimes it’s what you have. Use pandas_udf with the iterator flavor and batch the requests. If the API doesn’t support batching at all, you should probably stage the data and call the API outside Spark. If you absolutely cannot, that’s a UDF.
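A sketch of the batching pattern, assuming a hypothetical endpoint that accepts a list of records and returns one result per record. batch_api is whatever client function wraps that endpoint; here it is only a parameter, so the sketch runs without a network. Spark batches are again modeled as Python lists, and call_in_chunks is an illustrative name.

```python
from typing import Callable, Iterator, List


def call_in_chunks(batches: Iterator[List[str]],
                   batch_api: Callable[[List[str]], List[str]],
                   max_per_request: int) -> Iterator[List[str]]:
    """Send each Spark-sized batch to the API in request-sized chunks,
    keeping results in input order so they line up with the rows."""
    for batch in batches:
        results: List[str] = []
        for start in range(0, len(batch), max_per_request):
            chunk = batch[start:start + max_per_request]
            answers = batch_api(chunk)
            # A batch endpoint must return exactly one answer per input row
            if len(answers) != len(chunk):
                raise ValueError("API returned the wrong number of results")
            results.extend(answers)
        yield results
```

Inside a real iterator pandas_udf you would turn each pandas.Series batch into a list, run the same loop, and yield a pandas.Series of the results. Retries and rate limiting are left out.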

In every one of these cases, pandas_udf is still the better starting point than a plain udf. The plain UDF is the bottom of the fix order, used when you’ve ruled out the others.

The Scala UDF escape hatch

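Sometimes the cleanest solution is to write the UDF in Scala or Java, package it as a JAR, register it from PySpark, and call it like any built-in. The JVM-Python boundary disappears entirely: the function runs inside the executor JVM, its call site participates in codegen, and it pays no cross-process serialization tax. What remains is a cheap in-memory conversion between Spark’s internal types and ordinary JVM types such as String. Catalyst still can’t see inside the function body, though, so a predicate built on it won’t be pushed down into a Parquet scan. Cases where this is worth it: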

  • You’re calling Java/Scala libraries that don’t have Python equivalents (some financial or scientific libraries).
  • The UDF is a hot path in a workload that runs many times a day, and even pandas_udf overhead is unacceptable.
  • You’re writing a reusable function for a platform team and want it available identically from PySpark, Scala Spark, and Spark SQL.

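The mechanics: write the function as a class that implements one of Spark’s Java UDF interfaces (org.apache.spark.sql.api.java.UDF1 for a one-argument function), build a JAR, attach it to the session with --jars or the spark.jars setting, and register it by class name: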

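from pyspark.sql import SparkSession
from pyspark.sql.types import StringType

# spark.jars is only honored when this call starts a new SparkContext
spark = (SparkSession.builder
         .config("spark.jars", "/path/to/myudfs.jar")
         .getOrCreate())
spark.udf.registerJavaFunction("shout", "com.example.udfs.Shout", StringType())
df.selectExpr("shout(word) AS loud").show()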

It’s more operational hassle than people want — you now have a Scala build, a deploy step, and a JAR to keep version-aligned with Spark. That’s why most teams never go here. But for a small team running heavy production workloads, a tiny Scala UDF library can be a meaningful win.

What to remember

UDFs let you escape the SQL/DataFrame world into Python, and that’s exactly why they’re slow: every escape pays a serialization tax, and Catalyst can’t optimize through them. The fix order is built-ins, then pandas_udf, then plain UDF, with Scala UDFs as a last-resort escape hatch for hot paths. Most “I need a UDF” moments turn into “I need to read the functions docs once.”

Next lesson, we go straight at the optimizer that’s been doing all this rewriting behind your back: Catalyst.


References: Apache Spark Python user guide on Arrow integration (https://spark.apache.org/docs/latest/api/python/user_guide/sql/arrow_pandas.html) and the PySpark API reference for pyspark.sql.functions. Retrieved 2026-05-01.
