本文会介绍几个hive中关于排序的非常有用的窗口函数,它们可以帮助处理TopN,前N%这类问题,
更酷炫的是,它们还支持分组、排序, 前几不是问题,我们order by也可以解决。但是分组之后的前几能够 帮助我们极大的简化工作量。
我们后面有一个测试程序可以生成数据,测试本文要介绍的函数,这个程序并不需要依赖安装hive与spark, 只需要导入后面pom文件中的依赖就可以了。
下面是本文用到的数据,可以用于和后面算法结果对比,当然后面也有这些数据的csv格式数据,可以自己保存导入测试。
row_number
row_number() over ([partition col1] [order by col2])
row_number() over ([partition col1,col2,...] [order by col1,col2,...])
select country,gid,uv, row_number() over(partition by country order by uv desc) rank from temp
row_number顾名思义就是添加行号,配合分组排序,就可以计算每个分组中的前N个。
row_number既然是行号,那么就不会有重复的。
partition可以指定多个字段,但是谨慎使用多字段,数据量大的情况下,多指定一个partition会大大的加大统计时间。
rank() over ([partition col1] [order by col2])
select country,gid,uv, rank() over(partition by country order by gid desc) rank from temp
rank和row_number基本一致,唯一区别就是排序值可能不同,rank排序的值可以重复,
例如可以有2个并列第1,然后没有第2,直接第3,毕竟很多比赛只想奖励前3个人,而不是前3名。
dense_rank
dense_rank() over ([partition col1] [order by col2])
select country,gid,uv, dense_rank() over(partition by country order by gid desc) rank from temp
dense_rank和rank差不多,不过排序值有所不同,dense_rank允许并列,同时排序是顺序的,
例如有2个并列第1,那么第3个人的排名是第2,这点和rank不同。
关于row_number、rank、dense_rank的区别看下面一张图就明白了:
percent_rank
percent_rank() over ([partition col1] [order by col2])
select country,gid,uv, percent_rank() over(partition by country order by uv desc) rank from temp
percent_rank也是非常有用的,举2个简单的例子:
一次比赛如果我不想取前几名,而是想让前5%为一等奖,6%-15%为二等奖,%16-30%为三等奖,怎么处理?
一些数据需要预处理,我要过滤掉前10%和最后10%的数据,怎么处理?
如果使用row_number、rank、dense_rank还得取计算总数,非常麻烦。使用percent_rank轻松搞定。
percent_rank的算法如下:
percent_rank = 分组内当前行的rank值-1/分组内总行数-1
ntile
select country,gid,uv, ntile(10) over(partition by country order by uv desc) slice from temp
ntile(n),用于将分组数据按照顺序切分成n片,返回当前切片值,如果切片不均匀,默认增加第一个切片的分布
对于前面介绍的percent_rank函数,肯定有朋友吐槽,我要取前10%的数据,percent_rank的算法不精确啊。
其实对于分组数据比较多的情况,percent_rank的算法还是非常接近的,并且分组数据越大越准确,所以一般情况下percent_rank就够用了。
如果非要精确的百分比,就可以考虑ntile,要前10%的数据,就把分片分为10组,然后取分组值为1的数据。其实这样也是有一些问题的,因为分片可能不均匀。
所以不用强求,都用百分百了还要强求精确值干什么?业务就可以避免,要绝对精确就使用rank。
cume_dist
select country,gid,uv, cume_dist() over(partition by country order by uv desc) top_percent from temp
cume_dist主要用于计算在top多少里面。比如要计算某人的工资在前10%、还是前20%、还是前30%里面就可以使用cume_dist。
cume_dist计算公式 :
cume_dist = 小于等于当前值的行数/分组内总行数
first_value
select country,gid,uv, first_value(uv) over(partition by country order by uv desc) first from temp
取分组排序后,截止到当前行,第一个值。注意不是最大值最小值,最大值最小值最后使用rank计算。
last_value
select country,gid,uv, last_value(uv) over(partition by country order by uv desc) last from temp
取分组内排序后,截止到当前行,最后一个值
和first_value的基本属于同类但是相反的操作。
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.junit.Before;
import org.junit.Test;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
public class SparkRankTest implements Serializable{
private static final String DATA_PATH = "F:\\tmp\\data.csv";
private static final String DATA_OBJECT_PATH = "F:\\tmp\\info";
private static String[] counties = {"中国","俄罗斯","美国","日本","韩国"};
private SparkSession sparkSession;
private Dataset<Info> dataset;
@Before
public void setUp(){
sparkSession = SparkSession
.builder()
.appName("test")
.master("local")
.getOrCreate();
sparkSession.sparkContext().setLogLevel("WARN");
@Test
public void start() throws IOException, ClassNotFoundException {
// List<Info> infos = getData(true);
List<Info> infos = getData(false);
dataset = sparkSession.createDataset(infos, Encoders.bean(Info.class));
// rowNumber();
// rank();
denseRank();
// percentRank();
// ntile();
// cumeDist();
// firstValue();
// lastValue();
private void rowNumber(){
dataset.createOrReplaceTempView("temp");
String sql = "select country,gid,uv, row_number() over(partition by country order by uv desc) rank from temp";
Dataset<Row> ds = sparkSession.sql(sql);
// ds = ds.filter("rank < 0.95");
ds.show(100);
private void rank(){
dataset.createOrReplaceTempView("temp");
String sql = "select country,gid,uv, rank() over(partition by country order by gid desc) rank from temp";
Dataset<Row> ds = sparkSession.sql(sql);
ds.show(100);
private void denseRank(){
dataset.createOrReplaceTempView("temp");
String sql = "select country,gid,uv, dense_rank() over(partition by country order by gid desc) rank from temp";
Dataset<Row> ds = sparkSession.sql(sql);
ds.show(100);
private void percentRank(){
dataset.createOrReplaceTempView("temp");
String sql = "select country,gid,uv, percent_rank() over(partition by country order by uv desc) rank from temp";
Dataset<Row> ds = sparkSession.sql(sql);
ds.show(100);
private void ntile(){
dataset.createOrReplaceTempView("temp");
String sql = "select country,gid,uv, ntile(10) over(partition by country order by uv desc) slice from temp";
Dataset<Row> ds = sparkSession.sql(sql);
ds = ds.filter("slice < 9");
ds.show(100);
private void cumeDist(){
dataset.createOrReplaceTempView("temp");
String sql = "select country,gid,uv, cume_dist() over(partition by country order by uv desc) top_percent from temp";
Dataset<Row> ds = sparkSession.sql(sql);
ds.show(100);
private void firstValue(){
dataset.createOrReplaceTempView("temp");
String sql = "select country,gid,uv, first_value(uv) over(partition by country order by uv desc) first from temp";
Dataset<Row> ds = sparkSession.sql(sql);
ds.show(100);
private void lastValue(){
dataset.createOrReplaceTempView("temp");
String sql = "select country,gid,uv, last_value(uv) over(partition by country order by uv desc) last from temp";
Dataset<Row> ds = sparkSession.sql(sql);
ds.show(100);
private static List<Info> getData(Boolean newGen) throws IOException, ClassNotFoundException {
if(newGen != null && newGen == true){
return generateData();
}else {
return readList();
private static List<Info> generateData() throws IOException {
FileWriter fileWriter = new FileWriter(DATA_PATH);
LinkedList<Info> infos = new LinkedList<>();
Random random = new Random();
for(int i=0;i<50;i++){
Info info = new Info();
String county = counties[random.nextInt(counties.length)];
info.setCountry(county);
int gid = random.nextInt(5);
info.setGid(gid);
int uv = random.nextInt(10000);
info.setUv(uv);
infos.add(info);
fileWriter.write(String.format("%s,%d,%d\n",county,gid,uv));
fileWriter.flush();
writeList(infos);
return infos;
private static void writeList(LinkedList<Info> infos) throws IOException {
FileOutputStream fos = new FileOutputStream(DATA_OBJECT_PATH);
ObjectOutputStream oos = new ObjectOutputStream(fos);
oos.writeObject(infos);
private static LinkedList<Info> readList() throws IOException, ClassNotFoundException {
FileInputStream fis = new FileInputStream(DATA_OBJECT_PATH);
ObjectInputStream ois = new ObjectInputStream(fis);
LinkedList<Info> list = (LinkedList) ois.readObject();
return list;
* 必须public,必须实现Serializable
public static class Info implements Serializable {
private String country;
* 分组id
private Integer gid;
* 活跃用户
private Integer uv;
public String getCountry() {
return country;
public void setCountry(String country) {
this.country = country;
public Integer getGid() {
return gid;
public void setGid(Integer gid) {
this.gid = gid;
public Integer getUv() {
return uv;
public void setUv(Integer uv) {
this.uv = uv;
pom文件
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0
http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>org.curitis</groupId>
<artifactId>spark-learn</artifactId>
<version>1.0.0</version>
<packaging>jar</packaging>
<properties>
<spring.version>5.1.8.RELEASE</spring.version>
<junit.version>4.11</junit.version>
<spark.version>2.4.3</spark.version>
</properties>
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_2.12</artifactId>
<version>${spark.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_2.12</artifactId>
<version>${spark.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-hive_2.12</artifactId>
<version>${spark.version}</version>
</dependency>
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>fastjson</artifactId>
<version>1.2.56</version>
</dependency>
<!--test-->
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-test</artifactId>
<version>${spring.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>${junit.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<source>8</source>
<target>8</target>
</configuration>
</plugin>
</plugins>
</build>
</project>
韩国,0,5525
美国,1,7566
中国,4,1769
中国,2,1847
美国,2,772
韩国,3,5162
日本,1,8759
俄罗斯,0,2418
中国,4,3326
中国,0,7121
韩国,3,9026
俄罗斯,4,4353
韩国,1,3498
俄罗斯,2,4598
美国,0,4493
日本,3,3888
日本,0,1025
中国,2,5249
韩国,0,1874
日本,3,269
韩国,1,1120
韩国,2,2122
俄罗斯,1,87
俄罗斯,1,2266
日本,1,3406
中国,0,3267
中国,1,3043
美国,2,8298
日本,1,661
中国,2,5533
美国,2,9335
日本,0,6108
俄罗斯,0,7445
韩国,4,8657
美国,4,1136
韩国,2,2608
俄罗斯,4,2988
日本,0,3345
俄罗斯,1,6977
中国,4,4012
中国,1,9337
韩国,2,3282
中国,2,5126
韩国,0,8946
俄罗斯,3,1688
韩国,0,9249
中国,1,1947
美国,2,5402
俄罗斯,0,8479
俄罗斯,4,2536