Learning notes: heavy-light decomposition (tree chain decomposition)

Keywords: algorithms, data structures

Preface

Heavy-light decomposition (HLD) partitions the edges of a tree into chains. Modifications and queries on tree paths and subtrees then reduce to operations on a few contiguous ranges, which lowers their complexity.
This note introduces the decomposition into heavy and light edges.
Concepts:
Heavy son: the child of a node whose subtree contains the most nodes. If several children have subtrees of the same size, any one of them may be chosen.
Light son: every other child.
Heavy edge: the edge from a node to its heavy son.
Light edge: every other edge.
Heavy chain: a maximal path made of consecutive heavy edges.

Principle

Template problem: Luogu P3384.
Analysis:
Modifying or querying a path or a subtree suggests a segment tree, but a segment tree works on a sequence. We need to number the nodes of the tree so that every path and every subtree maps to a few contiguous ranges of that sequence. Heavy-light decomposition provides exactly such a numbering.

Heavy-light decomposition

The decomposition is computed with two depth-first searches (DFS).

First DFS

dep[] stores the depth of each node, far[] its parent, size[] the number of nodes in the subtree rooted at that node, and son[] its heavy son.
Implementation:
dep[]: the depth of a child is the depth of its parent + 1.
far[]: when the DFS enters a child, record the current node as the child's parent.
size[]: initialize size[x] to 1, then add the subtree size of each child when the DFS returns from it.
Recurrence: $size[x] = 1 + \sum_{y \in \mathrm{children}(x)} size[y]$
son[] is chosen while backtracking, much like finding the centroid in centroid decomposition. A node may have any number of children, and the child y with the largest size[y] becomes the heavy son.
code:

inline void dfs1(int x,int fa,int deep){
	dep[x]=deep;
	far[x]=fa;
	size[x]=1;
	int Max=-1;
	for(int i=first[x];i;i=nex[i]){
		int y=to[i];
		if(y==fa) continue;
		dfs1(y,x,deep+1);
		size[x]+=size[y];
		if(size[x]>Max) son[x]=y,Max=size[y];
	}
}

Second DFS

After the decomposition, the original node numbers are not contiguous along a chain, so each node receives a new number.
id[x] is the new number (DFS order) of node x, wt[i] is the weight of the node whose new number is i, and top[x] is the top node of the chain containing x.
Note: every chain starts at a light son (or at the root).
Implementation:
id[] and wt[] are filled in as each node is visited.
top[]: since every chain starts at a light son, each light son becomes the top of a new chain, while a heavy son inherits the top of its parent.
Starting from the top of a chain and repeatedly following heavy sons reaches the end of the chain. Visiting the heavy son first therefore gives the nodes of a chain consecutive numbers, which the segment tree requires.
So the DFS visits the heavy son first and the light sons afterwards.
code:

// Second DFS: assigns DFS-order ids so that every heavy chain occupies a
// contiguous id range, copies the node weight into wt[] at its new id, and
// records the top node of the chain containing x.
//   x    - current node
//   topf - top node of the chain that x belongs to
inline void dfs2(int x,int topf){
	top[x]=topf;
	id[x]=++cnt;
	wt[id[x]]=w[x];
	const int heavy=son[x];
	if(!heavy) return;              // no heavy son: x has no children
	dfs2(heavy,topf);               // heavy son first keeps the chain's ids consecutive
	for(int e=first[x];e;e=nex[e]){
		const int v=to[e];
		if(v!=far[x] && v!=heavy) dfs2(v,v); // every light son starts a new chain
	}
}

Segment tree maintenance

Maintaining paths

To modify or query the path between two nodes:
1. If both nodes are on the same heavy chain, their numbers form a contiguous range. That range can be updated or queried directly in the segment tree.
2. Otherwise their chain tops differ. Process the chain of one node, from its top down to the node, then jump to the parent of that chain's top. If both nodes are now on the same chain, apply case 1; otherwise repeat.
Always jump from the node whose chain top is deeper; this guarantees neither node jumps above their lowest common ancestor. In the final step, on the common chain, the shallower node has the smaller number. Swap the two nodes if necessary so the range is passed from small to large.
The query operation works the same way.

// Adds val (reduced modulo mod) to every node on the path x -> y.
// While x and y lie on different chains, the node whose chain top is deeper
// updates the id range [id[top], id[node]] and jumps to the parent of that top.
// Once both are on one chain, the segment between them is updated.
inline void update1(int x,int y,int val){
	val%=mod;
	for(;top[x]!=top[y];x=far[top[x]]){
		if(dep[top[x]]<dep[top[y]]) swap(x,y);
		T.update(1,id[top[x]],id[x],val);
	}
	int lo=x,hi=y;
	if(dep[lo]>dep[hi]) swap(lo,hi);
	T.update(1,id[lo],id[hi],val);
}

// Returns the sum of node weights on the path x -> y, modulo mod.
// Uses the same chain-jumping scheme as update1. The running sum is a
// long long reduced after every segment, so it cannot overflow even when
// mod is close to INT_MAX.
inline int query1(int x,int y){
	long long ans=0;
	while(top[x]!=top[y]){
		if(dep[top[x]]<dep[top[y]]) swap(x,y);
		ans=(ans+T.query(1,id[top[x]],id[x]))%mod;
		x=far[top[x]];
	}
	if(dep[x]>dep[y]) swap(x,y);
	ans=(ans+T.query(1,id[x],id[y]))%mod;
	return static_cast<int>(ans);
}

Maintaining subtrees

Because numbers are assigned in DFS order, the nodes of a subtree receive consecutive numbers.
The smallest number belongs to the subtree's root x, and the largest is id[x] + size[x] - 1.

inline void update2(int x,int val){
	T.update(1,id[x],id[x]+size[x]-1,val);
}

inline int query2(int x){
	return T.query(1,id[x],id[x]+size[x]-1)%mod;
}

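Segment tree

The segment tree is a standard one supporting range addition and range-sum queries with lazy tags (lazy propagation). If the modulus can be close to INT_MAX, intermediate sums and products must be computed in 64-bit integers before they are reduced.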

// Segment tree over positions 1..n (the DFS-order ids). It supports "add a value
// to every position in [l, r]" and "sum of [l, r]", both modulo mod, with lazy
// propagation. lc / rc are the child-index macros of node k.
// Stored sums and tags are kept in [0, mod). Because mod may be close to
// INT_MAX, every sum or product is formed in long long before it is reduced;
// int arithmetic (tag * length, sum + sum) would overflow.
struct TREE{
	struct node{
		int l,r;  // interval [l, r] covered by this node
		int w;    // sum of the interval, modulo mod
		int lz;   // pending per-position addition not yet pushed to the children
	}t[N<<2];
	// Adds v (0 <= v < mod) to every position covered by node k and records it
	// in the node's lazy tag (reduced modulo mod).
	inline void add_tag(int k,long long v){
		const long long len=t[k].r-t[k].l+1;
		t[k].w=static_cast<int>((t[k].w+v*len%mod)%mod);
		t[k].lz=static_cast<int>((t[k].lz+v)%mod);
	}
	// Recomputes node k's sum from its two children.
	inline void pushup(int k){
		t[k].w=static_cast<int>((static_cast<long long>(t[lc].w)+t[rc].w)%mod);
	}
	// Moves node k's pending tag down to its two children.
	inline void pushdown(int k){
		if(t[k].lz){
			add_tag(lc,t[k].lz);
			add_tag(rc,t[k].lz);
			t[k].lz=0;
		}
	}
	// Builds node k over [l, r] from wt[].
	inline void build(int k,int l,int r){
		t[k].l=l,t[k].r=r;
		if(l==r){
			t[k].w=wt[l]%mod;
			return;
		}
		int mid=(l+r)>>1;
		build(lc,l,mid);
		build(rc,mid+1,r);
		pushup(k);
	}
	// Adds val to every position in [l, r] within node k's interval.
	// val is first reduced into [0, mod), so callers need not pre-reduce it.
	inline void update(int k,int l,int r,int val){
		if(t[k].l>=l && t[k].r<=r){
			add_tag(k,(static_cast<long long>(val)%mod+mod)%mod);
			return;
		}
		pushdown(k);
		int mid=(t[k].l+t[k].r)>>1;
		if(l<=mid) update(lc,l,r,val);
		if(r>mid) update(rc,l,r,val);
		pushup(k);
	}
	// Returns the sum of [l, r] within node k's interval, modulo mod.
	inline int query(int k,int l,int r){
		if(t[k].l>=l && t[k].r<=r) return t[k].w;
		pushdown(k);
		int mid=(t[k].l+t[k].r)>>1;
		long long sum=0;
		if(l<=mid) sum+=query(lc,l,r);
		if(r>mid) sum+=query(rc,l,r);
		return static_cast<int>(sum%mod);
	}
}T;

Complete code

#include<bits/stdc++.h>
using namespace std;
#define lc k<<1     // left child of segment-tree node k (expects a variable named k in scope)
#define rc k<<1|1   // right child of segment-tree node k (expects a variable named k in scope)

// N: node-array capacity; M: edge capacity (each tree edge is added twice, once per direction).
const int N=1e5+5,M=2e5+5;
int n,m,rt,mod;  // node count, operation count, root node, modulus (all read in main)
// Adjacency lists: first[x] = head edge id of x, nex[e] = next edge, to[e] = target,
// tot = last used edge id. w[i] = input weight of node i.
int first[N],nex[M],to[M],w[M],tot;
// Per-node data from the two DFS passes: son = heavy son, id = DFS-order number,
// far = parent, cnt = last assigned number, dep = depth, size = subtree size,
// top = top node of the node's chain, wt[i] = weight of the node numbered i.
// NOTE: with `using namespace std;`, an unqualified use of `size` is ambiguous
// with std::size under C++17; write it as ::size to compile in that mode.
int son[N],id[N],far[N],cnt,dep[N],size[N],top[N],wt[N];

// Segment tree over positions 1..n (the DFS-order ids). It supports "add a value
// to every position in [l, r]" and "sum of [l, r]", both modulo mod, with lazy
// propagation. lc / rc are the child-index macros of node k.
// Stored sums and tags are kept in [0, mod). Because mod may be close to
// INT_MAX, every sum or product is formed in long long before it is reduced;
// int arithmetic (tag * length, sum + sum) would overflow.
struct TREE{
	struct node{
		int l,r;  // interval [l, r] covered by this node
		int w;    // sum of the interval, modulo mod
		int lz;   // pending per-position addition not yet pushed to the children
	}t[N<<2];
	// Adds v (0 <= v < mod) to every position covered by node k and records it
	// in the node's lazy tag (reduced modulo mod).
	inline void add_tag(int k,long long v){
		const long long len=t[k].r-t[k].l+1;
		t[k].w=static_cast<int>((t[k].w+v*len%mod)%mod);
		t[k].lz=static_cast<int>((t[k].lz+v)%mod);
	}
	// Recomputes node k's sum from its two children.
	inline void pushup(int k){
		t[k].w=static_cast<int>((static_cast<long long>(t[lc].w)+t[rc].w)%mod);
	}
	// Moves node k's pending tag down to its two children.
	inline void pushdown(int k){
		if(t[k].lz){
			add_tag(lc,t[k].lz);
			add_tag(rc,t[k].lz);
			t[k].lz=0;
		}
	}
	// Builds node k over [l, r] from wt[].
	inline void build(int k,int l,int r){
		t[k].l=l,t[k].r=r;
		if(l==r){
			t[k].w=wt[l]%mod;
			return;
		}
		int mid=(l+r)>>1;
		build(lc,l,mid);
		build(rc,mid+1,r);
		pushup(k);
	}
	// Adds val to every position in [l, r] within node k's interval.
	// val is first reduced into [0, mod), so callers need not pre-reduce it.
	inline void update(int k,int l,int r,int val){
		if(t[k].l>=l && t[k].r<=r){
			add_tag(k,(static_cast<long long>(val)%mod+mod)%mod);
			return;
		}
		pushdown(k);
		int mid=(t[k].l+t[k].r)>>1;
		if(l<=mid) update(lc,l,r,val);
		if(r>mid) update(rc,l,r,val);
		pushup(k);
	}
	// Returns the sum of [l, r] within node k's interval, modulo mod.
	inline int query(int k,int l,int r){
		if(t[k].l>=l && t[k].r<=r) return t[k].w;
		pushdown(k);
		int mid=(t[k].l+t[k].r)>>1;
		long long sum=0;
		if(l<=mid) sum+=query(lc,l,r);
		if(r>mid) sum+=query(rc,l,r);
		return static_cast<int>(sum%mod);
	}
}T;

// Appends the directed edge x -> y to x's adjacency list:
// first[x] is the head edge id, nex[] links edge ids, to[] stores targets.
inline void add(int x,int y){
	++tot;
	to[tot]=y;
	nex[tot]=first[x];
	first[x]=tot;
}

// Adds val (reduced modulo mod) to every node on the path x -> y.
// While x and y lie on different chains, the node whose chain top is deeper
// updates the id range [id[top], id[node]] and jumps to the parent of that top.
// Once both are on one chain, the segment between them is updated.
inline void update1(int x,int y,int val){
	val%=mod;
	for(;top[x]!=top[y];x=far[top[x]]){
		if(dep[top[x]]<dep[top[y]]) swap(x,y);
		T.update(1,id[top[x]],id[x],val);
	}
	int lo=x,hi=y;
	if(dep[lo]>dep[hi]) swap(lo,hi);
	T.update(1,id[lo],id[hi],val);
}

// Returns the sum of node weights on the path x -> y, modulo mod.
// Uses the same chain-jumping scheme as update1. The running sum is a
// long long reduced after every segment, so it cannot overflow even when
// mod is close to INT_MAX.
inline int query1(int x,int y){
	long long ans=0;
	while(top[x]!=top[y]){
		if(dep[top[x]]<dep[top[y]]) swap(x,y);
		ans=(ans+T.query(1,id[top[x]],id[x]))%mod;
		x=far[top[x]];
	}
	if(dep[x]>dep[y]) swap(x,y);
	ans=(ans+T.query(1,id[x],id[y]))%mod;
	return static_cast<int>(ans);
}

inline void update2(int x,int val){
	T.update(1,id[x],id[x]+size[x]-1,val);
}

inline int query2(int x){
	return T.query(1,id[x],id[x]+size[x]-1)%mod;
}

inline void dfs1(int x,int fa,int deep){
	dep[x]=deep;
	far[x]=fa;
	size[x]=1;
	int Max=-1;
	for(int i=first[x];i;i=nex[i]){
		int y=to[i];
		if(y==fa) continue;
		dfs1(y,x,deep+1);
		size[x]+=size[y];
		if(size[x]>Max) son[x]=y,Max=size[y];
	}
}

// Second DFS: assigns DFS-order ids so that every heavy chain occupies a
// contiguous id range, copies the node weight into wt[] at its new id, and
// records the top node of the chain containing x.
//   x    - current node
//   topf - top node of the chain that x belongs to
inline void dfs2(int x,int topf){
	top[x]=topf;
	id[x]=++cnt;
	wt[id[x]]=w[x];
	const int heavy=son[x];
	if(!heavy) return;              // no heavy son: x has no children
	dfs2(heavy,topf);               // heavy son first keeps the chain's ids consecutive
	for(int e=first[x];e;e=nex[e]){
		const int v=to[e];
		if(v!=far[x] && v!=heavy) dfs2(v,v); // every light son starts a new chain
	}
}

// Entry point for Luogu P3384. Reads n m rt mod, then n node weights,
// then n-1 edges, then m operations:
//   1 x y z : add z to every node on the path x -> y     (update1)
//   2 x y   : print the path sum x -> y, modulo mod       (query1)
//   3 x z   : add z to every node in the subtree of x     (update2)
//   4 x     : print the subtree sum of x, modulo mod      (query2)
// Answers are printed with printf, matching the scanf input. The previous
// cout << endl flushed the stream after every answer.
int main(){
	if(scanf("%d%d%d%d",&n,&m,&rt,&mod)!=4) return 0;
	int x,y,z;
	for(int i=1;i<=n;i++) scanf("%d",&w[i]);
	for(int i=1;i<n;i++){
		scanf("%d%d",&x,&y);
		add(x,y);
		add(y,x);
	}
	dfs1(rt,0,1);
	dfs2(rt,rt);
	T.build(1,1,n);
	int cas;
	while(m--){
		// Truncated input: stop rather than re-running a stale opcode.
		if(scanf("%d",&cas)!=1) break;
		if(cas==1){
			scanf("%d%d%d",&x,&y,&z);
			update1(x,y,z);
		}
		else if(cas==2){
			scanf("%d%d",&x,&y);
			printf("%d\n",query1(x,y));
		}
		else if(cas==3){
			scanf("%d%d",&x,&y);
			update2(x,y);
		}
		else if(cas==4){
			scanf("%d",&x);
			printf("%d\n",query2(x));
		}
	}
	return 0;
}

Posted by FuriousIrishman on Wed, 10 Nov 2021 05:51:17 -0800