Group control of several common window functions in Hive

Keywords: Programming SQL Java Apache Spark

brief introduction

Ordinary window functions are simple and need little explanation. This article introduces window frames instead, focusing on how to use `rows between` after partitioning (`partition by`) and sorting (`order by`).

The key is to understand the meaning of keywords in rows between:

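Keyword Meaning
preceding Rows before the current row
following Rows after the current row
current row The current row
unbounded No limit in the given direction (the first or last row of the partition)
unbounded preceding The frame starts at the first row of the partition
unbounded following The frame ends at the last row of the partition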

These definitions are rather abstract on their own, so let's look at an example.

max

-- For every row of temp, show charge next to four windowed maxima of charge,
-- each computed within the row's country with rows ordered by time.
select country,time,charge,
-- normal: order by only, no explicit frame clause
max(charge) over (partition by country order by time) as normal,
-- unb_pre_cur: frame runs from the first row of the partition through the current row
max(charge) over (partition by country order by time rows between unbounded preceding and current row) as unb_pre_cur,
-- pre2_fol1: frame is the two rows before the current row, the current row, and the one row after it
max(charge) over (partition by country order by time rows between 2 preceding and 1 following) as pre2_fol1,
-- cur_unb_fol: frame runs from the current row through the last row of the partition
max(charge) over (partition by country order by time rows between current row and unbounded following) as cur_unb_fol 
from temp

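By default (an `order by` with no explicit frame), each value is calculated over the rows from the start of the partition up to and including the current row. Strictly speaking, the default frame is `range between unbounded preceding and current row`, so rows that tie with the current row on `time` are included as well.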

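rows between unbounded preceding and current row: from the first row of the partition to the current row; the same as the default when the `order by` values within a partition are unique
rows between 2 preceding and 1 following: the two rows before the current row, the current row itself, and the one row after it
rows between current row and unbounded following: from the current row to the last row of the partition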

`rows between` has the same meaning for avg, min, max and sum; only the aggregate applied to the frame differs.

Complete test code


import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.junit.Before;
import org.junit.Test;

import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.time.LocalDate;
import java.time.format.DateTimeFormatter;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;

/**
 * JUnit demo of SQL window frames ({@code rows between ...}) run through Spark SQL.
 *
 * <p>The test loads (or generates) a small list of {@link Charge} records, registers it as the
 * temporary view {@code temp}, and prints one aggregate function evaluated over four window
 * specifications that share {@code partition by country order by time}:
 * <ul>
 *   <li>{@code normal} - no explicit frame clause;</li>
 *   <li>{@code unb_pre_cur} - {@code rows between unbounded preceding and current row};</li>
 *   <li>{@code pre2_fol1} - {@code rows between 2 preceding and 1 following};</li>
 *   <li>{@code cur_unb_fol} - {@code rows between current row and unbounded following}.</li>
 * </ul>
 */
public class SparkHiveFunctionTest implements Serializable{

    /** CSV copy of the generated data (one {@code country,time,charge} line per record). */
    private static final String DATA_PATH = "F:\\tmp\\charge.csv";

    /** Java-serialized copy of the generated list; this is what {@link #readList()} loads. */
    private static final String DATA_OBJECT_PATH = "F:\\tmp\\charge";

    /** Name of the temporary view the SQL statements select from. */
    private static final String TEMP_VIEW = "temp";

    /** Countries a generated record is randomly assigned to. */
    private static final String[] COUNTRIES = {"China","Russia","U.S.A","Japan","The Republic of Korea"};

    /** Format of {@link Charge#getTime()}; cached because the formatter is immutable. */
    private static final DateTimeFormatter DATE_FORMAT = DateTimeFormatter.ofPattern("yyyy-MM-dd");

    /** Number of records produced by {@link #generateData()}. */
    private static final int GENERATED_ROWS = 50;

    /** Generated dates fall on the start date plus 0..(this - 1) days. */
    private static final int DAY_SPREAD = 10;

    /** Smallest generated charge; the random part below is added on top of it. */
    private static final int MIN_CHARGE = 10000;

    /** Exclusive upper bound of the random amount added to {@link #MIN_CHARGE}. */
    private static final int CHARGE_SPREAD = 10000;

    /** Maximum number of rows printed for each query. */
    private static final int ROWS_TO_SHOW = 100;

    private SparkSession sparkSession;

    private Dataset<Charge> dataset;

    /** Creates (or reuses) a local Spark session and lowers its log level to WARN. */
    @Before
    public void setUp(){
        sparkSession = SparkSession
                .builder()
                .appName("test")
                .master("local")
                .getOrCreate();
        sparkSession.sparkContext().setLogLevel("WARN");
    }

    /**
     * Loads the test data, wraps it in a {@code Dataset<Charge>} and prints the window-frame
     * comparison for one aggregate function.
     *
     * @throws IOException            if the data files cannot be read or written
     * @throws ClassNotFoundException if the serialized data file cannot be deserialized
     */
    @Test
    public void start() throws IOException, ClassNotFoundException {
        // Pass true on the first run to (re)generate the data files; false reuses the stored list.
//        List<Charge> infos = getData(true);
        List<Charge> infos = getData(false);
        dataset = sparkSession.createDataset(infos, Encoders.bean(Charge.class));
        // Enable one of the other aggregates to compare its output with max().
//        sum();
//        avg();
//        min();
        max();

    }

    /** Prints the four window-frame variants of {@code sum(charge)}. */
    private void sum(){
        showWindowFrames("sum");
    }

    /** Prints the four window-frame variants of {@code avg(charge)}. */
    private void avg(){
        showWindowFrames("avg");
    }

    /** Prints the four window-frame variants of {@code min(charge)}. */
    private void min(){
        showWindowFrames("min");
    }

    /** Prints the four window-frame variants of {@code max(charge)}. */
    private void max(){
        showWindowFrames("max");
    }

    /**
     * Registers {@link #dataset} as the temporary view and prints {@code function(charge)}
     * evaluated over the four window specifications described on the class.
     *
     * @param function name of the SQL aggregate to apply, e.g. {@code "max"}; it is spliced into
     *                 the statement as-is, so only fixed, trusted names may be passed
     */
    private void showWindowFrames(String function){
        dataset.createOrReplaceTempView(TEMP_VIEW);
        // Common prefix of every window expression; each column only differs in its frame clause.
        String window = function + "(charge) over (partition by country order by time";
        String sql = "select country,time,charge," +
                window + ") as normal," +
                window + " rows between unbounded preceding and current row) as unb_pre_cur," +
                window + " rows between 2 preceding and 1 following) as pre2_fol1," +
                window + " rows between current row and unbounded following) as cur_unb_fol" +
                " from " + TEMP_VIEW;
        Dataset<Row> ds = sparkSession.sql(sql);
        ds.show(ROWS_TO_SHOW);
    }

    /**
     * Returns the test data, either freshly generated or read back from disk.
     *
     * @param newGen {@code true} to generate new data (and overwrite the stored files);
     *               {@code false} or {@code null} to read the previously stored list
     * @return the list of charge records
     * @throws IOException            if the data files cannot be read or written
     * @throws ClassNotFoundException if the stored list cannot be deserialized
     */
    private static List<Charge> getData(Boolean newGen) throws IOException, ClassNotFoundException {
        // Boolean.TRUE.equals is null-safe, so no separate null check is needed.
        if(Boolean.TRUE.equals(newGen)){
            return generateData();
        }
        return readList();
    }

    /**
     * Generates {@link #GENERATED_ROWS} random records, writes them as CSV to {@link #DATA_PATH}
     * and as a serialized list to {@link #DATA_OBJECT_PATH}.
     *
     * @return the generated records
     * @throws IOException if either file cannot be written
     */
    private static List<Charge> generateData() throws IOException {
        LinkedList<Charge> infos = new LinkedList<>();
        Random random = new Random();
        LocalDate firstDay = LocalDate.of(2020, 1, 4);
        // try-with-resources flushes and closes the writer even if a write fails.
        try (FileWriter fileWriter = new FileWriter(DATA_PATH)) {
            for(int i=0;i<GENERATED_ROWS;i++){
                String country = COUNTRIES[random.nextInt(COUNTRIES.length)];
                String time = firstDay.plusDays(random.nextInt(DAY_SPREAD)).format(DATE_FORMAT);
                int charge = MIN_CHARGE + random.nextInt(CHARGE_SPREAD);

                Charge info = new Charge();
                info.setCountry(country);
                info.setCharge(charge);
                info.setTime(time);
                infos.add(info);
                fileWriter.write(String.format("%s,%s,%d\n",country,time,charge));
            }
        }
        writeList(infos);
        return infos;
    }

    /**
     * Serializes the list to {@link #DATA_OBJECT_PATH} with Java object serialization.
     *
     * @param infos records to store
     * @throws IOException if the file cannot be written
     */
    private static void writeList(LinkedList<Charge> infos) throws IOException {
        // Both streams are declared so the file handle is released even if the
        // ObjectOutputStream constructor throws.
        try (FileOutputStream fos = new FileOutputStream(DATA_OBJECT_PATH);
             ObjectOutputStream oos = new ObjectOutputStream(fos)) {
            oos.writeObject(infos);
        }
    }

    /**
     * Reads back the list written by {@link #writeList(LinkedList)}.
     *
     * <p>NOTE(review): this uses Java native deserialization and therefore trusts the file
     * content; it is only meant for the file this test wrote itself.
     *
     * @return the stored records
     * @throws IOException            if the file is missing or cannot be read
     * @throws ClassNotFoundException if a class in the stored object graph cannot be resolved
     */
    private static LinkedList<Charge> readList() throws IOException, ClassNotFoundException {
        try (FileInputStream fis = new FileInputStream(DATA_OBJECT_PATH);
             ObjectInputStream ois = new ObjectInputStream(fis)) {
            // The file is only ever produced by writeList, which stores a LinkedList<Charge>.
            @SuppressWarnings("unchecked")
            LinkedList<Charge> list = (LinkedList<Charge>) ois.readObject();
            return list;
        }
    }

    /**
     * One recharge record. It is a public bean with getters and setters because it is handed to
     * {@code Encoders.bean}, and it implements {@link Serializable} because the list of records
     * is stored with {@link ObjectOutputStream}.
     *
     * <p>No {@code serialVersionUID} is declared and the members are left as they were, so that
     * a data file serialized by an earlier run still deserializes.
     */
    public static class Charge implements Serializable {
        /**
         * Country
         */
        private String country;

        /**
         * Recharge time, formatted as {@code yyyy-MM-dd}
         */
        private String time;
        /**
         * Recharge amount
         */
        private Integer charge;

        public String getCountry() {
            return country;
        }

        public void setCountry(String country) {
            this.country = country;
        }

        public String getTime() {
            return time;
        }

        public void setTime(String time) {
            this.time = time;
        }

        public Integer getCharge() {
            return charge;
        }

        public void setCharge(Integer charge) {
            this.charge = charge;
        }
    }
}

Posted by skyxmen on Thu, 09 Jan 2020 07:26:16 -0800