• 研究SparkSQL内置的数据类型,做成Java类与SparkSQL类型的映射表

推荐阅读spark源码 org.apache.spark.sql.catalyst.ScalaReflection类,其中列举了大部分基础类型与SparkSQL类型的映射。

但我还是重新写了这部分功能,最重要 的原因是源码只支持基本类型,对于复杂或嵌套Java类无能为力。

其次, 我想支持更多的类型,且我想做到对某些类型的对象进行自定义转换。

比如我遇到的Java类中有个属性为Map<String, Object> parameters; 其中的泛型Object无法映射到任何SparkSQL类型中,

导致StructType无法构建完整,造成不得不放弃一部分数据。

但我的做法是,对泛型未指定或指定为Object的,直接调用toString方法转换为String,可以挽回一部分数据丢失。

还有一些常见的,比如需要将java.util.Date转换为java.sql.Date, 将char[]转换为String的。

  • 研究 java.lang.reflect.Type

Type接口有一个子类和四个子接口,一个子类为java.lang.Class(最为大众所知),四个子接口为 GenericArrayType, ParameterizedType,

TypeVariable, WildcardType。

  1. Class : 基本类型,不包含泛型。比如List, Map, Integer
  2. ParameterizedType : 带有泛型的类型,比如List<String>, Map<String, Object>
  3. GenericArrayType : 带有泛型的类型 的数组,比如 List<String>[], Map<String, Object>[]
  4. TypeVariable和WildcardType我们不予考虑,主要是和<T>和<?>之类的东西相关,有兴趣的读者可自行研究

开发时间大概两周,运行较为稳定。下面分享代码,发现问题欢迎指正。

外部调用的主要是两个方法

def getStructType(clazz: Class[_]): Option[StructType]
def getRow(clazz: Class[_], obj: Any): Option[Row]
import java.lang.reflect.{ GenericArrayType, Modifier, ParameterizedType, Field }
import java.lang.{ Iterable => JIterable }
import java.util.{ Map => JMap }
import scala.collection.JavaConversions._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{ DataType, StructField, StructType, DecimalType, DataTypes }
import org.apache.spark.sql.types.DataTypes._
 * @author yizhu.sun 2016年7月21日
