Fast Number Theoretic Transform (NTT)

Keywords: Algorithm

I have been meaning to learn this for a while but never got around to it (orz), so I am catching up now.

There is not much genuinely new knowledge involved in learning this. It is best to be comfortable with FFT first.

For FFT, see: https://www.cnblogs.com/Tenshi/p/15434004.html

NTT solves polynomial multiplication when the coefficients are taken modulo a number. Floating-point FFT may not be accurate enough and has a large constant factor, so the complex roots of unity in FFT are replaced by number-theoretic ones, which gives NTT.

Principle

Let's first assume that the modulus \(P\) has a primitive root \(rt\), meaning that \(rt\) has multiplicative order exactly \(\varphi(P)\) modulo \(P\); in particular \(rt^{\varphi(P)} \equiv 1 \pmod P\). To be more specific, assume that the modulus is a prime (such as \(998244353\), for which \(rt = 3\) works). Then \(rt^{P-1} \equiv 1 \pmod P\) and \(rt^{P} \equiv rt \pmod P\).

In FFT, we have \((\cos\theta + i\sin\theta)^{2\pi/\theta} = 1\), which is very close in form to \(rt^{P-1} \equiv 1 \pmod P\). We restrict the modulus further: it must be expressible as \(k\times 2^x + 1\), because this guarantees that \(P-1\) is divisible by the powers of \(2\) that appear as transform lengths during the divide and conquer. For a length \(n = 2^j\) with \(j \le x\), the number \(rt^{(P-1)/n}\) then plays the role of the \(n\)-th root of unity.

Therefore, working modulo \(P\), we can write NTT code that closely mirrors FFT:

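// In-place NTT of a[0..tot-1] modulo the prime mod; type==1 is forward, type==-1 is inverse.
// Relies on the globals tot (power-of-two length), rev[] (bit-reversal table) and rt (primitive root).
void NTT(ll *a, int type, int mod){
	for(int i=0; i<tot; i++){
		a[i]%=mod;
		if(i<rev[i]) swap(a[i], a[rev[i]]);
	}
	
	for(int mid=1; mid<tot; mid<<=1){
		// w1 is a (2*mid)-th root of unity; the inverse transform uses its inverse.
		ll w1=fpow(rt, (type==1? (mod-1)/(mid<<1): mod-1-(mod-1)/(mid<<1)), mod);
		for(int i=0; i<tot; i+=mid*2){
			ll wk=1;
			for(int j=0; j<mid; j++, wk=wk*w1%mod){
				auto x=a[i+j], y=wk*a[i+j+mid]%mod;
				a[i+j]=(x+y)%mod, a[i+j+mid]=(x-y+mod)%mod;
			}
		}
	}
	
	if(type==-1){
		// Scale by tot^{-1}, computed once instead of once per element.
		ll tot_inv=inv(tot, mod);
		for(int i=0; i<tot; i++) a[i]=a[i]*tot_inv%mod;
	}
}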

Extension

As mentioned above, NTT places restrictions on the modulus, so it cannot be applied directly when the modulus is arbitrary. What should we do?

Consider the template problem:

https://www.luogu.com.cn/problem/P4245

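The modulus given in the problem is \(P\).

We can choose three moduli \(m_1, m_2, m_3\) satisfying \(m_1 m_2 m_3 > N P^2\), where \(N\) bounds the number of products summed into any single coefficient. Every coefficient of the exact (unreduced) product is then smaller than \(m_1 m_2 m_3\).

They can be taken as \(m_1=998244353,~ m_2=1004535809,~ m_3=469762049\).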

We first run NTT under each of these three moduli, and then combine the three results with the CRT (Chinese remainder theorem).

Merging them directly would overflow long long, so how do we merge? Specifically:

Suppose the result is \(ans\), and let \(inv(x,y)\) denote the inverse of \(x\) modulo \(y\).

\[\begin{aligned} ans &\equiv c_1 \pmod {m_1} \\ ans &\equiv c_2 \pmod {m_2} \\ ans &\equiv c_3 \pmod {m_3} \end{aligned}\]

Merge the first two first:

\(ans \equiv c_1\times m_2\times inv(m_2, m_1) + c_2\times m_1\times inv(m_1, m_2) \pmod {m_1m_2}\).

Write \(M=m_1m_2\) and \(C=\left(c_1\times m_2\times inv(m_2, m_1) + c_2\times m_1\times inv(m_1, m_2)\right) \bmod M\), so that \(0 \le C < M\).

Then we set \(ans = xM + C = ym_3 + c_3\).

Reducing modulo \(m_3\) gives \(x \equiv (c_3-C)\times inv(M, m_3) \pmod {m_3}\).

Let \(t = (c_3-C)\times inv(M, m_3) \bmod m_3\), so that \(0 \le t < m_3\); \(t\) can be computed directly. Let's further set \(x = km_3 + t\).

Then \(ans = km_1m_2m_3 + tM + C\).

Because \(0 \le ans < m_1 m_2 m_3\) and also \(0 \le tM + C < m_1 m_2 m_3\), we must have \(k = 0\), so \(ans=tM+C\).

Finally, reduce the result modulo \(P\).

Implementation:

// Problem: P4245 [template] arbitrary modulus polynomial multiplication
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P4245
// Memory Limit: 500 MB
// Time Limit: 2000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include<bits/stdc++.h>
using namespace std;

#define debug(x) cerr << #x << ": " << (x) << endl
#define rep(i,a,b) for(int i=(a);i<=(b);i++)
#define dwn(i,a,b) for(int i=(a);i>=(b);i--)

using pii = pair<int, int>;
using ll = long long;

inline void read(int &x){
    int s=0; x=1;
    char ch=getchar();
    while(ch<'0' || ch>'9') {if(ch=='-')x=-1;ch=getchar();}
    while(ch>='0' && ch<='9') s=(s<<3)+(s<<1)+ch-'0',ch=getchar();
    x*=s;
}

const int N=3e5+5;
const ll m1=998244353, m2=1004535809, m3=469762049, M=m1*m2, rt=3;

int n, m, P;
ll a[3][N], b[3][N], ans[N];

int rev[N], tot=1, bit;

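// x^p mod mod by binary exponentiation (mod < 2^31, so products fit in long long).
ll fpow(ll x, int p, ll mod){
	ll res=1;
	for(; p; p>>=1, x=x*x%mod) if(p&1) res=res*x%mod;
	return res;
}

// Inverse of x modulo the prime mod, by Fermat's little theorem.
ll inv(ll x, ll mod){
	return fpow(x, mod-2, mod);
}

// x*p mod mod by repeated doubling; avoids overflow when mod is close to 1e18 (p must fit in int).
ll mul(ll x, int p, ll mod){
	ll res=0;
	for(; p; p>>=1, x=(x+x)%mod) if(p&1) res=(res+x)%mod;
	return res;
}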

void NTT(ll *a, int type, int mod){
	for(int i=0; i<tot; i++){
		a[i]%=mod;
		if(i<rev[i]) swap(a[i], a[rev[i]]);
	}
	
	for(int mid=1; mid<tot; mid<<=1){
		ll w1=fpow(rt, (type==1? (mod-1)/(mid<<1): mod-1-(mod-1)/(mid<<1)), mod);
		for(int i=0; i<tot; i+=mid*2){
			ll wk=1;
			for(int j=0; j<mid; j++, wk=wk*w1%mod){
				auto x=a[i+j], y=wk*a[i+j+mid]%mod;
				a[i+j]=(x+y)%mod, a[i+j+mid]=(x-y+mod)%mod;
			}
		}
	}
	
	if(type==-1){
		// Scale by tot^{-1}, computed once instead of once per element.
		ll tot_inv=inv(tot, mod);
		for(int i=0; i<tot; i++) a[i]=a[i]*tot_inv%mod;
	}
}

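void CRT(){
	// The inverses do not depend on i, so compute them once.
	const ll inv21=inv(m2, m1), inv12=inv(m1, m2), invM=inv(M%m3, m3);
	for(int i=0; i<tot; i++){
		// Merge the first two residues: C = c1*m2*inv(m2,m1) + c2*m1*inv(m1,m2) (mod M).
		ll res=0;
		(res+=mul(a[0][i]*m2%M, inv21, M))%=M;
		(res+=mul(a[1][i]*m1%M, inv12, M))%=M;
		a[1][i]=res;
	}
	for(int i=0; i<tot; i++){
		// t = (c3 - C) * inv(M, m3) mod m3, then ans = t*M + C reduced modulo P.
		ll res=(a[2][i]-a[1][i]%m3+m3)%m3*invM%m3;
		ans[i]=(M%P*res%P+a[1][i]%P)%P;
	}
}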

void solve(int k, int mod){
	NTT(a[k], 1, mod), NTT(b[k], 1, mod);
	for(int i=0; i<tot; i++) a[k][i]=a[k][i]*b[k][i]%mod;
	NTT(a[k], -1, mod);
}

int main(){
	cin>>n>>m>>P;
	rep(i,0,n){
		int t; read(t);
		rep(j,0,2) a[j][i]=t%P;
	}
	rep(i,0,m){
		int t; read(t);
		rep(j,0,2) b[j][i]=t%P;
	}
	
	while(tot<=n+m) bit++, tot<<=1;
	for(int i=0; i<tot; i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
	
	solve(0, m1), solve(1, m2), solve(2, m3);
	CRT();
	
	rep(i,0,n+m) cout<<ans[i]<<' ';
	cout<<endl;
	
	return 0;
}

Posted by dnast on Sat, 04 Dec 2021 19:49:24 -0800