多项式乘法与卷积

对于两个多项式 f(x)=i=0naixi,g(x)=i=0nbixif(x)=\sum_{i=0}^na_ix^i,g(x)=\sum_{i=0}^nb_ix^i,如果要求出乘积多项式 fgfg,朴素实现是 O(n2)O(n^2) 的。但是,利用 FFT,我们可以在 O(nlogn)O(n\log n) 时间内完成这一问题。

从另一个角度来理解,多项式的乘法是一种加法卷积。
对于两个序列 {a0,a1,,an},{b0,b1,,bn}\{a_0,a_1,\ldots,a_n\},\{b_0,b_1,\ldots,b_n\},定义其卷积 {c0,c1,,c2n}\{c_0,c_1,\ldots,c_{2n}\}ck=i+j=kaibjc_k = \sum_{i+j=k}a_ib_j
不难发现 {an},{bn}\{a_n\},\{b_n\} 分别为 f,gf,g 的系数,而 {cn}\{c_n\} 为乘积 fgfg 的系数。

系数表示法与点值表示法

对于一个 nn 次多项式 f(x)=i=0naixif(x)=\sum_{i=0}^na_ix^i,我们把序列 {a0,a1,,an}\{a_0,a_1,\ldots,a_n\} 称作 f(x)f(x)系数表示

我们知道,通过 n+1n+1f(x)f(x) 上的点亦可以唯一确定这个多项式。
我们把 {(x0,f(x0)),(x1,f(x1)),,(xn,f(xn))}\{(x_0,f(x_0)),(x_1,f(x_1)),\ldots,(x_n,f(x_n))\} 称作 f(x)f(x)点值表示

对于多项式乘法来说,只需要把每一个位置上的点值相乘,得到的就是乘积多项式的点值表示。
(尽管因为乘积次数为 2n2n,所以我们理论上需要 2n+12n+1 个点才能确定乘积多项式,但在实际操作中我们会把两个多项式的高次项补上足够的 0,来让 n+1n+1 个点就足够表达出需要的多项式)

如果我们能快速完成点值表示和系数表示的转化,那么就能快速解决多项式乘法问题。这便是后文提到的 DFT 和 IDFT 的功能。

由此,我们就已经掌握了 FFT 的大致思路:

  1. f(x),g(x)f(x),g(x) 转化为点值表示; (DFT)
  2. 把点值表示对应点值相乘;
  3. 把点值表示还原为系数表示。(IDFT)

DFT

对于一个 n1n-1 次多项式,暴力求出其点值表示还是 O(n2)O(n^2) 的。如何加速这一过程?

DFT 在复数域内,用 nn 次单位根来带入这个多项式。

nn 次单位根定义为在复数域内,xn=1x^n=1nn 个解。
根据复数相关知识,这 nn 个解 wnk=cos(2kπn)+isin(2kπn) (k=0,1,,n1)w_n^k = \cos(\frac{2k\pi}{n})+\mathrm{i}\sin(\frac{2k\pi}{n})\ (k=0,1,\ldots,n-1)

运用分治方式,我们可以快速解决这个问题。
具体来说,对于一个 n1n-1 次多项式 f(x)=i=0n1aixif(x)=\sum_{i=0}^{n-1}a_ix^i(保证 nn 为偶数),我们将其按照奇偶分开:

f(x)=(a0+a2x2++an2xn2)+(a1x+a3x3++an1xn1)f(x)=(a_0+a_2x^2+\ldots+a_{n-2}x^{n-2})+(a_1x+a_3x^3+\ldots+a_{n-1}x^{n-1})

将其拆开为两个 n/21n/2 -1 次多项式 f1,f2f_1,f_2

f1(x)=a0+a2x++an2xn/21f2(x)=a1+a3x++an1xn/21\begin{aligned} f_1(x) = a_0+a_2x+\ldots+a_{n-2}x^{n/2 -1}\\ f_2(x) = a_1+a_3x+\ldots+a_{n-1}x^{n/2 -1} \end{aligned}

我们有:

f(x)=f1(x2)+xf2(x2)f(x)=f_1(x^2)+xf_2(x^2)

