Zth's Blog

记录学习路上的点滴

0%

多项式与FFT

算法导论第30章

前置知识

多项式

一个以$x$为变量的多项式定义在一个代数域$F$上, 将函数$A(x)$表示为形式和:

我们称$a_0, a_1, a_2, … , a_{n - 1}$为如上多项式的系数, 所有系数都属于域$F$, 典型的情形是负数集合$C$, 但是日常应用中实数集合$R$用的更多

如果一个多项式$A(x)$的最高次的非零系数是$a_k$, 则称$A(x)$的次数是$k$, 记做$degree(A) = k$。 任何严格大于一个多项式次数的整数都是该多项式的次数界, 因此, 对于次数界为$n$的多项式, 其次数可以是$0 … n - 1$之间的任何整数, 包括$0$和$n - 1$; 换言之, 对于一个次数为$n - 1$的多项式, 任何大于等于$n$的正整数都可以是其次数界

多项式加法

没什么好说的

多项式乘法

如果$A(x), B(x)$分别为次数界为$n, m$的多项式, 那么他们的乘积$C(x)$是一个次数界为$n + m - 1$的多项式, 对于所有属于定义域的$x$, 都有$C(x) = A(x) \times B(x)$(ps: 点值表达式用的)。

另外一种表示乘积$C(x)$的方法(系数表达)是

其中

由于一个次数界为$n + m - 1$的多项式也是次数界为$n + m$的多项式, 所以通常我们称乘积多项式$C$是一个次数界为$n + m$的多项式

多项式的表示

系数表达

对一个次数界为$n$的多项式$A(x) = \sum_{j = 0}^{n - 1} a_j x^j$而言, 其系数表达式是一个由系数组成的向量$a = (a_0, a_1, …, a_{n - 1})$。 在涉及到的矩阵方程中, 我们将向量当作列向量看待

显然, 对于给出的一个多项式$A(x)$的系数表达, 我们可以用$O(n)$的时间复杂度求出$A(x_0)$的值

现在考虑求两个用系数表达式给出的, 次数界均为$n$的两个多项式$A(x), B(x)$的乘法运算, 直接算的话时间复杂度显然是$O(n^2)$的, 因为$A(x)$的每一项的系数都要和$B(x)$中的每一项的系数相乘。 这样推导得出的向量$c$也被称为向量$a, b$的卷积, 表示为$c = a \otimes b$

点值表达

一个次数界为$n$的多项式$A(x)$的点值表达式就是一个由$n$个点值对所构成的集合

使得对$k = 0, 1, …, n - 1$, 所有的$x_k$各不相同, 其中

显然, 一个多项式可以有很多不同的点值表达

求值计算的逆(从一个多项式的点值表达式确定其系数表达式)称为差值, 当插值多项式的次数界等于已知点值对的数目时, 插值才是明确的, 即能确定唯一的系数表达。 证明过程用到了线性代数, 主要是范德蒙德矩阵, 书上有, 这里不证明了, 知道这个性质就好(不知道也不影响)

对于这种变化有一种时间复杂度为$O(n^2)$拉格朗日插值法, 这里也不具体说明了

点值表达有什么好处?

对于多项式乘法而言, 如果有$C(x) = A(x) \times B(x)$, 则对于任意的$x_k$, 有$C(x_k) = A(x_k) \times B(x_k)$, 对$A$的点值表达和$B$的点值表达逐点相乘, 就得到了$C$的点值表达

但同时我们必须注意到, $degree(C) = degree(A) + degree(B)$; 如果$A, B$次数界都为$n$, 那么$C$的次数界为$2n$, 按照上述方法我们只能得到$C$的$n$个点值对, 无法通过插值获得$C$的系数表达

为了得到所需要的$2n$个点值对, 我们需要对$A, B$进行扩展

则$C$的点值表达为

然后我们就可以插值了

系数形式表示的多项式的快速乘法

我们能否利用基于点值形式表达的多项式的线性时间乘法算法, 来加速基于系数形式表达的多项式乘法运算呢?

答案在于我们能否快速把一个多项式从系数形式转换为点值形式(求值), 以及从点值形式转换为系数形式(插值)

这就是我们后面要讲的FFT

