前言

前置知识:

FFT 已经能很好的在 O(nlogn)O(n\log n) 内解决卷积问题,但是 FFT 用到的单位根和三角函数有关,只能采用浮点数储存,必定会带来精度误差。

我们尝试不在复数域内,而是在模剩余系内寻求一种新的单位根来解决问题,这样带来了两个好处:

  1. 不再有误差(只要大小没有超出模数);
  2. 如果题目要求取模,一并可以完成取模的任务。

新型单位根

我们回看 FFT,FFT 中单位根 wnkw_n^k 满足了如下性质:

  1. wnk=(wn1)kw_n^k = \left(w_n^1\right)^k
  2. wn0,wn1,,wnn1w_n^0,w_n^1,\ldots,w_n^{n-1} 互不相同;
  3. wnk+n/2=wnkw_n^{k+n/2}=-w_n^k
    这个性质可以导出 wnkmodn=wnkw_n^{k\bmod n}=w_n^k
  4. w2n2k=wnkw_{2n}^{2k}=w_n^k

只要满足了以上性质,就可以作为 FFT 中的单位根。

假设我们的模数是质数 pp,其一个原根为 gg

定义 wn1=g(p1)/nw_n^1 = g^{(p-1)/n}wnk=(wn1)kw_n^k = \left(w_n^1\right)^k,可以证明其满足如上所有性质。(保证 nnp1p-1 的因数)

证明:

  1. 由我们的定义得证。
  2. 由阶的性质可知 wn1w_n^1 的阶为 p1gcd(p1,(p1)/n)=p1(p1)/n=n\dfrac{p-1}{\gcd(p-1,(p-1)/n)} = \dfrac{p-1}{(p-1)/n} = n,则得证。
  3. 只需证明 xnn/21(modp)x_n^{n/2}\equiv -1 \pmod p 即可。
    考虑引理(二次探测定理):如果 pp 是素数,x21(modp)x^2\equiv 1\pmod p 的解为 x1=1,x2=p1x_1=1,x_2=p-1
    已知 (xnn/2)21(modp)\left(x_n^{n/2}\right)^2 \equiv 1 \pmod p,显然其不能等于 11(否则与阶的性质矛盾),所以只能等于 p1p-1
  4. 左边=g2k(p1)/2n=gk(p1)/n=右边\text{左边}=g^{2k(p-1)/2n}=g^{k(p-1)/n}=\text{右边}

由此,我们就成功消灭了复数单位根。我们把采用这种单位根用来卷积的算法叫做快速数论变换(NTT)。

NTT 实现

我们考虑模数 pp 的选取。
由上文我们知道 p1p-1 一定要是我们分治长度 nn 的倍数,我们希望 p1p-1 中含有尽可能多的 22 的因子。

一个常用模数 998244353=119×223+1998244353 = 119 \times 2^{23}+1,这就表明,用 998244353998244353 做模数的 NTT 的序列长度最长为 2232^{23}。(当然不会这么长的,要不然先 TLE 了)

其一个原根为 33

NTT 相关质数与原根表 (Miskcoo’s Space)

我们用 998244353998244353 作为模数实现 NTT,模板代码如下(细节见注释):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#include<bits/stdc++.h>
using namespace std;
using ll = long long;

const int MOD = 998244353;
int ksm(int a,int b){ // 快速幂
int r = 1;
while(b){
if(b&1)r=(ll)r*a%MOD;
a=(ll)a*a%MOD;
b>>=1;
}
return r;
}
const int G = 3, INVG = ksm(G,MOD-2); // 原根和原根的逆

const int N = 4e6+5;
int rev[N];
void NTT(int* f,int n,bool flag){
for(int i=0;i<n;i++){
if(i<rev[i])swap(f[i],f[rev[i]]);
}
for(int p=2;p<=n;p<<=1){
int g = ksm(flag?INVG:G, (MOD-1)/p); // 换成新的单位根
for(int j=0;j<n;j+=p){
int gnow = 1;
for(int k=j;k<(j|p>>1);k++){
int t = (ll)f[k|p>>1] * gnow % MOD;
f[k|p>>1] = f[k] - t; if(f[k|p>>1]<0)f[k|p>>1]+=MOD; // 注意取模
f[k] = f[k] + t; if(f[k]>=MOD)f[k]-=MOD;
gnow = (ll)gnow * g % MOD;
}
}
}

int invN = ksm(n, MOD-2);
if(flag)for(int i=0;i<n;i++)f[i]=(ll)f[i]*invN%MOD; // 改成逆元
}