考虑代入单位根 wnk (0k<n/2)w_n^{k}\ (0\le k<n/2)

f(wnk)=f1(wn2k)+wnkf2(wn2k)=f1(wn/2k)+wnkf2(wn/2k)\begin{aligned} f(w_n^{k})&=f_1(w_n^{2k})+w_n^{k}f_2(w_n^{2k})\\ &=f_1(w_{n/2}^{k})+w_n^{k}f_2(w_{n/2}^{k})\\ \end{aligned}

代入单位根 wnk+n/2 (0k<n/2)w_n^{k+n/2}\ (0\le k<n/2)

f(wnk+n/2)=f1(wn2k+n)+wnk+n/2f2(wn2k+n)=f1(wn2k)+wnk+n/2f2(wn2k)=f1(wn/2k)+wnk+n/2f2(wn/2k)=f1(wn/2k)wnkf2(wn/2k)\begin{aligned} f(w_n^{k+n/2})&=f_1(w_n^{2k+n})+w_n^{k+n/2}f_2(w_n^{2k+n})\\ &=f_1(w_n^{2k})+w_n^{k+n/2}f_2(w_n^{2k})\\ &=f_1(w_{n/2}^{k})+w_n^{k+n/2}f_2(w_{n/2}^{k})\\ &=f_1(w_{n/2}^{k})-w_n^{k}f_2(w_{n/2}^{k})\\ \end{aligned}

得到了如上两式,不难发现,如果我们已经知道了 f1f_1f2f_2 的单位根点值表示,那么我们就可以在 O(n)O(n) 时间内求出 ff 的单位根点值表示。

我们只需要不断递归下去(这也要求了 nn22 的幂),就能在 O(nlogn)O(n\log n) 时间内求出所有的点值表示。
n=1n=1 时,直接返回即可,因为带入的是 w10=1w_1^0=1

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
const double pi = acos(-1);
const int N = 4e6+5;

struct Complex{
double real,imag;
Complex(double a=0,double b=0):real(a),imag(b){}
friend Complex operator+(const Complex& a,const Complex& b){return {a.real+b.real, a.imag+b.imag};}
friend Complex operator-(const Complex& a,const Complex& b){return {a.real-b.real, a.imag-b.imag};}
friend Complex operator*(const Complex& a,const Complex& b){return {a.real*b.real-a.imag*b.imag, a.real*b.imag+a.imag*b.real};}
};
Complex temp[N];
// f[0],f[1],...,f[n-1]存储了 f(x) 的系数, n 是 2 的幂
// 原地修改 f, 最终 f[i] = f(w_n^i)
void DFT(Complex* f,int n){
if(n==1)return; // n=1,次数就是求 f(1),直接返回即可
for(int i=0;i<n;i++)temp[i]=f[i]; // 把 f 拷贝到 temp
for(int i=0;i<n/2;i++)f[i]=temp[i*2],f[i+n/2]=temp[i*2+1]; // 奇偶分开
DFT(f,n/2);DFT(f+n/2,n/2); // 递归分治
for(int i=0;i<n;i++)temp[i]=f[i]; // 把 f 拷贝到 temp
Complex w{cos(2*pi/n),sin(2*pi/n)}; // w_n^1
Complex wnow{1,0}; // 当前的单位根
for(int i=0;i<n/2;i++){
f[i] = temp[i] + wnow * temp[i+n/2];
f[i+n/2] = temp[i] - wnow * temp[i+n/2];
wnow = wnow * w; // 下一个单位根
}
}

IDFT

IDFT 解决的问题是:知道了 n1n-1 次多项式 f(x)=i=0n1aixif(x)=\sum_{i=0}^{n-1}a_ix^inn 次单位根上的点值 f(wn0),f(wni),f(wnn1)f(w_n^0),f(w_n^i)\ldots,f(w_n^{n-1}),如何求出 a0,a1,,an1a_0,a_1,\ldots,a_{n-1}

事实上,我们只需要将 DFT 过程中用到的单位根 wnkw_n^k 替换成 wnkw_n^{-k},最后得到的结果每一项再除 nn 即可。

