
I'm new to Spark SQL, and I want to calculate the percentage of each status in my data. Here is my data:

A   B
11  1
11  3
12  1
13  3
12  2
13  1
11  1
12  2

So, I can do it in SQL like this:

select (C.oneTotal / C.total)   as onePercentage,
       (C.twoTotal / C.total)   as twoPercentage,
       (C.threeTotal / C.total) as threePercentage
from (select count(*) as total,
             sum(case when B = '1' then 1 else 0 end) as oneTotal,
             sum(case when B = '2' then 1 else 0 end) as twoTotal,
             sum(case when B = '3' then 1 else 0 end) as threeTotal
      from test
      group by A) as C;

But with the Spark SQL DataFrame API, I first calculate the total count for every status like below:

// wrong code
val cc = transDF.select("transData.*").groupBy("A")
      .agg(count("transData.*").alias("total"),
        sum(when(col("B") === "1", 1)).otherwise(0)).alias("oneTotal")
        sum(when(col("B") === "2", 1).otherwise(0)).alias("twoTotal")

The problem is that the result of the sum(when) is zero.

Am I using it incorrectly? How can I implement it in Spark SQL, just like my SQL above, and then calculate the proportion of every status?
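
For reference, the likely culprit in the attempt above, shown as a minimal sketch with the same column names: .otherwise can only be chained onto a when(...) column, not onto the surrounding sum(...), and a comma is missing between the two aggregations, so the expression never builds the intended columns:

import org.apache.spark.sql.functions._

// Correct placement: .otherwise(0) belongs on the when(...) column itself.
val oneTotal = sum(when(col("B") === "1", 1).otherwise(0)).alias("oneTotal")

// Misplaced: chaining .otherwise onto the sum(...) column fails, since
// otherwise() can only be applied to a column generated by when().
// val broken = sum(when(col("B") === "1", 1)).otherwise(0)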

Thank you for your help. In the end, I solved it with sum(when). Below is my current code.

val cc = transDF.select("transData.*").groupBy("A")
  .agg(count("transData.*").alias("total"),
    sum(when(col("B") === "1", 1).otherwise(0)).alias("oneTotal"),
    sum(when(col("B") === "2", 1).otherwise(0)).alias("twoTotal"))
  .select(col("total"),
    col("A"),
    (col("oneTotal") / col("total")).alias("oneRate"),
    (col("twoTotal") / col("total")).alias("twoRate"))

Note that the divisions are parenthesized before .alias; without the parentheses, the alias binds only to col("total") and the rate columns keep their default names.

Thanks again.

You can use sum(when(...)) or also count(when(...)), the second option being shorter to write:

// assuming an existing SparkSession named spark
import org.apache.spark.sql.functions._
import spark.implicits._

val df = Seq(
  (11, 1),
  (11, 3),
  (12, 1),
  (13, 3),
  (12, 2),
  (13, 1),
  (11, 1),
  (12, 2)
).toDF("A", "B")

df.groupBy($"A")
  .agg(
    count("*").as("total"),
    count(when($"B" === 1, $"A")).as("oneTotal"),
    count(when($"B" === 2, $"A")).as("twoTotal"),
    count(when($"B" === 3, $"A")).as("threeTotal"))
  .select(
    $"A",
    ($"oneTotal" / $"total").as("onePercentage"),
    ($"twoTotal" / $"total").as("twoPercentage"),
    ($"threeTotal" / $"total").as("threePercentage"))
  .show()

gives

+---+------------------+------------------+------------------+
|  A|     onePercentage|     twoPercentage|   threePercentage|
+---+------------------+------------------+------------------+
| 12|0.3333333333333333|0.6666666666666666|               0.0|
| 13|               0.5|               0.0|               0.5|
| 11|0.6666666666666666|               0.0|0.3333333333333333|
+---+------------------+------------------+------------------+
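
As an aside on why count(when(...)) works: when(...) without an .otherwise yields null for non-matching rows, and count only counts non-null values, so it counts exactly the matching rows. A minimal sketch of the equivalence, reusing the df above:

// Both spellings count the rows where B == 1: when(...) is null for
// non-matching rows and count skips nulls, while sum adds an explicit 0.
df.groupBy($"A")
  .agg(
    count(when($"B" === 1, $"A")).as("viaCount"),
    sum(when($"B" === 1, 1).otherwise(0)).as("viaSum"))
  .show()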

Alternatively, you could produce a "long" table with window functions:

.groupBy($"A",$"B").count() .withColumn("total",sum($"count").over(Window.partitionBy($"A"))) .select( $"A", $"B", ($"count"/$"total").as("percentage") ).orderBy($"A",$"B") .show() +---+---+------------------+ | A| B| percentage| +---+---+------------------+ | 11| 1|0.6666666666666666| | 11| 3|0.3333333333333333| | 12| 1|0.3333333333333333| | 12| 2|0.6666666666666666| | 13| 1| 0.5| | 13| 3| 0.5| +---+---+------------------+

As far as I understand, you want to implement the logic of the SQL shown above in the question.

One way is shown in the example below:

package examples

import org.apache.log4j.Level
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object AggTest extends App {
  val logger = org.apache.log4j.Logger.getLogger("org")
  logger.setLevel(Level.WARN)

  val spark = SparkSession.builder.appName(getClass.getName)
    .master("local[*]").getOrCreate

  import spark.implicits._

  val df = Seq(
    (11, 1),
    (11, 3),
    (12, 1),
    (13, 3),
    (12, 2),
    (13, 1),
    (11, 1),
    (12, 2)
  ).toDF("A", "B")

  df.show(false)
  df.createOrReplaceTempView("test")

  spark.sql(
    """
      |select (C.oneTotal / C.total)   as onePercentage,
      |       (C.twoTotal / C.total)   as twoPercentage,
      |       (C.threeTotal / C.total) as threePercentage
      |from (select count(*) as total,
      |             A,
      |             sum(case when B = '1' then 1 else 0 end) as oneTotal,
      |             sum(case when B = '2' then 1 else 0 end) as twoTotal,
      |             sum(case when B = '3' then 1 else 0 end) as threeTotal
      |      from test
      |      group by A) as C
    """.stripMargin).show
}

Result:

+---+---+
|A  |B  |
+---+---+
|11 |1  |
|11 |3  |
|12 |1  |
|13 |3  |
|12 |2  |
|13 |1  |
|11 |1  |
|12 |2  |
+---+---+
+------------------+------------------+------------------+
|     onePercentage|     twoPercentage|   threePercentage|
+------------------+------------------+------------------+
|0.3333333333333333|0.6666666666666666|               0.0|
|               0.5|               0.0|               0.5|
|0.6666666666666666|               0.0|0.3333333333333333|
+------------------+------------------+------------------+
        
