分治 NTT
给定序列 g1…n−1,求序列 f0…n−1。
其中 fi=∑j=1ifi−jgj,边界为 f0=1。
答案对 998244353 取模。
考虑对原序列进行分治,假设当前要求出的区间是 [l,r],其中 [l,mid] 已经递归左半部分算完,那么我们需要先考虑 [l,mid] 对 [mid+1,r] 的贡献再递归右半部分即可。
那我们发现,[l,mid] 对于 [mid+1,r] 的贡献是一个卷积形式。所以我们把 f[l,mid] 和 g[1,r−l] 抽出来做一个 NTT 卷积,然后把答案加到右边,之后再递归右边就能解决这个问题。
因为分治了 O(logn) 层,所以复杂度是 O(nlog2n)。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102
| #include<bits/stdc++.h> using namespace std;
const int p = 998244353; int qpow(int a,int b){ int r = 1; while(b){ if(b&1)r=1ll*r*a%p; a=1ll*a*a%p; b>>=1; } return r; }
using LL=long long; using ULL=unsigned long long;
const int N = 4e5+5; vector<int> omega[25];
void ntt_init(int n){ for(int k=2,d=0;k<=n;k*=2,d++){ omega[d].resize(k+1); int wn=qpow(3,(p-1)/k),tmp=1; for(int i=0;i<=k;i++){ omega[d][i]=tmp; tmp=(LL)tmp*wn%p; } } }
void ntt(int *c,int n,int tp){ static ULL a[N]; for(int i=0;i<n;i++)a[i]=c[i]; for(int i=1,j=0;i<n-1;i++){ int k=n; do j^=(k>>=1); while(j<k); if(i<j)swap(a[i],a[j]); }
for(int k=1,d=0;k<n;k*=2,d++){ if(d==16)for(int i=0;i<n;i++)a[i]%=p; for(int i=0;i<n;i+=k*2){ for(int j=0;j<k;j++){ int w = omega[d][tp>0 ? j : k*2-j]; ULL u = a[i+j], v = w*a[i+j+k]%p; a[i+j]=u+v; a[i+j+k]=u-v+p; } } } if(tp>0){for(int i=0;i<n;i++)c[i]=a[i]%p;} else{ int inv = qpow(n, p-2); for(int i=0;i<n;i++)c[i]=a[i]*inv%p; } }
int get_lim(int x){ int r = 1; while(r<=x)r<<=1; return r; }
int A[N],B[N]; int f[N],g[N]; void solve(int l,int r){ if(l==r){ if(!l) f[l] = 1; return; } int mid = (l+r)>>1;
solve(l, mid);
int lim = get_lim(2*(r-l+1)); for(int i=0;i<lim;i++)A[i]=B[i]=0; for(int i=l;i<=mid;i++)A[i-l]=f[i]; for(int i=0;i<=r-l;i++)B[i]=g[i];
ntt(A, lim, 1); ntt(B, lim, 1); for(int i=0;i<lim;i++)A[i]=1ll*A[i]*B[i]%p; ntt(A, lim, -1); for(int i=mid+1;i<=r;i++)f[i]=(f[i]+A[i-l])%p; solve(mid+1, r); }
int main(){ ios::sync_with_stdio(0);cin.tie(0); int n; cin>>n; for(int i=1;i<n;i++)cin>>g[i]; ntt_init(n*4);
solve(0, n-1); for(int i=0;i<n;i++)cout<<f[i]<<' '; cout<<'\n'; return 0; }
|
例题 1
题意
[2024 Shanghai Regional F] Fast Bogosort
考虑如下算法流程:
给定一个排列 p,将 p 划分为极多的 [l,r],使得 p[l,r] 恰好是 l,l+1,⋯,r 这些数。对于这些区间,若 l<r,则将其随机打乱。
现在给定一个排列 p,求让 p 变为有序所需要的期望打乱次数,对 998244353 取模。
n≤105。
解法
设 g(i) 表示长度为 i 的排列不能再分割的个数。
容易写出递推式
g(n)=n!−i=1∑n−1g(i)(n−i)!
边界条件 g(1)=1,这是一个分治 NTT 形式,直接处理即可。
设 f(i) 表示长度为 i 的所有排列变为有序所需要的期望打乱次数.
枚举打乱之后第一个不可划分的长度,可以写出递推式
f(n)=i=1∑nn!g(i)(n−i)!([i=1]+f(i)+f(n−i))
这里 [i=1] 表示如果第一段长度不是 1,那么就需要打乱一次。
把右边 f(n) 那一项和 [i=1] 处理一下,有
f(n)=n!−g(n)n!(1−n1+i=1∑n−1n!g(i)(n−i)!(f(i)+f(n−i)))=n!−g(n)n!(1−n1+n!1i=1∑n−1(g(i)f(i)(n−i)!+g(i)f(n−i)(n−i)!))=n!−g(n)n!(1−n1+n!1i=1∑n−1(g(i)f(i)(n−i)!+f(i)i!g(n−i)))
边界条件 f(1)=0,这也是一个分治 FFT 形式,所以做完了。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150
| #include<bits/stdc++.h> using namespace std;
const int p = 998244353;
int qpow(int a,int b){ int r = 1; while(b){ if(b&1)r=1ll*r*a%p; a=1ll*a*a%p; b>>=1; } return r; }
using LL=long long; using ULL=unsigned long long;
const int N = 4e5+5; vector<int> omega[25]; void ntt_init(int n){ for(int k=2,d=0;k<=n;k*=2,d++){ omega[d].resize(k+1); int wn=qpow(3,(p-1)/k),tmp=1; for(int i=0;i<=k;i++){ omega[d][i]=tmp; tmp=(LL)tmp*wn%p; } } } void ntt(int *c,int n,int tp){ static ULL a[N]; for(int i=0;i<n;i++)a[i]=c[i]; for(int i=1,j=0;i<n-1;i++){ int k=n; do j^=(k>>=1); while(j<k); if(i<j)swap(a[i],a[j]); }
for(int k=1,d=0;k<n;k*=2,d++){ if(d==16)for(int i=0;i<n;i++)a[i]%=p; for(int i=0;i<n;i+=k*2){ for(int j=0;j<k;j++){ int w = omega[d][tp>0 ? j : k*2-j]; ULL u = a[i+j], v = w*a[i+j+k]%p; a[i+j]=u+v; a[i+j+k]=u-v+p; } } } if(tp>0){for(int i=0;i<n;i++)c[i]=a[i]%p;} else{ int inv = qpow(n, p-2); for(int i=0;i<n;i++)c[i]=a[i]*inv%p; } }
int f[N],g[N]; int A[N],B[N]; int fact[N],invfact[N];
int get_lim(int n){ int r=1; for(;r<=n;r<<=1); return r; } void solve_g(int l,int r){ if(l==r){ g[l] = (g[l] + fact[l])%p; return; } int mid = (l+r)>>1; solve_g(l,mid); int lim = get_lim(2*(r-l+1)); for(int i=0;i<lim;i++)A[i]=B[i]=0; for(int i=l;i<=mid;i++)A[i-l]=g[i]; for(int i=1;i<=r-l;i++)B[i]=fact[i]; ntt(A, lim, 1), ntt(B, lim, 1); for(int i=0;i<lim;i++)A[i]=1LL*A[i]*B[i]%p; ntt(A, lim, -1); for(int i=mid+1;i<=r;i++)g[i]=(g[i]-A[i-l]+p)%p; solve_g(mid+1,r); } void solve_f(int l,int r){ if(l==r){ f[l] = 1LL * (1LL*f[l]*invfact[l] %p +1-qpow(l, p-2)+p)%p ; f[l] = 1LL * f[l] * fact[l] % p * qpow((fact[l]-g[l]+p)%p, p-2) % p; return; } int mid = (l+r)>>1; solve_f(l,mid); int lim = get_lim(2*(r-l+1));
for(int i=0;i<lim;i++)A[i]=B[i]=0; for(int i=l;i<=mid;i++)A[i-l]=1LL*g[i]*f[i]%p; for(int i=1;i<=r-l;i++)B[i]=fact[i]; ntt(A, lim, 1), ntt(B, lim, 1); for(int i=0;i<lim;i++)A[i]=1LL*A[i]*B[i]%p; ntt(A, lim, -1); for(int i=mid+1;i<=r;i++)f[i]=(f[i]+A[i-l])%p;
for(int i=0;i<lim;i++)A[i]=B[i]=0; for(int i=l;i<=mid;i++)A[i-l]=1LL*fact[i]*f[i]%p; for(int i=1;i<=r-l;i++)B[i]=g[i]; ntt(A, lim, 1), ntt(B, lim, 1); for(int i=0;i<lim;i++)A[i]=1LL*A[i]*B[i]%p; ntt(A, lim, -1); for(int i=mid+1;i<=r;i++)f[i]=(f[i]+A[i-l])%p;
solve_f(mid+1,r); }
int main(){ fact[0]=1; for(int i=1;i<N;i++)fact[i]=1LL*fact[i-1]*i%p; invfact[N-1]=qpow(fact[N-1], p-2); for(int i=N-1;i;i--)invfact[i-1]=1LL*invfact[i]*i%p;
ntt_init(4e5);
int n; cin>>n;
solve_g(1, n); solve_f(1, n);
vector<int> a(n+1); for(int i=1;i<=n;i++)cin>>a[i];
int ans = 0; int mn = 1e9, mx = 0; for(int r=1,l=1;r<=n;r++){ mn = min(mn, a[r]); mx = max(mx, a[r]); if(mn==l && mx==r){ if(l!=r){ ans = (ans + 1 + f[r-l+1]) % p; } l = r+1; mn = 1e9, mx = 0; } } cout << ans << '\n'; return 0; }
|
例题 2
题意
QOJ 11365. Popping Balloons
给定一个由 {0,1,2} 组成的序列,每秒随机删除其中一个数,求第一次序列变为单调不下降的期望时间,对 998244353 取模。
n≤2×105。
做法
考虑最终数组的形态。
设 f(k) 表示原数组长度为 k 的单调不下降子序列个数。
设最后的数组长度为 X,则有
E[X]=k=1∑nPr[X≥k]=k=1∑n(kn)f(k)
这是因为如果之前已经成为了单调不下降子序列,怎么删得到的还是单调不下降的子序列,那么在计算 Pr[X≥k] 时,不需要关心什么时候第一次变成单调不下降的时刻,而只需要求出删除了 n−k 个数之后,能得到单调不下降子序列的概率即可。
所求即为 n−E[X]。
考虑如何快速求出 f(n)。
考虑分治,设 f[l,r](j,x,y) 表示 [l,r] 中的数组成的长度为 j 的单调不下降子序列,且开头是 x,结尾是 y 的方案数。
考虑分治的合并过程,发现可以用 NTT 快速合并,于是做完了,复杂度 O(∣Σ∣4nlog2n),其中 Σ 为字符集大小。(事实上复杂度前面一项是 C(∣Σ∣+3,4),只是渐进 O(∣Σ∣4))
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142
| #include<bits/stdc++.h> using namespace std;
const int p = 998244353;
int qpow(int a,int b){ int r = 1; while(b){ if(b&1)r=1ll*r*a%p; a=1ll*a*a%p; b>>=1; } return r; }
using LL=long long; using ULL=unsigned long long;
const int N = 8e5+5; vector<int> omega[25]; void ntt_init(int n){ for(int k=2,d=0;k<=n;k*=2,d++){ omega[d].resize(k+1); int wn=qpow(3,(p-1)/k),tmp=1; for(int i=0;i<=k;i++){ omega[d][i]=tmp; tmp=(LL)tmp*wn%p; } } } void ntt(int *c,int n,int tp){ static ULL a[N]; for(int i=0;i<n;i++)a[i]=c[i]; for(int i=1,j=0;i<n-1;i++){ int k=n; do j^=(k>>=1); while(j<k); if(i<j)swap(a[i],a[j]); }
for(int k=1,d=0;k<n;k*=2,d++){ if(d==16)for(int i=0;i<n;i++)a[i]%=p; for(int i=0;i<n;i+=k*2){ for(int j=0;j<k;j++){ int w = omega[d][tp>0 ? j : k*2-j]; ULL u = a[i+j], v = w*a[i+j+k]%p; a[i+j]=u+v; a[i+j+k]=u-v+p; } } } if(tp>0){for(int i=0;i<n;i++)c[i]=a[i]%p;} else{ int inv = qpow(n, p-2); for(int i=0;i<n;i++)c[i]=a[i]*inv%p; } }
int arr[N]; int f[3][3][N]; int A[N],B[N]; int fact[N],invfact[N];
int get_lim(int n){ int r=1; for(;r<=n;r<<=1); return r; }
void solve(int l,int r){ static int tmp[3][3][N];
if(l==r){ f[arr[l]][arr[l]][l] = 1; return; } int mid = (l+r)>>1; solve(l,mid); solve(mid+1, r);
int lim = get_lim(r-l+1);
for(int a=0;a<3;a++)for(int b=a;b<3;b++) for(int i=0;i<lim;i++)tmp[a][b][i]=0;
for(int a=0;a<3;a++)for(int b=a;b<3;b++) for(int c=b;c<3;c++)for(int d=c;d<3;d++){ for(int i=0;i<lim;i++)A[i]=B[i]=0; for(int i=l;i<=mid;i++)A[i-l+1]=f[a][b][i]; for(int i=mid+1;i<=r;i++)B[i-mid]=f[c][d][i]; ntt(A, lim, 1), ntt(B, lim, 1); for(int i=0;i<lim;i++)tmp[a][d][i] = (tmp[a][d][i] + 1ll*A[i]*B[i])%p; }
for(int a=0;a<3;a++)for(int b=a;b<3;b++){ ntt(tmp[a][b], lim, -1); for(int i=l;i<=mid;i++)tmp[a][b][i-l+1]=(tmp[a][b][i-l+1] + f[a][b][i])%p; for(int i=mid+1;i<=r;i++)tmp[a][b][i-mid]=(tmp[a][b][i-mid] + f[a][b][i])%p;
for(int i=l;i<=r;i++){ f[a][b][i] = tmp[a][b][i-l+1]; } } }
int comb(int n,int r){ return 1ll * fact[n] * invfact[r] %p * invfact[n-r] % p; } int invcomb(int n,int r){ return 1ll * invfact[n] * fact[r] %p * fact[n-r] % p; }
int main(){ fact[0]=1; for(int i=1;i<N;i++)fact[i]=1LL*fact[i-1]*i%p; invfact[N-1]=qpow(fact[N-1], p-2); for(int i=N-1;i;i--)invfact[i-1]=1LL*invfact[i]*i%p;
ntt_init(8e5);
string s;cin>>s; int n = s.size(); for(int i=1;i<=n;i++){ if(s[i-1]=='B')arr[i]=0; if(s[i-1]=='Y')arr[i]=1; if(s[i-1]=='R')arr[i]=2; } solve(1, n);
int ans =0 ; for(int i=1;i<=n;i++){ for(int a=0;a<3;a++)for(int b=a;b<3;b++){ ans = (ans + 1ll * f[a][b][i] * invcomb(n, i))%p; } }
cout << (n - ans + p) %p << '\n'; return 0; }
|