分治 NTT

给定序列 g1n1g_{1\dots n - 1},求序列 f0n1f_{0\dots n - 1}

其中 fi=j=1ifijgjf_i=\sum_{j=1}^if_{i-j}g_j,边界为 f0=1f_0=1

答案对 998244353998244353 取模。

考虑对原序列进行分治,假设当前要求出的区间是 [l,r][l,r],其中 [l,mid][l,mid] 已经递归左半部分算完,那么我们需要先考虑 [l,mid][l,mid][mid+1,r][mid+1,r] 的贡献再递归右半部分即可。

那我们发现,[l,mid][l,mid] 对于 [mid+1,r][mid+1,r] 的贡献是一个卷积形式。所以我们把 f[l,mid]f[l,mid]g[1,rl]g[1,r-l] 抽出来做一个 NTT 卷积,然后把答案加到右边,之后再递归右边就能解决这个问题。

因为分治了 O(logn)O(\log n) 层,所以复杂度是 O(nlog2n)O(n\log^2 n)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
#include<bits/stdc++.h>
using namespace std;

const int p = 998244353;
int qpow(int a,int b){
int r = 1;
while(b){
if(b&1)r=1ll*r*a%p;
a=1ll*a*a%p;
b>>=1;
}
return r;
}

using LL=long long;
using ULL=unsigned long long;

const int N = 4e5+5;
vector<int> omega[25];
// n 是 DFT 的最大长度,如果有 2 个长为 m 的多项式相乘,n 需要 >=2m
void ntt_init(int n){
for(int k=2,d=0;k<=n;k*=2,d++){
omega[d].resize(k+1);
int wn=qpow(3,(p-1)/k),tmp=1;
for(int i=0;i<=k;i++){
omega[d][i]=tmp;
tmp=(LL)tmp*wn%p;
}
}
}
// 传入的必须在 [0,p) 范围内,不能有负的
// 否则要把 d==16 改成 d%8==0 之类,多取几次模
void ntt(int *c,int n,int tp){
static ULL a[N];
for(int i=0;i<n;i++)a[i]=c[i];
for(int i=1,j=0;i<n-1;i++){
int k=n;
do j^=(k>>=1); while(j<k);
if(i<j)swap(a[i],a[j]);
}

for(int k=1,d=0;k<n;k*=2,d++){
if(d==16)for(int i=0;i<n;i++)a[i]%=p;
for(int i=0;i<n;i+=k*2){
for(int j=0;j<k;j++){
int w = omega[d][tp>0 ? j : k*2-j];
ULL u = a[i+j], v = w*a[i+j+k]%p;
a[i+j]=u+v;
a[i+j+k]=u-v+p;
}
}
}
if(tp>0){for(int i=0;i<n;i++)c[i]=a[i]%p;}
else{
int inv = qpow(n, p-2);
for(int i=0;i<n;i++)c[i]=a[i]*inv%p;
}
}

int get_lim(int x){
int r = 1;
while(r<=x)r<<=1;
return r;
}

int A[N],B[N];
int f[N],g[N];
void solve(int l,int r){
if(l==r){
if(!l) f[l] = 1;
return;
}
int mid = (l+r)>>1;

solve(l, mid);

int lim = get_lim(2*(r-l+1));
for(int i=0;i<lim;i++)A[i]=B[i]=0;
for(int i=l;i<=mid;i++)A[i-l]=f[i];
for(int i=0;i<=r-l;i++)B[i]=g[i];

ntt(A, lim, 1);
ntt(B, lim, 1);
for(int i=0;i<lim;i++)A[i]=1ll*A[i]*B[i]%p;
ntt(A, lim, -1);

for(int i=mid+1;i<=r;i++)f[i]=(f[i]+A[i-l])%p;
solve(mid+1, r);
}

int main(){
ios::sync_with_stdio(0);cin.tie(0);

int n; cin>>n;
for(int i=1;i<n;i++)cin>>g[i];
ntt_init(n*4);

solve(0, n-1);
for(int i=0;i<n;i++)cout<<f[i]<<' ';
cout<<'\n';
return 0;
}

例题 1

题意

[2024 Shanghai Regional F] Fast Bogosort

考虑如下算法流程:
给定一个排列 pp,将 pp 划分为极多的 [l,r][l,r],使得 p[l,r]p[l,r] 恰好是 l,l+1,,rl,l+1,\cdots,r 这些数。对于这些区间,若 l<rl<r,则将其随机打乱。

现在给定一个排列 pp,求让 pp 变为有序所需要的期望打乱次数,对 998244353998244353 取模。