DFT与FFT

单位复数根

$n$次单位复数根是满足$ \omega^n = 1$的复数$\omega$。 $n$次单位复数根恰好有$n$个: 对于$k = 0, 1, …, n - 1$, 这些根是$e^{2 \pi k / n}$

为了解释上述表达式, 我们利用复数的指数形式的定义

未完待续

贴个代码得了

FFT的

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>

const double PI = acos(-1.0);

int n, m;

struct Complex
{
double x, y;

Complex(double xx = 0, double yy = 0) {x = xx, y = yy;}

Complex operator + (Complex c)
{
return Complex (x + c.x, y + c.y);
}

Complex operator - (Complex c)
{
return Complex(x - c.x, y - c.y);
}

Complex operator * (Complex c)
{
return Complex(x * c.x - y * c.y, x * c.y + y * c.x);
}
} a[8000006], b[8000006];

void Fft(Complex *c, int len, int on)
{
if(len == 1) return;

int mid = len / 2;
Complex *c1 = new Complex [mid], *c2 = new Complex [mid];
for(int i = 0; i < len; i += 2) c1[i >> 1] = c[i], c2[i >> 1] = c[i + 1];

Fft(c1, mid, on);
Fft(c2, mid, on);

Complex wn = Complex(cos(2.0 * PI / len), on * sin(2.0 * PI / len)), w = Complex(1, 0);

for(int i = 0; i < mid; i ++)
{
c[i] = c1[i] + w * c2[i];
c[i + mid] = c1[i] - w * c2[i];

w = w * wn;
}

delete [] c1;
delete [] c2;
}

int main()
{
scanf("%d%d", &n, &m);
for(int i = 0; i <= n; i ++)
{
int x;
scanf("%d", &x);

a[i].x = x;
}
for(int i = 0; i <= m; i ++)
{
int x;
scanf("%d", &x);

b[i].x = x;
}

int cn = 1;
while(cn <= n + m) cn <<= 1;

Fft(a, cn, 1);
Fft(b, cn, 1);

for(int i = 0; i < cn; i ++) a[i] = a[i] * b[i];
Fft(a, cn, -1);

for(int i = 0; i <= n + m; i ++) printf("%d%c", (int)(a[i].x / cn + 0.5), " \n"[i == n + m]);

return 0;
}

NTT的

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#include <iostream>
#include <cstdio>

const int mod = 998244353;

int n, m;
int a[2123456], b[2123456];

int Qpow(int x, int y)
{
int res = 1;
for( ; y; x = 1LL * x * x % mod, y >>= 1)
if(y & 1)
res = 1LL * res * x % mod;

return res;
}

void Ntt(int* c, int len, int on)
{
if(len == 1) return;

int mid = len >> 1;
int *c1 = new int [mid], *c2 = new int [mid];

for(int i = 0; i < len; i += 2)
c1[i >> 1] = c[i], c2[i >> 1] = c[i + 1];

Ntt(c1, mid, on);
Ntt(c2, mid, on);

int wi = Qpow(3, (mod - 1) / len), w = 1;
if(! on) wi = Qpow(wi, mod - 2);

for(int i = 0; i < mid; i ++)
{
c[i] = (c1[i] + 1LL * w * c2[i]) % mod;
c[i + mid] = ((c1[i] - 1LL * w * c2[i]) % mod + mod) % mod;

w = 1LL * w * wi % mod;
}

delete [] c1;
delete [] c2;
}

int main()
{
scanf("%d%d", &n, &m);
for(int i = 0; i <= n; i ++) scanf("%d", &a[i]);
for(int i = 0; i <= m; i ++) scanf("%d", &b[i]);

int cn = 1;
while(cn <= n + m) cn <<= 1;

Ntt(a, cn, 1);
Ntt(b, cn, 1);

for(int i = 0; i < cn; i ++) a[i] = 1LL * a[i] * b[i] % mod;

Ntt(a, cn, 0);

int inv = Qpow(cn, mod - 2);
for(int i = 0; i < cn; i ++) a[i] = 1LL * a[i] * inv % mod;

cn = n + m;
for(int i = 0; i <= cn; i ++) printf("%d%c", a[i], " \n"[i == cn]);

return 0;
}