可以看如下代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
// 【注意】最后得到的系数还要再除 n
void IDFT(Complex* f,int n){
if(n==1)return;
for(int i=0;i<n;i++)temp[i]=f[i];
for(int i=0;i<n/2;i++)f[i]=temp[i*2],f[i+n/2]=temp[i*2+1];
IDFT(f,n/2);IDFT(f+n/2,n/2);
for(int i=0;i<n;i++)temp[i]=f[i];
Complex w{cos(2*pi/n),-sin(2*pi/n)}; // 单位根反过来
Complex wnow{1,0};
for(int i=0;i<n/2;i++){
f[i] = temp[i] + wnow * temp[i+n/2];
f[i+n/2] = temp[i] - wnow * temp[i+n/2];
wnow = wnow * w;
}
}

/*
...
*/

for(int i=0;i<n;i++)f[i].real/=n,f[i].imag/=n;

我们用较朴素的方法来证明这个结论。

设原来系数为 FF,DFT 得到的点值为 GG,即

G[k]=i=0n1F[i](wnk)iG[k]=\sum_{i=0}^{n-1}F[i]\left(w_n^k\right)^i

需要证明的是

F[k]=1ni=0n1G[i](wnk)iF[k] = \dfrac 1 n \sum_{i=0}^{n-1}G[i]\left(w_n^{-k}\right)^i

证明:

右边=1ni=0n1G[i](wnk)i=1ni=0n1j=0n1F[j](wni)j(wnk)i=1ni=0n1j=0n1F[j]wn(jk)i\begin{aligned} \text{右边} &= \dfrac 1 n \sum_{i=0}^{n-1}G[i]\left(w_n^{-k}\right)^i\\ &= \dfrac 1 n \sum_{i=0}^{n-1}\sum_{j=0}^{n-1}F[j]\left(w_n^i\right)^j\left(w_n^{-k}\right)^i\\ &= \dfrac 1 n \sum_{i=0}^{n-1}\sum_{j=0}^{n-1}F[j] w_n^{(j-k)i}\\ \end{aligned}

考察每一个 F[j]F[j] 的贡献。

  • j=kj=k 时,此时 i=0,1,,n1i=0,1,\ldots,n-1 时,乘的都是 wn0=1w_n^0=1,即 F[k]F[k] 的贡献是 nn
  • jkj\ne k 时,此时我们发现 i=ti=tt+n/2t+n/2 的贡献正好相反抵消,没有任何贡献。

所以得证。

模板

P3803 【模板】多项式乘法(FFT)

除了 DFT 和 IDFT,还要解决的问题就是补 0。设两个多项式的次数为 nnmm,则我们最后的次数要补到大于等于 n+mn+m 的一个 22 的幂减一。这样 IDFT 才能准确得到结果多项式。

用一个 for 就可以完成。

1
for(m+=n,n=1;n<=m;n<<=1); // m 是最后多项式的次数, n-1 是我们运算时候的多项式次数

综上所述,我们就可以通过模板题了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#include<bits/stdc++.h>
using namespace std;

const double pi = acos(-1);
const int N = 4e6+5;

struct Complex{
double real,imag;
Complex(double a=0,double b=0):real(a),imag(b){}
friend Complex operator+(const Complex& a,const Complex& b){return {a.real+b.real, a.imag+b.imag};}
friend Complex operator-(const Complex& a,const Complex& b){return {a.real-b.real, a.imag-b.imag};}
friend Complex operator*(const Complex& a,const Complex& b){return {a.real*b.real-a.imag*b.imag, a.real*b.imag+a.imag*b.real};}
};
Complex temp[N];
// f[0],f[1],...,f[n-1]存储了 f(x) 的系数, n 是 2 的幂
// 原地修改 f, 最终 f[i] = f(w_n^i)
void DFT(Complex* f,int n){
if(n==1)return; // n=1,次数就是求 f(1),直接返回即可
for(int i=0;i<n;i++)temp[i]=f[i]; // 把 f 拷贝到 temp
for(int i=0;i<n/2;i++)f[i]=temp[i*2],f[i+n/2]=temp[i*2+1]; // 奇偶分开
DFT(f,n/2);DFT(f+n/2,n/2); // 递归分治
for(int i=0;i<n;i++)temp[i]=f[i]; // 把 f 拷贝到 temp
Complex w{cos(2*pi/n),sin(2*pi/n)}; // w_n^1
Complex wnow{1,0}; // 当前的单位根
for(int i=0;i<n/2;i++){
f[i] = temp[i] + wnow * temp[i+n/2];
f[i+n/2] = temp[i] - wnow * temp[i+n/2];
wnow = wnow * w; // 下一个单位根
}
}