object DataFrameReflectUtil {
  /** 成员变量的类型和sparkSQL类型的映射 */
  val predefinedDataType: collection.mutable.Map[Class[_], DataType] =
    collection.mutable.Map(
      (classOf[Boolean], BooleanType),
      (classOf[java.lang.Boolean], BooleanType),
      (classOf[Byte], ByteType),
      (classOf[java.lang.Byte], ByteType),
      (classOf[Array[Byte]], BinaryType),
      (classOf[Array[java.lang.Byte]], BinaryType),
      (classOf[Short], ShortType),
      (classOf[java.lang.Short], ShortType),
      (classOf[Int], IntegerType),
      (classOf[java.lang.Integer], IntegerType),
      (classOf[Long], LongType),
      (classOf[java.lang.Long], LongType),
      (classOf[Float], FloatType),
      (classOf[java.lang.Float], FloatType),
      (classOf[Double], DoubleType),
      (classOf[java.lang.Double], DoubleType),
      (classOf[Char], StringType),
      (classOf[java.lang.Character], StringType),
      (classOf[Array[Char]], StringType),
      (classOf[Array[java.lang.Character]], StringType),
      (classOf[String], StringType),
      (classOf[java.math.BigDecimal], DecimalType.SYSTEM_DEFAULT),
      (classOf[java.util.Date], DateType),
      (classOf[java.sql.Date], DateType),
      (classOf[java.security.Timestamp], TimestampType),
      (classOf[java.util.Calendar], CalendarIntervalType),
      // 成员为Object类型的,都转为String
      (classOf[Any], StringType))
  /** 类之间的转换。比如将java.util.Date转换为java.sql.Date */
  private val classConverter: Map[Class[_], (Any) => _ <: Any] =
      classOf[java.util.Date] ->
        ((o: Any) => new java.sql.Date(o.asInstanceOf[java.util.Date].getTime)),
      classOf[Char] ->
        ((o: Any) => o.asInstanceOf[Char].toString),
      classOf[java.lang.Character] ->
        ((o: Any) => o.asInstanceOf[java.lang.Character].toString),
      classOf[Array[Char]] ->
        ((o: Any) => new String(o.asInstanceOf[Array[Char]])),
      classOf[Array[java.lang.Character]] ->
        ((o: Any) => new String(o.asInstanceOf[Array[java.lang.Character]].map(_.charValue))),
      classOf[Any] ->
        ((o: Any) => o.toString))
  /** cache of Class -> Option[StructType] */
  private val structTypeCache = new org.apache.commons.collections.map.LRUMap(100)
  /** cache of java.lang.reflect.Type -> Option[DataType] */
  private val dataTypeCache = new org.apache.commons.collections.map.LRUMap(1000)
  /** cache of Class -> Array[Field] */
  private val classFieldsCache = collection.mutable.Map[Class[_], Array[Field]]()
  /** scala.collection.Map 类型的Class的cache */
  private val scalaMapClassCache = collection.mutable.Set[Class[_]]()
  /** scala.collection.Iterable 类型的Class的cache */
  private val scalaIterableClassCache = collection.mutable.Set[Class[_]]()
  /** java.util.Map 类型的Class的cache */
  private val javaMapClassCache = collection.mutable.Set[Class[_]]()
  /** java.lang.Iterable 类型的Class的cache */
  private val javaIterableClassCache = collection.mutable.Set[Class[_]]()
  // 注意在Scala中Map是Iterable的子类
  def isScalaMapClass(clazz: Class[_]) = {
    if (scalaMapClassCache.contains(clazz)) true
    else if (classOf[Map[_, _]].isAssignableFrom(clazz)) {
      scalaMapClassCache += clazz
    } else false
  def isScalaIterableClass(clazz: Class[_]) = {
    if (scalaIterableClassCache.contains(clazz)) true
    else if (classOf[Iterable[_]].isAssignableFrom(clazz)) {
      scalaIterableClassCache += clazz
    } else false
  def isJavaMapClass(clazz: Class[_]) = {
    if (javaMapClassCache.contains(clazz)) true
    else if (classOf[JMap[_, _]].isAssignableFrom(clazz)) {
      javaMapClassCache += clazz
    } else false
  def isJavaIterableClass(clazz: Class[_]) = {
    if (javaIterableClassCache.contains(clazz)) true
    else if (classOf[JIterable[_]].isAssignableFrom(clazz)) {
      javaIterableClassCache += clazz
    } else false
  def getFields(clazz: Class[_]) =
    classFieldsCache.getOrElseUpdate(clazz, {
      val fields = clazz.getDeclaredFields
        .filterNot(f => Modifier.isTransient(f.getModifiers))
        .flatMap(f =>
          getDataType(f.getGenericType) match {
            case Some(_) => Some(f)
            case None => None
      fields.foreach(_.setAccessible(true))
      fields
   * 根据Class对象,生成StructType对象。
  def getStructType(clazz: Class[_]): Option[StructType] = {
    val cachedStructType = structTypeCache.get(clazz)
    if (cachedStructType == null) {
      val fields = getFields(clazz)
      val newStructType =
        if (fields.isEmpty) None
        else {
          val types = fields.map(f => {
            val dataType = getDataType(f.getGenericType).get
            StructField(f.getName, dataType, true) // 默认所有的字段都可能为空
          if (types.isEmpty) None else Some(StructType(types))
      structTypeCache.put(clazz, newStructType)
      newStructType
    } else cachedStructType.asInstanceOf[Option[StructType]]
   * 根据java.lang.reflect.Type获取org.apache.spark.sql.types.DataType
   * 递归处理嵌套类型
  private def getDataType(tp: java.lang.reflect.Type): Option[DataType] = {
    val cachedDataType = dataTypeCache.get(tp)
    if (cachedDataType == null) {
      val newDataType = tp match {
        case ptp: ParameterizedType => // 带有泛型的数据类型,e.g. List[String]
          val clazz = ptp.getRawType.asInstanceOf[Class[_]]
          val rowTypes = ptp.getActualTypeArguments
          if (isScalaMapClass(clazz) || isJavaMapClass(clazz)) {
            (getDataType(rowTypes(0)), getDataType(rowTypes(1))) match {
              case (Some(keyType), Some(valueType)) =>
                Some(DataTypes.createMapType(keyType, valueType, true))
              case _ => None
          } else if (isScalaIterableClass(clazz) || isJavaIterableClass(clazz)) {
            getDataType(rowTypes(0)) match {
              case Some(dataType) => Some(DataTypes.createArrayType(dataType, true))
              case None => None
          } else {
            getStructType(clazz)
        case gatp: GenericArrayType => // 泛型数据类型的数组,e.g. Array[List[String]]
          getDataType(gatp.getGenericComponentType) match {
            case Some(dataType) => Some(DataTypes.createArrayType(dataType, true))
            case None => None
        case clazz: Class[_] => // 没有泛型的类型(包括没有指定泛型的Map和Collection)
          predefinedDataType.get(clazz) match {
            case Some(tp) => Some(tp)
            case None =>
              if (clazz.isArray) { // 非泛型对象的数组
                getDataType(clazz.getComponentType) match {
                  case Some(dataType) => Some(DataTypes.createArrayType(dataType, true))
                  case None => None
              } else if (isScalaMapClass(clazz) || isJavaMapClass(clazz)) {
                Some(DataTypes.createMapType(StringType, StringType, true))
              } else if (isScalaIterableClass(clazz) || isJavaIterableClass(clazz)) {
                Some(DataTypes.createArrayType(StringType, true))
              } else { // 一般Object类型,转换为嵌套类型
                getStructType(clazz)
        case _ =>
          throw new IllegalArgumentException("不支持 WildcardType 和 TypeVariable")
      dataTypeCache.put(tp, newDataType)
      newDataType
    } else cachedDataType.asInstanceOf[Option[DataType]]
   * 读取一行数据
  def getRow(clazz: Class[_], obj: Any): Option[Row] =
    getStructType(clazz) match {
      case Some(_) =>
        if (obj == null) Some(null)
        else Some(Row(getFields(clazz).flatMap(f => getCell(f.getGenericType, f.get(obj))): _*))
      case None => None
   * 读取单个数据
  private def getCell(tp: java.lang.reflect.Type, value: Any): Option[Any] =
    tp match {
      case ptp: ParameterizedType => // 带有泛型的数据类型,e.g. List[String]
        val clazz = ptp.getRawType.asInstanceOf[Class[_]]
        val rowTypes = ptp.getActualTypeArguments
        if (isScalaMapClass(clazz)) {
          (getDataType(rowTypes(0)), getDataType(rowTypes(1))) match {
            case (Some(keyType), Some(valueType)) =>
              if (value == null) Some(null)
              else Some(value.asInstanceOf[Map[Any, Any]].filterKeys(_ != null)
                .map { case (k, v) => getCell(rowTypes(0), k).get -> getCell(rowTypes(1), v).get })
            case _ => None
        } else if (isScalaIterableClass(clazz)) {
          getDataType(rowTypes(0)) match {
            case Some(_) =>
              if (value == null) Some(null)
              else Some(value.asInstanceOf[Iterable[Any]].filter(_ != null).map(v => getCell(rowTypes(0), v).get).toSeq)
            case None => None
        } else if (isJavaIterableClass(clazz)) {
          getDataType(rowTypes(0)) match {
            case Some(_) =>
              if (value == null) Some(null)
              else Some(value.asInstanceOf[JIterable[Any]].filter(_ != null).map(v => getCell(rowTypes(0), v).get).toSeq)
            case None => None
        } else if (isJavaMapClass(clazz)) {
          (getDataType(rowTypes(0)), getDataType(rowTypes(1))) match {
            case (Some(keyType), Some(valueType)) =>
              if (value == null) Some(null)
              else Some(value.asInstanceOf[JMap[Any, Any]].filterKeys(_ != null)
                .map { case (k, v) => getCell(rowTypes(0), k).get -> getCell(rowTypes(1), v).get })
            case _ => None
        } else {
          getCell(clazz, value)
      case gatp: GenericArrayType => // 泛型数据类型的数组,e.g. Array[List[String]]
        getDataType(gatp.getGenericComponentType) match {
          case Some(dataType) => Some(value.asInstanceOf[Array[Any]].map(v => getCell(gatp.getGenericComponentType, v).get).toSeq)
          case None => None
      case clazz: Class[_] => // 没有泛型的类型(包括没有指定泛型的Map和Collection)
        predefinedDataType.get(clazz) match {
          case Some(_) =>
            classConverter.get(clazz) match {
              case Some(converter) => Some(if (value == null) null else converter(value))
              case None => Some(value)
          case None =>
            if (clazz.isArray) { // 非泛型对象的数组
              getDataType(clazz.getComponentType) match {
                case Some(dataType) =>
                  if (value == null) Some(null)
                  else Some(value.asInstanceOf[Array[_]].filter(_ != null).flatMap(v => getCell(clazz.getComponentType, v)).toSeq)
                case None => None
            } else if (isScalaMapClass(clazz)) {
              Some(value.asInstanceOf[Map[Any, Any]].filterKeys(_ != null)
                .map { case (k, v) => getCell(classOf[Any], k).get -> getCell(classOf[Any], v).get })
            } else if (isScalaIterableClass(clazz)) {
              Some(value.asInstanceOf[Iterable[Any]].filter(_ != null)
                .map(v => getCell(classOf[Any], v).get).toSeq)
            } else if (isJavaIterableClass(clazz)) {
              Some(value.asInstanceOf[JIterable[Any]].filter(_ != null)
                .map(v => getCell(classOf[Any], v).get).toSeq)
            } else if (isJavaMapClass(clazz)) {
              Some(value.asInstanceOf[JMap[Any, Any]].filterKeys(_ != null)
                .map { case (k, v) => getCell(classOf[Any], k).get -> getCell(classOf[Any], v).get })
            } else { // 一般Object类型,转换为嵌套类型
              getRow(clazz, value)
      case _ =>
        throw new IllegalArgumentException("不支持 WildcardType 和 TypeVariable")
  val list1: List[Array[Char]],
  val map1: Map[String, Array[Int]],
  val obj1: TInnerClass) extends Serializable
class TInnerClass(
  val date1: java.util.Date) extends Serializable 
// sc:  SparkContext
// ssc: SQLContext
val obj1 = new TClass(
  List(Array('1', '2', '3'), null),
  Map("123" -> Array(1, 2, 3),
    "nil" -> null),
  new TInnerClass(new java.util.Date))
val obj2 = new TClass(
  List(Array('1', '2', '3'), null),
  Map("empty" -> Array(),
    "90" -> Array(9, 0)),
  new TInnerClass(null))
val tClazz = classOf[TClass]
val rdd = sc.makeRDD(Seq(obj1, obj2))
val rowRDD = rdd.flatMap(DataFrameReflectUtil.getRow(tClazz, _))
DataFrameReflectUtil.getStructType(tClazz) match {
  case Some(scheme) =>
    val df = ssc.createDataFrame(rowRDD, scheme)
    df.registerTempTable("df")
    df.printSchema
    ssc.sql("select list1, map1, obj1 from df").show(false)
    ssc.sql("select map1['90'], map1['90'][0], date_add(obj1.date1, 1) from df").show(false)
  case None =>
    println("getStructType failed")
 |-- list1: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- map1: map (nullable = true)
 |    |-- key: string
 |    |-- value: array (valueContainsNull = true)
 |    |    |-- element: integer (containsNull = true)
 |-- obj1: struct (nullable = true)
 |    |-- date1: date (nullable = true)
+-----+------------------------------------------------------+------------+
|list1|map1                                                  |obj1        |
+-----+------------------------------------------------------+------------+
|[123]|Map(123 -> WrappedArray(1, 2, 3), nil -> null)        |[2016-09-01]|
|[123]|Map(empty -> WrappedArray(), 90 -> WrappedArray(9, 0))|[null]      |
+-----+------------------------------------------------------+------------+
+------+----+----------+
|_c0   |_c1 |_c2       |
+------+----+----------+
|null  |null|2016-09-02|
|[9, 0]|9   |null      |
+------+----+----------+
import pandas as pd a={'one':['A','A','B','C','C','A','B','B','A','A'], 'tao':['B','B','C','C','A','A','C','B','C','A'], 'three':['C','B','A','A','B','B','B','A','C','D']} b=pd. DataFrame (a) b.describe() b是 换后 DataFrame ,显示如表格: one tao three 0 A B
DataFrame 是一个组织成命名列的数据集。它在概念上等同于关系数据库中的表或R/Python中的数据框架,但其经过了优化。 DataFrame s可以从各种各样的源构建,例如:结构化数据文件,Hive中的表,外部数据库或现有 RDD DataFrame API 可以被Scala, Java ,Python和R调用。 在Scala和 Java 中, DataFrame 由Rows的数据集表示。 在Scala API中, DataFrame 只是一个类型别名Dataset[Row]。而在 Java API中,用户需要Dataset用来表示 DataFrame 。 在本文档中,我们经常将Scala/ Java 数据
本文来自dongkelun,讲各种情况下的sc.defaultParallelism,defaultMinPartitions,各种情况下创建以及 化。熟悉 Spark 的分区对于 Spark 性能调优很重要,本文总结 Spark 通过各种函数创建 RDD DataFrame 时默认的分区数,其中主要和sc.defaultParallelism、sc.defaultMinPartitions以及HDFS文件的Block数量有关,还有很坑的某些情况的默认分区数为1。如果分区数少,那么并行执行的task就少,特别情况下,分区数为1,即使你分配的Executor很多,而实际执行的Executor只有1个,如果数据很
一: RDD DataFrame 换 通过 反射 的方式来推断 RDD 元素中的元数据。因为 RDD 本身一条数据本身是没有元数据的,例如Person,而Person有name,id等,而record是不知道这些的,但是变成 DataFrame 背后一定知道,通过 反射 的方式就可以了解到背后这些元数据,进而 转换成 DataFrame 。如何 反射 ? Scala: 通过case class...
spark 官方提供了两种方法实现从 RDD 换到 DataFrame 。第一种方法是利用 反射 机制来推断包含特定类型 对象 的Schema,这种方式适用于对已知的数据结构的 RDD 换;第二种方法通过编程接口构造一个 Schema ,并将其应用在已知的 RDD 数据中。 (一) 反射 机制推断Schema 在Windows系统下开发Scala 代码,可以使用本地环境测试,因此首先需要在本地磁 盘准备文本数据文件,这里将HD FS中的/ spark /person.txt文件下载到本地D:/spa...
private val schema: StructType = StructType(List( StructField("name", DataTypes.StringType), StructFiel.
DataFrame 中的StructType类型字段下的所有内容 换为Json字符串。 调用 Spark 源码中的org.apache. spark . sql .execution.datasources.json.JacksonGenerator类,使用Jackson,根据传入的StructType、JsonGenerator和InternalRow,生成Json字符串。 需要实现org.apache. spark . sql .catalyst.expressions包下的UnaryExpression
from py spark . sql import Spark Session spark = Spark Session.builder.appName("text_file_reader").getOrCreate() 2. 使用 Spark Session的read方法读取文本文件 ```python text_file = spark .read.text("path/to/text/file") 3. 将 RDD 换为 DataFrame ```python df = text_file.toDF() 完整代码示例: ```python from py spark . sql import Spark Session spark = Spark Session.builder.appName("text_file_reader").getOrCreate() text_file = spark .read.text("path/to/text/file") df = text_file.toDF() df.show() 其中,"path/to/text/file"为文本文件的路径。