Codeforces 802J & 802K Send the Fool Further! (easy & medium) -- Tree DP

Problem summary:

Given a weighted tree, find the maximum total length of a walk that starts at the root and visits every vertex at most k times (an edge that is traversed several times contributes its length only once).

Obviously, each subtree is entered at most once.
Let f[i] denote the answer within the subtree rooted at i when the walk does not have to return to i, g[i] denote the answer within the subtree rooted at i when the walk must return to i, and dist[i] denote the length of the edge between vertex i and its parent.
Then g[i] is the sum of the k-1 largest values of g[j]+dist[j] over all children j of i.
To compute f[i], enumerate each child j as the subtree in which the walk ends and update f[i] with g[i]+f[j]+dist[j], taking care not to count j twice when it is already one of the children that make up g[i].
Then output f[0] (this is f[1] in the code below, which shifts the vertex labels to be 1-based).

Code:

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<queue>
using namespace std;

// Capacity of every per-vertex array in this program; the edge pool is
// sized at twice this value.
#define N 100010

// Declared here but not referenced anywhere else in this listing.
priority_queue<int> Q;
// One directed entry of the adjacency list (linked lists stored in an array).
struct Edge{
    int t;   // vertex this entry points to
    int w;   // weight stored with the entry
    int nx;  // index of the next entry in the same list; 0 ends the list
};

// Edge pool; Add() stores two entries per input edge, hence twice N.
Edge e[N*2];
// Scratch record describing one child while a vertex is being processed.
struct Node{
    int w;  // g[child] plus the weight of the edge leading to that child
    int f;  // vertex id of the child

    Node(){}
    Node(int value,int vertex):w(value),f(vertex){}
};

// Shared scratch buffer, filled from index 1 up to Top by Getf()/Getg().
Node a[N];
// Values read from the first input line in main().
int n,k;
// Loop counter and per-edge scratch variables used by main() while reading.
int i,x,y,z;
// h[v]: index of the first adjacency entry of v (0 = none); Num: entries used.
int h[N],Num;
// DP tables filled bottom-up by Getf() and Getg() respectively.
int f[N],g[N];
// b[c]==p marks child c as one of the first k-1 sorted children of p in Getf().
int b[N];
// Number of records currently in a[] and a running sum used by Getf().
int Top,Sum;
// Not referenced by the code in this listing.
int j,m,M;
// Prepends an entry pointing to y with weight z to the adjacency list of x.
inline void Add(int x,int y,int z){
    const int id=++Num;   // claim the next free slot of the edge pool
    e[id].t=y;
    e[id].w=z;
    e[id].nx=h[x];        // the previous head becomes the successor
    h[x]=id;
}
// Returns the larger of the two arguments (x when they are equal).
inline int Max(int x,int y){
    if(x<y){
        return y;
    }
    return x;
}
// Sort predicate: places records with a larger w first (descending order).
inline bool Cmp(Node a,Node b){
    return b.w<a.w;
}
// Computes f[x] for vertex x whose parent is y, using the f/g values of its
// children (which Dfs() has already filled in).
//
// Every child c has the value g[c]+w(x,c). Exactly one child, the "last" one,
// contributes f[c]+w(x,c) instead. At most k-1 other children are added.
inline void Getf(int x,int y){
    // Gather all children of x into a[1..Top] and total their values in Sum.
    Top=0;
    Sum=0;
    for(int id=h[x];id!=0;id=e[id].nx){
        const int child=e[id].t;
        if(child==y)continue;            // skip the entry leading back up
        const int value=g[child]+e[id].w;
        Sum+=value;
        a[++Top]=Node(value,child);
    }

    // At most k children: all of them are used, so just try each as "last".
    // Sum-g[c]+f[c] swaps that child's g part for its f part.
    if(Top<=k){
        for(int idx=1;idx<=Top;++idx){
            const int last=a[idx].f;
            f[x]=Max(f[x],Sum-g[last]+f[last]);
        }
        return;
    }

    // More than k children: order by value and keep the best k-1 in Sum,
    // tagging them in b[] so the final pass can recognise them.
    sort(a+1,a+Top+1,Cmp);
    Sum=0;
    for(int idx=1;idx<k;++idx){
        b[a[idx].f]=x;
        Sum+=a[idx].w;
    }

    // "Last" child is one of the best k-1: the k-th best value is added on
    // top, and the chosen child's g part is replaced by its f part.
    const int extra=a[k].w;
    for(int idx=1;idx<k;++idx){
        const int last=a[idx].f;
        f[x]=Max(f[x],extra+Sum-g[last]+f[last]);
    }

    // "Last" child is outside the best k-1: add its f value and edge weight.
    for(int id=h[x];id!=0;id=e[id].nx){
        const int child=e[id].t;
        if(child!=y&&b[child]!=x){
            f[x]=Max(f[x],Sum+f[child]+e[id].w);
        }
    }
}
// Computes g[x] for vertex x whose parent is y: adds to g[x] the k-1 largest
// values of g[c]+w(x,c) over the children c of x (all of them if fewer exist).
inline void Getg(int x,int y){
    // Rebuild the scratch list of children in a[1..Top].
    Top=0;
    for(int id=h[x];id!=0;id=e[id].nx){
        const int child=e[id].t;
        if(child==y)continue;            // skip the entry leading back up
        a[++Top]=Node(g[child]+e[id].w,child);
    }

    sort(a+1,a+Top+1,Cmp);

    // Take the first min(k-1, Top) records of the descending order.
    const int limit=(k-1<Top)?(k-1):Top;
    for(int idx=1;idx<=limit;++idx){
        g[x]+=a[idx].w;
    }
}
// Post-order traversal from x (reached from y): all children are processed
// first, then f[x] and g[x] are computed from their results.
inline void Dfs(int x,int y){
    int id=h[x];
    while(id!=0){
        const int child=e[id].t;
        if(child!=y){
            Dfs(child,x);
        }
        id=e[id].nx;
    }
    // Getf must run before Getg: it reads only the children's values, and
    // both share the scratch buffer a[].
    Getf(x,y);
    Getg(x,y);
}
// Reads n and k, then n-1 weighted edges, and prints f for the root vertex.
int main(){
    scanf("%d%d",&n,&k);

    for(i=1;i<n;i++){
        scanf("%d%d%d",&x,&y,&z);
        // Shift labels up by one so that 0 can serve as "no vertex/no edge".
        ++x;
        ++y;
        // Store the edge in both directions.
        Add(x,y,z);
        Add(y,x,z);
    }

    // Vertex 1 (input label 0) is the root; it has no parent, hence 0.
    Dfs(1,0);
    printf("%d\n",f[1]);
    return 0;
}

Posted by cookspyder on Wed, 13 Feb 2019 17:54:18 -0800