Gym 102040F: Path Intersection

Problem: https://vjudge.net/problem/Gym-102040F

Task: Given a tree and k paths on it, count the vertices that lie on all k paths.

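Idea: The vertices common to all k paths themselves form a single contiguous path, and each of them is covered exactly k times. Decompose the tree with heavy-light decomposition and keep a lazy range-add segment tree that stores the maximum and minimum coverage count of every interval. Add 1 along each of the k paths, then query along any one of them (the last one here): a segment-tree node whose maximum and minimum both equal k contributes its whole length, and a node whose maximum is not k contributes nothing. Because k is small, there is no need to clear the whole tree after every query; instead we remember the k paths and subtract 1 along each of them again.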

Code:

#include <bits/stdc++.h>
// Child indices of segment-tree node `rt` (heap layout: children of i are 2i and 2i+1).
// Fully parenthesized so the expansion stays correct inside larger expressions
// (unparenthesized, `ls + 1` would expand to `rt << 2`).
#define ls (rt << 1)
#define rs ((rt << 1) | 1)
using namespace std;
const int maxn = 1e4+5;  // capacity of the per-vertex arrays (presumably n <= 1e4 per the problem — confirm)

vector<int>E[maxn];      // adjacency lists of the tree; vertices are 1-indexed
// Heavy-light decomposition state, filled by dfs1/dfs2:
//   d    - depth (root has depth 1)      f   - parent
//   siz  - subtree size                  son - heavy child (0 if none)
//   top  - head of the vertex's chain    id  - position in the HLD order (1..n)
//   rk   - inverse of id (vertex at a given position); Index - last position handed out
int Index, d[maxn], top[maxn], son[maxn], f[maxn], siz[maxn], id[maxn], rk[maxn];
// First HLD pass: for the subtree rooted at u (entered from parent fa) compute
// depth d[], parent f[], subtree size siz[] and heavy child son[].
// The heavy child is the child with the largest subtree; on ties the first one
// in adjacency order wins. son[u] == 0 means "no child" — siz[0] is never
// assigned here, so it stays 0 (assuming vertices are numbered from 1).
void dfs1(int u, int fa){
    son[u] = 0;
    siz[u] = 1;
    for (const int child : E[u]) {
        if (child != fa) {
            f[child] = u;
            d[child] = d[u] + 1;
            dfs1(child, u);
            siz[u] += siz[child];
            if (siz[son[u]] < siz[child]) son[u] = child;
        }
    }
}
// Second HLD pass: assign consecutive positions id[] (with inverse rk[]) so that
// every heavy chain occupies a contiguous range, and record each vertex's chain
// head in top[]. The heavy child is visited first to keep the chain contiguous;
// every other (light) child starts a new chain headed by itself.
void dfs2(int u, int rt){
    top[u] = rt;
    const int pos = ++Index;
    id[u] = pos;
    rk[pos] = u;
    if (son[u] != 0) dfs2(son[u], rt);
    for (const int child : E[u]) {
        const bool skip = (child == f[u]) || (child == son[u]);
        if (!skip) dfs2(child, child);
    }
}
int n, q, k;  // vertex count (also the segment tree's right end), query count, paths in the current query
// Segment tree over HLD positions: per-node max/min coverage count and pending lazy add.
int Max[maxn<<2], Min[maxn<<2], tag[maxn<<2];
// Recompute node rt's max/min from its two children. update() calls this right
// after pushdown(rt), so tag[rt] is 0 and the children hold the full values.
void pushup(int rt){
    const int hi = Max[ls] > Max[rs] ? Max[ls] : Max[rs];
    const int lo = Min[ls] < Min[rs] ? Min[ls] : Min[rs];
    Max[rt] = hi;
    Min[rt] = lo;
}
void pushdown(int rt){
    if(tag[rt]){
        tag[ls] += tag[rt]; tag[rs] += tag[rt];
        Max[ls] += tag[rt]; Max[rs] += tag[rt];
        Min[ls] += tag[rt]; Min[rs] += tag[rt];
        tag[rt] = 0;
    }
}
// Add val to every position in [L, R] of the segment tree node rt covering [l, r].
// A fully covered node absorbs the add into its max/min and its lazy tag.
// Note: parameter f is never read (it only gets passed down) and shadows the
// global parent array f inside this function; it is kept for signature compatibility.
void update(int L, int R, int val, int f=1, int rt=1, int l=1, int r=n){
    const bool covered = L <= l && r <= R;
    if (!covered) {
        pushdown(rt);
        const int mid = (l + r) >> 1;
        if (L <= mid) update(L, R, val, f, ls, l, mid);
        if (mid < R) update(L, R, val, f, rs, mid + 1, r);
        pushup(rt);
        return;
    }
    Max[rt] += val;
    Min[rt] += val;
    tag[rt] += val;
}
// Count positions in [L, R] (within node rt covering [l, r]) whose value equals
// the global k. A covered node with min == max == k contributes its whole length.
// Any node whose max is not k is pruned: each path adds 1 to each of its
// vertices at most once (Upd's ranges are disjoint), so no value exceeds k.
int query(int L, int R, int rt=1, int l=1, int r=n){
    const bool covered = L <= l && r <= R;
    if (covered && Min[rt] == k && Max[rt] == k) return r - l + 1;
    const bool disjoint = R < l || r < L;
    if (disjoint || Max[rt] != k) return 0;
    pushdown(rt);
    const int mid = (l + r) >> 1;
    const int fromLeft = query(L, R, ls, l, mid);
    return fromLeft + query(L, R, rs, mid + 1, r);
}
// Add val to every vertex on the tree path u..v. The endpoint whose chain head is
// deeper repeatedly climbs: its chain segment is updated, then it jumps to the
// parent of the chain head, until both endpoints share a chain.
void Upd(int u, int v, int val){
    for (; top[u] != top[v]; u = f[top[u]]) {
        if (d[top[u]] < d[top[v]]) swap(u, v);
        update(id[top[u]], id[u], val);
    }
    // Same chain now: positions run from the shallower endpoint to the deeper one.
    if (d[u] > d[v]) swap(u, v);
    update(id[u], id[v], val);
}
// Number of vertices on the tree path u..v whose coverage count equals k,
// summed over the HLD chain segments the path decomposes into.
int Qry(int u, int v){
    int total = 0;
    for (; top[u] != top[v]; u = f[top[u]]) {
        if (d[top[u]] < d[top[v]]) swap(u, v);
        total += query(id[top[u]], id[u]);
    }
    if (d[u] > d[v]) swap(u, v);
    return total + query(id[u], id[v]);
}
int t;     // number of test cases still to process
int Case;  // index of the last printed "Case" header (starts at 0)
int main()
{
    scanf("%d", &t);
    while(t--){
        scanf("%d", &n);

        Index = 0;
        for(int i=1; i<=n; i++) E[i].clear();
        int u, v;
        for(int i=1; i<n; i++) {
            scanf("%d%d", &u, &v);
            E[u].push_back(v); E[v].push_back(u);
        }
        d[1] = 1; f[1] = 1; dfs1(1, 0); dfs2(1, 1);

        printf("Case %d:\n", ++Case);
        scanf("%d", &q);
        pair<int, int> p[55];
        while(q--){
            scanf("%d", &k);
            update(1, n, 0);
            for(int i=1; i<=k; i++){
                scanf("%d%d", &u, &v);
                p[i] = make_pair(u, v);
                Upd(u, v, 1);
                if(i==k) printf("%d\n", Qry(u, v));
            }
            for(int i=1; i<=k; i++) Upd(p[i].first, p[i].second, -1);
        }
    }
}

 

Posted by mervyndk on Mon, 30 Sep 2019 07:50:17 -0700