Problem: there is an infinite rooted tree in which every node has exactly n children, and the edge from a node to its i-th child has length $d_i$. Find the number of nodes whose distance from the root is at most x, modulo $10^9 + 7$.
Idea: look at the data range — every $d_i$ is at most 100. Let $dp[i]$ be the number of nodes at distance exactly i from the root, and let $sum[i] = dp[0] + dp[1] + \dots + dp[i]$ be the number of nodes at distance at most i. The answer is then $sum[x]$.
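State transition. Let $cnt[j]$ be the number of children whose edge has length j. The base case is $dp[0] = sum[0] = 1$ (the root itself), and
$$dp[i] = \sum_{j=1}^{i} cnt[j]\cdot dp[i-j], \qquad sum[i] = sum[i-1] + dp[i].$$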
1 for(int i=1;i<=100;i++) 2 { 3 for(int j=1;j<=i;j++) 4 { 5 dp[i]=(dp[i]+dp[i-j]*cnt[j])%MOD; 6 } 7 sum[i]=(sum[i-1]+dp[i])%MOD; 8 }
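Because x is too large to iterate up to directly, use matrix exponentiation. For $k \ge 100$, $dp[k+1] = \sum_{j=1}^{100} cnt[j]\cdot dp[k+1-j]$ depends only on the previous 100 values.

Keep the row vector $v_k = (dp[k-99], \dots, dp[k], sum[k])$ of length 101 and build a $101 \times 101$ transition matrix T with $v_{k+1} = v_k \cdot T$:
- the first 99 columns slide the window of dp values;
- the 100th column computes $dp[k+1]$;
- the 101st column accumulates $sum[k+1] = sum[k] + dp[k+1]$.

Starting from $v_{100} = (dp[1], \dots, dp[100], sum[100])$, we get $v_x = v_{100} \cdot T^{x-100}$ by fast matrix exponentiation, and the answer is the last entry of $v_x$. (In `qpow`, the accumulator C must start as the identity matrix!)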
// Counts the nodes of an infinite rooted tree whose distance from the root is
// at most x, modulo 1e9+7. Every node has n children, with edge lengths
// d_1..d_n, each in [1, MAXD]. Distances 0..MAXD are handled by a direct DP;
// larger x uses fast exponentiation of a (MAXD+1) x (MAXD+1) transition matrix.
#include <cstdio>
#include <cstring>

using ll = long long;
constexpr ll MOD = 1000000007LL;

// Largest allowed edge length: the recurrence looks back at most MAXD steps.
constexpr int MAXD = 100;
// Size of the DP state vector: a window of MAXD dp values plus one prefix sum.
constexpr int DIM = MAXD + 1;

// Dense matrix with 1-based indexing; only a[1..x][1..y] is meaningful.
struct matrix
{
    ll x, y;
    ll a[DIM + 1][DIM + 1];
    // Empty (0 x 0) matrix with every member initialized.
    matrix() : matrix(0, 0) {}
    // xx-by-yy zero matrix.
    matrix(ll xx, ll yy) : x(xx), y(yy)
    {
        std::memset(a, 0, sizeof(a));
    }
} Base;

// Returns A * B with every entry reduced mod MOD. Requires A.y == B.x.
// Arguments are const references so the large matrices are not copied per call.
matrix mul(const matrix &A, const matrix &B)
{
    matrix C(A.x, B.y);
    // i-k-j order walks rows of B and C contiguously and lets us skip zero
    // entries of A, which dominate the sparse transition matrix.
    for (int i = 1; i <= A.x; i++)
    {
        for (int k = 1; k <= B.x; k++)
        {
            const ll aik = A.a[i][k];
            if (aik == 0) continue;
            for (int j = 1; j <= B.y; j++)
                C.a[i][j] = (C.a[i][j] + aik * B.a[k][j]) % MOD;
        }
    }
    return C;
}

ll n, x, cnt[DIM + 1], sum[DIM + 1], dp[DIM + 1];

// Reads n, x and the n edge lengths from stdin. Fills dp[0..MAXD] and
// sum[0..MAXD], then builds the transition matrix Base.
// Returns false on malformed input, or on an edge length outside [1, MAXD]
// (the latter is also reported on stderr).
bool init()
{
    std::memset(cnt, 0, sizeof(cnt));
    std::memset(dp, 0, sizeof(dp));
    std::memset(sum, 0, sizeof(sum));
    if (std::scanf("%lld%lld", &n, &x) != 2) return false;
    for (ll i = 1; i <= n; i++)
    {
        int p;
        if (std::scanf("%d", &p) != 1) return false;
        if (p < 1 || p > MAXD)
        {
            std::fprintf(stderr, "edge length %d outside [1, %d]\n", p, MAXD);
            return false;
        }
        cnt[p]++;
    }

    dp[0] = 1;  // the root itself
    sum[0] = 1;
    for (int i = 1; i <= MAXD; i++)
    {
        for (int j = 1; j <= i; j++)
            dp[i] = (dp[i] + dp[i - j] * cnt[j]) % MOD;
        sum[i] = (sum[i - 1] + dp[i]) % MOD;
    }

    // Transition matrix T acting on the row vector
    //   v_k = (dp[k-MAXD+1], ..., dp[k], sum[k])   so that   v_{k+1} = v_k * T:
    //   column i < MAXD : copies v_k[i+1]                          (slide window)
    //   column MAXD     : dp[k+1]  = sum_i v_k[i] * cnt[MAXD+1-i]
    //   column DIM      : sum[k+1] = sum[k] + dp[k+1]
    Base = matrix(DIM, DIM);
    for (int i = 1; i < MAXD; i++) Base.a[i + 1][i] = 1;
    for (int i = 1; i <= MAXD; i++)
        Base.a[i][MAXD] = Base.a[i][DIM] = cnt[MAXD + 1 - i];
    Base.a[DIM][DIM] = 1;
    return true;
}

// Returns M^b (mod MOD) for a square matrix M, using binary exponentiation.
matrix qpow(matrix M, ll b)
{
    matrix R(M.x, M.y);
    for (int i = 1; i <= M.x; i++) R.a[i][i] = 1;  // R starts as the identity
    while (b > 0)
    {
        if (b & 1) R = mul(R, M);
        b >>= 1;
        if (b) M = mul(M, M);  // skip the final, unused squaring
    }
    return R;
}

int main()
{
    if (!init()) return 1;
    if (x < 0)  // no node lies at a negative distance
    {
        std::puts("0");
        return 0;
    }
    if (x <= MAXD)
    {
        std::printf("%lld\n", sum[x]);
        return 0;
    }
    matrix A(1, DIM);  // v_MAXD = (dp[1], ..., dp[MAXD], sum[MAXD])
    for (int i = 1; i <= MAXD; i++) A.a[1][i] = dp[i];
    A.a[1][DIM] = sum[MAXD];
    A = mul(A, qpow(Base, x - MAXD));  // v_x = v_MAXD * T^(x - MAXD)
    std::printf("%lld\n", A.a[1][DIM]);
    return 0;
}