CS Academy Round 75 Permutations NTT

Problem

You are given a positive integer $n$ and $q$ queries. Each query consists of two positions $x$ and $y$; count the permutations $P$ of length $n$ that satisfy:
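1. $p_y = \max_{i=1}^{y} p_i$, i.e. $p_y$ is the largest of the first $y$ elements (this is the condition the answer formula below counts, with $x$ placed before $y$);
2. $2p_x < p_y$.

Constraints: $n, q < 10^5$.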

Analysis

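Since $x$ only identifies a position (one of the first $y-1$), its exact value does not change the answer for a fixed $y$. So we only need to compute the answer as a function of $y$.

Enumerate the value $i = p_y$. Then:
- $p_x$ can be any of the $\lfloor (i-1)/2 \rfloor$ values $v$ with $2v < i$.
- The other $y-2$ positions before $y$ receive distinct values chosen from the remaining $i-2$ values smaller than $i$. There are $\binom{i-2}{y-2}(y-2)!$ ways to do this.
- The last $n-y$ positions take the leftover values in any order, in $(n-y)!$ ways.

Hence: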

$$ans(y) = \sum_{i=1}^{n} \left\lfloor \frac{i-1}{2} \right\rfloor \binom{i-2}{y-2} (y-2)!\,(n-y)!$$

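Expanding the binomial gives $\binom{i-2}{y-2}(y-2)! = \frac{(i-2)!}{(i-y)!}$, so

$$ans(y) = (n-y)! \sum_{i} f_i \, g_{n+y-i}, \qquad f_i = \left\lfloor \frac{i-1}{2} \right\rfloor (i-2)!, \quad g_j = \frac{1}{(n-j)!}.$$

This is coefficient $n+y$ of the convolution $f * g$. A single NTT-based polynomial multiplication therefore yields the answers for all $y$ at once.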

Code

#include <bits/stdc++.h>

// Shorthand for 64-bit intermediates.
using ll = long long;

constexpr int N = 100010;       // capacity bound for n; the global arrays are sized from it
constexpr int MOD = 998244353;  // NTT-friendly prime 119 * 2^23 + 1 (the NTT uses root 3)

/// Reads the next decimal integer from stdin.
/// Any non-digit characters before the number are skipped; a '-' among them makes the
/// result negative. Returns 0 if end of input is reached before any digit instead of
/// spinning forever waiting for one.
int read()
{
    int x = 0, f = 1;
    int ch = getchar();  // int, not char, so EOF stays distinguishable from a real byte
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        if (ch == '-')
            f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * f;
}

// NTT work buffers; the transform length L must not exceed 4 * N.
int f[N * 4];    // f[i] = floor((i - 1) / 2) * (i - 2)!, later overwritten by the product f * g
int g[N * 4];    // g[i] = 1 / (n - i)!
int rev[N * 4];  // bit-reversal permutation for the current length L

int ny[N];  // modular inverses of 1..n, converted in place into inverse factorials
int jc[N];  // factorials modulo MOD

int L;  // current NTT length (a power of two)
int n;  // permutation length
int q;  // number of queries

/// Binary exponentiation: returns x^y modulo MOD (callers pass y >= 0).
int ksm(int x, int y)
{
    long long base = x;
    long long acc = 1;
    for (; y != 0; y >>= 1)
    {
        if (y & 1)
            acc = acc * base % MOD;  // fold in the current power when this bit is set
        base = base * base % MOD;    // base becomes x^(2^k) for the next bit
    }
    return static_cast<int>(acc);
}

/// In-place iterative number-theoretic transform of a[0..L) modulo MOD, using 3 as the
/// primitive root.
/// dir == 1 performs the forward transform. Any other value uses the inverse roots, and
/// dir == -1 additionally multiplies every entry by 1/L to complete the inverse transform.
/// Relies on the globals L (length, a power of two) and rev (bit-reversal permutation).
void ntt(int *a, int dir)
{
    // Reorder into bit-reversed positions so the butterflies can run bottom-up.
    for (int i = 0; i < L; i++)
        if (i < rev[i])
            std::swap(a[i], a[rev[i]]);

    for (int half = 1; half < L; half <<= 1)
    {
        const int len = half << 1;
        // Root of unity of order len (or its inverse when dir != 1).
        const int step = (MOD - 1) / half / 2;
        const int root = ksm(3, dir == 1 ? step : MOD - 1 - step);
        for (int start = 0; start < L; start += len)
        {
            int *lo = a + start;
            int *hi = lo + half;
            int w = 1;
            for (int k = 0; k < half; k++)
            {
                const int u = lo[k];
                const int v = static_cast<int>(1ll * hi[k] * w % MOD);
                lo[k] = (u + v) % MOD;
                hi[k] = (u - v + MOD) % MOD;
                w = static_cast<int>(1ll * w * root % MOD);
            }
        }
    }

    if (dir == -1)
    {
        const int inv_len = ksm(L, MOD - 2);  // 1/L via Fermat's little theorem
        for (int i = 0; i < L; i++)
            a[i] = static_cast<int>(1ll * a[i] * inv_len % MOD);
    }
}

int main()
{
    n = read(), q = read();

    // jc[i] = i! mod MOD. ny[i] first receives the modular inverse of i via the linear
    // recurrence inv(i) = -(MOD / i) * inv(MOD % i). The second loop then turns it, in place,
    // into the inverse factorial 1 / i!.
    jc[0] = jc[1] = ny[0] = ny[1] = 1;
    for (int i = 2; i <= n; i++)
        jc[i] = 1ll * jc[i - 1] * i % MOD, ny[i] = 1ll * (MOD - MOD / i) * ny[MOD % i] % MOD;
    for (int i = 2; i <= n; i++)
        ny[i] = 1ll * ny[i - 1] * ny[i] % MOD;

    // f[i] = floor((i - 1) / 2) * (i - 2)!, where i is the value of p_y (f[i] = 0 for i <= 2).
    for (int i = 3; i <= n; i++)
        f[i] = 1ll * (i - 1) / 2 * jc[i - 2] % MOD;
    // g[j] = 1 / (n - j)!, so that (f * g)[n + y] = sum_i f[i] / (i - y)!.
    for (int i = 0; i <= n; i++)
        g[i] = ny[n - i];

    // L = smallest power of two greater than 2n. The product has degree <= 2n, so the
    // cyclic convolution computed by the NTT does not wrap around.
    int lg = 0;
    for (L = 1; L <= n * 2; L <<= 1, lg++)
    {
    }
    for (int i = 0; i < L; i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg - 1));

    ntt(f, 1), ntt(g, 1);
    for (int i = 0; i < L; i++)
        f[i] = 1ll * f[i] * g[i] % MOD;
    ntt(f, -1);

    while (q--)
    {
        read();  // x: only a position, it does not affect the answer for a fixed y
        const int y = read();
        // NOTE(review): assumes 1 <= y <= n; jc[n - y] is out of bounds otherwise.
        // ans(y) = (n - y)! * (f * g)[n + y]. The product is long long, so narrow it to int
        // before printing with %d (a long long argument for %d is undefined behavior).
        const int ans = static_cast<int>(1ll * f[n + y] * jc[n - y] % MOD);
        printf("%d\n", ans);
    }
    return 0;
}

Posted by dcuellar on Fri, 08 Feb 2019 15:12:18 -0800