多项式乘法与卷积
对于两个多项式 f(x)=∑i=0naixi,g(x)=∑i=0nbixi,如果要求出乘积多项式 fg,朴素实现是 O(n2) 的。但是,利用 FFT,我们可以在 O(nlogn) 时间内完成这一问题。
从另一个角度来理解,多项式的乘法是一种加法卷积。
对于两个序列 {a0,a1,…,an},{b0,b1,…,bn},定义其卷积 {c0,c1,…,c2n},ck=∑i+j=kaibj。
不难发现 {an},{bn} 分别为 f,g 的系数,而 {cn} 为乘积 fg 的系数。
系数表示法与点值表示法
对于一个 n 次多项式 f(x)=∑i=0naixi,我们把序列 {a0,a1,…,an} 称作 f(x) 的系数表示。
我们知道,通过 n+1 个 f(x) 上的点亦可以唯一确定这个多项式。
我们把 {(x0,f(x0)),(x1,f(x1)),…,(xn,f(xn))} 称作 f(x) 的点值表示。
对于多项式乘法来说,只需要把每一个位置上的点值相乘,得到的就是乘积多项式的点值表示。
(尽管因为乘积次数为 2n,所以我们理论上需要 2n+1 个点才能确定乘积多项式,但在实际操作中我们会把两个多项式的高次项补上足够的 0,来让 n+1 个点就足够表达出需要的多项式)
如果我们能快速完成点值表示和系数表示的转化,那么就能快速解决多项式乘法问题。这便是后文提到的 DFT 和 IDFT 的功能。
由此,我们就已经掌握了 FFT 的大致思路:
- 把 f(x),g(x) 转化为点值表示; (DFT)
- 把点值表示对应点值相乘;
- 把点值表示还原为系数表示。(IDFT)
DFT
对于一个 n−1 次多项式,暴力求出其点值表示还是 O(n2) 的。如何加速这一过程?
DFT 在复数域内,用 n 次单位根来带入这个多项式。
n 次单位根定义为在复数域内,xn=1 的 n 个解。
根据复数相关知识,这 n 个解 wnk=cos(n2kπ)+isin(n2kπ) (k=0,1,…,n−1)。
运用分治方式,我们可以快速解决这个问题。
具体来说,对于一个 n−1 次多项式 f(x)=∑i=0n−1aixi(保证 n 为偶数),我们将其按照奇偶分开:
f(x)=(a0+a2x2+…+an−2xn−2)+(a1x+a3x3+…+an−1xn−1)
将其拆开为两个 n/2−1 次多项式 f1,f2:
f1(x)=a0+a2x+…+an−2xn/2−1f2(x)=a1+a3x+…+an−1xn/2−1
我们有:
f(x)=f1(x2)+xf2(x2)
考虑代入单位根 wnk (0≤k<n/2):
f(wnk)=f1(wn2k)+wnkf2(wn2k)=f1(wn/2k)+wnkf2(wn/2k)
代入单位根 wnk+n/2 (0≤k<n/2):
f(wnk+n/2)=f1(wn2k+n)+wnk+n/2f2(wn2k+n)=f1(wn2k)+wnk+n/2f2(wn2k)=f1(wn/2k)+wnk+n/2f2(wn/2k)=f1(wn/2k)−wnkf2(wn/2k)
得到了如上两式,不难发现,如果我们已经知道了 f1 和 f2 的单位根点值表示,那么我们就可以在 O(n) 时间内求出 f 的单位根点值表示。
我们只需要不断递归下去(这也要求了 n 是 2 的幂),就能在 O(nlogn) 时间内求出所有的点值表示。
当 n=1 时,直接返回即可,因为带入的是 w10=1。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
| const double pi = acos(-1); const int N = 4e6+5;
struct Complex{ double real,imag; Complex(double a=0,double b=0):real(a),imag(b){} friend Complex operator+(const Complex& a,const Complex& b){return {a.real+b.real, a.imag+b.imag};} friend Complex operator-(const Complex& a,const Complex& b){return {a.real-b.real, a.imag-b.imag};} friend Complex operator*(const Complex& a,const Complex& b){return {a.real*b.real-a.imag*b.imag, a.real*b.imag+a.imag*b.real};} }; Complex temp[N];
void DFT(Complex* f,int n){ if(n==1)return; for(int i=0;i<n;i++)temp[i]=f[i]; for(int i=0;i<n/2;i++)f[i]=temp[i*2],f[i+n/2]=temp[i*2+1]; DFT(f,n/2);DFT(f+n/2,n/2); for(int i=0;i<n;i++)temp[i]=f[i]; Complex w{cos(2*pi/n),sin(2*pi/n)}; Complex wnow{1,0}; for(int i=0;i<n/2;i++){ f[i] = temp[i] + wnow * temp[i+n/2]; f[i+n/2] = temp[i] - wnow * temp[i+n/2]; wnow = wnow * w; } }
|
IDFT
IDFT 解决的问题是:知道了 n−1 次多项式 f(x)=∑i=0n−1aixi 在 n 次单位根上的点值 f(wn0),f(wni)…,f(wnn−1),如何求出 a0,a1,…,an−1。
事实上,我们只需要将 DFT 过程中用到的单位根 wnk 替换成 wn−k,最后得到的结果每一项再除 n 即可。
可以看如下代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
| void IDFT(Complex* f,int n){ if(n==1)return; for(int i=0;i<n;i++)temp[i]=f[i]; for(int i=0;i<n/2;i++)f[i]=temp[i*2],f[i+n/2]=temp[i*2+1]; IDFT(f,n/2);IDFT(f+n/2,n/2); for(int i=0;i<n;i++)temp[i]=f[i]; Complex w{cos(2*pi/n),-sin(2*pi/n)}; Complex wnow{1,0}; for(int i=0;i<n/2;i++){ f[i] = temp[i] + wnow * temp[i+n/2]; f[i+n/2] = temp[i] - wnow * temp[i+n/2]; wnow = wnow * w; } }
for(int i=0;i<n;i++)f[i].real/=n,f[i].imag/=n;
|
我们用较朴素的方法来证明这个结论。
设原来系数为 F,DFT 得到的点值为 G,即
G[k]=i=0∑n−1F[i](wnk)i
需要证明的是
F[k]=n1i=0∑n−1G[i](wn−k)i
证明:
右边=n1i=0∑n−1G[i](wn−k)i=n1i=0∑n−1j=0∑n−1F[j](wni)j(wn−k)i=n1i=0∑n−1j=0∑n−1F[j]wn(j−k)i
考察每一个 F[j] 的贡献。
- 当 j=k 时,此时 i=0,1,…,n−1 时,乘的都是 wn0=1,即 F[k] 的贡献是 n。
- 当 j=k 时,此时我们发现 i=t 和 t+n/2 的贡献正好相反抵消,没有任何贡献。
所以得证。
模板
P3803 【模板】多项式乘法(FFT)
除了 DFT 和 IDFT,还要解决的问题就是补 0。设两个多项式的次数为 n,m,则我们最后的次数要补到大于等于 n+m 的一个 2 的幂减一。这样 IDFT 才能准确得到结果多项式。
用一个 for
就可以完成。
1
| for(m+=n,n=1;n<=m;n<<=1);
|
综上所述,我们就可以通过模板题了。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
| #include<bits/stdc++.h> using namespace std;
const double pi = acos(-1); const int N = 4e6+5;
struct Complex{ double real,imag; Complex(double a=0,double b=0):real(a),imag(b){} friend Complex operator+(const Complex& a,const Complex& b){return {a.real+b.real, a.imag+b.imag};} friend Complex operator-(const Complex& a,const Complex& b){return {a.real-b.real, a.imag-b.imag};} friend Complex operator*(const Complex& a,const Complex& b){return {a.real*b.real-a.imag*b.imag, a.real*b.imag+a.imag*b.real};} }; Complex temp[N];
void DFT(Complex* f,int n){ if(n==1)return; for(int i=0;i<n;i++)temp[i]=f[i]; for(int i=0;i<n/2;i++)f[i]=temp[i*2],f[i+n/2]=temp[i*2+1]; DFT(f,n/2);DFT(f+n/2,n/2); for(int i=0;i<n;i++)temp[i]=f[i]; Complex w{cos(2*pi/n),sin(2*pi/n)}; Complex wnow{1,0}; for(int i=0;i<n/2;i++){ f[i] = temp[i] + wnow * temp[i+n/2]; f[i+n/2] = temp[i] - wnow * temp[i+n/2]; wnow = wnow * w; } }
void IDFT(Complex* f,int n){ if(n==1)return; for(int i=0;i<n;i++)temp[i]=f[i]; for(int i=0;i<n/2;i++)f[i]=temp[i*2],f[i+n/2]=temp[i*2+1]; IDFT(f,n/2);IDFT(f+n/2,n/2); for(int i=0;i<n;i++)temp[i]=f[i]; Complex w{cos(2*pi/n),-sin(2*pi/n)}; Complex wnow{1,0}; for(int i=0;i<n/2;i++){ f[i] = temp[i] + wnow * temp[i+n/2]; f[i+n/2] = temp[i] - wnow * temp[i+n/2]; wnow = wnow * w; } }
Complex f[N],g[N];
int main(){ int n,m;cin>>n>>m; for(int i=0;i<=n;i++)cin>>f[i].real; for(int i=0;i<=m;i++)cin>>g[i].real; for(m+=n,n=1;n<=m;n<<=1);
DFT(f,n);DFT(g,n); for(int i=0;i<n;i++)f[i]=f[i]*g[i]; IDFT(f,n); for(int i=0;i<=m;i++)cout<<(int)round(f[i].real/n)<<' ';
return 0; }
|
优化常数
这样我们已经能够通过模板题了。
事实上,通过一些巧妙的处理,我们可以完全删去代码中的任何数组拷贝。
在上面的代码中,分治的分开,合并都用到了拷贝,我们分别来看这两个过程。
分开
我们追踪分治过程中的下标变化。

