Pivot, Unpivot, and Mapping with PySpark DataFrames

I recently had to perform a kind of manual unpivoting: mapping every row value to a separate dataframe with a specific schema. PySpark supports pivoting and unpivoting, similar to Excel and SQL Server. The examples below use this dataframe:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

products_list = [
    (1, "Shirts", 9, 12, 32, 9),
    (2, "Joggers", 10, 3, 35, 21),
    (3, "Pants", 73, 43, 24, 61),
    (4, "Hoodies", 101, 42, 53, 85),
]

schema = ["Id", "Product", "Small", "Medium", "Large", "ExLarge"]
df_product_qty = spark.createDataFrame(products_list, schema)

df_product_qty.show()
+---+-------+-----+------+-----+-------+
| Id|Product|Small|Medium|Large|ExLarge|
+---+-------+-----+------+-----+-------+
|  1| Shirts|    9|    12|   32|      9|
|  2|Joggers|   10|     3|   35|     21|
|  3|  Pants|   73|    43|   24|     61|
|  4|Hoodies|  101|    42|   53|     85|
+---+-------+-----+------+-----+-------+

Pivot

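The pivot() function is called on grouped data (after groupBy()). It turns the distinct values of one column into separate columns and aggregates the values that fall into each new cell, so it looks a bit like a rotation. It is meant for columns with categorical values (a small set of distinct labels), not numerical ones. Pivoting typically converts data from a stacked/narrow format to a wide/flat format. For example, here is what happens when the dataframe above is pivoted on its numeric Small column: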

df_product_small = df_product_qty.groupBy("Product").pivot("Small").count()
df_product_small.show()
+-------+----+----+----+----+
|Product|   9|  10|  73| 101|
+-------+----+----+----+----+
|Hoodies|null|null|null|   1|
|  Pants|null|null|   1|null|
|Joggers|null|   1|null|null|
| Shirts|   1|null|null|null|
+-------+----+----+----+----+
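The nulls come from how the cells are filled. Below is a plain-Python model of what `groupBy(...).pivot(...).count()` computes. It is only an illustration under simplifying assumptions: rows are dicts, `pivot_count` is a hypothetical helper, and this is not how Spark executes the query. Every distinct value of the pivot column becomes a column. Each cell holds the number of rows for that group and value, or None when there are none.

```python
from collections import Counter
from typing import Any, Dict, List


def pivot_count(rows: List[Dict[str, Any]], group_col: str,
                pivot_col: str) -> List[Dict[str, Any]]:
    """Model of df.groupBy(group_col).pivot(pivot_col).count() on a list of dicts."""
    # Without explicit values, Spark sorts the distinct pivot values to order the columns.
    pivot_values = sorted({row[pivot_col] for row in rows})
    counts = Counter((row[group_col], row[pivot_col]) for row in rows)
    groups = sorted({row[group_col] for row in rows})

    result = []
    for group in groups:
        out: Dict[str, Any] = {group_col: group}
        for value in pivot_values:
            # A group with no rows for this pivot value gets None (null), not 0.
            out[str(value)] = counts.get((group, value))
        result.append(out)
    return result


rows = [
    {"Product": "Shirts", "Small": 9},
    {"Product": "Joggers", "Small": 10},
    {"Product": "Pants", "Small": 73},
    {"Product": "Hoodies", "Small": 101},
]
for r in pivot_count(rows, "Product", "Small"):
    print(r)
```

Small has four distinct numeric values, so the pivot produces four mostly-null columns. This is why pivoting suits categorical columns better.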

A more typical use is to pivot a categorical column. Consider a dataframe of costs per product and size:

schema2 = ["Product", "Size", "Cost"]
products_list = [
  ("Shirts", "Small", 10), ("Shirts", "Medium", 12), ("Shirts", "Large", 14), ("Shirts", "ExtraLarge", 16),
  ("Joggers", "Small", 20), ("Joggers", "Medium", 22), ("Joggers", "Large", 24), ("Joggers", "ExtraLarge", 26),
  ("Pants", "Small", 30), ("Pants", "Medium", 32), ("Pants", "Large", 34), ("Pants", "ExtraLarge", 36),
  ("Hoodies", "Small", 25), ("Hoodies", "Medium", 27), ("Hoodies", "Large", 29), ("Hoodies", "ExtraLarge", 31)
]
df_product_qty = spark.createDataFrame(products_list, schema2)
df_product_qty.show()
+-------+----------+----+
|Product|      Size|Cost|
+-------+----------+----+
| Shirts|     Small|  10|
| Shirts|    Medium|  12|
| Shirts|     Large|  14|
| Shirts|ExtraLarge|  16|
|Joggers|     Small|  20|
|Joggers|    Medium|  22|
|Joggers|     Large|  24|
|Joggers|ExtraLarge|  26|
|  Pants|     Small|  30|
|  Pants|    Medium|  32|
|  Pants|     Large|  34|
|  Pants|ExtraLarge|  36|
|Hoodies|     Small|  25|
|Hoodies|    Medium|  27|
|Hoodies|     Large|  29|
|Hoodies|ExtraLarge|  31|
+-------+----------+----+
df_product_small = df_product_qty.groupBy('Product').pivot("Size").sum("Cost")
df_product_small.show()
+-------+----------+-----+------+-----+
|Product|ExtraLarge|Large|Medium|Small|
+-------+----------+-----+------+-----+
|Hoodies|        31|   29|    27|   25|
|  Pants|        36|   34|    32|   30|
|Joggers|        26|   24|    22|   20|
| Shirts|        16|   14|    12|   10|
+-------+----------+-----+------+-----+
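The size columns above come out in alphabetical order. `pivot()` also accepts an explicit list of values, for example `pivot("Size", ["Small", "Medium", "Large", "ExtraLarge"])`. With a list, Spark skips the extra job it otherwise runs to find the distinct values, and the output columns follow the order you give. The plain-Python model below mirrors that behaviour for a sum aggregation. `pivot_sum` is a hypothetical helper and only illustrates the semantics.

```python
from typing import Any, Dict, List, Optional, Sequence


def pivot_sum(rows: List[Dict[str, Any]], group_col: str, pivot_col: str,
              value_col: str, values: Optional[Sequence[Any]] = None) -> List[Dict[Any, Any]]:
    """Model of df.groupBy(group_col).pivot(pivot_col, values).sum(value_col)."""
    if values is None:
        # Without an explicit list, the distinct pivot values are collected and sorted first.
        values = sorted({row[pivot_col] for row in rows})
    totals: Dict[Any, Dict[Any, Any]] = {}
    for row in rows:
        cells = totals.setdefault(row[group_col], {})  # every group appears in the output
        if row[pivot_col] in values:  # values missing from an explicit list are ignored
            cells[row[pivot_col]] = cells.get(row[pivot_col], 0) + row[value_col]
    result = []
    for group in sorted(totals):
        out: Dict[Any, Any] = {group_col: group}
        for value in values:
            out[value] = totals[group].get(value)  # None (null) when no rows matched
        result.append(out)
    return result


rows = [
    {"Product": "Shirts", "Size": "Small", "Cost": 10},
    {"Product": "Shirts", "Size": "ExtraLarge", "Cost": 16},
    {"Product": "Pants", "Size": "Small", "Cost": 30},
]
print(pivot_sum(rows, "Product", "Size", "Cost"))
print(pivot_sum(rows, "Product", "Size", "Cost", ["Small", "Medium", "Large", "ExtraLarge"]))
```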

Unpivot

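Unpivot (also called melt) is the reverse of pivot: it converts data from a wide/flat format back to a narrow/stacked format. Spark versions before 3.4 have no unpivot() function, but the stack() SQL function achieves the same result by turning each input row into several output rows: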

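from pyspark.sql import functions as f

# stack(4, ...) turns the 4 label/value pairs of each pivoted row into 4 (Size, Cost) rows.
expr_unpivot = "stack(4, 'Small', Small, 'Medium', Medium, 'Large', Large, 'ExtraLarge', ExtraLarge) as (Size, Cost)"
df_unpivot_products = df_product_small.select("Product", f.expr(expr_unpivot)) \
    .where(f.col('Cost') > 0)  # keep positive costs; null cells are dropped as well
df_unpivot_products.show()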
+-------+----------+----+
|Product|      Size|Cost|
+-------+----------+----+
|Hoodies|     Small|  25|
|Hoodies|    Medium|  27|
|Hoodies|     Large|  29|
|Hoodies|ExtraLarge|  31|
|  Pants|     Small|  30|
|  Pants|    Medium|  32|
|  Pants|     Large|  34|
|  Pants|ExtraLarge|  36|
|Joggers|     Small|  20|
|Joggers|    Medium|  22|
|Joggers|     Large|  24|
|Joggers|ExtraLarge|  26|
| Shirts|     Small|  10|
| Shirts|    Medium|  12|
| Shirts|     Large|  14|
| Shirts|ExtraLarge|  16|
+-------+----------+----+
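`stack(n, expr1, ..., exprk)` splits its k arguments into n rows of k/n columns each. Here is a plain-Python model of what the expression above does to a single pivoted row. `stack_row` is a hypothetical name, and the model only covers the evenly divisible case used in this article.

```python
from typing import Any, List, Tuple


def stack_row(n: int, *exprs: Any) -> List[Tuple[Any, ...]]:
    """Model of Spark SQL stack(n, expr1, ..., exprk) applied to one input row."""
    # Only the evenly divisible case (as used in this article) is modelled here.
    if n <= 0 or not exprs or len(exprs) % n != 0:
        raise ValueError("number of expressions must be a positive multiple of n")
    width = len(exprs) // n  # columns per output row
    return [tuple(exprs[i * width:(i + 1) * width]) for i in range(n)]


# One pivoted row for Hoodies: Small=25, Medium=27, Large=29, ExtraLarge=31.
hoodies = {"Small": 25, "Medium": 27, "Large": 29, "ExtraLarge": 31}
for size, cost in stack_row(4, "Small", hoodies["Small"], "Medium", hoodies["Medium"],
                            "Large", hoodies["Large"], "ExtraLarge", hoodies["ExtraLarge"]):
    print("Hoodies", size, cost)
```

On Spark 3.4 or later, `DataFrame.unpivot(ids, values, variableColumnName, valueColumnName)` (alias `melt()`) covers this case directly. Something like `df_product_small.unpivot("Product", ["Small", "Medium", "Large", "ExtraLarge"], "Size", "Cost")` should give the same rows before the null filter.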

Unpivot Map

Suppose we want to take every individual value in a dataframe and store it as a separate row of a different dataframe, together with its row id, table name, column name, and data type.

This is the target schema:

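import string
from typing import Optional

from pyspark.sql.types import StructType, StructField, StringType

# Column names of the unpivot-map schema (they match the output shown further below).
COMPOUND_KEY = '__compound_key__'
PARTITION_KEY = '__partition_key__'
ROW_ID = '__id__'
TABLE_NAME = 'tableName'
COLUMN_NAME = 'columnName'
COLUMN_VALUE = 'columnValue'
COLUMN_DATATYPE = 'columnDataType'

schema = StructType([
    StructField(COMPOUND_KEY, StringType(), False),
    StructField(PARTITION_KEY, StringType(), True),  # get_partition_key() may return None
    StructField(ROW_ID, StringType(), False),
    StructField(TABLE_NAME, StringType(), True),
    StructField(COLUMN_NAME, StringType(), True),
    StructField(COLUMN_VALUE, StringType(), True),
    StructField(COLUMN_DATATYPE, StringType(), False),
])

def get_partition_key(value: str) -> Optional[str]:
    """Derive a short partition key from the first and last characters of a value."""
    if not value or value.isspace():
        return None
    alpha_numeric = string.digits + string.ascii_letters
    if len(value) == 1 and value in alpha_numeric:
        return value.lower()
    if value[0] in alpha_numeric and value[-1] in alpha_numeric:
        return (value[0] + value[-1]).lower()
    return 'some_static_partition'

# The row id travels in its own column, so it is not unpivoted as a value.
# Assumes the source dataframe names its id column ROW_ID.
EXCLUDED_COLUMNS = [ROW_ID]

def unpivot_map_dataframe_with_collect(table_name, target_schema, df):
    mapped_rows = []
    for row in df.collect():
        for field in df.schema:
            if field.name in EXCLUDED_COLUMNS:
                continue
            column_value = str(row[field.name])
            compound_key_cols = [table_name, field.name, column_value]
            # generate_unique_id() (not shown) builds the compound key from the key columns.
            mapped_rows.append(
                (generate_unique_id(compound_key_cols), get_partition_key(column_value),
                 str(row[ROW_ID]), table_name, field.name, column_value, field.dataType.typeName()))
    return spark.createDataFrame(mapped_rows, target_schema)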

The above works, but collect() pulls every row to the Spark driver and does all the processing there. That is fine for small data sets, but it becomes a bottleneck (and a memory risk) for larger ones.
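The code above calls `generate_unique_id()`, which isn't shown. The compound keys in the output further down are 32 hexadecimal characters, which is consistent with an MD5 digest. A minimal sketch might therefore hash the key parts joined by a separator. This is an assumption, not the original helper: the separator and encoding are guesses, so it is not guaranteed to reproduce the exact keys shown.

```python
import hashlib
from typing import Iterable


def generate_unique_id(key_parts: Iterable[str], sep: str = "|") -> str:
    """Hypothetical stand-in: MD5 hex digest of the key parts joined by sep."""
    # MD5 serves as a compact deterministic key here, not as a security measure.
    joined = sep.join(str(part) for part in key_parts)
    return hashlib.md5(joined.encode("utf-8")).hexdigest()


key = generate_unique_id(["Products", "Small", "9"])
print(key, len(key))
```

For non-null values, the same digest can be computed in Spark without a Python UDF, using `f.md5(f.concat_ws("|", ...))` over the key columns. This keeps a distributed version consistent with the collect() version.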

This version keeps the processing distributed across the executors, so it works with bigger data sets:

from pyspark.sql import functions as f
from pyspark.sql.types import StringType

def unpivot_map_dataframe(table_name, target_schema, df_source):
    df_final = spark.createDataFrame([], target_schema)
    compound_key_cols = [TABLE_NAME, COLUMN_NAME, COLUMN_VALUE]

    # Python UDF wrapping the partition-key logic; it runs on the executors.
    udf_part_key = f.udf(get_partition_key, StringType())
    for field in df_source.schema:
        if field.name in EXCLUDED_COLUMNS:
            continue
        # One narrow dataframe per source column: (row id, table, column, value, type).
        df = df_source.select(
            f.col(ROW_ID).cast(StringType()).alias(ROW_ID),
            f.lit(table_name).alias(TABLE_NAME),
            f.lit(field.name).alias(COLUMN_NAME),
            f.col(field.name).cast(StringType()).alias(COLUMN_VALUE),
            f.lit(field.dataType.typeName()).alias(COLUMN_DATATYPE),
        )
        # Assumption: the original compound_key_udf isn't shown; an MD5 hex digest
        # of the key columns stands in for it here.
        df = df.withColumn(COMPOUND_KEY, f.md5(f.concat_ws('|', *compound_key_cols))) \
               .withColumn(PARTITION_KEY, udf_part_key(f.col(COLUMN_VALUE)))
        df_final = df_final.unionByName(df, allowMissingColumns=False)
    return df_final
+--------------------------------+-----------------+------+---------+----------+-----------+--------------+
|__compound_key__                |__partition_key__|__id__|tableName|columnName|columnValue|columnDataType|
+--------------------------------+-----------------+------+---------+----------+-----------+--------------+
|8e09418c308aedb1658d6c3bafe12e51|32               |1     |Products |Large     |32         |long          |
|671569205acfe5a43bf814347ac78482|9                |1     |Products |ExLarge   |9          |long          |
|0c0edcc8dd34e600b9a3f58095c9b775|12               |1     |Products |Medium    |12         |long          |
|5de13acc47b2b46c8b3604cce659814f|ss               |1     |Products |Product   |Shirts     |string        |
|a9eaa88b3171427dbc8cc1b1f67a7846|9                |1     |Products |Small     |9          |long          |
|24d6eeb985c6d40358162dd685c7e590|js               |2     |Products |Product   |Joggers    |string        |
|6da73d351fc40f5a7b44bc5f1b113de7|3                |2     |Products |Medium    |3          |long          |
|7bb69389a7c4ffd434bc1df4cb56c130|35               |2     |Products |Large     |35         |long          |
|7441aa5d293ef65176f90f043222b511|10               |2     |Products |Small     |10         |long          |
|f95db6dbdabc7020a2f307ae0ef95ec1|21               |2     |Products |ExLarge   |21         |long          |
|c10155defaf838f43e427e866452c280|73               |3     |Products |Small     |73         |long          |
|5d0283fc87334bfc04172fdec494fce5|24               |3     |Products |Large     |24         |long          |
|cb55295c15b29c035f0e81eca0443315|ps               |3     |Products |Product   |Pants      |string        |
|f9e515f600efec068aa163888882713a|61               |3     |Products |ExLarge   |61         |long          |
|2fb4f8b1d3115f6e71cce8b74b7d3803|43               |3     |Products |Medium    |43         |long          |
|3454a4f461d80a5fff69a6b927bfc877|hs               |4     |Products |Product   |Hoodies    |string        |
|ba925f58504a695bb6574923a3e2e08d|11               |4     |Products |Small     |101        |long          |
|caad44607a2cb5519d55e56d6db4a6ca|53               |4     |Products |Large     |53         |long          |
|b849319118ddda27bd752789a6e9cc10|42               |4     |Products |Medium    |42         |long          |
|9a98dffe4a3991d7dca2b47601fa8ef0|85               |4     |Products |ExLarge   |85         |long          |
+--------------------------------+-----------------+------+---------+----------+-----------+--------------+
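Each loop iteration in `unpivot_map_dataframe` adds another `select()` and `unionByName()` to the query plan, so very wide tables produce very large plans. One alternative is to emit all the (name, value, datatype) triples in a single `select()` with one `stack()` expression. This is a swapped-in technique, not the original code. The hypothetical helper below only builds the SQL expression string.

```python
from typing import List, Tuple


def build_unpivot_map_expr(columns: List[Tuple[str, str]], name_col: str,
                           value_col: str, type_col: str) -> str:
    """Build one stack() expression that emits (name, value, type) per source column.

    columns holds (column_name, type_name) pairs, for example
    [(fld.name, fld.dataType.typeName()) for fld in df.schema].
    """
    if not columns:
        raise ValueError("at least one column is required")
    parts = []
    for name, type_name in columns:
        literal_name = name.replace("'", "\\'")  # escape quotes inside the SQL string literal
        identifier = name.replace("`", "``")     # escape backticks inside the quoted identifier
        # stack() requires matching types per output column, so every value is cast to string.
        parts.append(f"'{literal_name}', cast(`{identifier}` as string), '{type_name}'")
    return (f"stack({len(columns)}, {', '.join(parts)}) "
            f"as ({name_col}, {value_col}, {type_col})")


expr_str = build_unpivot_map_expr([("Product", "string"), ("Small", "long")],
                                  "columnName", "columnValue", "columnDataType")
print(expr_str)
```

The expression could then be applied as `df_source.select(ROW_ID, f.expr(expr_str))`, with the compound and partition key columns added afterwards. Whether this beats the union loop for a given table is worth measuring rather than assuming.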

Pivot Map

The pivot-map operation reverses the unpivot-map operation, rebuilding the original table from the mapped rows.

This was one of the initial approaches:

from pydoc import locate

from pyspark.sql import DataFrame
from pyspark.sql.types import StringType, StructField, StructType

def pivot_map_old(df_table: DataFrame) -> DataFrame:
  target_schema = [StructField(ROW_ID, StringType(), False)]
  col_count, rows, new_row = 0, [], []

  columns_count = df_table.select(COLUMN_NAME).distinct().count()
  column_values = df_table.orderBy(ROW_ID, COLUMN_NAME).collect()

  for row in column_values:
      row_id, name, string_value = row[ROW_ID], row[COLUMN_NAME], row[COLUMN_VALUE]
      datatype = row[COLUMN_DATATYPE]
      # Assumes the datatype column holds class names such as 'LongType'.
      datatype_class = locate(f"pyspark.sql.types.{datatype}")
      # Fields must be nullable, since any cell may hold a null value.
      struct_field = StructField(name, datatype_class(), True)  # type: ignore
      # convert_type_alt() (not shown) converts the string back to the original Python type.
      value = convert_type_alt(string_value)

      # The first row id's cells define the schema.
      if len(target_schema) <= columns_count:
          target_schema.append(struct_field)

      # Start a new row every columns_count cells; assumes every row id has every column.
      if col_count % columns_count == 0:
          if col_count > 0:
              rows.append(tuple(new_row))
          new_row = [row_id, value]
      else:
          new_row.append(value)

      col_count += 1

  # Could be a lot of rows so may be better to create row then union to dataframe.
  rows.append(tuple(new_row))
  return spark.createDataFrame(rows, StructType(target_schema))

The downside of this approach is that it may give inconsistent results when running on an actual cluster. Either that, or there is a bug here that I haven't spent time fixing, since the approach below should perform better anyway (it doesn't collect() the individual values).

from pyspark.sql import DataFrame
from pyspark.sql.functions import col, first
from pyspark.sql.types import StructType

def pivot_map(df_table: DataFrame, schema: StructType = None) -> DataFrame:
  """
  Pivot-map the passed unpivot-mapped dataframe back to the original schema of its table.

  :param DataFrame df_table: A feed/table in the unpivot-map schema.
  :param StructType schema: Optional. If not passed in, the column types are taken from the datatype column.
  :returns: DataFrame in the original schema of the table/feed.
  :raises Exception: If the pivot-map operation fails.
  """
  try:
    if schema:
      # Cast each column to the type declared in the supplied schema.
      casts = {field.name: field.dataType for field in schema.fields}
    else:
      # ROW_ID is also a special column since it is the unique ID for the original schema.
      # It will not contain a corresponding datatype value.
      df_datatypes = df_table.groupBy(ROW_ID).pivot(COLUMN_NAME).agg(first(col(COLUMN_DATATYPE))) \
          .drop(ROW_ID)

      dict_datatypes = df_datatypes.take(1)[0].asDict()
      dict_datatypes[ROW_ID] = 'StringType'  # Add back __id__ since it is a special column.

      # Normalize 'LongType' or 'long' style names to cast strings such as 'long'.
      # The dictionary is not modified after this point, so its iteration order is stable.
      casts = {name: datatype.lower().replace("type", "")
               for name, datatype in dict_datatypes.items()}

    # first() picks the single value stored for each (row id, column name) cell.
    df_pivoted = df_table.groupBy(ROW_ID).pivot(COLUMN_NAME).agg(first(col(COLUMN_VALUE)))

    # Alternatively could use a user-defined function (udf) with casting logic (infer or mapping).
    # cast() accepts both DataType instances and type-name strings.
    return df_pivoted.select(*[col(name).cast(datatype) for name, datatype in casts.items()])

  except Exception as ex:
    print(f"Pivot-map error: {str(ex)}")
    raise

This approach uses first() as the aggregation, so each cell simply takes the single value stored for that row id and column name. In effect it is a pivot without any real aggregation.

+-------+-----+------+-------+-----+------+
|ExLarge|Large|Medium|Product|Small|__id__|
+-------+-----+------+-------+-----+------+
|      9|   32|    12| Shirts|    9|     1|
|     21|   35|     3|Joggers|   10|     2|
|     61|   24|    43|  Pants|   73|     3|
|     85|   53|    42|Hoodies|  101|     4|
+-------+-----+------+-------+-----+------+

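The columns come out in alphabetical order (pivot() sorts the distinct column names) with __id__ appended last, rather than in the original order. This can be fixed by selecting the columns in a particular order, for example the column order of the original dataframe. Most Spark and SQL operations refer to columns by name, so column order mainly matters for display; positional operations such as union() are the exception.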
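Part of why the groupBy(ROW_ID).pivot(COLUMN_NAME) approach is more robust than the cell counting in pivot_map_old is that it groups cells by row id. A cell missing for some id becomes a null instead of shifting later values into the wrong row. The plain-Python model below illustrates this. `pivot_map_rows` is a hypothetical name, and cells are plain (row id, column name, value) tuples rather than Spark rows.

```python
from typing import Dict, Iterable, List, Optional, Tuple

# (row id, column name, value as string)
Cell = Tuple[str, str, Optional[str]]


def pivot_map_rows(cells: Iterable[Cell], id_col: str = "__id__") -> List[Dict[str, Optional[str]]]:
    """Rebuild wide rows from narrow cells by grouping on the row id."""
    rows: Dict[str, Dict[str, Optional[str]]] = {}
    columns = set()
    for row_id, name, value in cells:
        # Keep the first value seen for a duplicated cell; Spark's first() makes a
        # similar (though order-dependent) choice.
        rows.setdefault(row_id, {}).setdefault(name, value)
        columns.add(name)
    ordered = sorted(columns)  # pivot() also sorts the distinct column names
    # A cell missing for some row id becomes None instead of shifting later values.
    return [{id_col: row_id, **{c: rows[row_id].get(c) for c in ordered}}
            for row_id in sorted(rows)]


cells = [
    ("2", "Small", "10"), ("1", "Product", "Shirts"), ("1", "Small", "9"),
    ("2", "Product", "Joggers"),
    ("3", "Product", "Pants"),  # id 3 has no Small cell
]
for r in pivot_map_rows(cells):
    print(r)
```

Because the rows are keyed by id, the input order no longer matters. This is the property the modulo counting in pivot_map_old depends on.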

Summary

Pivoting gives a compact tabular overview, while unpivoting puts the data into a long format that suits aggregation and further analysis. Why would we want this? Sometimes data is not in a shape that Spark's set-based operations can process easily, and changing the schema/shape gives us more flexibility and, often, better performance.