// 【注意】最后得到的系数还要再除 n
void IDFT(Complex* f,int n){
if(n==1)return;
for(int i=0;i<n;i++)temp[i]=f[i];
for(int i=0;i<n/2;i++)f[i]=temp[i*2],f[i+n/2]=temp[i*2+1];
IDFT(f,n/2);IDFT(f+n/2,n/2);
for(int i=0;i<n;i++)temp[i]=f[i];
Complex w{cos(2*pi/n),-sin(2*pi/n)}; // 单位根反过来
Complex wnow{1,0};
for(int i=0;i<n/2;i++){
f[i] = temp[i] + wnow * temp[i+n/2];
f[i+n/2] = temp[i] - wnow * temp[i+n/2];
wnow = wnow * w;
}
}

Complex f[N],g[N];

int main(){
int n,m;cin>>n>>m;
for(int i=0;i<=n;i++)cin>>f[i].real;
for(int i=0;i<=m;i++)cin>>g[i].real;
for(m+=n,n=1;n<=m;n<<=1); // 补 0

DFT(f,n);DFT(g,n);
for(int i=0;i<n;i++)f[i]=f[i]*g[i];
IDFT(f,n);
for(int i=0;i<=m;i++)cout<<(int)round(f[i].real/n)<<' '; // 除 n

return 0;
}

优化常数

这样我们已经能够通过模板题了。

事实上,通过一些巧妙的处理,我们可以完全删去代码中的任何数组拷贝。

在上面的代码中,分治的分开,合并都用到了拷贝,我们分别来看这两个过程。

分开

我们追踪分治过程中的下标变化。

我们发现,最后的下标就是原下标在 log2n\log_2 n 位二进制下反转后得到的结果。

比如 6=(110)26=(110)_2 最后变到了 3=(011)23=(011)_2 的位置。

可以通过递推 O(n)O(n) 求出反转结果:

1
2
3
4
for(int i=0;i<n;i++){
rev[i]=rev[i>>1]>>1;
if(i&1)rev[i]|=n>>1;
}

假设原来的数为 (abcdef)2(abcdef)_2,我们用右移后 (0abcde)2(0abcde)_2 的反转结果来递推。
(0abcde)2(0abcde)_2 反转后变成 (edcba0)2(edcba0)_2,右移得到 (0edcba)2(0edcba)_2,就差最高位的结果。

我们只需判断如果 f=1f=1,那么就将最高位设为 11 就能递推出结果。

因此在 DFT/IDFT 开始的时候,把 f[i]f[rev[i]] 交换即可。

1
2
3
for(int i=0;i<n;i++){
if(i<rev[i])swap(f[i],f[rev[i]]);
}

注意要求 i<rev[i],这是因为一个数反转两次还是自身,比如说 rev[3]=6,rev[6]=3,我们只需要第一次换就可了,不用两次都换,要不然就白换了。

合并

现在 f[0],f[1],...,f[n/2-1] 储存了偶次的结果,f[n/2],f[n/2+1],...,f[n] 储存了奇次的结果。
怎么在不进行拷贝的情况下合并出 f[0],f[1],...,f[n-1] 呢?

回看我们的式子:

f(wnk)=f1(wn/2k)+wnkf2(wn/2k)f(wnk+n/2)=f1(wn/2k)wnkf2(wn/2k)\begin{aligned} f(w_n^{k})=f_1(w_{n/2}^{k})+w_n^{k}f_2(w_{n/2}^{k})\\ f(w_n^{k+n/2})=f_1(w_{n/2}^{k})-w_n^{k}f_2(w_{n/2}^{k})\\ \end{aligned}

