AtCoder Grand Contest 019 E: Shuffle and Swap — Solution

A very nice DP + combinatorics problem.
This write-up follows tourist's editorial.
Instead of reasoning about the two random shuffles directly, we split the process into two independent choices:
1. Choose the matching between sequence a and sequence b (which a_i is paired with which b_i).
2. Choose the order in which these matched pairs are swapped.
Once the matching is fixed, draw a directed edge from a_i to b_i for every i.
Every position that is 1 in A or in B then falls into exactly one of three classes:
1. It appears in a but not in b (A = 1, B = 0), so the point has only an outgoing edge.
2. It appears in both a and b (A = 1, B = 1), so the point has one incoming and one outgoing edge.
3. It appears in b but not in a (A = 0, B = 1), so the point has only an incoming edge.
Let there be e points of class 1 and m points of class 2. Since A and B contain the same number of 1s, there are also e points of class 3.
Looking at the graph itself, every vertex has in-degree and out-degree at most 1, so the graph is a disjoint union of cycles and chains.
Every position on a cycle has A = 1. Each swap along a cycle therefore exchanges two 1s and changes nothing, so the edges of a cycle may be performed in any order.
A chain, in contrast, admits exactly one valid order of its edges. Along a chain, A is 1 everywhere except at the tail (where it is 0), and B is 1 everywhere except at the head (where it is 0). The 1 at the head must end up at the tail. The edges therefore have to be performed backwards, starting with the edge at the tail end of the chain and finishing with the edge leaving the head.
Every point on a cycle is of class 2. Each chain starts at a class-1 point, ends at a class-3 point, and all its interior points are of class 2.
So we need to count the ways to distribute the class-2 points among the interiors of the e chains and a collection of cycles.
Let dp[i][j] be the (weighted) number of ways to place j class-2 points into the interiors of the first i chains:
$$dp[i][j] = \sum_{k=0}^{j} \frac{dp[i-1][k]}{(j-k+1)!}, \qquad dp[0][0] = 1$$
Here chain i receives j-k interior points; the factorial in the denominator is explained below.
Finally, add up dp[e][0..m] and multiply by $e!\,m!\,(e+m)!$:
- $e!$ counts the ways to pair the heads of the e chains with their tails.
- $m!$ counts how the m class-2 points are connected. Choosing which j of them lie on chains and ordering them along the chains gives $\binom{m}{j}\, j!$ ways. The remaining m-j points can form cycles in $(m-j)!$ ways. The total is $m!$ regardless of j.
- $(e+m)!$ counts the orders of all e+m edges. However, a chain with u interior points has u+1 edges, and only one of their $(u+1)!$ relative orders is valid. We must therefore divide by $(u+1)!$ for every chain, and this division is already built into the dp.
This gives an $O(n^3)$ algorithm.
Optimization
The dp transition is a convolution. Let $f(x) = \sum_{i=0}^{m} \frac{x^i}{(i+1)!}$. Then row i of the dp is row i-1 multiplied by f (the exponents satisfy k + (j-k) = j). Each layer can therefore be computed with NTT, reducing the complexity to $O(n^2 \log n)$.
Moreover, every layer multiplies by the same polynomial f, so dp[e] is just $f(x)^e$ truncated to degree m. Computing it by binary exponentiation with NTT takes $O(n \log^2 n)$.

#include <cstdio>
#include <iostream>
#include <cstring>
#include <string>
#include <cstdlib>
#include <utility>
#include <cctype>
#include <algorithm>
#include <bitset>
#include <set>
#include <map>
#include <vector>
#include <queue>
#include <deque>
#include <stack>
#include <cmath>
#define LL long long
#define LB long double
// NOTE(review): mapping the single letters x / y onto pair members is
// hazardous: every later identifier spelled `x` or `y` (e.g. the parameters
// of quick_pow and Quick_pow) is silently rewritten to `first` / `second`.
// It only works here because every such use is rewritten consistently.
#define x first
#define y second
#define Pair pair<int,int>
#define pb push_back
#define pf push_front
#define mp make_pair
// Lowest set bit of its argument. Fully parenthesised so that expressions
// such as LOWBIT(i) == 2 or LOWBIT(a + b) expand as intended.
#define LOWBIT(x) ((x) & (-(x)))
using namespace std;