n105n\le 10^5

解法

g(i)g(i) 表示长度为 ii 的排列不能再分割的个数。

容易写出递推式

g(n)=n!i=1n1g(i)(ni)!g(n) = n! - \sum_{i=1}^{n-1}g(i)(n-i)!

边界条件 g(1)=1g(1)=1,这是一个分治 NTT 形式,直接处理即可。

f(i)f(i) 表示长度为 ii所有排列变为有序所需要的期望打乱次数.

枚举打乱之后第一个不可划分的长度,可以写出递推式

f(n)=i=1ng(i)(ni)!n!([i1]+f(i)+f(ni))f(n) = \sum_{i=1}^{n} \dfrac{g(i)(n-i)!}{n!} ([i\ne 1]+f(i) + f(n-i))

这里 [i1][i\ne 1] 表示如果第一段长度不是 11,那么就需要打乱一次。

把右边 f(n)f(n) 那一项和 [i1][i\ne 1] 处理一下,有

f(n)=n!n!g(n)(11n+i=1n1g(i)(ni)!n!(f(i)+f(ni)))=n!n!g(n)(11n+1n!i=1n1(g(i)f(i)(ni)!+g(i)f(ni)(ni)!))=n!n!g(n)(11n+1n!i=1n1(g(i)f(i)(ni)!+f(i)i!g(ni)))\begin{aligned} f(n) &= \dfrac{n!}{n! - g(n)}\left( 1 -\dfrac 1 n + \sum_{i=1}^{n-1} \dfrac{g(i)(n-i)!}{n!} (f(i) + f(n-i))\right)\\ &= \dfrac{n!}{n! - g(n)}\left( 1 -\dfrac 1 n +\dfrac{1}{n!} \sum_{i=1}^{n-1}\left(g(i)f(i)(n-i)! + g(i)f(n-i)(n-i)!\right)\right)\\ &= \dfrac{n!}{n! - g(n)}\left( 1 -\dfrac 1 n +\dfrac{1}{n!} \sum_{i=1}^{n-1}\left(g(i)f(i)(n-i)! + f(i)i!g(n-i)\right)\right) \end{aligned}

边界条件 f(1)=0f(1)=0,这也是一个分治 FFT 形式,所以做完了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#include<bits/stdc++.h>
using namespace std;

const int p = 998244353;



int qpow(int a,int b){
int r = 1;
while(b){
if(b&1)r=1ll*r*a%p;
a=1ll*a*a%p;
b>>=1;
}
return r;
}


using LL=long long;
using ULL=unsigned long long;

const int N = 4e5+5;
vector<int> omega[25];
void ntt_init(int n){
for(int k=2,d=0;k<=n;k*=2,d++){
omega[d].resize(k+1);
int wn=qpow(3,(p-1)/k),tmp=1;
for(int i=0;i<=k;i++){
omega[d][i]=tmp;
tmp=(LL)tmp*wn%p;
}
}
}
void ntt(int *c,int n,int tp){
static ULL a[N];
for(int i=0;i<n;i++)a[i]=c[i];
for(int i=1,j=0;i<n-1;i++){
int k=n;
do j^=(k>>=1); while(j<k);
if(i<j)swap(a[i],a[j]);
}

for(int k=1,d=0;k<n;k*=2,d++){
if(d==16)for(int i=0;i<n;i++)a[i]%=p;
for(int i=0;i<n;i+=k*2){
for(int j=0;j<k;j++){
int w = omega[d][tp>0 ? j : k*2-j];
ULL u = a[i+j], v = w*a[i+j+k]%p;
a[i+j]=u+v;
a[i+j+k]=u-v+p;
}
}
}
if(tp>0){for(int i=0;i<n;i++)c[i]=a[i]%p;}
else{
int inv = qpow(n, p-2);
for(int i=0;i<n;i++)c[i]=a[i]*inv%p;
}
}

int f[N],g[N];
int A[N],B[N];
int fact[N],invfact[N];

