Probability + Tree DP: Skillful Partitioning


The root is not necessarily node 1, but it is uniquely determined: it is the one node that never appears as anybody's son.
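A minimal sketch of that root search, assuming the tree is given as 1-indexed child lists. The name find_root and the vector-based input are illustrative assumptions, not taken from the code further down.

```cpp
#include <cassert>
#include <vector>

// children[u] lists the sons of node u (1-indexed, children[0] unused).
// The root is the only node that never shows up as a son.
// Returns 0 if every node is somebody's son (malformed input).
int find_root(const std::vector<std::vector<int>>& children) {
    int n = static_cast<int>(children.size()) - 1;
    std::vector<bool> is_son(n + 1, false);
    for (int u = 1; u <= n; ++u) {
        for (int c : children[u]) {
            assert(1 <= c && c <= n);
            is_son[c] = true;
        }
    }
    for (int u = 1; u <= n; ++u) {
        if (!is_son[u]) return u;
    }
    return 0;
}
```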
Let's solve the problem bottom-up, from the sons to the root. Let f[i][j] denote the probability that the longest light chain has length j in the subtree rooted at i.
Every son has the same probability of being chosen as the heavy son, and a heavy son contributes to his father differently from a light son, so we have to fix each son in turn as the heavy one and, for each such choice, enumerate all the sons one by one. That is already O(N^2), and we still need to enumerate the chain length. If we enumerate it up to size[root], the total becomes O(N^3), which is too slow. But it is enough to enumerate up to size[son]+1: anything larger is meaningless, because the probability must be zero.
The f array is still being read while the new values are computed, so store the new values in a temporary array first and copy them over afterwards.
Here is how to process all the sons for each choice of heavy son. Let g[i][j] denote a prefix sum: the probability that the longest light chain in the subtree of node i has length between 0 and j.
f[i][j] = f[son][j] * g[i][j] + g[son][j] * f[i][j] - f[i][j] * f[son][j]; (for the heavy son)
Consider adding one son at a time to the father's answer. There are two ways for the father's longest chain to be exactly j. 1. The length j comes from the sons that were added earlier (f[i][j]); then the length in the newly added son does not matter as long as it is between 0 and j. 2. Similarly, if the length j comes from the newly added son, it does not matter how long the earlier part was, as long as it is between 0 and j. The case where both are exactly j is counted in both, so it is subtracted once.
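As a hedged sketch of this merge step on its own: given two independent distributions of "longest chain is exactly j", the distribution of their maximum follows from the two cases above minus the doubly counted tie. The function name merge_max and the use of two separate exact-probability vectors are assumptions for illustration; the code further down keeps prefix sums in shared arrays instead.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

const long long MOD = 1000000007LL;

// a[j], b[j]: probabilities (mod MOD) that two independent parts have
// longest chain exactly j. Returns c[j] = P(max of the two == j).
std::vector<long long> merge_max(const std::vector<long long>& a,
                                 const std::vector<long long>& b) {
    assert(a.size() == b.size());
    std::vector<long long> c(a.size(), 0);
    long long A = 0, B = 0;  // running prefix sums of a and b
    for (std::size_t j = 0; j < a.size(); ++j) {
        A = (A + a[j]) % MOD;
        B = (B + b[j]) % MOD;
        // j comes from a (b is anything <= j), or from b (a is anything <= j)
        long long v = (a[j] * B % MOD + A * b[j] % MOD) % MOD;
        // both equal to j was counted twice, remove it once
        v = (v - a[j] * b[j] % MOD + MOD) % MOD;
        c[j] = v;
    }
    return c;
}
```

For example, merging two parts that are each 0 or 1 with probability 1/2 gives 1/4 for length 0 and 3/4 for length 1.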
What about light sons? Unlike the heavy son, the edge between a light son and his father is a light edge and adds 1 to the chain length, so just change the son's j to j-1.
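One way to picture the j to j-1 substitution is as a shift of the son's distribution by one position before it is merged into the father. The helper below is only an illustrative sketch with a hypothetical name; the code further down applies the shift through its indices rather than building a new array.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// d[j] = P(longest light chain inside the son's subtree == j).
// Seen from the father through a light edge, every chain is one longer,
// so length 0 becomes impossible and d[j] moves to position j + 1.
std::vector<long long> shift_for_light_son(const std::vector<long long>& d) {
    assert(!d.empty());
    std::vector<long long> shifted(d.size() + 1, 0);
    for (std::size_t j = 0; j < d.size(); ++j) {
        shifted[j + 1] = d[j];
    }
    return shifted;
}
```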
Finally, collect the answer: for each j, take the probability that the longest light chain in the root's subtree has length j, multiply it by j, and sum everything up (this is the expected length).
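A small sketch of that final summation, assuming the root's values are stored as prefix sums (probability of length at most j) modulo 1000000007. The name expected_length and the vector input are illustrative assumptions.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

const long long MOD = 1000000007LL;

// prefix[j] = P(longest light chain <= j) mod MOD.
// Returns sum over j of j * P(longest light chain == j), mod MOD.
long long expected_length(const std::vector<long long>& prefix) {
    assert(!prefix.empty());
    long long ans = 0;
    for (std::size_t j = 1; j < prefix.size(); ++j) {
        // recover the exact probability from two neighbouring prefix sums
        long long exact = (prefix[j] - prefix[j - 1] + MOD) % MOD;
        ans = (ans + static_cast<long long>(j) % MOD * exact) % MOD;
    }
    return ans;
}
```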
For division, just multiply by the modular inverse.
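Since the modulus 1000000007 is prime, the inverse can be obtained with Fermat's little theorem as a^(MOD-2). The sketch below shows this with hypothetical helper names power_mod and inverse_mod; it assumes the argument is not a multiple of the modulus.

```cpp
#include <cassert>

const long long MOD = 1000000007LL;

// Binary exponentiation: base^exp mod MOD, for exp >= 0.
long long power_mod(long long base, long long exp) {
    assert(exp >= 0);
    long long result = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

// Inverse of a modulo the prime MOD (a must not be divisible by MOD).
long long inverse_mod(long long a) {
    return power_mod(a, MOD - 2);
}
```

Dividing by the number of sons k then becomes a multiplication by inverse_mod(k).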
Note: better not to read my code (just kidding), it's hard to understand... I merged the f and g arrays into one...

#pragma GCC optimize("O3")
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<algorithm>
#define N 3005
#define mod 1000000007
#define ll long long
using namespace std;
int read()
{
    int sum=0,f=1;char x=getchar();
    while(x<'0'||x>'9'){if(x=='-')f=-1;x=getchar();}
    while(x>='0'&&x<='9'){sum=(sum<<1)+(sum<<3)+x-'0';x=getchar();}
    return sum*f;
}
struct road{int v,next;}lu[N*2];
int n,e,adj[N];
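// out[x]: number of sons; sz[x]: subtree size; g,h: scratch arrays for merging;
// f[x][j]: once dfs(x) returns, P(longest light chain in subtree of x <= j)
ll out[N],sz[N],g[N],h[N],f[N][N];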
bool v[N];
// fast modular exponentiation: x^m mod 1e9+7
inline ll cheng(ll x,int m)
{
    ll ans=1;
    while(m)
    {
        if(m&1)ans=ans*x%mod;
        x=x*x%mod;
        m/=2;
    }
    return ans;
}
inline void dfs(int x)
{
    sz[x]=1;
    for(int i=adj[x];i;i=lu[i].next)
    {
        dfs(lu[i].v);
        sz[x]+=sz[lu[i].v];
    }
    ll fm=cheng(out[x],mod-2); // modular inverse of the number of sons
    for(int i=adj[x];i;i=lu[i].next)
    {
        int zz=lu[i].v;
        for(int j=0;j<=n;j++)g[j]=1;
        for(int j=adj[x];j;j=lu[j].next)
        {
            int to=lu[j].v;
            for(int k=0;k<=sz[to]+1;k++)
            {
                ll s=g[k],sum=f[to][k];
                if(k)s-=g[k-1],sum-=f[to][k-1];
                if(s<0)s+=mod;if(sum<0)sum+=mod;
                if(to==zz)
                    h[k]=((sum*g[k]+s*f[to][k]-sum*s)%mod+mod)%mod; 
                else if(k)
                {
                    sum=f[to][k-1];if(k!=1)sum-=f[to][k-2];
                    if(sum<0)sum+=mod;
                    h[k]=((sum*g[k]+s*f[to][k-1]-sum*s)%mod+mod)%mod;
                }
            }
            g[0]=h[0];h[0]=0;
            for(int k=1;k<=sz[to]+1;k++)g[k]=(g[k-1]+h[k])%mod,h[k]=0;
        }
        for(int j=sz[x];j>=1;j--)g[j]=(g[j]-g[j-1]+mod)%mod;
        for(int j=0;j<=sz[x];j++)f[x][j]=(f[x][j]+g[j]*fm%mod)%mod;
    }
    if(!adj[x])f[x][0]=1;
    for(int i=1;i<=n;i++)f[x][i]=(f[x][i]+f[x][i-1])%mod;
}
int main()
{
    //freopen("tree.in","r",stdin);
    //freopen("tree.out","w",stdout);
    n=read();
    for(int i=1;i<=n;i++)
    {
        int k=read();out[i]=k;
        for(int j=1;j<=k;j++)
        {
            int x=read();v[x]=1;
            lu[++e]=(road){x,adj[i]};adj[i]=e;
        }
    }
    int root=0;
    for(int i=1;i<=n;i++)if(!v[i])root=i;
    dfs(root);
    ll ans=0;
    for(ll i=1;i<=n;i++)
        ans=(ans+i*(f[root][i]-f[root][i-1]+mod)%mod)%mod;
    cout<<ans;
}

Posted by Mr.x on Mon, 20 May 2019 16:36:12 -0700