pyspark实现连续数据分桶并映射到自定义标签(类似panda.cut功能)

本篇要解决的问题:利用 pyspark 已有的API实现 pandas.cut 的功能。
选择的工具是:分桶工具 Bucketizer

  • 定义数据集
  • >>> splits = [float("-inf"),10000.0,20000.0,30000.0,float('inf')]
    >>> labels = [ "(-inf,10000)","[10000,20000)","[20000,30000)","[30000,inf)"]
    >>> df = sc.parallelize([(1,4000),(2,12000),(3,13500),(4,21000),(5,31000)]).toDF(['id','sale'])
    >>> df.show()
    +---+-----+
    | id| sale|
    +---+-----+
    |  1| 4000|
    |  2|12000|
    |  3|13500|
    |  4|21000|
    |  5|31000|
    +---+-----+
    
    >>> from pyspark.ml.feature import Bucketizer
    >>> from pyspark.sql.functions import array, col, lit
    >>> bucketizer = Bucketizer(splits=splits, inputCol='sale',outputCol='split')
    >>> with_split = bucketizer.transform(df)
    >>> with_split.show()
    +---+-----+-----+
    | id| sale|split|
    +---+-----+-----+
    |  1| 4000|  0.0|
    |  2|12000|  1.0|
    |  3|13500|  1.0|
    |  4|21000|  2.0|
    |  5|31000|  3.0|
    +---+-----+-----+
    

    显示分桶后标签

    >>> label_array = array(*(lit(label) for label in labels))
    >>> print label_array
    Column<array((-inf,10000), [10000,20000), [20000,30000), [30000,inf))>
    >>> with_label = with_split.withColumn('label', label_array.getItem(col('split').cast('integer')))
    >>> with_label.show()
    +---+-----+-----+-------------+
    | id| sale|split|        label|
    +---+-----+-----+-------------+
    |  1| 4000|  0.0| (-inf,10000)|
    |  2|12000|  1.0|[10000,20000)|
    |  3|13500|  1.0|[10000,20000)|
    |  4|21000|  2.0|[20000,30000)|
    |  5|31000|  3.0|  [30000,inf)|
    +---+-----+-----+-------------+
    
  • 方法二:其实与方法一相同,只不过改成了udf的方式。
  • >>> from pyspark.sql.functions import udf
    >>> from pyspark.sql.types import *
    >>> t = {0.0: "(-inf,10000)",1.0:"[10000,20000)",2.0:"[20000,30000)",3.0:"[30000,inf)"}
    >>> udf_foo = udf(lambda x: t[x], StringType())
    >>> with_split.withColumn("label",udf_foo("split")).show()
    +---+-----+-----+-------------+
    | id| sale|split|        label|
    +---+-----+-----+-------------+
    |  1| 4000|  0.0| (-inf,10000)|
    |  2|12000|  1.0|[10000,20000)|
    |  3|13500|  1.0|[10000,20000)|
    |  4|21000|  2.0|[20000,30000)|
    |  5|31000|  3.0|  [30000,inf)|
    +---+-----+-----+-------------+
    

    整理成最终解决方案

    将前面的过程最终整理成函数的形式

    from pyspark.ml.feature import Bucketizer
    from pyspark.sql.functions import array, col, lit
    def cut(df,splits,inputCol,outputCol='cut',labels=[]):
        if len(splits) < 2:
            raise RuntimeError("splits's length must grater then 2.")
        if len(labels) != len(splits) -1:
            labels = []
            begin = str(splits[0])
            for i in range(1,len(splits)):
                end = str(splits[i])
                labels.append("[%s,%s)" % (begin,end))
                begin = end
        bucketizer = Bucketizer(splits=splits, inputCol=inputCol,outputCol='split')
        with_split = bucketizer.transform(df)
        label_array = array(*(lit(label) for label in labels))
        with_label = with_split.withColumn(outputCol, label_array.getItem(col('split').cast('integer')))
        return with_label
    df = sc.parallelize([(1,4000),(2,12000),(3,13500),(4,21000),(5,31000)]).toDF(['id','sale'])
    splits = [float("-inf"),10000.0,20000.0,30000.0,float('inf')]
    dfr=cut(df,splits,inputCol='sale')
    dfr.show()
    

    输出结果如下:

    +---+-----+-----+-----------------+
    | id| sale|split|              cut|
    +---+-----+-----+-----------------+
    |  1| 4000|  0.0|   [-inf,10000.0)|
    |  2|12000|  1.0|[10000.0,20000.0)|
    |  3|13500|  1.0|[10000.0,20000.0)|
    |  4|21000|  2.0|[20000.0,30000.0)|
    |  5|31000|  3.0|    [30000.0,inf)|
    +---+-----+-----+-----------------+
    Bucketizer知识点补充说明
    

    Bucketizer的作用是将连续值映射到离散的桶中,分桶规则是左闭右开

  • 下面的示例测试了分桶的边界取值逻辑
  • >>> df1 = sc.parallelize([(1,4000),(2,10000),(3,13500),(4,20000),(5,31000)]).toDF(['id','sale'])
    >>> with_split1 = bucketizer.transform(df1)
    >>> with_split1.show()
    +---+-----+-----+
    | id| sale|split|
    +---+-----+-----+
    |  1| 4000|  0.0|
    |  2|10000|  1.0|
    |  3|13500|  1.0|
    |  4|20000|  2.0|
    |  5|31000|  3.0|
    +---+-----+-----+
    >>> with_split1.withColumn('label', label_array.getItem(col('split').cast('integer'))).show()
    +---+-----+-----+-------------+
    | id| sale|split|        label|
    +---+-----+-----+-------------+
    |  1| 4000|  0.0| (-inf,10000)|
    |  2|10000|  1.0|[10000,20000)|
    |  3|13500|  1.0|[10000,20000)|
    |  4|20000|  2.0|[20000,30000)|
    |  5|31000|  3.0|  [30000,inf)|