int get_lim(int n){
int r=1;
for(;r<=n;r<<=1);
return r;
}
void solve_g(int l,int r){
if(l==r){
g[l] = (g[l] + fact[l])%p;
return;
}
int mid = (l+r)>>1;
solve_g(l,mid);
int lim = get_lim(2*(r-l+1));
// 分治计算 g
for(int i=0;i<lim;i++)A[i]=B[i]=0;
for(int i=l;i<=mid;i++)A[i-l]=g[i];
for(int i=1;i<=r-l;i++)B[i]=fact[i];
ntt(A, lim, 1), ntt(B, lim, 1);
for(int i=0;i<lim;i++)A[i]=1LL*A[i]*B[i]%p;
ntt(A, lim, -1);
for(int i=mid+1;i<=r;i++)g[i]=(g[i]-A[i-l]+p)%p;
solve_g(mid+1,r);
}
void solve_f(int l,int r){
if(l==r){
f[l] = 1LL * (1LL*f[l]*invfact[l] %p +1-qpow(l, p-2)+p)%p ;
f[l] = 1LL * f[l] * fact[l] % p * qpow((fact[l]-g[l]+p)%p, p-2) % p;
return;
}
int mid = (l+r)>>1;
solve_f(l,mid);
int lim = get_lim(2*(r-l+1));


for(int i=0;i<lim;i++)A[i]=B[i]=0;
for(int i=l;i<=mid;i++)A[i-l]=1LL*g[i]*f[i]%p;
for(int i=1;i<=r-l;i++)B[i]=fact[i];
ntt(A, lim, 1), ntt(B, lim, 1);
for(int i=0;i<lim;i++)A[i]=1LL*A[i]*B[i]%p;
ntt(A, lim, -1);
for(int i=mid+1;i<=r;i++)f[i]=(f[i]+A[i-l])%p;


for(int i=0;i<lim;i++)A[i]=B[i]=0;
for(int i=l;i<=mid;i++)A[i-l]=1LL*fact[i]*f[i]%p;
for(int i=1;i<=r-l;i++)B[i]=g[i];
ntt(A, lim, 1), ntt(B, lim, 1);
for(int i=0;i<lim;i++)A[i]=1LL*A[i]*B[i]%p;
ntt(A, lim, -1);
for(int i=mid+1;i<=r;i++)f[i]=(f[i]+A[i-l])%p;

solve_f(mid+1,r);
}

int main(){
fact[0]=1;
for(int i=1;i<N;i++)fact[i]=1LL*fact[i-1]*i%p;
invfact[N-1]=qpow(fact[N-1], p-2);
for(int i=N-1;i;i--)invfact[i-1]=1LL*invfact[i]*i%p;

ntt_init(4e5);

int n; cin>>n;

solve_g(1, n);
solve_f(1, n);

vector<int> a(n+1);
for(int i=1;i<=n;i++)cin>>a[i];

int ans = 0;
int mn = 1e9, mx = 0;
for(int r=1,l=1;r<=n;r++){
mn = min(mn, a[r]);
mx = max(mx, a[r]);
if(mn==l && mx==r){
if(l!=r){
ans = (ans + 1 + f[r-l+1]) % p;
}
l = r+1;
mn = 1e9, mx = 0;
}
}
cout << ans << '\n';
return 0;
}

例题 2

题意

QOJ 11365. Popping Balloons

给定一个由 {0,1,2}\{0,1,2\} 组成的序列,每秒随机删除其中一个数,求第一次序列变为单调不下降的期望时间,对 998244353998244353 取模。

n2×105n \le 2\times 10^5

做法

考虑最终数组的形态。

f(k)f(k) 表示原数组长度为 kk 的单调不下降子序列个数。

设最后的数组长度为 XX,则有

E[X]=k=1nPr[Xk]=k=1nf(k)(nk)\mathbb E[X] = \sum_{k=1}^n \Pr[X\ge k] =\sum_{k=1}^n \dfrac{f(k)}{\binom n k}

这是因为如果之前已经成为了单调不下降子序列,怎么删得到的还是单调不下降的子序列,那么在计算 Pr[Xk]\Pr[X\ge k] 时,不需要关心什么时候第一次变成单调不下降的时刻,而只需要求出删除了 nkn-k 个数之后,能得到单调不下降子序列的概率即可。

所求即为 nE[X]n-\mathbb E[X]

考虑如何快速求出 f(n)f(n)

考虑分治,设 f[l,r](j,x,y)f_{[l,r]}(j,x,y) 表示 [l,r][l,r] 中的数组成的长度为 jj 的单调不下降子序列,且开头是 xx,结尾是 yy 的方案数。

