[Splay] [Learning Notes]

Keywords: C++

Preface

Since I learned Treap first, and splay has a lot in common with treap, many parts here are not explained in great detail. For Treap and balanced trees in general, see this blog.

About splay

splay, also known as the splay tree (literally "stretching tree"), is a binary search tree that supports insertion, lookup and deletion in amortized O(log n) time. It was invented by Daniel Sleator and Robert Tarjan. The splay tree is a self-adjusting binary search tree: it moves a node to the root through a series of rotations along the path from that node to the root.

Compared with other balanced trees, splay is more flexible and can also handle interval (range) problems. Almost anything other balanced trees can do, splay can do as well, which is why many experienced competitors say that whenever a balanced tree is needed, you can just write a splay. Its main drawback is probably a larger constant factor than treap.

Definition

struct node {
    int ch[2],val,siz,cnt,pre;
}TR[N];

ch[0] and ch[1] are the left and right children of the current node. val is the value stored in the node. siz is the size of the subtree rooted at the node (duplicates included), and cnt is the number of occurrences of val. pre is the parent of the node. Index 0 stands for "no node", so TR[0].siz is always 0.
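As a small check of these definitions, the sketch below builds a three-node tree by hand and verifies that every node satisfies siz = siz(left) + siz(right) + cnt. up() follows the function of the same name in the complete code; checkSizes, buildExample and the array size are hypothetical choices made only for this illustration.

```cpp
#include <cassert>

const int N = 16;
struct node {
    int ch[2]; // ch[0] = left child, ch[1] = right child, 0 = none
    int val;   // value stored in the node
    int siz;   // number of values in this subtree, duplicates included
    int cnt;   // how many times val occurs
    int pre;   // parent, 0 = none
} TR[N];       // node 0 is the "no node" sentinel

// Same formula as up() in the article.
void up(int cur) {
    TR[cur].siz = TR[TR[cur].ch[0]].siz + TR[TR[cur].ch[1]].siz + TR[cur].cnt;
}

// Hypothetical checker: parent links and siz must be consistent everywhere.
bool checkSizes(int cur) {
    if (!cur) return true;
    for (int d = 0; d < 2; ++d) {
        int c = TR[cur].ch[d];
        if (c && TR[c].pre != cur) return false;
        if (!checkSizes(c)) return false;
    }
    return TR[cur].siz == TR[TR[cur].ch[0]].siz + TR[TR[cur].ch[1]].siz + TR[cur].cnt;
}

// Hypothetical example: value 5 stored twice at node 1,
// with 3 (node 2) on the left and 8 (node 3) on the right.
void buildExample() {
    TR[1] = node{{2, 3}, 5, 0, 2, 0};
    TR[2] = node{{0, 0}, 3, 0, 1, 1};
    TR[3] = node{{0, 0}, 8, 0, 1, 1};
    up(2); up(3); up(1); // children before parents
}
```

Calling up() bottom-up is the same order the article relies on in rotate, where the node that ends up lower is updated first.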

rotate

Like treap, splay relies on rotations to restructure the tree, and splay's rotation is easy to understand.

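As shown in the figure, suppose we want to rotate node 4 up to the position of its parent 2. In the figure, 2 is the right child of 1, 4 is the right child of 2, and 8 is the left child of 4.

Step 1: Cut the edge 2-4 and make 8 the right child of 2.

Step 2: Cut the edge 1-2 and make 4 the right child of 1.

Step 3: Make 2 the left child of 4.

Step 4: Update the subtree sizes of nodes 2 and 4 (2 first, since it is now below 4) to complete the rotation.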

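void rotate(int cur) {
    // f: which side of its parent cur is on; fa: parent; gr: grandparent
    int f = getwh(cur),fa = TR[cur].pre,gr = TR[fa].pre;
    // cur takes fa's place under gr (getwh(fa) must run before fa's parent changes)
    TR[gr].ch[getwh(fa)] = cur;
    TR[cur].pre = gr;
    // cur's inner child (8 in the figure) moves over to fa
    TR[fa].ch[f] = TR[cur].ch[f ^ 1];
    TR[TR[cur].ch[f ^ 1]].pre = fa;
    // fa becomes cur's child on the opposite side
    TR[fa].pre = cur;
    TR[cur].ch[f ^ 1] = fa;
    // fa is now below cur, so update it first
    up(fa);
    up(cur);
}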

Here getwh(cur) tells whether cur is its parent's left child (0) or right child (1), fa is cur's parent, and gr is cur's grandparent.

getwh is implemented as follows:

int getwh(int cur) {
    return TR[TR[cur].pre].ch[1] == cur;
}
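The four steps above can be traced on the figure's tree with the article's rotate and getwh. This sketch assumes that 1 is the root and uses the figure's labels directly as array indices (val is left unset, since only the shape matters here); attach and buildFigure are hypothetical helpers.

```cpp
#include <cassert>

const int N = 16;
struct node { int ch[2], val, siz, cnt, pre; } TR[N]; // node 0 = "no node"

void up(int cur) { TR[cur].siz = TR[TR[cur].ch[0]].siz + TR[TR[cur].ch[1]].siz + TR[cur].cnt; }
int getwh(int cur) { return TR[TR[cur].pre].ch[1] == cur; }
void rotate(int cur) {
    int f = getwh(cur), fa = TR[cur].pre, gr = TR[fa].pre;
    TR[gr].ch[getwh(fa)] = cur;        // step 2: cur replaces fa under gr
    TR[cur].pre = gr;
    TR[fa].ch[f] = TR[cur].ch[f ^ 1];  // step 1: cur's inner child moves to fa
    TR[TR[cur].ch[f ^ 1]].pre = fa;
    TR[fa].pre = cur;                  // step 3: fa becomes cur's child
    TR[cur].ch[f ^ 1] = fa;
    up(fa); up(cur);                   // step 4: fa is lower now, update it first
}

// Hypothetical helper: hang child under parent on side dir (0 left, 1 right).
void attach(int parent, int dir, int child) {
    TR[parent].ch[dir] = child;
    TR[child].pre = parent;
}

// Assumed shape of the figure: 1 is the root, 2 is 1's right child,
// 4 is 2's right child and 8 is 4's left child.
void buildFigure() {
    int labels[] = {1, 2, 4, 8};
    for (int v : labels) TR[v].cnt = 1;
    attach(1, 1, 2);
    attach(2, 1, 4);
    attach(4, 0, 8);
    up(8); up(4); up(2); up(1);
}
```

After buildFigure() and rotate(4), node 4 hangs under 1, node 2 is 4's left child, and 8 has moved over to become 2's right child, exactly as in steps 1 to 3.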

splay (stretch)

Compared with treap, splay has one very important extra operation: the splay (stretch) operation.

Splaying moves a node, through a series of rotations, up to a target position (usually the root).

A splay step falls into one of three cases (what splaying is good for will come up later; for now, just know what it does).

Case 1:

As shown in the figure, x only has to move up to just below its grandfather, i.e., its grandfather is the target and only its parent is in the way. In this case, a single rotation of x is enough.

Case 2:

x has to move above its grandfather, and x, its parent p and its grandfather g lie on the same line.

What does "on the same line" mean? It means x is the same kind of child of p as p is of g: both are left children, or both are right children.

As shown in the figure, g, p and x lie on one line going to the right. Rotate p up first, and then rotate x up.
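A sketch of case 2 (often called zig-zig) on a three-node line, reusing the article's rotate. The numbering (g = 1, p = 2, x = 3, with values equal to the labels) is an assumed example, and zigZig is a hypothetical name; splay performs these same two rotations when getwh(cur) == getwh(parent).

```cpp
#include <cassert>

const int N = 8;
struct node { int ch[2], val, siz, cnt, pre; } TR[N]; // node 0 = "no node"

void up(int cur) { TR[cur].siz = TR[TR[cur].ch[0]].siz + TR[TR[cur].ch[1]].siz + TR[cur].cnt; }
int getwh(int cur) { return TR[TR[cur].pre].ch[1] == cur; }
void rotate(int cur) {
    int f = getwh(cur), fa = TR[cur].pre, gr = TR[fa].pre;
    TR[gr].ch[getwh(fa)] = cur; TR[cur].pre = gr;
    TR[fa].ch[f] = TR[cur].ch[f ^ 1]; TR[TR[cur].ch[f ^ 1]].pre = fa;
    TR[fa].pre = cur; TR[cur].ch[f ^ 1] = fa;
    up(fa); up(cur);
}
void attach(int parent, int dir, int child) { TR[parent].ch[dir] = child; TR[child].pre = parent; }

// Assumed example: g = 1 is the root, p = 2 its right child, x = 3 p's right child.
void buildRightLine() {
    for (int v = 1; v <= 3; ++v) { TR[v].val = v; TR[v].cnt = 1; }
    attach(1, 1, 2);
    attach(2, 1, 3);
    up(3); up(2); up(1);
}

// Case 2: rotate the parent first, then x.
void zigZig(int x) {
    rotate(TR[x].pre);
    rotate(x);
}
```

After zigZig(3), x = 3 is on top with p = 2 as its left child and g = 1 below p, so the line is flipped rather than merely shifted.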

Case 3:

x has to move above its grandfather, and x, p and g do not lie on the same line (x is a left child while p is a right child, or vice versa).

As shown in the figure, rotate x up to the position of p, and then rotate x again up to the position of g.
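A matching sketch of case 3 (zig-zag). In this assumed example p = 3 is g's right child and x = 2 is p's left child, with values equal to labels so the BST order holds; zigZag is a hypothetical name. Rotating x twice leaves x on top with g and p as its two children.

```cpp
#include <cassert>

const int N = 8;
struct node { int ch[2], val, siz, cnt, pre; } TR[N]; // node 0 = "no node"

void up(int cur) { TR[cur].siz = TR[TR[cur].ch[0]].siz + TR[TR[cur].ch[1]].siz + TR[cur].cnt; }
int getwh(int cur) { return TR[TR[cur].pre].ch[1] == cur; }
void rotate(int cur) {
    int f = getwh(cur), fa = TR[cur].pre, gr = TR[fa].pre;
    TR[gr].ch[getwh(fa)] = cur; TR[cur].pre = gr;
    TR[fa].ch[f] = TR[cur].ch[f ^ 1]; TR[TR[cur].ch[f ^ 1]].pre = fa;
    TR[fa].pre = cur; TR[cur].ch[f ^ 1] = fa;
    up(fa); up(cur);
}
void attach(int parent, int dir, int child) { TR[parent].ch[dir] = child; TR[child].pre = parent; }

// Assumed example: g = 1 is the root, p = 3 its right child, x = 2 p's left child.
void buildZigZag() {
    for (int v = 1; v <= 3; ++v) { TR[v].val = v; TR[v].cnt = 1; }
    attach(1, 1, 3);
    attach(3, 0, 2);
    up(2); up(3); up(1);
}

// Case 3: rotate x to p's position, then to g's position.
void zigZag(int x) {
    rotate(x);
    rotate(x);
}
```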

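PS

In my own tests, case 2 can also be handled the same way as case 3 (rotating x twice) and the answers stay correct, because every rotation preserves the binary search tree order. Case 2 is handled separately to keep the tree balanced: rotating the parent first is what gives splay its amortized O(log n) bound, and without it the tree can stay long and thin, so the program runs slower.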
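To see the difference described in the PS, the sketch below builds a left-leaning chain of n nodes (inserting 1..n in increasing order and splaying each new node produces exactly that), splays the deepest node to the root with and without the case-2 rule, and measures the height of the result. splayTo, buildChain and heightAfterSplayingDeepest are hypothetical names; the zigzig = true path is the article's splay.

```cpp
#include <algorithm>
#include <cassert>

const int N = 64;
struct node { int ch[2], val, siz, cnt, pre; } TR[N]; // node 0 = "no node"
int rt, tot;

void up(int cur) { TR[cur].siz = TR[TR[cur].ch[0]].siz + TR[TR[cur].ch[1]].siz + TR[cur].cnt; }
int getwh(int cur) { return TR[TR[cur].pre].ch[1] == cur; }
void rotate(int cur) {
    int f = getwh(cur), fa = TR[cur].pre, gr = TR[fa].pre;
    TR[gr].ch[getwh(fa)] = cur; TR[cur].pre = gr;
    TR[fa].ch[f] = TR[cur].ch[f ^ 1]; TR[TR[cur].ch[f ^ 1]].pre = fa;
    TR[fa].pre = cur; TR[cur].ch[f ^ 1] = fa;
    up(fa); up(cur);
}

// zigzig = true is the article's splay; false handles case 2 like case 3.
void splayTo(int cur, int to, bool zigzig) {
    while (TR[cur].pre != to) {
        int fa = TR[cur].pre;
        if (TR[fa].pre != to) {
            if (zigzig && getwh(cur) == getwh(fa)) rotate(fa);
            else rotate(cur);
        }
        rotate(cur);
    }
    if (!to) rt = cur;
}

int height(int cur) {
    return cur ? 1 + std::max(height(TR[cur].ch[0]), height(TR[cur].ch[1])) : 0;
}

// Each new value is the maximum, so it hangs on the root's right and one
// rotation makes it the new root: the result is a left-leaning chain with
// value 1 (node index 1) at the bottom.
void buildChain(int n) {
    for (int i = 0; i < N; ++i) TR[i] = node{};
    rt = tot = 0;
    for (int v = 1; v <= n; ++v) {
        int cur = ++tot;
        TR[cur].val = v;
        TR[cur].cnt = TR[cur].siz = 1;
        if (!rt) { rt = cur; continue; }
        TR[rt].ch[1] = cur;
        TR[cur].pre = rt;
        up(rt);
        splayTo(cur, 0, true);
    }
}

int heightAfterSplayingDeepest(int n, bool zigzig) {
    buildChain(n);
    splayTo(1, 0, zigzig);
    return height(rt);
}
```

With n = 16, the version that always rotates x twice still leaves a chain of height 16, while the zig-zig version gives a noticeably shorter tree.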

In summary

Putting the three cases together gives the splay code. splay(cur, to) rotates cur up until its parent is to; when to is 0, cur becomes the root.

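void splay(int cur,int to) {
    // keep rotating cur up until its parent is `to`
    while(TR[cur].pre != to) {
        if(TR[TR[cur].pre].pre != to) {
            // two levels left: zig-zig rotates the parent first,
            // zig-zag rotates cur twice
            if(getwh(cur) == getwh(TR[cur].pre)) rotate(TR[cur].pre);
            else rotate(cur);
        }
        rotate(cur); // final rotation (or the single rotation of case 1)
    }
    if(!to) rt = cur; // splayed all the way up: cur is the new root
}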
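Because splay accepts a target other than 0, it can also bring a node directly below another one. The sketch below splays node a to the root and then node b right below it. With a < b, b becomes a's right child and b's left subtree holds exactly the values strictly between a and b, which is the usual way splay isolates a range for interval problems. between, collect and the example values are hypothetical; node indices equal values here only because 1..7 are inserted in order.

```cpp
#include <cassert>
#include <vector>

const int N = 1000;
struct node { int ch[2], val, siz, cnt, pre; } TR[N]; // node 0 = "no node"
int rt, tot;

void up(int cur) { TR[cur].siz = TR[TR[cur].ch[0]].siz + TR[TR[cur].ch[1]].siz + TR[cur].cnt; }
int getwh(int cur) { return TR[TR[cur].pre].ch[1] == cur; }
void rotate(int cur) {
    int f = getwh(cur), fa = TR[cur].pre, gr = TR[fa].pre;
    TR[gr].ch[getwh(fa)] = cur; TR[cur].pre = gr;
    TR[fa].ch[f] = TR[cur].ch[f ^ 1]; TR[TR[cur].ch[f ^ 1]].pre = fa;
    TR[fa].pre = cur; TR[cur].ch[f ^ 1] = fa;
    up(fa); up(cur);
}
void splay(int cur, int to) {
    while (TR[cur].pre != to) {
        int fa = TR[cur].pre;
        if (TR[fa].pre != to) rotate(getwh(cur) == getwh(fa) ? fa : cur); // zig-zig / zig-zag
        rotate(cur);
    }
    if (!to) rt = cur;
}
// The article's insert, with ls/rs written out.
void insert(int cur, int val, int lst) {
    if (!cur) {
        cur = ++tot;
        TR[cur].pre = lst; TR[cur].siz = TR[cur].cnt = 1; TR[cur].val = val;
        TR[lst].ch[val > TR[lst].val] = cur;
        splay(cur, 0);
        return;
    }
    TR[cur].siz++;
    if (val == TR[cur].val) { TR[cur].cnt++; return; }
    insert(TR[cur].ch[val > TR[cur].val], val, cur);
}

// In-order traversal, each value repeated cnt times.
void collect(int cur, std::vector<int>& out) {
    if (!cur) return;
    collect(TR[cur].ch[0], out);
    for (int i = 0; i < TR[cur].cnt; ++i) out.push_back(TR[cur].val);
    collect(TR[cur].ch[1], out);
}

// a and b are node indices with TR[a].val < TR[b].val.
std::vector<int> between(int a, int b) {
    splay(a, 0); // a becomes the root
    splay(b, a); // b ends up directly below a, on its right
    std::vector<int> out;
    collect(TR[b].ch[0], out);
    return out;
}
```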

insert

Insertion in splay works like insertion in treap: keep walking down to find the right place for the value. If the value is already present, increment cnt; otherwise create a new node.

Finally, don't forget to splay the newly inserted node to the root.

void insert(int cur,int val,int lst) {
    if(!cur) {
        cur = ++tot;
        TR[cur].pre = lst;
        TR[cur].siz = TR[cur].cnt = 1;
        TR[cur].val = val;
        TR[lst].ch[val > TR[lst].val] = cur;
        splay(cur,0);
        return;
    }
    TR[cur].siz++;
    if(val == TR[cur].val) {TR[cur].cnt++;return;}
    if(val > TR[cur].val) insert(rs,val,cur);
    else insert(ls,val,cur);
}
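A short sketch of insertion with a duplicate value, using the same insert as the article with ls/rs written out; values() is a hypothetical in-order helper that repeats each value cnt times. Inserting an existing value only increments cnt (and siz along the path) and does not splay.

```cpp
#include <cassert>
#include <vector>

const int N = 1000;
struct node { int ch[2], val, siz, cnt, pre; } TR[N]; // node 0 = "no node"
int rt, tot;

void up(int cur) { TR[cur].siz = TR[TR[cur].ch[0]].siz + TR[TR[cur].ch[1]].siz + TR[cur].cnt; }
int getwh(int cur) { return TR[TR[cur].pre].ch[1] == cur; }
void rotate(int cur) {
    int f = getwh(cur), fa = TR[cur].pre, gr = TR[fa].pre;
    TR[gr].ch[getwh(fa)] = cur; TR[cur].pre = gr;
    TR[fa].ch[f] = TR[cur].ch[f ^ 1]; TR[TR[cur].ch[f ^ 1]].pre = fa;
    TR[fa].pre = cur; TR[cur].ch[f ^ 1] = fa;
    up(fa); up(cur);
}
void splay(int cur, int to) {
    while (TR[cur].pre != to) {
        int fa = TR[cur].pre;
        if (TR[fa].pre != to) rotate(getwh(cur) == getwh(fa) ? fa : cur); // zig-zig / zig-zag
        rotate(cur);
    }
    if (!to) rt = cur;
}
void insert(int cur, int val, int lst) {
    if (!cur) {                       // fell off the tree: create a new node here
        cur = ++tot;
        TR[cur].pre = lst; TR[cur].siz = TR[cur].cnt = 1; TR[cur].val = val;
        TR[lst].ch[val > TR[lst].val] = cur;
        splay(cur, 0);                // bring the new node to the root
        return;
    }
    TR[cur].siz++;                    // the new value lands in this subtree
    if (val == TR[cur].val) { TR[cur].cnt++; return; }
    insert(TR[cur].ch[val > TR[cur].val], val, cur);
}

// Hypothetical helper: in-order traversal, each value repeated cnt times.
void collect(int cur, std::vector<int>& out) {
    if (!cur) return;
    collect(TR[cur].ch[0], out);
    for (int i = 0; i < TR[cur].cnt; ++i) out.push_back(TR[cur].val);
    collect(TR[cur].ch[1], out);
}
std::vector<int> values() {
    std::vector<int> out;
    collect(rt, out);
    return out;
}
```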

merge

Treap (the rotation-based version) has no merge operation. In splay, merge mainly serves deletion.

Merging combines two trees into one. Two trees can be merged only if every element of one tree is larger than every element of the other.

It is quite simple. If every element of tree x is smaller than every element of tree y, find the rightmost (that is, the largest) node of x and make y the right child of that node.

Before attaching y, splay that largest node to the root. Being the maximum, it then has no right child, so y can be attached there directly. If one of the two trees is empty, the other one simply becomes the whole tree.

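void merge(int cur,int y) {
    // make cur the tree with the smaller values (an empty tree is node 0, whose val is 0)
    if(TR[cur].val > TR[y].val) swap(cur,y);
    if(!cur) {          // the smaller side is empty: y is the whole tree
        rt = y;
        return;
    }
    while(rs) cur = rs; // walk to the largest node
    splay(cur,0);       // splay it to the root; it now has no right child
    rs = y;             // hang y on its right
    TR[y].pre = cur;
    up(cur);
}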
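A sketch of merging two separately built trees. mergeTrees is a hypothetical variant of the article's merge: instead of the swap trick, it assumes the caller passes the tree with the smaller values first and checks for an empty tree directly. In the test, tree A holds {1, 2, 3} with 3 not at its root, so the walk to the largest node is actually exercised.

```cpp
#include <cassert>
#include <vector>

const int N = 1000;
struct node { int ch[2], val, siz, cnt, pre; } TR[N]; // node 0 = "no node"
int rt, tot;

void up(int cur) { TR[cur].siz = TR[TR[cur].ch[0]].siz + TR[TR[cur].ch[1]].siz + TR[cur].cnt; }
int getwh(int cur) { return TR[TR[cur].pre].ch[1] == cur; }
void rotate(int cur) {
    int f = getwh(cur), fa = TR[cur].pre, gr = TR[fa].pre;
    TR[gr].ch[getwh(fa)] = cur; TR[cur].pre = gr;
    TR[fa].ch[f] = TR[cur].ch[f ^ 1]; TR[TR[cur].ch[f ^ 1]].pre = fa;
    TR[fa].pre = cur; TR[cur].ch[f ^ 1] = fa;
    up(fa); up(cur);
}
void splay(int cur, int to) {
    while (TR[cur].pre != to) {
        int fa = TR[cur].pre;
        if (TR[fa].pre != to) rotate(getwh(cur) == getwh(fa) ? fa : cur); // zig-zig / zig-zag
        rotate(cur);
    }
    if (!to) rt = cur;
}
void insert(int cur, int val, int lst) {
    if (!cur) {
        cur = ++tot;
        TR[cur].pre = lst; TR[cur].siz = TR[cur].cnt = 1; TR[cur].val = val;
        TR[lst].ch[val > TR[lst].val] = cur;
        splay(cur, 0);
        return;
    }
    TR[cur].siz++;
    if (val == TR[cur].val) { TR[cur].cnt++; return; }
    insert(TR[cur].ch[val > TR[cur].val], val, cur);
}
void collect(int cur, std::vector<int>& out) {
    if (!cur) return;
    collect(TR[cur].ch[0], out);
    for (int i = 0; i < TR[cur].cnt; ++i) out.push_back(TR[cur].val);
    collect(TR[cur].ch[1], out);
}
std::vector<int> values() { std::vector<int> out; collect(rt, out); return out; }

// Hypothetical variant: every value in tree x must be smaller than every value in y.
void mergeTrees(int x, int y) {
    if (!x) { rt = y; return; }          // x empty: y is the whole tree
    while (TR[x].ch[1]) x = TR[x].ch[1]; // largest node of x
    splay(x, 0);                         // now x's root, with no right child
    TR[x].ch[1] = y;
    if (y) TR[y].pre = x;
    up(x);
}
```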

Search node

This is another helper operation: it finds the node whose value is val.

It is an ordinary binary search tree lookup: if val is larger than the current node's value, search the right subtree; if smaller, search the left subtree; if equal, the current node is the answer. If val is not in the tree, the function returns the last node visited, which holds the predecessor or the successor of val.

int getpos(int cur,int val) {
    int lst = 0; // stays 0 if the tree is empty
    while(cur) {
        lst = cur;
        if(TR[cur].val == val) return cur;
        cur = TR[cur].ch[val > TR[cur].val];
    }
    return lst;
}
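A sketch showing what getpos returns for present and missing values. lst starts at 0 here so that an empty tree returns 0 instead of an uninitialized value, which is what the `if(!cur)` check in del expects. The example values are arbitrary.

```cpp
#include <cassert>

const int N = 1000;
struct node { int ch[2], val, siz, cnt, pre; } TR[N]; // node 0 = "no node"
int rt, tot;

void up(int cur) { TR[cur].siz = TR[TR[cur].ch[0]].siz + TR[TR[cur].ch[1]].siz + TR[cur].cnt; }
int getwh(int cur) { return TR[TR[cur].pre].ch[1] == cur; }
void rotate(int cur) {
    int f = getwh(cur), fa = TR[cur].pre, gr = TR[fa].pre;
    TR[gr].ch[getwh(fa)] = cur; TR[cur].pre = gr;
    TR[fa].ch[f] = TR[cur].ch[f ^ 1]; TR[TR[cur].ch[f ^ 1]].pre = fa;
    TR[fa].pre = cur; TR[cur].ch[f ^ 1] = fa;
    up(fa); up(cur);
}
void splay(int cur, int to) {
    while (TR[cur].pre != to) {
        int fa = TR[cur].pre;
        if (TR[fa].pre != to) rotate(getwh(cur) == getwh(fa) ? fa : cur); // zig-zig / zig-zag
        rotate(cur);
    }
    if (!to) rt = cur;
}
void insert(int cur, int val, int lst) {
    if (!cur) {
        cur = ++tot;
        TR[cur].pre = lst; TR[cur].siz = TR[cur].cnt = 1; TR[cur].val = val;
        TR[lst].ch[val > TR[lst].val] = cur;
        splay(cur, 0);
        return;
    }
    TR[cur].siz++;
    if (val == TR[cur].val) { TR[cur].cnt++; return; }
    insert(TR[cur].ch[val > TR[cur].val], val, cur);
}

// Returns the node holding val, or the last node on the search path
// (val's predecessor or successor) if val is absent, or 0 for an empty tree.
int getpos(int cur, int val) {
    int lst = 0;
    while (cur) {
        lst = cur;
        if (TR[cur].val == val) return cur;
        cur = TR[cur].ch[val > TR[cur].val];
    }
    return lst;
}
```

Inserting 10, 20 and 30 in that order leaves 30 at the root with 20 and then 10 hanging to the left, so a search for 25 stops at 20 and a search for 5 stops at 10.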

delete

With merge in place, deletion is easy. First find the node to delete and splay it to the root. If its value occurs more than once, just decrement cnt; otherwise detach its left and right subtrees and merge them.

void del(int cur,int val) {
    cur = getpos(rt,val);
    if(!cur) return;
    if(TR[cur].val != val) return;
    splay(cur,0);
    if(TR[cur].cnt > 1) {TR[cur].cnt--;TR[cur].siz--;return;}
    TR[ls].pre = TR[rs].pre = 0;
    merge(ls,rs);
}
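A sketch of deletion using the article's merge and del (macros written out) on a tree holding a duplicated value. The one-argument del signature is a hypothetical simplification of the article's del(cur, val), and values() is a hypothetical in-order helper.

```cpp
#include <cassert>
#include <utility>
#include <vector>

const int N = 1000;
struct node { int ch[2], val, siz, cnt, pre; } TR[N]; // node 0 = "no node"
int rt, tot;

void up(int cur) { TR[cur].siz = TR[TR[cur].ch[0]].siz + TR[TR[cur].ch[1]].siz + TR[cur].cnt; }
int getwh(int cur) { return TR[TR[cur].pre].ch[1] == cur; }
void rotate(int cur) {
    int f = getwh(cur), fa = TR[cur].pre, gr = TR[fa].pre;
    TR[gr].ch[getwh(fa)] = cur; TR[cur].pre = gr;
    TR[fa].ch[f] = TR[cur].ch[f ^ 1]; TR[TR[cur].ch[f ^ 1]].pre = fa;
    TR[fa].pre = cur; TR[cur].ch[f ^ 1] = fa;
    up(fa); up(cur);
}
void splay(int cur, int to) {
    while (TR[cur].pre != to) {
        int fa = TR[cur].pre;
        if (TR[fa].pre != to) rotate(getwh(cur) == getwh(fa) ? fa : cur); // zig-zig / zig-zag
        rotate(cur);
    }
    if (!to) rt = cur;
}
void insert(int cur, int val, int lst) {
    if (!cur) {
        cur = ++tot;
        TR[cur].pre = lst; TR[cur].siz = TR[cur].cnt = 1; TR[cur].val = val;
        TR[lst].ch[val > TR[lst].val] = cur;
        splay(cur, 0);
        return;
    }
    TR[cur].siz++;
    if (val == TR[cur].val) { TR[cur].cnt++; return; }
    insert(TR[cur].ch[val > TR[cur].val], val, cur);
}
int getpos(int cur, int val) {
    int lst = 0;
    while (cur) {
        lst = cur;
        if (TR[cur].val == val) return cur;
        cur = TR[cur].ch[val > TR[cur].val];
    }
    return lst;
}
void merge(int cur, int y) {
    if (TR[cur].val > TR[y].val) std::swap(cur, y);
    if (!cur) { rt = y; return; }
    while (TR[cur].ch[1]) cur = TR[cur].ch[1];
    splay(cur, 0);
    TR[cur].ch[1] = y;
    TR[y].pre = cur;
    up(cur);
}
void del(int val) {
    int cur = getpos(rt, val);
    if (!cur || TR[cur].val != val) return; // val is not in the tree
    splay(cur, 0);
    if (TR[cur].cnt > 1) { TR[cur].cnt--; TR[cur].siz--; return; }
    int l = TR[cur].ch[0], r = TR[cur].ch[1];
    TR[l].pre = TR[r].pre = 0;              // detach both subtrees
    merge(l, r);
}
void collect(int cur, std::vector<int>& out) {
    if (!cur) return;
    collect(TR[cur].ch[0], out);
    for (int i = 0; i < TR[cur].cnt; ++i) out.push_back(TR[cur].val);
    collect(TR[cur].ch[1], out);
}
std::vector<int> values() { std::vector<int> out; collect(rt, out); return out; }
```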

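Query ranking

Find the node with the lookup operation and splay it to the root. If it holds val, the size of its left subtree + 1 is the rank of val. If val is not in the tree, the lookup may stop at val's predecessor, which is itself smaller than val, so its cnt has to be added as well.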
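The article only shows Rank in the complete code. Below is a sketch of the same idea with one addition: it also handles a val that is not in the tree, where getpos may stop at val's predecessor. The one-argument signature is a hypothetical simplification of the article's Rank(cur, val).

```cpp
#include <cassert>

const int N = 1000;
struct node { int ch[2], val, siz, cnt, pre; } TR[N]; // node 0 = "no node"
int rt, tot;

void up(int cur) { TR[cur].siz = TR[TR[cur].ch[0]].siz + TR[TR[cur].ch[1]].siz + TR[cur].cnt; }
int getwh(int cur) { return TR[TR[cur].pre].ch[1] == cur; }
void rotate(int cur) {
    int f = getwh(cur), fa = TR[cur].pre, gr = TR[fa].pre;
    TR[gr].ch[getwh(fa)] = cur; TR[cur].pre = gr;
    TR[fa].ch[f] = TR[cur].ch[f ^ 1]; TR[TR[cur].ch[f ^ 1]].pre = fa;
    TR[fa].pre = cur; TR[cur].ch[f ^ 1] = fa;
    up(fa); up(cur);
}
void splay(int cur, int to) {
    while (TR[cur].pre != to) {
        int fa = TR[cur].pre;
        if (TR[fa].pre != to) rotate(getwh(cur) == getwh(fa) ? fa : cur); // zig-zig / zig-zag
        rotate(cur);
    }
    if (!to) rt = cur;
}
void insert(int cur, int val, int lst) {
    if (!cur) {
        cur = ++tot;
        TR[cur].pre = lst; TR[cur].siz = TR[cur].cnt = 1; TR[cur].val = val;
        TR[lst].ch[val > TR[lst].val] = cur;
        splay(cur, 0);
        return;
    }
    TR[cur].siz++;
    if (val == TR[cur].val) { TR[cur].cnt++; return; }
    insert(TR[cur].ch[val > TR[cur].val], val, cur);
}
int getpos(int cur, int val) {
    int lst = 0;
    while (cur) {
        lst = cur;
        if (TR[cur].val == val) return cur;
        cur = TR[cur].ch[val > TR[cur].val];
    }
    return lst;
}

// Rank = (number of stored values smaller than val) + 1.
int Rank(int val) {
    int cur = getpos(rt, val);
    if (!cur) return 1;                            // empty tree
    splay(cur, 0);
    int smaller = TR[TR[cur].ch[0]].siz;           // everything left of the root
    if (TR[cur].val < val) smaller += TR[cur].cnt; // stopped at the predecessor
    return smaller + 1;
}
```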

Query the k-th number

As in treap: if k is at most the size of the left subtree, look for the k-th smallest value in the left subtree. If k is larger than the left subtree size + cnt of the current node, look for the (k - left subtree size - cnt)-th smallest value in the right subtree. Otherwise, the current node's value is the answer.

int kth(int cur,int x) {
    while(cur) {
        if(x <= TR[ls].siz) cur = ls;
        else if(x > TR[ls].siz + TR[cur].cnt) x -= TR[cur].cnt + TR[ls].siz,cur = rs;
        else return TR[cur].val;
    }
    return cur;
}
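A quick check of kth on a tree holding a duplicate value: thanks to cnt, the duplicate occupies two consecutive ranks. kth is the article's function with ls/rs written out; it returns 0 when k exceeds the number of stored values.

```cpp
#include <cassert>

const int N = 1000;
struct node { int ch[2], val, siz, cnt, pre; } TR[N]; // node 0 = "no node"
int rt, tot;

void up(int cur) { TR[cur].siz = TR[TR[cur].ch[0]].siz + TR[TR[cur].ch[1]].siz + TR[cur].cnt; }
int getwh(int cur) { return TR[TR[cur].pre].ch[1] == cur; }
void rotate(int cur) {
    int f = getwh(cur), fa = TR[cur].pre, gr = TR[fa].pre;
    TR[gr].ch[getwh(fa)] = cur; TR[cur].pre = gr;
    TR[fa].ch[f] = TR[cur].ch[f ^ 1]; TR[TR[cur].ch[f ^ 1]].pre = fa;
    TR[fa].pre = cur; TR[cur].ch[f ^ 1] = fa;
    up(fa); up(cur);
}
void splay(int cur, int to) {
    while (TR[cur].pre != to) {
        int fa = TR[cur].pre;
        if (TR[fa].pre != to) rotate(getwh(cur) == getwh(fa) ? fa : cur); // zig-zig / zig-zag
        rotate(cur);
    }
    if (!to) rt = cur;
}
void insert(int cur, int val, int lst) {
    if (!cur) {
        cur = ++tot;
        TR[cur].pre = lst; TR[cur].siz = TR[cur].cnt = 1; TR[cur].val = val;
        TR[lst].ch[val > TR[lst].val] = cur;
        splay(cur, 0);
        return;
    }
    TR[cur].siz++;
    if (val == TR[cur].val) { TR[cur].cnt++; return; }
    insert(TR[cur].ch[val > TR[cur].val], val, cur);
}

int kth(int cur, int x) {
    while (cur) {
        int l = TR[cur].ch[0];
        if (x <= TR[l].siz) cur = l;                   // answer is in the left subtree
        else if (x > TR[l].siz + TR[cur].cnt) {        // skip left subtree and this node
            x -= TR[l].siz + TR[cur].cnt;
            cur = TR[cur].ch[1];
        }
        else return TR[cur].val;                       // answer is this node
    }
    return 0; // k is larger than the number of stored values
}
```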

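Predecessor

Find the node for val and check it first: if its value is already smaller than val, the lookup fell off its right side, so it is the predecessor. Otherwise, splay it to the root; the predecessor is then the maximum value in its left subtree.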

int pred(int cur,int val) {
    cur = getpos(rt,val);
    if(TR[cur].val < val) return TR[cur].val;
    splay(cur,0);
    cur = ls;
    while(rs) cur = rs;
    return TR[cur].val;
}
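A sketch exercising both branches of pred, with a value that is present and with values that are missing. Like the article's version, it assumes a predecessor exists; the one-argument signature is a hypothetical simplification of pred(cur, val).

```cpp
#include <cassert>

const int N = 1000;
struct node { int ch[2], val, siz, cnt, pre; } TR[N]; // node 0 = "no node"
int rt, tot;

void up(int cur) { TR[cur].siz = TR[TR[cur].ch[0]].siz + TR[TR[cur].ch[1]].siz + TR[cur].cnt; }
int getwh(int cur) { return TR[TR[cur].pre].ch[1] == cur; }
void rotate(int cur) {
    int f = getwh(cur), fa = TR[cur].pre, gr = TR[fa].pre;
    TR[gr].ch[getwh(fa)] = cur; TR[cur].pre = gr;
    TR[fa].ch[f] = TR[cur].ch[f ^ 1]; TR[TR[cur].ch[f ^ 1]].pre = fa;
    TR[fa].pre = cur; TR[cur].ch[f ^ 1] = fa;
    up(fa); up(cur);
}
void splay(int cur, int to) {
    while (TR[cur].pre != to) {
        int fa = TR[cur].pre;
        if (TR[fa].pre != to) rotate(getwh(cur) == getwh(fa) ? fa : cur); // zig-zig / zig-zag
        rotate(cur);
    }
    if (!to) rt = cur;
}
void insert(int cur, int val, int lst) {
    if (!cur) {
        cur = ++tot;
        TR[cur].pre = lst; TR[cur].siz = TR[cur].cnt = 1; TR[cur].val = val;
        TR[lst].ch[val > TR[lst].val] = cur;
        splay(cur, 0);
        return;
    }
    TR[cur].siz++;
    if (val == TR[cur].val) { TR[cur].cnt++; return; }
    insert(TR[cur].ch[val > TR[cur].val], val, cur);
}
int getpos(int cur, int val) {
    int lst = 0;
    while (cur) {
        lst = cur;
        if (TR[cur].val == val) return cur;
        cur = TR[cur].ch[val > TR[cur].val];
    }
    return lst;
}

// Largest stored value smaller than val (assumes one exists).
int pred(int val) {
    int cur = getpos(rt, val);
    if (TR[cur].val < val) return TR[cur].val; // search fell off its right side
    splay(cur, 0);
    cur = TR[cur].ch[0];                        // maximum of the left subtree
    while (TR[cur].ch[1]) cur = TR[cur].ch[1];
    return TR[cur].val;
}
```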

Successor

Symmetrically, if the node found for val is already larger than val, it is the successor. Otherwise, splay it to the root; the successor is then the minimum value in its right subtree.

int nex(int cur,int val) {
    cur = getpos(rt,val);
    if(TR[cur].val > val) return TR[cur].val;
    splay(cur,0);
    cur = rs;
    while(ls) cur = ls;
    return TR[cur].val;
}
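The symmetric check for nex, under the same assumptions as the predecessor sketch: a successor exists, and the one-argument signature is a hypothetical simplification of nex(cur, val).

```cpp
#include <cassert>

const int N = 1000;
struct node { int ch[2], val, siz, cnt, pre; } TR[N]; // node 0 = "no node"
int rt, tot;

void up(int cur) { TR[cur].siz = TR[TR[cur].ch[0]].siz + TR[TR[cur].ch[1]].siz + TR[cur].cnt; }
int getwh(int cur) { return TR[TR[cur].pre].ch[1] == cur; }
void rotate(int cur) {
    int f = getwh(cur), fa = TR[cur].pre, gr = TR[fa].pre;
    TR[gr].ch[getwh(fa)] = cur; TR[cur].pre = gr;
    TR[fa].ch[f] = TR[cur].ch[f ^ 1]; TR[TR[cur].ch[f ^ 1]].pre = fa;
    TR[fa].pre = cur; TR[cur].ch[f ^ 1] = fa;
    up(fa); up(cur);
}
void splay(int cur, int to) {
    while (TR[cur].pre != to) {
        int fa = TR[cur].pre;
        if (TR[fa].pre != to) rotate(getwh(cur) == getwh(fa) ? fa : cur); // zig-zig / zig-zag
        rotate(cur);
    }
    if (!to) rt = cur;
}
void insert(int cur, int val, int lst) {
    if (!cur) {
        cur = ++tot;
        TR[cur].pre = lst; TR[cur].siz = TR[cur].cnt = 1; TR[cur].val = val;
        TR[lst].ch[val > TR[lst].val] = cur;
        splay(cur, 0);
        return;
    }
    TR[cur].siz++;
    if (val == TR[cur].val) { TR[cur].cnt++; return; }
    insert(TR[cur].ch[val > TR[cur].val], val, cur);
}
int getpos(int cur, int val) {
    int lst = 0;
    while (cur) {
        lst = cur;
        if (TR[cur].val == val) return cur;
        cur = TR[cur].ch[val > TR[cur].val];
    }
    return lst;
}

// Smallest stored value larger than val (assumes one exists).
int nex(int val) {
    int cur = getpos(rt, val);
    if (TR[cur].val > val) return TR[cur].val; // search fell off its left side
    splay(cur, 0);
    cur = TR[cur].ch[1];                        // minimum of the right subtree
    while (TR[cur].ch[0]) cur = TR[cur].ch[0];
    return TR[cur].val;
}
```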

Complete code

#include<cstdio>
#include<iostream>
#include<utility>
using namespace std;
typedef long long ll;
const int N = 100000 + 100;
#define ls TR[cur].ch[0]
#define rs TR[cur].ch[1]
ll read() {
    ll x = 0,f = 1;char c = getchar();
    while(c < '0' || c > '9') {
        if(c == '-') f = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9') {
        x = x * 10 + c -'0';
        c = getchar();
    }
    return x * f;
}
int rt;
struct node {
    int ch[2],val,siz,cnt,pre;
}TR[N];
void up(int cur) {
    TR[cur].siz = TR[ls].siz + TR[rs].siz + TR[cur].cnt;
}
int getwh(int cur) {
    return TR[TR[cur].pre].ch[1] == cur;
}
void rotate(int cur) {
    int f = getwh(cur),fa = TR[cur].pre,gr = TR[fa].pre;
    TR[gr].ch[getwh(fa)] = cur;
    TR[cur].pre = gr;
    TR[fa].ch[f] = TR[cur].ch[f ^ 1];
    TR[TR[cur].ch[f ^ 1]].pre = fa;
    TR[fa].pre = cur;
    TR[cur].ch[f ^ 1] = fa;
    up(fa);
    up(cur);
}
void splay(int cur,int to) {
    while(TR[cur].pre != to) {
        if(TR[TR[cur].pre].pre != to) {
            // zig-zig rotates the parent first; zig-zag rotates cur twice
            if(getwh(cur) == getwh(TR[cur].pre)) rotate(TR[cur].pre);
            else rotate(cur);
        }
        rotate(cur);
    }
    if(!to) rt = cur;
}
int tot;
void insert(int cur,int val,int lst) {
    if(!cur) {
        cur = ++tot;
        TR[cur].pre = lst;
        TR[cur].siz = TR[cur].cnt = 1;
        TR[cur].val = val;
        TR[lst].ch[val > TR[lst].val] = cur;
        splay(cur,0);
        return;
    }
    TR[cur].siz++;
    if(val == TR[cur].val) {TR[cur].cnt++;return;}
    if(val > TR[cur].val) insert(rs,val,cur);
    else insert(ls,val,cur);
}
void merge(int cur,int y) {
    if(TR[cur].val > TR[y].val) swap(cur,y);
    if(!cur) {
        rt = y;
        return;
    }
    while(rs) cur = rs;
    splay(cur,0);
    rs = y;
    TR[y].pre = cur;
    up(cur);
}
int getpos(int cur,int val) {
    int lst = 0; // stays 0 if the tree is empty
    while(cur) {
        lst = cur;
        if(TR[cur].val == val) return cur;
        cur = TR[cur].ch[val > TR[cur].val];
    }
    return lst;
}
void del(int cur,int val) {
    cur = getpos(rt,val);
    if(!cur) return;
    if(TR[cur].val != val) return;
    splay(cur,0);
    if(TR[cur].cnt > 1) {TR[cur].cnt--;TR[cur].siz--;return;}
    TR[ls].pre = TR[rs].pre = 0;
    merge(ls,rs);
}
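int Rank(int cur,int val) {
    cur = getpos(rt,val);
    if(!cur) return 1; // empty tree
    splay(cur,0);
    // if val is absent, getpos may stop at val's predecessor, which must be counted too
    if(TR[cur].val < val) return TR[ls].siz + TR[cur].cnt + 1;
    return TR[ls].siz + 1;
}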
int kth(int cur,int x) {
    while(cur) {
        if(x <= TR[ls].siz) cur = ls;
        else if(x > TR[ls].siz + TR[cur].cnt) x -= TR[cur].cnt + TR[ls].siz,cur = rs;
        else return TR[cur].val;
    }
    return cur;
}
int pred(int cur,int val) {
    cur = getpos(rt,val);
    if(TR[cur].val < val) return TR[cur].val;
    splay(cur,0);
    cur = ls;
    while(rs) cur = rs;
    return TR[cur].val;
}
int nex(int cur,int val) {
    cur = getpos(rt,val);
    if(TR[cur].val > val) return TR[cur].val;
    splay(cur,0);
    cur = rs;
    while(ls) cur = ls;
    return TR[cur].val;
}
int main() {
    int n = read();
    while(n--) {
        int opt = read(),x = read();
        if(opt == 1) insert(rt,x,0);
        if(opt == 2) del(rt,x);
        if(opt == 3) printf("%d\n",Rank(rt,x));
        if(opt == 4) printf("%d\n",kth(rt,x));
        if(opt == 5) printf("%d\n",pred(rt,x));
        if(opt == 6) printf("%d\n",nex(rt,x));
    }
    return 0;
}

Posted by EnDee321 on Wed, 20 Mar 2019 12:03:28 -0700