上式中有四项,考察他们储存的位置:

  • f1(wn/2k)f_1(w_{n/2}^{k}) 已经存在 f[k]
  • f2(wn/2k)f_2(w_{n/2}^{k}) 已经存在 f[k+n/2]
  • f(wnk)f(w_n^{k}) 要存在 f[k]
  • f(wnk+n/2)f(w_n^{k+n/2}) 要存在 f[k+n/2]

我们把 wnkf2(wn/2k)w_n^{k}f_2(w_{n/2}^{k}) 存在临时变量 t 里,然后就可以这么写:

1
2
3
4
5
6
for(int k=0;k<n/2;k++){
Complex t = f[k+n/2] * wnow;
f[k+n/2] = f[k] - t;
f[k] = f[k] + t; // 注意顺序
wnow = wnow * wn;
}

把递归改成循环

最后我们可以顺便把递归改成循环,进一步加快速度。

因为 DFT 和 IDFT 内容非常相似,我们把它们写进同一个函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
void fft(Complex* f,int n,bool flag){ //flag 0=DFT, 1=IDFT
for(int i=0;i<n;i++){
if(i<rev[i])swap(f[i],f[rev[i]]); // 分开
}

for(int p=2;p<=n;p<<=1){
Complex wn{cos(2*pi/p),sin(2*pi/p)};
if(flag)wn.imag=-wn.imag; // IDFT 的单位根是反过来的
for(int j=0;j<n;j+=p){
Complex wnow{1,0};
for(int k=j;k<j+p/2;k++){
Complex t = f[k+p/2] * wnow;
f[k+p/2] = f[k] - t; // 合并
f[k] = f[k] + t;
wnow = wnow * wn;
}
}
}

if(flag)for(int i=0;i<n;i++)f[i].real/=n,f[i].imag/=n; //IDFT 还要除 n
}

完整代码

可以通过模板题。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#include<bits/stdc++.h>
using namespace std;

const double pi = acos(-1);
const int N = 4e6+5;

struct Complex{
double real,imag;
Complex(double a=0,double b=0):real(a),imag(b){}
friend Complex operator+(const Complex& a,const Complex& b){
return {a.real+b.real, a.imag+b.imag};
}
friend Complex operator-(const Complex& a,const Complex& b){
return {a.real-b.real, a.imag-b.imag};
}
friend Complex operator*(const Complex& a,const Complex& b){
return {a.real*b.real-a.imag*b.imag, a.real*b.imag+a.imag*b.real};
}
};


int rev[N];

void fft(Complex* f,int n,bool flag){
for(int i=0;i<n;i++){
if(i<rev[i])swap(f[i],f[rev[i]]);
}
for(int p=2;p<=n;p<<=1){
Complex wn{cos(2*pi/p),sin(2*pi/p)};
if(flag)wn.imag=-wn.imag;
for(int j=0;j<n;j+=p){
Complex w{1,0};
for(int k=j;k<(j|p>>1);k++){
Complex t = f[k|p>>1] * w;
f[k|p>>1] = f[k] - t;
f[k] = f[k] + t;
w = w * wn;
}
}
}

if(flag)for(int i=0;i<n;i++)f[i].real/=n,f[i].imag/=n;
}

Complex f[N],g[N];
int main(){
ios::sync_with_stdio(0);cin.tie(0);

int n,m;cin>>n>>m;
for(int i=0;i<=n;i++)cin>>f[i].real;
for(int i=0;i<=m;i++)cin>>g[i].real;

for(m+=n,n=1;n<=m;n<<=1);
for(int i=0;i<n;i++){
rev[i]=rev[i>>1]>>1;
if(i&1)rev[i]|=n>>1;
}

fft(f,n,0);fft(g,n,0);
for(int i=0;i<n;i++)f[i]=f[i]*g[i];
fft(f,n,1);
for(int i=0;i<=m;i++){
cout<<(int)round(f[i].real)<<' ';
}
cout<<'\n';
return 0;
}

练习

P3338 [ZJOI2014] 力
AGC047C - Product Modulo

参考资料

https://www.luogu.com.cn/article/v7vgqau1