# Codeforces 1606F - Tree Queries (virtual tree + tree dp)

Clearly, the vertices we choose to delete, together with \(u\), form a connected block. Otherwise, not deleting the deleted vertices that lie outside the connected block containing \(u\) can only make the answer better (never worse): they do not change the number of children of \(u\), but each of them costs \(k\).

Note that if we choose to delete a child \(v\) of \(u\), the answer increases by \(chd_v-1-k\), where \(chd_v\) is the number of children of vertex \(v\). The initial answer \(chd_u\) is a fixed value, so the task is equivalent to assigning the weight \(chd_x-1-k\) to every vertex \(x\), and then finding a connected block in the tree rooted at \(u\) (containing \(u\)) such that the sum of the weights of the vertices in the block other than \(u\) is as large as possible.

Therefore, there is a \(\Theta(n^2)\) method: run a tree DP for each \(k\). Let \(dp_u\) be the maximum weight sum of a connected block whose topmost vertex is \(u\); the transition is clearly \(dp_u = chd_u-1-k+\sum\limits_{v\in son_u}\max(dp_v,0)\). Note that every vertex on the boundary of the chosen block (that is, every vertex \(x\) that is in the block while none of the children of \(x\) are) must have positive weight; otherwise, removing the boundary vertices whose weight is not positive would make the answer no worse. Therefore, for each \(k\), take all vertices with positive weight and build a virtual tree on them. A vertex \(x\) has positive weight only for \(k<chd_x-1\), and \(\sum\limits_{i=1}^n chd_i=n-1\), so over all \(k\) the total number of positive-weight vertices is \(O(n)\). (The code below keeps every vertex with non-negative weight, i.e. \(k<chd_x\); this keeps the same bound.) Then run the tree DP on the virtual tree and answer each query, i.e. the maximum weight sum of a connected block rooted at \(u\). If \(u\) itself is on the virtual tree, just use the DP value of \(u\) directly (minus the weight of \(u\), which is not counted). Otherwise, find the virtual-tree vertex \(v\) at the lower end of the virtual-tree chain that \(u\) lies on — this can be done by storing the DFS order numbers of all virtual-tree vertices in a set and calling lower_bound on it. Then the maximum weight sum of a connected block rooted at \(u\) is \(dp_v\) plus the weight sum of the path from \(fa_v\) up to \(u\) (again excluding \(u\) itself).

Time complexity: \(\Theta(n\log n)\).

// Maximum number of vertices.
const int MAXN=200000;
// Highest binary-lifting level; fa[v][i] is defined for i in [0, LOG_N].
const int LOG_N=18;

// n: vertex count, qu: query count.
int n;
int qu;
// Adjacency lists as linked edge lists: hd[v] = first edge id of v (0 ends a list),
// to[e] = target of edge e, nxt[e] = next edge after e.
// ec presumably counts inserted edges — confirm against the edge reader.
int hd[MAXN+5];
int to[MAXN*2+5];
int nxt[MAXN*2+5];
int ec=0;

// chd[v]: number of children of v; dep[v]: depth (the root's depth is seeded by the caller);
// fa[v][i]: 2^i-th ancestor of v (level 0 is written by dfs0).
int chd[MAXN+5];
int dep[MAXN+5];
int fa[MAXN+5][LOG_N+2];

// dfn[v]: DFS entry time, rid[t]: the vertex entered at time t,
// edt[v]: largest entry time inside v's subtree, so subtree(v) = [dfn[v], edt[v]]; tim: clock.
int dfn[MAXN+5];
int tim=0;
int rid[MAXN+5];
int edt[MAXN+5];

// DFS from x (parent f) that records parent, entry/exit times, child counts and depths.
void dfs0(int x,int f){
	fa[x][0]=f;
	dfn[x]=++tim;
	rid[tim]=x;
	for(int e=hd[x];e!=0;e=nxt[e]){
		const int y=to[e];
		if(y==f) continue;
		++chd[x];
		dep[y]=dep[x]+1;
		dfs0(y,x);
	}
	edt[x]=tim;
}
int sum_chd[MAXN+5];
void dfs_chd(int x,int f){
for(int e=hd[x];e;e=nxt[e]){
int y=to[e];if(y==f) continue;
sum_chd[y]=sum_chd[x]+chd[y];
dfs_chd(y,x);
}
}
// Roots the tree at vertex 1 with depth 1 (so the sentinel vertex 0 sits at depth 0),
// runs dfs0, then fills the binary-lifting table: fa[v][lv] = fa[fa[v][lv-1]][lv-1].
// Jumps past the root land on 0 because fa[0][*] is never written (globals are zeroed).
void lca_init(){
	dep[1]=1;
	dfs0(1,0);
	for(int lv=1;lv<=LOG_N;++lv){
		for(int v=1;v<=n;++v){
			const int half=fa[v][lv-1];
			fa[v][lv]=fa[half][lv-1];
		}
	}
}
// Lowest common ancestor of x and y via binary lifting (needs lca_init first).
int getlca(int x,int y){
	if(dep[x]<dep[y]) swap(x,y);
	// Lift the deeper vertex x up to y's depth, largest jumps first.
	for(int i=LOG_N;i>=0;--i){
		if(dep[x]-dep[y]>=(1<<i)) x=fa[x][i];
	}
	if(x==y) return x;
	// Lift both while their ancestors differ; they end just below the LCA.
	for(int i=LOG_N;i>=0;--i){
		if(fa[x][i]==fa[y][i]) continue;
		x=fa[x][i];
		y=fa[y][i];
	}
	return fa[x][0];
}
// pos[k]: vertices placed on the virtual tree for parameter k (filled by main).
vector<int> pos[MAXN+5];
// Stack used while building the virtual tree (insert/fin); tp is its top index.
int stk[MAXN+5];
int tp=0;
// dfn values of the vertices on the current virtual tree.
set<int> st;
// qv[k]: queries with parameter k, stored as (vertex, query index).
vector<pii> qv[MAXN+5];