// NTT-friendly prime: 998244353 = 119 * 2^23 + 1.
const int MOD=998244353;
// The following constants are not used by the code in this file.
const LL LINF=2e16;
const int INF=1e9;
const int magic=348;
const double eps=1e-10;
const double pi=3.14159265;
// Primitive root modulo MOD; roots of unity for NTT are powers of G.
const int G=3;

// Reads the next (optionally negative) decimal integer from stdin, skipping
// any leading characters that are neither digits nor '-'.
// Returns 0 if end-of-file is reached before a number starts.
// NOTE(review): not called in this file; main() reads its input with scanf.
inline int getint()
{
    // int, not char: getchar() may return EOF, which must stay distinguishable,
    // and isdigit() is only defined for EOF and unsigned-char values.
    int ch;
    while ((ch=getchar())!=EOF && !isdigit(ch) && ch!='-') {}
    if (ch==EOF) return 0;
    bool neg=(ch=='-');
    int res=neg?0:ch-'0';
    while (isdigit(ch=getchar())) res=res*10+ch-'0';
    return neg?-res:res;
}

// Modular tables filled by init_inv():
//   inv[i]  = i^{-1} mod MOD
//   finv[i] = (i!)^{-1} mod MOD
//   fac[i]  = i! mod MOD
int inv[200048];
LL finv[200048];
LL fac[200048];

// Input length and the two strings, stored 1-indexed (s1[1..n], s2[1..n]).
int n;
char s1[100048];
char s2[100048];

// e = number of positions with A=1,B=0; m = number with A=1,B=1 (set in main).
int e;
int m;

// Fills fac[], finv[] and inv[] for every index the tables can hold
// (previously only up to 100000, leaving the rest of the 200048-entry
// arrays zero even though NTT() may look up inv[len] beyond that).
// inv[i] uses the linear recurrence inv[i] = -(MOD/i) * inv[MOD % i] (mod MOD).
inline void init_inv()
{
    const int limit=(int)(sizeof(fac)/sizeof(fac[0]));
    fac[0]=fac[1]=inv[0]=inv[1]=finv[0]=finv[1]=1;
    for (int i=2;i<limit;i++)
    {
        fac[i]=(fac[i-1]*i)%MOD;
        inv[i]=MOD-((long long)(MOD/i)*inv[MOD%i])%MOD;
        finv[i]=(finv[i-1]*inv[i])%MOD;
    }
}

// Modular exponentiation: returns x^y mod MOD by repeated squaring
// (returns 1 when y == 0).
inline LL quick_pow(LL x,LL y)
{
    LL base=x%MOD;
    LL acc=1;
    for (;y;y>>=1)
    {
        if (y&1) acc=acc*base%MOD;
        base=base*base%MOD;
    }
    return acc;
}

// NTT transform length (a power of two). main() picks the smallest power of
// two greater than 2*m, so the product of two degree-m polynomials fits.
int len;
// wn_pos[c] / wn_neg[c]: primitive c-th root of unity mod MOD and its inverse,
// for each power of two c in [2, len]. Indexed by c itself, so len must stay
// below the array size (100048).
LL wn_pos[100048],wn_neg[100048];
// Precomputes the roots of unity used by NTT() for the current len.
// Must be called after len is set.
inline void init_wn()
{
    for (int clen=2;clen<=len;clen<<=1)
    {
        wn_pos[clen]=quick_pow(G,(MOD-1)/clen);
        wn_neg[clen]=quick_pow(G,(MOD-1)-(MOD-1)/clen);
    }
}

// Global scratch buffers for NTT multiplication: poly::operator* loads its two
// operands here and calc_NTT() leaves their product in a[].
LL a[100048];
LL b[100048];

