我有这个在pandas数据帧中本地运行的python代码:
df_result = pd.DataFrame(df
.groupby('A')
.apply(lambda x: myFunction(zip(x.B, x.C), x.name))
我想在PySpark中运行它,但在处理pyspark.sql.group.GroupedData对象时遇到问题.
我尝试过以下方法:
sparkDF
.groupby('A')
.agg(myFunction(zip('B', 'C'), 'A'))
返回
KeyError: 'A'
我推测因为’A’不再是一列而我找不到x.name的等价物.
然后
sparkDF
.groupby('A')
.map(lambda row: Row(myFunction(zip('B', 'C'), 'A')))
.toDF()
但是得到以下错误:
AttributeError: 'GroupedData' object has no attribute 'map'
任何建议将非常感谢!
解决方法:
您尝试的是编写UDAF(用户定义聚合函数)而不是UDF(用户定义函数). UDAF是处理按密钥分组的数据的函数.具体来说,他们需要定义如何在单个分区中合并组中的多个值,然后如何跨分区合并键的结果.目前在python中没有办法实现UDAF,它们只能在Scala中实现.
但是,您可以在Python中解决它.您可以使用收集集来收集分组值,然后使用常规UDF来执行您想要的操作.唯一需要注意的是collect_set仅适用于原始值,因此您需要将它们编码为字符串.
from pyspark.sql.types import StringType
from pyspark.sql.functions import col, collect_list, concat_ws, udf
def myFunc(data_list):
for val in data_list:
b, c = data.split(',')
# do something
return <whatever>
myUdf = udf(myFunc, StringType())
df.withColumn('data', concat_ws(',', col('B'), col('C'))) \
.groupBy('A').agg(collect_list('data').alias('data'))
.withColumn('data', myUdf('data'))
如果要进行重复数据删除,请使用collect_set.此外,如果您的某些密钥有很多值,这将会很慢,因为密钥的所有值都需要在集群中的某个分区中收集.如果你的最终结果是你通过以某种方式组合每个键的值来构建的值(例如将它们相加),那么使用RDD aggregateByKey方法实现它可能会更快,这可以让你在改组之前为分区中的每个键构建一个中间值周围的数据.
编辑:11/21/2018
由于这个答案是写的,pyspark使用Pandas增加了对UDAF的支持.使用Panda的UDF和UDAF比使用RDD的直接python函数有一些不错的性能改进.在引擎盖下,它会对列进行矢量化(将多行中的值批处理在一起以优化处理和压缩).请查看here以获得更好的解释,或者查看下面的user6910411答案的示例.