// Weight sum over the vertical path from u (exclusive) down to v (inclusive),
// where each vertex x weighs chd[x]-1-k. Requires u to be an ancestor of v.
ll calc_sum(int u,int v,int k){
	const ll child_total=sum_chd[v]-sum_chd[u];
	const ll length=dep[v]-dep[u];
	return child_total-(k+1ll)*length;
}
// Weight of a single vertex u for parameter k: chd[u]-1-k.
int calc_val(int u,int k){
	const int cost=k+1;
	return chd[u]-cost;
}
// g[v]: children of v on the current virtual tree.
vector<int> g[MAXN+5];
// dp[v]: best block weight with topmost vertex v (virtual-tree DP); res[i]: answer to query i.
ll dp[MAXN+5];
ll res[MAXN+5];

// Adds x to the virtual tree under construction. Expects vertices in increasing dfn order.
// stk holds the current chain of virtual-tree vertices; popped vertices are
// attached to their virtual parent through g.
void insert(int x){
	if(tp==0){
		stk[++tp]=x;
		return;
	}
	const int anc=getlca(x,stk[tp]);
	// Close off every chain vertex whose stack predecessor is still deeper than anc.
	for(;tp>1&&dep[stk[tp-1]]>dep[anc];--tp) g[stk[tp-1]].pb(stk[tp]);
	// The last vertex deeper than anc hangs directly under anc.
	if(tp!=0&&dep[stk[tp]]>dep[anc]){
		g[anc].pb(stk[tp]);
		--tp;
	}
	if(tp==0||stk[tp]!=anc) stk[++tp]=anc;
	stk[++tp]=x;
}
// Finishes the virtual tree: links each remaining stacked vertex to the one below it,
// then empties the stack.
void fin(){
	for(;tp>=2;--tp) g[stk[tp-1]].pb(stk[tp]);
	tp=0;
}
void dfs_dp(int x,int k){
dp[x]=calc_val(x,k);st.insert(dfn[x]);
for(int y:g[x]){
dfs_dp(y,k);
dp[x]+=max(dp[y]+calc_sum(x,y,k)-calc_val(y,k),0ll);
//		printf("calc_sum %d %d %d = %lld\n",x,y,k,calc_sum(x,y,k));
} //printf("DP %d %lld\n",x,dp[x]);
}
// Recursively empties the virtual-tree child lists reachable from x.
void clr(int x){
	for(const int child:g[x]){
		clr(child);
	}
	g[x].clear();
}
int main(){
scanf("%d",&n);
lca_init();sum_chd[1]=chd[1];dfs_chd(1,0);scanf("%d",&qu);
//	for(int i=1;i<=n;i++) printf("chd[%d]=%d\n",i,chd[i]);
for(int i=1;i<=n;i++) for(int j=0;j<chd[i];j++) pos[j].pb(i);
for(int i=1,u,k;i<=qu;i++) scanf("%d%d",&u,&k),qv[k].pb(mp(u,i));
for(int i=0;i<=MAXN;i++){
st.clear();
if(!pos[i].empty()){
//			printf("solving %d\n",i);
sort(pos[i].begin(),pos[i].end(),[&](int x,int y){return dfn[x]<dfn[y];});
if(pos[i][0]!=1) insert(1);
for(int x:pos[i]) insert(x);
fin();dfs_dp(1,i);
}
for(pii p:qv[i]){
int u=p.fi,id=p.se;
set<int>::iterator it=st.lower_bound(dfn[u]);
if(it==st.end()||(*it)>edt[u]) res[id]=chd[u];
else{
int pt=rid[*it];
//				printf("%d %d\n",id,pt);
//				printf("%lld\n",calc_sum(fa[u][0],fa[pt][0],i));
res[id]=chd[u]+max(0ll,dp[pt]+calc_sum(fa[u][0],fa[pt][0],i)-calc_val(u,i));
}
}
clr(1);
}
for(int i=1;i<=qu;i++) printf("%lld\n",res[i]);
return 0;
}
/*
11
1 2
2 3
2 4
1 5
5 6
5 7
7 8
7 9
7 10
1 11
4
1 0
1 1
5 2
11 0
*/


Posted by Salsaboy on Mon, 01 Nov 2021 03:51:40 -0700