我们发现,最后的下标就是原下标在 log2n 位二进制下反转后得到的结果。
比如 6=(110)2 最后变到了 3=(011)2 的位置。
可以通过递推 O(n) 求出反转结果:
1 2 3 4
| for(int i=0;i<n;i++){ rev[i]=rev[i>>1]>>1; if(i&1)rev[i]|=n>>1; }
|
假设原来的数为 (abcdef)2,我们用右移后 (0abcde)2 的反转结果来递推。
(0abcde)2 反转后变成 (edcba0)2,右移得到 (0edcba)2,就差最高位的结果。
我们只需判断如果 f=1,那么就将最高位设为 1 就能递推出结果。
因此在 DFT/IDFT 开始的时候,把 f[i]
和 f[rev[i]]
交换即可。
1 2 3
| for(int i=0;i<n;i++){ if(i<rev[i])swap(f[i],f[rev[i]]); }
|
注意要求 i<rev[i]
,这是因为一个数反转两次还是自身,比如说 rev[3]=6,rev[6]=3
,我们只需要第一次换就可了,不用两次都换,要不然就白换了。
合并
现在 f[0],f[1],...,f[n/2-1]
储存了偶次的结果,f[n/2],f[n/2+1],...,f[n]
储存了奇次的结果。
怎么在不进行拷贝的情况下合并出 f[0],f[1],...,f[n-1]
呢?
回看我们的式子:
f(wnk)=f1(wn/2k)+wnkf2(wn/2k)f(wnk+n/2)=f1(wn/2k)−wnkf2(wn/2k)
上式中有四项,考察他们储存的位置:
- f1(wn/2k) 已经存在
f[k]
- f2(wn/2k) 已经存在
f[k+n/2]
- f(wnk) 要存在
f[k]
- f(wnk+n/2) 要存在
f[k+n/2]
我们把 wnkf2(wn/2k) 存在临时变量 t
里,然后就可以这么写:
1 2 3 4 5 6
| for(int k=0;k<n/2;k++){ Complex t = f[k+n/2] * wnow; f[k+n/2] = f[k] - t; f[k] = f[k] + t; wnow = wnow * wn; }
|
把递归改成循环
最后我们可以顺便把递归改成循环,进一步加快速度。
因为 DFT 和 IDFT 内容非常相似,我们把它们写进同一个函数:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
| void fft(Complex* f,int n,bool flag){ for(int i=0;i<n;i++){ if(i<rev[i])swap(f[i],f[rev[i]]); }
for(int p=2;p<=n;p<<=1){ Complex wn{cos(2*pi/p),sin(2*pi/p)}; if(flag)wn.imag=-wn.imag; for(int j=0;j<n;j+=p){ Complex wnow{1,0}; for(int k=j;k<j+p/2;k++){ Complex t = f[k+p/2] * wnow; f[k+p/2] = f[k] - t; f[k] = f[k] + t; wnow = wnow * wn; } } }
if(flag)for(int i=0;i<n;i++)f[i].real/=n,f[i].imag/=n; }
|
完整代码
可以通过模板题。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
| #include<bits/stdc++.h> using namespace std;
const double pi = acos(-1); const int N = 4e6+5;
struct Complex{ double real,imag; Complex(double a=0,double b=0):real(a),imag(b){} friend Complex operator+(const Complex& a,const Complex& b){ return {a.real+b.real, a.imag+b.imag}; } friend Complex operator-(const Complex& a,const Complex& b){ return {a.real-b.real, a.imag-b.imag}; } friend Complex operator*(const Complex& a,const Complex& b){ return {a.real*b.real-a.imag*b.imag, a.real*b.imag+a.imag*b.real}; } };
int rev[N];
void fft(Complex* f,int n,bool flag){ for(int i=0;i<n;i++){ if(i<rev[i])swap(f[i],f[rev[i]]); } for(int p=2;p<=n;p<<=1){ Complex wn{cos(2*pi/p),sin(2*pi/p)}; if(flag)wn.imag=-wn.imag; for(int j=0;j<n;j+=p){ Complex w{1,0}; for(int k=j;k<(j|p>>1);k++){ Complex t = f[k|p>>1] * w; f[k|p>>1] = f[k] - t; f[k] = f[k] + t; w = w * wn; } } }
if(flag)for(int i=0;i<n;i++)f[i].real/=n,f[i].imag/=n; }
Complex f[N],g[N]; int main(){ ios::sync_with_stdio(0);cin.tie(0);
int n,m;cin>>n>>m; for(int i=0;i<=n;i++)cin>>f[i].real; for(int i=0;i<=m;i++)cin>>g[i].real;
for(m+=n,n=1;n<=m;n<<=1); for(int i=0;i<n;i++){ rev[i]=rev[i>>1]>>1; if(i&1)rev[i]|=n>>1; }
fft(f,n,0);fft(g,n,0); for(int i=0;i<n;i++)f[i]=f[i]*g[i]; fft(f,n,1); for(int i=0;i<=m;i++){ cout<<(int)round(f[i].real)<<' '; } cout<<'\n'; return 0; }
|
练习
P3338 [ZJOI2014] 力
AGC047C - Product Modulo
参考资料
https://www.luogu.com.cn/article/v7vgqau1