考虑分治的合并过程,发现可以用 NTT 快速合并,于是做完了,复杂度 O(Σ4nlog2n)O(|\Sigma|^4 n\log^2 n),其中 Σ\Sigma 为字符集大小。(事实上复杂度前面一项是 C(Σ+3,4)C(|\Sigma|+3, 4),只是渐进 O(Σ4)O(|\Sigma|^4)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#include<bits/stdc++.h>
using namespace std;

const int p = 998244353;

int qpow(int a,int b){
int r = 1;
while(b){
if(b&1)r=1ll*r*a%p;
a=1ll*a*a%p;
b>>=1;
}
return r;
}


using LL=long long;
using ULL=unsigned long long;

const int N = 8e5+5;
vector<int> omega[25];
void ntt_init(int n){
for(int k=2,d=0;k<=n;k*=2,d++){
omega[d].resize(k+1);
int wn=qpow(3,(p-1)/k),tmp=1;
for(int i=0;i<=k;i++){
omega[d][i]=tmp;
tmp=(LL)tmp*wn%p;
}
}
}
void ntt(int *c,int n,int tp){
static ULL a[N];
for(int i=0;i<n;i++)a[i]=c[i];
for(int i=1,j=0;i<n-1;i++){
int k=n;
do j^=(k>>=1); while(j<k);
if(i<j)swap(a[i],a[j]);
}

for(int k=1,d=0;k<n;k*=2,d++){
if(d==16)for(int i=0;i<n;i++)a[i]%=p;
for(int i=0;i<n;i+=k*2){
for(int j=0;j<k;j++){
int w = omega[d][tp>0 ? j : k*2-j];
ULL u = a[i+j], v = w*a[i+j+k]%p;
a[i+j]=u+v;
a[i+j+k]=u-v+p;
}
}
}
if(tp>0){for(int i=0;i<n;i++)c[i]=a[i]%p;}
else{
int inv = qpow(n, p-2);
for(int i=0;i<n;i++)c[i]=a[i]*inv%p;
}
}

int arr[N];
int f[3][3][N];
int A[N],B[N];
int fact[N],invfact[N];

int get_lim(int n){
int r=1;
for(;r<=n;r<<=1);
return r;
}

void solve(int l,int r){
static int tmp[3][3][N];

if(l==r){
f[arr[l]][arr[l]][l] = 1;
return;
}
int mid = (l+r)>>1;
solve(l,mid); solve(mid+1, r);


int lim = get_lim(r-l+1);

for(int a=0;a<3;a++)for(int b=a;b<3;b++)
for(int i=0;i<lim;i++)tmp[a][b][i]=0;

for(int a=0;a<3;a++)for(int b=a;b<3;b++)
for(int c=b;c<3;c++)for(int d=c;d<3;d++){
// f[a][b][l..mid] + f[c][d][mid+1..r]
// -> f[a][d][l..r]
for(int i=0;i<lim;i++)A[i]=B[i]=0;
for(int i=l;i<=mid;i++)A[i-l+1]=f[a][b][i];
for(int i=mid+1;i<=r;i++)B[i-mid]=f[c][d][i];
ntt(A, lim, 1), ntt(B, lim, 1);
for(int i=0;i<lim;i++)tmp[a][d][i] = (tmp[a][d][i] + 1ll*A[i]*B[i])%p;
}

for(int a=0;a<3;a++)for(int b=a;b<3;b++){
ntt(tmp[a][b], lim, -1);
for(int i=l;i<=mid;i++)tmp[a][b][i-l+1]=(tmp[a][b][i-l+1] + f[a][b][i])%p;
for(int i=mid+1;i<=r;i++)tmp[a][b][i-mid]=(tmp[a][b][i-mid] + f[a][b][i])%p;

for(int i=l;i<=r;i++){
f[a][b][i] = tmp[a][b][i-l+1];
}
}
}

int comb(int n,int r){
return 1ll * fact[n] * invfact[r] %p * invfact[n-r] % p;
}
int invcomb(int n,int r){
return 1ll * invfact[n] * fact[r] %p * fact[n-r] % p;
}

int main(){
fact[0]=1;
for(int i=1;i<N;i++)fact[i]=1LL*fact[i-1]*i%p;
invfact[N-1]=qpow(fact[N-1], p-2);
for(int i=N-1;i;i--)invfact[i-1]=1LL*invfact[i]*i%p;

ntt_init(8e5);

string s;cin>>s;
int n = s.size();
for(int i=1;i<=n;i++){
if(s[i-1]=='B')arr[i]=0;
if(s[i-1]=='Y')arr[i]=1;
if(s[i-1]=='R')arr[i]=2;
}

solve(1, n);

int ans =0 ;
for(int i=1;i<=n;i++){
for(int a=0;a<3;a++)for(int b=a;b<3;b++){
ans = (ans + 1ll * f[a][b][i] * invcomb(n, i))%p;
}
}

cout << (n - ans + p) %p << '\n';
return 0;
}