本文会介绍几个hive中关于排序的非常有用的窗口函数,它们可以帮助处理TopN,前N%这类问题,

更酷炫的是,它们还支持分组、排序, 前几不是问题,我们order by也可以解决。但是分组之后的前几能够 帮助我们极大的简化工作量。

我们后面有一个测试程序可以生成数据,测试本文要介绍的函数,这个程序并不需要依赖安装hive与spark, 只需要导入后面pom文件中的依赖就可以了。

下面是本文用到的数据,可以用于和后面算法结果对比,当然后面也有这些数据的csv格式数据,可以自己保存导入测试。

row_number

row_number() over ([partition col1] [order by col2])
row_number() over ([partition col1,col2,...] [order by col1,col2,...])
select country,gid,uv, row_number() over(partition by country order by uv desc) rank from temp

row_number顾名思义就是添加行号,配合分组排序,就可以计算每个分组中的前N个。

row_number既然是行号,那么就不会有重复的。

partition可以指定多个字段,但是谨慎使用多字段,数据量大的情况下,多指定一个partition会大大的加大统计时间。

rank() over ([partition col1] [order by col2])
select country,gid,uv, rank() over(partition by country order by gid desc) rank from temp

rank和row_number基本一致,唯一区别就是排序值可能不同,rank排序的值可以重复,

例如可以有2个并列第1,然后没有第2,直接第3,毕竟很多比赛只想奖励前3个人,而不是前3名。

dense_rank

dense_rank() over ([partition col1] [order by col2])
select country,gid,uv, dense_rank() over(partition by country order by gid desc) rank from temp

dense_rank和rank差不多,不过排序值有所不同,dense_rank允许并列,同时排序是顺序的,

例如有2个并列第1,那么第3个人的排名是第2,这点和rank不同。

关于row_number、rank、dense_rank的区别看下面一张图就明白了:

percent_rank

percent_rank() over ([partition col1] [order by col2])
select country,gid,uv, percent_rank() over(partition by country order by uv desc) rank from temp

