Converting a PySpark DataFrame Column to a Python List
In this article, we will learn how to convert the columns of a PySpark DataFrame to a Python list. PySpark applications start by initializing a SparkSession, the entry point of PySpark, as shown below.
# SparkSession initialization
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
Note: The PySpark shell (launched via the pyspark executable) automatically creates the session in the variable spark, so you can also run these examples in the shell without initializing the session yourself.
Method 1: Using Pandas
To convert the columns of a PySpark DataFrame to a Python list, we first need a PySpark DataFrame. We have already seen how to create a PySpark DataFrame; we will use the same DataFrame here and extract the values of its columns into Python lists.
We first select all columns using PySpark's select() function and then call the built-in toPandas() method, which converts the Spark DataFrame into a pandas DataFrame. We then extract the column values by column name and wrap them in list() to get a Python list.
# need to import to use Row in PySpark
from pyspark.sql import Row

# need to import to use date and datetime
from datetime import datetime, date

# need to import for session creation
from pyspark.sql import SparkSession

# creating the session
spark = SparkSession.builder.getOrCreate()

# creating the DataFrame by passing a list of Rows
df = spark.createDataFrame([
    Row(a=1, b=4., c='GFG1', d=date(2000, 8, 1),
        e=datetime(2000, 8, 1, 12, 0)),
    Row(a=2, b=8., c='GFG2', d=date(2000, 6, 2),
        e=datetime(2000, 6, 2, 12, 0)),
    Row(a=4, b=5., c='GFG3', d=date(2000, 5, 3),
        e=datetime(2000, 5, 3, 12, 0))
])

# using df.select() to select all columns and
# then using toPandas() to convert it to a pandas DataFrame
columns = df.select("*").toPandas()

# extracting the column values by column name
# and using list() to store them in a Python list
col1 = list(columns['a'])
col2 = list(columns['b'])

# printing the lists
print("Values in Column1 are ", col1)
print("Values in Column2 are ", col2)
Output:
Values in Column1 are [1, 2, 4]
Values in Column2 are [4.0, 8.0, 5.0]
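As a side note, a pandas Series also provides a tolist() method, which returns native Python scalars. Here is a minimal sketch, reusing the columns pandas DataFrame created above:

# alternative: Series.tolist() converts a pandas column
# to a list of plain Python values
col1 = columns['a'].tolist()
col2 = columns['b'].tolist()

print("Values in Column1 are ", col1)
print("Values in Column2 are ", col2)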
Method 2: Using collect()
The collect() operation by itself is not enough to extract column values into a Python list: calling collect() alone returns the DataFrame rows as a list of Row objects. We can combine collect() with other PySpark operations to extract the values of a column into a Python list, as the quick sketch below illustrates.
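As a minimal illustration (reusing the df created in Method 1), here is what collect() alone returns, and how a plain list comprehension can already unpack one column from the collected Rows:

# collect() alone returns the DataFrame rows as a list of Row objects
rows = df.collect()

# the first element is a Row, e.g. Row(a=1, b=4.0, c='GFG1', ...)
print(rows[0])

# a list comprehension over the Rows extracts one column into a Python list
col_a = [row.a for row in rows]
print("Values in Column a are ", col_a)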
1. collect() with rdd.map() lambda expression
- To convert a DataFrame column to a Python list, we first select the column we want using an rdd.map() lambda expression and then collect() the result.
- In the rdd.map() lambda expression we can specify either the column index or the column name.
# need to import to use Row in PySpark
from pyspark.sql import Row

# need to import to use date and datetime
from datetime import datetime, date

# need to import for session creation
from pyspark.sql import SparkSession

# creating the session
spark = SparkSession.builder.getOrCreate()

# creating the DataFrame by passing a list of Rows
df = spark.createDataFrame([
    Row(a=1, b=4., c='GFG1', d=date(2000, 8, 1),
        e=datetime(2000, 8, 1, 12, 0)),
    Row(a=2, b=8., c='GFG2', d=date(2000, 6, 2),
        e=datetime(2000, 6, 2, 12, 0)),
    Row(a=4, b=5., c='GFG3', d=date(2000, 5, 3),
        e=datetime(2000, 5, 3, 12, 0))
])

# using rdd.map() to extract the 4th column (index 3)
# from the DataFrame into a Python list, by column index
column_d_values = df.rdd.map(lambda x: x[3]).collect()

# using rdd.map() to extract the column named 'd'
# from the DataFrame into a Python list, by column name
column_d_values1 = df.rdd.map(lambda x: x.d).collect()

# printing the lists
print("Values in Column d using column index are ", column_d_values)
print("Values in Column d using column name are ", column_d_values1)
2. collect() with flatMap() Transformation
- Instead of rdd.map(), we can use the flatMap() transformation along with collect(): after selecting a single column, flat-mapping the single-value Rows flattens the result directly into a Python list of values.
# need to import to use Row in PySpark
from pyspark.sql import Row

# need to import to use date and datetime
from datetime import datetime, date

# need to import for session creation
from pyspark.sql import SparkSession

# creating the session
spark = SparkSession.builder.getOrCreate()

# creating the DataFrame by passing a list of Rows
df = spark.createDataFrame([
    Row(a=1, b=4., c='GFG1', d=date(2000, 8, 1),
        e=datetime(2000, 8, 1, 12, 0)),
    Row(a=2, b=8., c='GFG2', d=date(2000, 6, 2),
        e=datetime(2000, 6, 2, 12, 0)),
    Row(a=4, b=5., c='GFG3', d=date(2000, 5, 3),
        e=datetime(2000, 5, 3, 12, 0))
])

# using the flatMap() transformation to extract the column named 'd'
# from the DataFrame into a Python list, by column name
column_d_values = df.select(df.d).rdd.flatMap(lambda x: x).collect()

# printing the list
print("Values in Column d using column name are ", column_d_values)
These are a few methods of converting a PySpark DataFrame column to a Python list. Keep in mind that both collect() and toPandas() bring all the data to the driver, so these approaches are best suited to DataFrames small enough to fit in driver memory.