Combining rows into an array in pyspark

Overview

I’ve just spent a bit of time working out how to group a Spark DataFrame by a given column and then aggregate the rows into a single ArrayType column.

Given the input:

transaction_id  item
1               a
1               b
1               c
1               d
2               a
2               d
3               c
4               b
4               c
4               d

I want to turn that into the following:

transaction_id  items
1               [a, b, c, d]
2               [a, d]
3               [c]
4               [b, c, d]

To achieve this, I can use the following query:

from pyspark.sql.functions import collect_list

df = spark.sql('select transaction_id, item from transaction_data')

# collect_list gathers the item values in each group into an ArrayType column
grouped_transactions = df.groupBy('transaction_id').agg(collect_list('item').alias('items'))