相关文章推荐
鼻子大的人字拖  ·  2009年春季号·  6 天前    · 
开心的登山鞋  ·  阿雅_百度百科·  1 年前    · 
  • 我们此篇使用的树都是User.json这个,具体如下图

{“username”: “zhangsan”,“age”: 20}
{“username”: “lisi”,“age”: 21}
{“username”: “wangwu”,“age”: 19}

自定义UDF UDF的简介

UDF: 输入一行, 返回一个结果. 一对一关系,放入函数一个值, 就返回一个值, 而不会返回多个值 。如下面的例子就可以看出:

(x: String) => "Name=" + x

这个函数, 入参为一个, 返回也是一个, 而不会返回多个值

object UDF {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("UTF")
      .getOrCreate()
    val df = spark.read
      .json("data/user.json")
    df.createOrReplaceTempView("user")
    //注册udf
    spark.udf.register("prefixName", (name: String) => {
      "Name:" + name
    spark.sql("select age,prefixName(username) from user").show()
    spark.close()

结果展示
【Spark】自定义函数UDF和UDAF_spark
解释

  • UDF在使用之前,需要先注册spark.udf.register

跳转顶部

自定义UDAF UDAF的简介

UDAF主要可以分为强类型和弱类型

  • 强弱类型的主要区别就是强类型要注意数据的类型

强类型的 Dataset 和弱类型的 DataFrame 都提供了相关的聚合函数, 如 count()countDistinct()avg()max()min()。除此之外,用户可以设定自己的自定义聚合函数。通过继承 UserDefinedAggregateFunction 来实现用户自定义弱类型聚合函数。如今UserDefinedAggregateFunction已经不推荐使用了。可以统一采用强类型聚合函数Aggregator

弱类型的UDAF

自定义UDAF

  class MyAvgUDAF extends UserDefinedAggregateFunction {
     * 输入数据的结构,我们这里是求年龄的平均值,所以输入的数据是年龄
     * 由于是聚合函数,肯定时输入一个数组的数据,最后返回一个数据也就是平均值
     * 所以输入的是一个数组,数据的类别名叫age,数据的类型是longType
    override def inputSchema: StructType = {
      StructType(
        Array(
          StructField("age", LongType)
     * 缓冲区
     * 缓冲区是用来暂时存储数据,数据会在这里进行暂时的存储、运算然后才输出数据
     * 例如求平均值:数据在缓冲区进行求和和计算数量,求出平均值后输出
     * @return
    override def bufferSchema: StructType = {
      StructType(
        Array(
          StructField("total", LongType),
          StructField("count", LongType)
     * 函数输出的数据类型就是是计算结果的数据类型
     * @return
    override def dataType: DataType = LongType
     * 函数的稳定性
     * @return
    override def deterministic: Boolean = true
     * 缓冲区的初始换
     * @param buffer
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      //这里就是如何该初始哈缓冲区的数据(也就是归零),这里有两个方法来归零
      //方法一
      //buffer(0) = 0l
      //buffer(1) = 0l
      //方法二
      buffer.update(0, 0l)
      buffer.update(1, 0l)
     * 根据输入的数据来更新缓冲区的数据,也就是缓冲区的计算规则
     * @param buffer
     * @param input
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      //第一个数据就是求和,缓冲区里的数据加上输入的数据
      buffer.update(0, buffer.getLong(0) + input.getLong(0))
      //第二个数据就是计算总数,每次加一即可
      buffer.update(1, buffer.getLong(1) + 1)
     * 缓冲区的数据合并
     * 保留1
     * @param buffer1
     * @param buffer2
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
      buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
     * 计算平均值
     * @param buffer
     * @return
    override def evaluate(buffer: Row): Any = (buffer.getLong(0) / buffer.getLong(1))

主要步骤:

  • 继承UserDefinedAggregateFunction
  • 实现他的方法

方法的含义各是什么?

  • inputSchema:输入数据的结构。由于是聚合,输入数据肯定是一个数组
  • bufferSchema:缓冲区数据的结构,缓冲区就是编写计算规则的,如选哟计算平均值,那么就需要在缓冲区中计算出总数和总和
  • dataType:输出的数据结构,即输出结果的数据结构
  • deterministic:函数的稳定性,确保一致性, 一般用true
  • initialize:缓冲区的初始化即归零
  • update:根据输入的数据来更新缓冲区的数据,也就是缓冲区的计算规则
  • merge:缓冲区的合并
  • evaluate:计算平均值

注册并且使用

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("UDAF")
      .getOrCreate()
    val df = spark.read
      .json("data/user.json")
    df.createOrReplaceTempView("user")
    //注册函数
    spark.udf.register("ageAvg",new MyAvgUDAF())
    spark.sql("select ageAvg(age) from user").show()
    spark.close()
强类型的UDAF 

自定义两个样例类

  //存储缓冲区的数据
  case class Buff(var total: Long, var count: Long)
  //存储输入数据
  case class User(var username: String, var age: Long)

自定义强类型UDAF类

  class MyAvgAgeUDAF extends Aggregator[User, Buff, Long] {
     * 初始值或者是零值
     * 缓冲区的初始化
     * @return
    override def zero: Buff = {
      Buff(0l, 0l)
     * 根据输入的数据来更新缓冲区的数据
     * @param b
     * @param a
     * @return
    override def reduce(b: Buff, a: User): Buff = {
      b.total += a.age
      b.count += 1
     * 合并缓冲区
     * @param b1
     * @param b2
     * @return
    override def merge(b1: Buff, b2: Buff): Buff = {
      b1.total += b2.total
      b1.count += b2.count
     * 计算结果
     * @param reduction
     * @return
    override def finish(reduction: Buff): Long = (reduction.total / reduction.count)
     * 这是固定的写法,若是自定义的类那么就是:product
     * 缓冲区的编码操作
     * @return
    override def bufferEncoder: Encoder[Buff] = Encoders.product
     * 这也是固定的写法,若是scala存在的类(如long,int,string……)就是选择对应的即可
     * 输出的编码操作
     * @return
    override def outputEncoder: Encoder[Long] = Encoders.scalaLong
  • 继承Aggregator
  • 与弱类型相比,此时这里需要定义输入、缓冲区和输出数据的泛型

方法的简绍

  • zero:缓冲区的初始化
  • reduce:根据输入的数据来更新缓冲区的数据,也就是计算总数据数和数据和
  • merge:合并缓冲区数据
  • finsh:计算结果
  • bufferEncoder和·outputEncoder:这两个分别是缓冲区和输出的编码格式,其实是由固定格式的,若再次阶段输出的数据是自定义的那么就是Encoders.product,若输出的数据是scala自带的那么就是Encoders.scalaLong后面的long根据自己输出的数据类型而定

注册并且使用

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("UDAF")
      .getOrCreate()
    import spark.implicits._
    val df = spark.read
      .json("data/user.json")
    df.createOrReplaceTempView("user")
    val ds = df.as[User]
    //将UDAF变成查询的列对象
    val udafCol = new MyAvgAgeUDAF().toColumn
    ds.select(udafCol).show()
    spark.close()