percent_rank也是非常有用的,举2个简单的例子:

  • 一次比赛如果我不想取前几名,而是想让前5%为一等奖,6%-15%为二等奖,%16-30%为三等奖,怎么处理?
  • 一些数据需要预处理,我要过滤掉前10%和最后10%的数据,怎么处理?
  • 如果使用row_number、rank、dense_rank还得取计算总数,非常麻烦。使用percent_rank轻松搞定。

    percent_rank的算法如下:

    percent_rank = 分组内当前行的rank值-1/分组内总行数-1
    

    ntile

    select country,gid,uv, ntile(10) over(partition by country order by uv desc) slice from temp
    

    ntile(n),用于将分组数据按照顺序切分成n片,返回当前切片值,如果切片不均匀,默认增加第一个切片的分布

    对于前面介绍的percent_rank函数,肯定有朋友吐槽,我要取前10%的数据,percent_rank的算法不精确啊。

    其实对于分组数据比较多的情况,percent_rank的算法还是非常接近的,并且分组数据越大越准确,所以一般情况下percent_rank就够用了。

    如果非要精确的百分比,就可以考虑ntile,要前10%的数据,就把分片分为10组,然后取分组值为1的数据。其实这样也是有一些问题的,因为分片可能不均匀。

    所以不用强求,都用百分百了还要强求精确值干什么?业务就可以避免,要绝对精确就使用rank。

    cume_dist

    select country,gid,uv, cume_dist() over(partition by country order by uv desc) top_percent from temp
    

    cume_dist主要用于计算在top多少里面。比如要计算某人的工资在前10%、还是前20%、还是前30%里面就可以使用cume_dist。

    cume_dist计算公式 :

    cume_dist = 小于等于当前值的行数/分组内总行数
    

    first_value

    select country,gid,uv, first_value(uv) over(partition by country order by uv desc) first from temp
    

    取分组排序后,截止到当前行,第一个值。注意不是最大值最小值,最大值最小值最后使用rank计算。

    last_value

    select country,gid,uv, last_value(uv) over(partition by country order by uv desc) last from temp
    

    取分组内排序后,截止到当前行,最后一个值

    和first_value的基本属于同类但是相反的操作。

    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Encoders;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.SparkSession;
    import org.junit.Before;
    import org.junit.Test;
    import java.io.FileInputStream;
    import java.io.FileOutputStream;
    import java.io.FileWriter;
    import java.io.IOException;
    import java.io.ObjectInputStream;
    import java.io.ObjectOutputStream;
    import java.io.Serializable;
    import java.util.LinkedList;
    import java.util.List;
    import java.util.Random;
    public class SparkRankTest implements Serializable{
        private static final String DATA_PATH = "F:\\tmp\\data.csv";
        private static final String DATA_OBJECT_PATH = "F:\\tmp\\info";
        private static String[] counties = {"中国","俄罗斯","美国","日本","韩国"};
        private SparkSession sparkSession;
        private Dataset<Info> dataset;
        @Before
        public void setUp(){
            sparkSession = SparkSession
                    .builder()
                    .appName("test")
                    .master("local")
                    .getOrCreate();
            sparkSession.sparkContext().setLogLevel("WARN");
        @Test
        public void start() throws IOException, ClassNotFoundException {
    //        List<Info> infos = getData(true);
            List<Info> infos = getData(false);
            dataset = sparkSession.createDataset(infos, Encoders.bean(Info.class));
    //        rowNumber();
    //        rank();
            denseRank();
    //        percentRank();
    //        ntile();
    //        cumeDist();
    //        firstValue();
    //        lastValue();
        private void rowNumber(){
            dataset.createOrReplaceTempView("temp");
            String sql = "select country,gid,uv, row_number() over(partition by country order by uv desc) rank from temp";
            Dataset<Row> ds = sparkSession.sql(sql);
    //        ds = ds.filter("rank < 0.95");
            ds.show(100);
        private void rank(){
            dataset.createOrReplaceTempView("temp");
            String sql = "select country,gid,uv, rank() over(partition by country order by gid desc) rank from temp";
            Dataset<Row> ds = sparkSession.sql(sql);
            ds.show(100);
        private void denseRank(){
            dataset.createOrReplaceTempView("temp");
            String sql = "select country,gid,uv, dense_rank() over(partition by country order by gid desc) rank from temp";
            Dataset<Row> ds = sparkSession.sql(sql);
            ds.show(100);
        private void percentRank(){
            dataset.createOrReplaceTempView("temp");
            String sql = "select country,gid,uv, percent_rank() over(partition by country order by uv desc) rank from temp";
            Dataset<Row> ds = sparkSession.sql(sql);
            ds.show(100);
        private void ntile(){
            dataset.createOrReplaceTempView("temp");
            String sql = "select country,gid,uv, ntile(10) over(partition by country order by uv desc) slice from temp";
            Dataset<Row> ds = sparkSession.sql(sql);
            ds = ds.filter("slice < 9");
            ds.show(100);
        private void cumeDist(){
            dataset.createOrReplaceTempView("temp");
            String sql = "select country,gid,uv, cume_dist() over(partition by country order by uv desc) top_percent from temp";
            Dataset<Row> ds = sparkSession.sql(sql);
            ds.show(100);
        private void firstValue(){
            dataset.createOrReplaceTempView("temp");
            String sql = "select country,gid,uv, first_value(uv) over(partition by country order by uv desc) first from temp";
            Dataset<Row> ds = sparkSession.sql(sql);
            ds.show(100);
        private void lastValue(){
            dataset.createOrReplaceTempView("temp");
            String sql = "select country,gid,uv, last_value(uv) over(partition by country order by uv desc) last from temp";
            Dataset<Row> ds = sparkSession.sql(sql);
            ds.show(100);
        private static List<Info> getData(Boolean newGen) throws IOException, ClassNotFoundException {
            if(newGen != null && newGen == true){
                return generateData();
            }else {
                return readList();
        private static List<Info> generateData() throws IOException {
            FileWriter fileWriter = new FileWriter(DATA_PATH);
            LinkedList<Info> infos = new LinkedList<>();
            Random random = new Random();
            for(int i=0;i<50;i++){
                Info info = new Info();
                String county = counties[random.nextInt(counties.length)];
                info.setCountry(county);
                int gid = random.nextInt(5);
                info.setGid(gid);
                int uv = random.nextInt(10000);
                info.setUv(uv);
                infos.add(info);
                fileWriter.write(String.format("%s,%d,%d\n",county,gid,uv));
            fileWriter.flush();
            writeList(infos);
            return infos;
        private static void writeList(LinkedList<Info> infos) throws IOException {
            FileOutputStream fos = new FileOutputStream(DATA_OBJECT_PATH);
            ObjectOutputStream oos = new ObjectOutputStream(fos);
            oos.writeObject(infos);
        private static LinkedList<Info> readList() throws IOException, ClassNotFoundException {
            FileInputStream fis = new FileInputStream(DATA_OBJECT_PATH);
            ObjectInputStream ois = new ObjectInputStream(fis);
            LinkedList<Info> list = (LinkedList) ois.readObject();
            return list;
         * 必须public,必须实现Serializable
        public static class Info implements Serializable {
            private String country;
             * 分组id
            private Integer gid;
             * 活跃用户
            private Integer uv;
            public String getCountry() {
                return country;
            public void setCountry(String country) {
                this.country = country;
            public Integer getGid() {
                return gid;
            public void setGid(Integer gid) {
                this.gid = gid;
            public Integer getUv() {
                return uv;
            public void setUv(Integer uv) {
                this.uv = uv;
    

    pom文件

    <project xmlns="http://maven.apache.org/POM/4.0.0"
             xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
             xsi:schemaLocation="http://maven.apache.org/POM/4.0.0
            http://maven.apache.org/xsd/maven-4.0.0.xsd">
        <modelVersion>4.0.0</modelVersion>
        <groupId>org.curitis</groupId>
        <artifactId>spark-learn</artifactId>
        <version>1.0.0</version>
        <packaging>jar</packaging>
        <properties>
            <spring.version>5.1.8.RELEASE</spring.version>
            <junit.version>4.11</junit.version>
            <spark.version>2.4.3</spark.version>
        </properties>
        <dependencies>
            <dependency>
                <groupId>org.apache.spark</groupId>
                <artifactId>spark-core_2.12</artifactId>
                <version>${spark.version}</version>
            </dependency>
            <dependency>
                <groupId>org.apache.spark</groupId>
                <artifactId>spark-sql_2.12</artifactId>
                <version>${spark.version}</version>
            </dependency>
            <dependency>
                <groupId>org.apache.spark</groupId>
                <artifactId>spark-hive_2.12</artifactId>
                <version>${spark.version}</version>
            </dependency>
            <dependency>
                <groupId>com.alibaba</groupId>
                <artifactId>fastjson</artifactId>
                <version>1.2.56</version>
            </dependency>
            <!--test-->
            <dependency>
                <groupId>org.springframework</groupId>
                <artifactId>spring-test</artifactId>
                <version>${spring.version}</version>
                <scope>test</scope>
            </dependency>
            <dependency>
                <groupId>junit</groupId>
                <artifactId>junit</artifactId>
                <version>${junit.version}</version>
                <scope>test</scope>
            </dependency>
        </dependencies>
        <build>
            <plugins>
                <plugin>
                    <groupId>org.apache.maven.plugins</groupId>
                    <artifactId>maven-compiler-plugin</artifactId>
                    <configuration>
                        <source>8</source>
                        <target>8</target>
                    </configuration>
                </plugin>
            </plugins>
        </build>
    </project>
    
    韩国,0,5525
    美国,1,7566
    中国,4,1769
    中国,2,1847
    美国,2,772
    韩国,3,5162
    日本,1,8759
    俄罗斯,0,2418
    中国,4,3326
    中国,0,7121
    韩国,3,9026
    俄罗斯,4,4353
    韩国,1,3498
    俄罗斯,2,4598
    美国,0,4493
    日本,3,3888
    日本,0,1025
    中国,2,5249
    韩国,0,1874
    日本,3,269
    韩国,1,1120
    韩国,2,2122
    俄罗斯,1,87
    俄罗斯,1,2266
    日本,1,3406
    中国,0,3267
    中国,1,3043
    美国,2,8298
    日本,1,661
    中国,2,5533
    美国,2,9335
    日本,0,6108
    俄罗斯,0,7445
    韩国,4,8657
    美国,4,1136
    韩国,2,2608
    俄罗斯,4,2988
    日本,0,3345
    俄罗斯,1,6977
    中国,4,4012
    中国,1,9337
    韩国,2,3282
    中国,2,5126
    韩国,0,8946
    俄罗斯,3,1688
    韩国,0,9249
    中国,1,1947
    美国,2,5402
    俄罗斯,0,8479
    俄罗斯,4,2536