There are times when you want a user-defined function (UDF) to return multiple values. You can return them as an array, map, or struct column type. However, what if you want to create new columns based on those values?
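For instance, a minimal sketch of the array approach (the sizes_udf name and the columns it packs are hypothetical, separate from the worked example that follows):

from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, IntegerType

# Hypothetical: pack two integer columns into a single array column
sizes_udf = udf(lambda small, medium: [small, medium], ArrayType(IntegerType()))

Returning a struct, as below, has the advantage of named, individually typed fields that can later be promoted to top-level columns.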
from uuid import uuid4
from pyspark.sql.functions import udf, col
from pyspark.sql.types import StructType, StructField, IntegerType, StringType
PRODUCT, ID = "Product", "ID"
products_list = [(1, "Shirts", 9, 12, 32, 9), (2, "Joggers", 10, 3, 35, 21), (3, "Pants", 73, 43, 24, 61), (4, "Hoodies", 101, 42, 53, 85)]
df_product_qty = spark.createDataFrame(products_list, [ID, PRODUCT, "Small", "Medium", "Large", "ExLarge"])
df_product_qty.show()
+---+-------+-----+------+-----+-------+
| ID|Product|Small|Medium|Large|ExLarge|
+---+-------+-----+------+-----+-------+
| 1| Shirts| 9| 12| 32| 9|
| 2|Joggers| 10| 3| 35| 21|
| 3| Pants| 73| 43| 24| 61|
| 4|Hoodies| 101| 42| 53| 85|
+---+-------+-----+------+-----+-------+
schema = StructType([
    StructField(ID, IntegerType(), False),
    StructField("value", StringType(), True),
    StructField("created_datetime", StringType(), True)
])
def udf_map(id, product):
    # Only "Pants" gets a new value and a generated hex id; other rows get nulls
    if product == "Pants":
        return (id, "Pantaloons", uuid4().hex)
    return (id, None, None)
map_udf = udf(udf_map, schema)
df = df_product_qty.select(map_udf(col(ID), col(PRODUCT)).alias("g"))
df = df.select([f"g.{field}" for field in schema.names])
df.show(truncate=False)
df.printSchema()
+---+----------+--------------------------------+
|ID |value |created_datetime |
+---+----------+--------------------------------+
|1 |null |null |
|2 |null |null |
|3 |Pantaloons|1462163bedf246409a4870e3128349a0|
|4 |null |null |
+---+----------+--------------------------------+
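As an aside, the list comprehension can be replaced with Spark's struct-expansion syntax, which selects every field of the struct column in one step:

# Expand all fields of the struct column "g" at once
df = df_product_qty.select(map_udf(col(ID), col(PRODUCT)).alias("g")).select("g.*")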
Alternatively, you can achieve the same result in one line with flatMap:

df = df_product_qty.select(map_udf(col(ID), col(PRODUCT))).rdd.flatMap(lambda x: x).toDF(schema)

Each row of the intermediate DataFrame is a Row holding a single struct; flatMap(lambda x: x) iterates over that outer Row and yields the inner struct, so toDF(schema) rebuilds the DataFrame with the struct's fields as top-level columns.
The returned DataFrame can then be joined back to the original df_product_qty on the ID column.
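A minimal sketch of that join, assuming a left join on the shared ID column (df_joined is just an illustrative name):

# Attach the UDF output back onto the original rows by ID
df_joined = df_product_qty.join(df, on=ID, how="left")
df_joined.show(truncate=False)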