Return columns from a PySpark UDF


With a multi-step process.


There are times when you want a user-defined function (UDF) to return multiple values. You can return them as an array, map, or struct column. But what if you want to turn those values into new columns?

```python
from uuid import uuid4
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, col
from pyspark.sql.types import StructType, StructField, IntegerType, StringType

# On Databricks or in the pyspark shell, `spark` already exists.
spark = SparkSession.builder.getOrCreate()

PRODUCT, ID = "Product", "ID"
products_list = [(1, "Shirts", 9, 12, 32, 9), (2, "Joggers", 10, 3, 35, 21), (3, "Pants", 73, 43, 24, 61), (4, "Hoodies", 101, 42, 53, 85)]
df_product_qty = spark.createDataFrame(products_list, [ID, PRODUCT, "Small", "Medium", "Large", "ExLarge"])

df_product_qty.show()
```

```
+---+-------+-----+------+-----+-------+
| ID|Product|Small|Medium|Large|ExLarge|
+---+-------+-----+------+-----+-------+
|  1| Shirts|    9|    12|   32|      9|
|  2|Joggers|   10|     3|   35|     21|
|  3|  Pants|   73|    43|   24|     61|
|  4|Hoodies|  101|    42|   53|     85|
+---+-------+-----+------+-----+-------+
```
```python
# Schema of the struct the UDF returns; each field becomes a column later.
schema = StructType([
    StructField(ID, IntegerType(), False),
    StructField("value", StringType(), True),
    # Holds a UUID hex string in this example, despite the name.
    StructField("created_datetime", StringType(), True),
])

def udf_map(id, product):
    # Return one tuple per row, in the same order as the schema fields.
    if product == "Pants":
        return (id, "Pantaloons", uuid4().hex)
    return (id, None, None)

map_udf = udf(udf_map, schema)

# Step 1: call the UDF, producing a single struct column named "g".
df = df_product_qty.select(map_udf(col(ID), col(PRODUCT)).alias("g"))
# Step 2: promote each struct field to a top-level column.
df = df.select([f"g.{name}" for name in schema.names])

df.show(truncate=False)
df.printSchema()
```

```
+---+----------+--------------------------------+
|ID |value     |created_datetime                |
+---+----------+--------------------------------+
|1  |null      |null                            |
|2  |null      |null                            |
|3  |Pantaloons|1462163bedf246409a4870e3128349a0|
|4  |null      |null                            |
+---+----------+--------------------------------+
```

Alternatively, the two steps can be collapsed into one line by going through the RDD API with flatMap:

```python
df = df_product_qty.select(map_udf(col(ID), col(PRODUCT))).rdd.flatMap(lambda x: x).toDF(schema)
```

The resulting DataFrame can then be joined back to the original df_product_qty on the ID column.