// In-place iterative number-theoretic transform of c[0..len-1] modulo MOD.
//   fl ==  1 : forward transform (roots from wn_pos).
//   fl == -1 : inverse transform (roots from wn_neg), including the 1/len scaling.
// Requires init_wn() to have been called for the current len.
inline void NTT(LL c[],int fl)
{
    int i,j,k,clen;
    // Bit-reversal permutation: i tracks the bit-reversed value of j.
    for (i=(len>>1),j=1;j<len;j++)
    {
        if (i<j) swap(c[i],c[j]);
        for (k=(len>>1);i&k;k>>=1) i^=k;
        i^=k;
    }
    // Butterflies over blocks of size clen = 2, 4, ..., len.
    for (clen=2;clen<=len;clen<<=1)
    {
        LL wn=(fl==1?wn_pos[clen]:wn_neg[clen]);
        int half=clen>>1;
        for (j=0;j<len;j+=clen)
        {
            LL w=1;
            for (k=j;k<j+half;k++)
            {
                LL tmp1=c[k],tmp2=(c[k+half]*w)%MOD;
                c[k]=(tmp1+tmp2)%MOD;
                c[k+half]=((tmp1-tmp2)%MOD+MOD)%MOD;
                w=(w*wn)%MOD;
            }
        }
    }
    if (fl==-1)
    {
        // Compute len^{-1} directly (MOD is prime) rather than reading inv[len],
        // so correctness does not depend on how far init_inv() filled inv[].
        LL len_inv=quick_pow(len,MOD-2);
        for (i=0;i<len;i++) c[i]=(c[i]*len_inv)%MOD;
    }
}

// Multiplies the polynomials held in the global buffers a[] and b[] as a
// length-len cyclic convolution. The product is left in a[]; b[] is left in
// its transformed (pointwise) form.
inline void calc_NTT()
{
    NTT(a,1);
    NTT(b,1);
    for (int i=0;i<len;i++) a[i]=(a[i]*b[i])%MOD;
    NTT(a,-1);
}

// Polynomial truncated to degree m: only coefficients A[0..m] are meaningful.
struct poly
{
    LL A[100048];
    // Returns (*this) * B truncated to degree m, coefficients reduced mod MOD.
    // B is taken by const reference: a poly is ~800 KB, and copying it onto the
    // stack on every multiplication risks stack overflow. B may alias *this
    // (e.g. x*x): both operands are copied into a[]/b[] before res is written.
    inline poly operator * (const poly &B) const
    {
        int i;poly res;
        memset(a,0,sizeof(a));memset(b,0,sizeof(b));
        for (i=0;i<=m;i++) a[i]=A[i],b[i]=B.A[i];
        calc_NTT();
        for (i=0;i<=m;i++) res.A[i]=a[i];
        return res;
    }
};

// Computes x^y as a polynomial truncated to degree m, by binary exponentiation.
// For y == 0 the result is the constant polynomial 1.
inline poly Quick_pow(poly x,LL y)
{
    int i;poly res;
    for (i=0;i<=m;i++) x.A[i]%=MOD;
    // res is a local, so its coefficients are indeterminate until assigned;
    // operator* reads res.A[0..m], so all of them must be initialised.
    for (i=0;i<=m;i++) res.A[i]=0;
    res.A[0]=1;
    while (y)
    {
        if (y&1) res=res*x;
        y>>=1;
        if (y) x=x*x; // skip the final squaring, whose result would be unused
    }
    return res;
}

int main ()
{
    int i,j,k;
    scanf("%s%s",s1+1,s2+1);n=strlen(s1+1);
    init_inv();
    e=m=0;
    for (i=1;i<=n;i++)
    {
        if (s1[i]=='1' && s2[i]=='0') e++;
        if (s1[i]=='1' && s2[i]=='1') m++;
    }
    len=1;while (len<=m*2) len<<=1;
    init_wn();
    poly ans;for (i=0;i<=m;i++) ans.A[i]=finv[i+1];
    ans=Quick_pow(ans,e);
    LL fans=0;
    for (i=0;i<=m;i++) fans=(fans+ans.A[i])%MOD;
    fans=fans*fac[e]%MOD*fac[m]%MOD*fac[e+m]%MOD;
    printf("%lld\n",fans);
    return 0;
}

Posted by cvincent on Tue, 25 Dec 2018 11:57:06 -0800