📜  PySpark Collect() – 从 DataFrame 中检索数据(1)

📅  最后修改于: 2023-12-03 15:04:02.020000             🧑  作者: Mango

PySpark Collect() – 从 DataFrame 中检索数据

在 PySpark 中,Collect() 是一种检索 DataFrame 中数据的方法。可以使用 PySpark Collect() 方法将 DataFrame 转换为本地数据结构(例如 List、Set等),并将该数据结构存储在驱动程序中。

collect() 方法的使用

collect() 方法非常简单。只需调用 DataFrame 上的 collect() 方法即可。PySpark 将 DataFrame 转换为本地数据结构,并将它们存储在驱动程序中。以下是一个简单的示例:

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("CollectMethodDemo").getOrCreate()

data = [('Alice', 1), ('Bob', 2), ('Charlie', 3), ('Dave', 4), ('Eva', 5)]
rdd = spark.sparkContext.parallelize(data)
df = spark.createDataFrame(rdd, ['Name', 'ID'])

data_list = df.collect()

for row in data_list:
    print(row)

在上面的示例中,我们从一个 RDD 创建了一个 DataFrame,并在 DataFrame 上调用了 collect() 方法。然后,将DataFrame转换为本地 Python List 并在控制台上打印出来。输出结果如下:

Row(Name='Alice', ID=1)
Row(Name='Bob', ID=2)
Row(Name='Charlie', ID=3)
Row(Name='Dave', ID=4)
Row(Name='Eva', ID=5)

可以看到,collect() 方法返回一个包含DataFrame的所有行的本地 Python 列表。如果 DataFrame 包含数百万行,那么要小心使用 collect() 方法,因为它可能导致 OOM 错误。通常,collect() 可用于检查 DataFrame 是否正确加载。一旦数据加载完毕,我们应该避免使用 collect() 方法。

检索指定列

如果只需要检索 DataFrame 中的特定列,则不必为整个数据集调用collect() 方法。相反,可以使用 PySpark中的 select() 方法来选择数据集中的特定列,并调用 collect() 以从 DataFrame 中检索数据。下面是一个例子:

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("CollectMethodDemo").getOrCreate()

data = [('Alice', 1), ('Bob', 2), ('Charlie', 3), ('Dave', 4), ('Eva', 5)]
rdd = spark.sparkContext.parallelize(data)
df = spark.createDataFrame(rdd, ['Name', 'ID'])

selected_data = df.select('Name').collect()

for row in selected_data:
    print(row)

在上面的示例中,我们使用 select() 方法选择了 DataFrame 中的 'Name' 列,并在 DataFrame 上调用了 collect() 方法。然后,将DataFrame转换为本地 Python List 并在控制台上打印出来。输出的结果如下:

Row(Name='Alice')
Row(Name='Bob')
Row(Name='Charlie')
Row(Name='Dave')
Row(Name='Eva')
总结

在 PySpark 中,collect() 方法是从 DataFrame 中检索数据的一种简单方法。使用 collect() 方法来检查 DataFrame 是否正确加载,但是我们要避免在数据集太大时使用 collect(),因为它可能导致 OOM 错误。如果只需要 DataFrame 中的特定列,则可以使用 select() 方法来选择数据集中的特定列,并调用 collect() 以从 DataFrame 中检索数据。