int f[N],g[N]; // 主函数完全没有变动
int main(){
ios::sync_with_stdio(0);cin.tie(0);

int n,m;cin>>n>>m;
for(int i=0;i<=n;i++)cin>>f[i];
for(int i=0;i<=m;i++)cin>>g[i];

for(m+=n,n=1;n<=m;n<<=1);
for(int i=0;i<n;i++){
rev[i]=rev[i>>1]>>1;
if(i&1)rev[i]|=n>>1;
}

NTT(f,n,0);NTT(g,n,0);
for(int i=0;i<n;i++)f[i]=(ll)f[i]*g[i]%MOD;
NTT(f,n,1);
for(int i=0;i<=m;i++){
cout<<f[i]<<' ';
}
cout<<'\n';
return 0;
}

例题(利用 NTT 解决计数问题)

ABC390G - Permutation Concatenation

题意

给定正整数 nn,求 1n1\sim nn!n! 个排列拼接起来得到的数字和,对 998244353998244353 取模。
比如 n=3n=3,答案为 123+132+213+231+312+321=1332123+132+213+231+312+321=1332

解法

考虑拆开考虑每一个数的贡献。
cic_i 表示 1n1\sim n 中长度为 ii 的数个数,lil_i 表示 ii 的长度。

假设当前数为 mm,定义 did_i 表示除了 mm 其他数中长度为 ii 的数个数。

枚举后面有 jj 个数,前面就有 n1jn-1-j 个数,贡献即为

j=0n1j!(nj1)!i1+i2++i6=jik<dk(dkik)10k=16kik\sum_{j=0}^{n-1}j!(n-j-1)!\sum_{i_1 + i_2+\ldots + i_6 = j}^{i_k<d_k}\prod{\binom{d_k}{i_k}}10^{\sum_{k=1}^6{ki_k}}

两个阶乘表示了前后的数的顺序排序,组合数表示了选法,1010 的幂表示了 mm 后面有几位。

我们尝试处理出阶乘后面的一项,也就是后面有 jj 个数的取法带上 1010 的幂的贡献。

定义形式幂级数

Fi=j=0ci10ij(cij)xjGi=j=0ci110ij(ci1j)xjHt=Gti=16FiFt\begin{aligned} & F_i=\sum_{j=0}^{c_i}10^{ij}\binom{c_i}{j}x^j \\ & G_i=\sum_{j=0}^{c_i-1}10^{ij}\binom{c_i-1}{j}x^j \\ & H_t=\frac{G_t\prod_{i=1}^6F_i}{F_t} \\ \end{aligned}

这样 [xj]Hi[x^j]H_i 表示当前数长度为 ii,后面有 jj 个数的贡献,则最后的系数可以写作:

j=0nj!(nj1)![xj]Ht\sum_{j=0}^nj!(n-j-1)![x^j]H_t

在实现方面,要计算 H1H6H_1\sim H_6,直接把 F1F6,G1G6F_1\sim F_6, G_1\sim G_6 都转化成点值形式,然后 HH 的点值直接乘出来,最后统一化回系数形式即可。

这样只需要 1818 次 NTT,复杂度 O(nlog2n)O(n\log^2 n),通过了本题。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
#include<bits/stdc++.h>
using namespace std;

