Converting a PySpark DataFrame Column to a Python List

Chirag Shilwant
4 min read · Jul 6, 2021


In this article, we will learn how to convert columns of a PySpark DataFrame to a Python list. A PySpark application starts by initializing a SparkSession, which is the entry point of PySpark, as shown below.

# SparkSession initialization
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()

Note: the PySpark shell (launched via the pyspark executable) automatically creates the session and exposes it in the variable spark, so if you are working in the shell you can use that session directly.

Method 1: Using Pandas

To convert the columns of a PySpark DataFrame to a Python list, we first need a PySpark DataFrame. We have already seen how to create a PySpark DataFrame; we will use the same DataFrame here and extract the values of its columns into Python lists.

We first select all columns with PySpark's select() function and then call the built-in toPandas() method, which converts the Spark DataFrame into a pandas DataFrame. We can then pick a column by name and pass it to list() to store all of its values in a Python list.

# need to import to use Row in pyspark
from pyspark.sql import Row
# need to import to use date and datetime
from datetime import datetime, date
# need to import for session creation
from pyspark.sql import SparkSession
# creating the session
spark = SparkSession.builder.getOrCreate()
# creating the DataFrame by passing a list of Rows
df = spark.createDataFrame([
    Row(a=1, b=4., c='GFG1', d=date(2000, 8, 1),
        e=datetime(2000, 8, 1, 12, 0)),
    Row(a=2, b=8., c='GFG2', d=date(2000, 6, 2),
        e=datetime(2000, 6, 2, 12, 0)),
    Row(a=4, b=5., c='GFG3', d=date(2000, 5, 3),
        e=datetime(2000, 5, 3, 12, 0))
])
# using df.select() to select all columns and
# then using toPandas() to convert it to a pandas DataFrame
columns = df.select("*").toPandas()
# simply extracting a column's values by column name
# and then using list() to store all the values in a Python list
col1 = list(columns['a'])
col2 = list(columns['b'])
# printing the lists
print("Values in Column1 are ", col1)
print("Values in Column2 are ", col2)

Output:

Values in Column1 are  [1, 2, 4]
Values in Column2 are  [4.0, 8.0, 5.0]
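
As a side note, once toPandas() has been called, each column is a pandas Series, which also provides a tolist() method that can be used instead of list(). A minimal sketch, assuming the same df created above:

# assumes the same df created above
pandas_df = df.toPandas()
# Series.tolist() returns the column values as a plain Python list
col1 = pandas_df['a'].tolist()
print("Values in Column1 are ", col1)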

Method 2: Using collect()

The collect() operation by itself is not enough to extract the values of a column into a Python list: calling collect() alone returns the DataFrame rows as a list of Row objects. We can combine collect() with other PySpark operations to extract the values of the columns we want into a Python list.
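
To see what collect() alone returns, here is a minimal sketch, assuming the same df created above; each element is a Row, and a list comprehension can then pull a single field out of each one:

# assumes the same df created above
rows = df.collect()
# each element is a Row, e.g. Row(a=1, b=4.0, c='GFG1', d=datetime.date(2000, 8, 1), e=datetime.datetime(2000, 8, 1, 12, 0))
print(rows[0])
# extracting column 'a' from the collected Rows with a list comprehension
col_a = [row.a for row in rows]
print("Values in Column a are ", col_a)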

1. collect() with rdd.map() lambda expression

  • In order to convert a DataFrame column to a Python list, we first select the column we want using an rdd.map() lambda expression and then collect() the result.
  • In the rdd.map() lambda expression we can refer to the column either by its index or by its name.
# need to import to use Row in pyspark
from pyspark.sql import Row
# need to import to use date and datetime
from datetime import datetime, date
# need to import for session creation
from pyspark.sql import SparkSession
# creating the session
spark = SparkSession.builder.getOrCreate()
# creating the DataFrame by passing a list of Rows
df = spark.createDataFrame([
    Row(a=1, b=4., c='GFG1', d=date(2000, 8, 1),
        e=datetime(2000, 8, 1, 12, 0)),
    Row(a=2, b=8., c='GFG2', d=date(2000, 6, 2),
        e=datetime(2000, 6, 2, 12, 0)),
    Row(a=4, b=5., c='GFG3', d=date(2000, 5, 3),
        e=datetime(2000, 5, 3, 12, 0))
])
# using rdd.map() to extract the 4th column (index 3) from the DataFrame into a Python list
# the column index is used to extract the column
column_d_values = df.rdd.map(lambda x: x[3]).collect()
# using rdd.map() to extract the column named 'd' from the DataFrame into a Python list
# the column name is used to extract the column
column_d_values1 = df.rdd.map(lambda x: x.d).collect()
# printing the lists
print("Values in Column d using column index are ", column_d_values)
print("Values in Column d using column name are ", column_d_values1)

2. collect() with flatMap() Transformation

  • Instead of using rdd.map(), we can use the flatMap() transformation along with collect() to extract the values of a column into a Python list (a short note on why this works follows the example below).
# need to import to use Row in pyspark
from pyspark.sql import Row
# need to import to use date and datetime
from datetime import datetime, date
# need to import for session creation
from pyspark.sql import SparkSession
# creating the session
spark = SparkSession.builder.getOrCreate()
# creating the DataFrame by passing a list of Rows
df = spark.createDataFrame([
    Row(a=1, b=4., c='GFG1', d=date(2000, 8, 1),
        e=datetime(2000, 8, 1, 12, 0)),
    Row(a=2, b=8., c='GFG2', d=date(2000, 6, 2),
        e=datetime(2000, 6, 2, 12, 0)),
    Row(a=4, b=5., c='GFG3', d=date(2000, 5, 3),
        e=datetime(2000, 5, 3, 12, 0))
])
# using the flatMap() transformation to extract the column named 'd'
# from the DataFrame into a Python list; the column name is used to select the column
column_d_values = df.select(df.d).rdd.flatMap(lambda x: x).collect()
# printing the list
print("Values in Column d using column name are ", column_d_values)

These are a few methods for converting a PySpark DataFrame column to a Python list.