template <class T>
using must_int = enable_if_t<is_integral<T>::value, int>;
template <unsigned umod>
struct modint {
static constexpr int mod = umod;
unsigned v;
modint() : v(0) {}
template <class T, must_int<T> = 0>
modint(T _v) {int x = _v % (int)umod; v = x < 0 ? x + umod : x;}
modint operator+() const { return *this; }
modint operator-() const { return modint() - *this; }
friend int raw(const modint &self) { return self.v; }
friend ostream &operator<<(ostream &os, const modint &self) { return os << raw(self);}
modint &operator+=(const modint &rhs) {v += rhs.v;if (v >= umod) v -= umod;return *this;}
modint &operator-=(const modint &rhs) {v -= rhs.v;if (v >= umod) v += umod;return *this;}
modint &operator*=(const modint &rhs) {v = 1ull * v * rhs.v % umod; return *this;}
modint &operator/=(const modint &rhs) { return *this *= rhs.inv(); }
modint inv() const {
assert(v);
static unsigned lim = 1 << 21;
static vector<modint> inv{0, 1};
if (v >= lim) return qpow(*this, mod - 2);
inv.reserve(v + 1);
while (v >= inv.size()) {
int m = inv.size();
inv.resize(m << 1);
for (int i = m; i < m << 1; i++)inv[i] = (mod - mod / i) * inv[mod % i];
}
return inv[v];
}
template <class T, must_int<T> = 0>
friend modint qpow(modint a, T b) {modint r = 1;for (; b; b >>= 1, a *= a)if (b & 1) r *= a;return r;}
friend modint operator+(modint lhs, const modint &rhs) { return lhs += rhs; }
friend modint operator-(modint lhs, const modint &rhs) { return lhs -= rhs; }
friend modint operator*(modint lhs, const modint &rhs) { return lhs *= rhs; }
friend modint operator/(modint lhs, const modint &rhs) { return lhs /= rhs; }
friend bool operator==(const modint &lhs, const modint &rhs) { return lhs.v == rhs.v; }
friend bool operator!=(const modint &lhs, const modint &rhs) { return lhs.v != rhs.v; }
};
const int MOD = 998244353;
using mint = modint<MOD>;

inline mint C(int n,int r){
static vector<mint> f{1, 1}, finv{1, 1};
while(n >= f.size()){
int m = f.size();
f.resize(m << 1), finv.resize(m << 1);
for (int i = m; i < m << 1; i++)f[i]=f[i-1]*i,finv[i]=finv[i-1]/i;
}
return (r<0)?0:(f[n]*finv[r]*finv[n-r]);
}

const mint G = 3, INVG = mint{3}.inv();

const int N = 2e6+5;
mint pow10[N];
void NTT(vector<mint>& f,int n,bool flag){
vector<int> rev(n);
for(int i=0;i<n;i++){
rev[i]=rev[i>>1]>>1;
if(i&1)rev[i]|=n>>1;
}
for(int i=0;i<n;i++){
if(i<rev[i])swap(f[i],f[rev[i]]);
}
for(int p=2;p<=n;p<<=1){
mint g = qpow((flag?INVG:G),(MOD-1)/p);
for(int j=0;j<n;j+=p){
mint gnow = 1;
for(int k=j;k<(j|p>>1);k++){
mint t = f[k|p>>1] * gnow;
f[k|p>>1] = f[k] - t;
f[k] = f[k] + t;
gnow *= g;
}
}
}

mint invN = qpow(mint{n}, MOD-2);
if(flag)for(int i=0;i<n;i++)f[i]*=invN;
}

vector<mint> f[7],g[7],h[7];
int cnt[7];

int len[N];
mint fact[N];
mint coef[7];

int main(){
ios::sync_with_stdio(0);cin.tie(0);
int n;cin>>n;
for(int i=1;i<=n;i++){
len[i]=len[i/10]+1;
cnt[len[i]]++;
}
pow10[0]=1;
for(int i=1;i<N;i++)pow10[i]=pow10[i-1]*10;
fact[0]=1;
for (int i =1; i <=n; i++)fact[i]=fact[i-1]*i;

int m;
for(m=n,n=1;n<=m;n<<=1);


for(int i=1;i<=6;i++){
f[i].resize(n); g[i].resize(n);
for(int j=0;j<=cnt[i];j++) f[i][j]=C(cnt[i],j)*pow10[i*j];
for(int j=0;j<=cnt[i]-1;j++)g[i][j]=C(cnt[i]-1,j)*pow10[i*j];

NTT(f[i],n,0);NTT(g[i],n,0);
}


for(int i=1;i<=6;i++){
h[i].resize(n);
for(int k=0;k<n;k++){
h[i][k]=1;
for(int j=1;j<i;j++)h[i][k]*=f[j][k];
for(int j=i+1;j<=6;j++)h[i][k]*=f[j][k];
h[i][k]*=g[i][k];
}

NTT(h[i],n,1);

for(int j=0;j<m;j++){
coef[i] += fact[j]*fact[m-1-j]*h[i][j];
}
}

mint ans = 0;
for(int i=1;i<=n;i++){
ans += coef[len[i]] * i;
}

cout << ans << '\n';
return 0;
}