算法导论第30章
前置知识
多项式
一个以$x$为变量的多项式定义在一个代数域$F$上, 将函数$A(x)$表示为形式和:
我们称$a_0, a_1, a_2, … , a_{n - 1}$为如上多项式的系数, 所有系数都属于域$F$, 典型的情形是负数集合$C$, 但是日常应用中实数集合$R$用的更多
如果一个多项式$A(x)$的最高次的非零系数是$a_k$, 则称$A(x)$的次数是$k$, 记做$degree(A) = k$。 任何严格大于一个多项式次数的整数都是该多项式的次数界, 因此, 对于次数界为$n$的多项式, 其次数可以是$0 … n - 1$之间的任何整数, 包括$0$和$n - 1$; 换言之, 对于一个次数为$n - 1$的多项式, 任何大于等于$n$的正整数都可以是其次数界
多项式加法
没什么好说的
多项式乘法
如果$A(x), B(x)$分别为次数界为$n, m$的多项式, 那么他们的乘积$C(x)$是一个次数界为$n + m - 1$的多项式, 对于所有属于定义域的$x$, 都有$C(x) = A(x) \times B(x)$(ps: 点值表达式用的)。
另外一种表示乘积$C(x)$的方法(系数表达)是
其中
由于一个次数界为$n + m - 1$的多项式也是次数界为$n + m$的多项式, 所以通常我们称乘积多项式$C$是一个次数界为$n + m$的多项式
多项式的表示
系数表达
对一个次数界为$n$的多项式$A(x) = \sum_{j = 0}^{n - 1} a_j x^j$而言, 其系数表达式是一个由系数组成的向量$a = (a_0, a_1, …, a_{n - 1})$。 在涉及到的矩阵方程中, 我们将向量当作列向量看待
显然, 对于给出的一个多项式$A(x)$的系数表达, 我们可以用$O(n)$的时间复杂度求出$A(x_0)$的值
现在考虑求两个用系数表达式给出的, 次数界均为$n$的两个多项式$A(x), B(x)$的乘法运算, 直接算的话时间复杂度显然是$O(n^2)$的, 因为$A(x)$的每一项的系数都要和$B(x)$中的每一项的系数相乘。 这样推导得出的向量$c$也被称为向量$a, b$的卷积, 表示为$c = a \otimes b$
点值表达
一个次数界为$n$的多项式$A(x)$的点值表达式就是一个由$n$个点值对所构成的集合
使得对$k = 0, 1, …, n - 1$, 所有的$x_k$各不相同, 其中
显然, 一个多项式可以有很多不同的点值表达
求值计算的逆(从一个多项式的点值表达式确定其系数表达式)称为差值, 当插值多项式的次数界等于已知点值对的数目时, 插值才是明确的, 即能确定唯一的系数表达。 证明过程用到了线性代数, 主要是范德蒙德矩阵, 书上有, 这里不证明了, 知道这个性质就好(不知道也不影响)
对于这种变化有一种时间复杂度为$O(n^2)$拉格朗日插值法, 这里也不具体说明了
点值表达有什么好处?
对于多项式乘法而言, 如果有$C(x) = A(x) \times B(x)$, 则对于任意的$x_k$, 有$C(x_k) = A(x_k) \times B(x_k)$, 对$A$的点值表达和$B$的点值表达逐点相乘, 就得到了$C$的点值表达
但同时我们必须注意到, $degree(C) = degree(A) + degree(B)$; 如果$A, B$次数界都为$n$, 那么$C$的次数界为$2n$, 按照上述方法我们只能得到$C$的$n$个点值对, 无法通过插值获得$C$的系数表达
为了得到所需要的$2n$个点值对, 我们需要对$A, B$进行扩展
则$C$的点值表达为
然后我们就可以插值了
系数形式表示的多项式的快速乘法
我们能否利用基于点值形式表达的多项式的线性时间乘法算法, 来加速基于系数形式表达的多项式乘法运算呢?
答案在于我们能否快速把一个多项式从系数形式转换为点值形式(求值), 以及从点值形式转换为系数形式(插值)
这就是我们后面要讲的FFT
DFT与FFT
单位复数根
$n$次单位复数根是满足$ \omega^n = 1$的复数$\omega$。 $n$次单位复数根恰好有$n$个: 对于$k = 0, 1, …, n - 1$, 这些根是$e^{2 \pi k / n}$
为了解释上述表达式, 我们利用复数的指数形式的定义
未完待续
贴个代码得了
FFT的
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
| #include <iostream> #include <cstdio> #include <cstring> #include <cmath>
const double PI = acos(-1.0);
int n, m;
struct Complex { double x, y;
Complex(double xx = 0, double yy = 0) {x = xx, y = yy;}
Complex operator + (Complex c) { return Complex (x + c.x, y + c.y); }
Complex operator - (Complex c) { return Complex(x - c.x, y - c.y); }
Complex operator * (Complex c) { return Complex(x * c.x - y * c.y, x * c.y + y * c.x); } } a[8000006], b[8000006];
void Fft(Complex *c, int len, int on) { if(len == 1) return;
int mid = len / 2; Complex *c1 = new Complex [mid], *c2 = new Complex [mid]; for(int i = 0; i < len; i += 2) c1[i >> 1] = c[i], c2[i >> 1] = c[i + 1];
Fft(c1, mid, on); Fft(c2, mid, on);
Complex wn = Complex(cos(2.0 * PI / len), on * sin(2.0 * PI / len)), w = Complex(1, 0);
for(int i = 0; i < mid; i ++) { c[i] = c1[i] + w * c2[i]; c[i + mid] = c1[i] - w * c2[i];
w = w * wn; }
delete [] c1; delete [] c2; }
int main() { scanf("%d%d", &n, &m); for(int i = 0; i <= n; i ++) { int x; scanf("%d", &x);
a[i].x = x; } for(int i = 0; i <= m; i ++) { int x; scanf("%d", &x);
b[i].x = x; }
int cn = 1; while(cn <= n + m) cn <<= 1;
Fft(a, cn, 1); Fft(b, cn, 1);
for(int i = 0; i < cn; i ++) a[i] = a[i] * b[i]; Fft(a, cn, -1);
for(int i = 0; i <= n + m; i ++) printf("%d%c", (int)(a[i].x / cn + 0.5), " \n"[i == n + m]);
return 0; }
|
NTT的
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
| #include <iostream> #include <cstdio>
const int mod = 998244353;
int n, m; int a[2123456], b[2123456];
int Qpow(int x, int y) { int res = 1; for( ; y; x = 1LL * x * x % mod, y >>= 1) if(y & 1) res = 1LL * res * x % mod;
return res; }
void Ntt(int* c, int len, int on) { if(len == 1) return;
int mid = len >> 1; int *c1 = new int [mid], *c2 = new int [mid];
for(int i = 0; i < len; i += 2) c1[i >> 1] = c[i], c2[i >> 1] = c[i + 1];
Ntt(c1, mid, on); Ntt(c2, mid, on);
int wi = Qpow(3, (mod - 1) / len), w = 1; if(! on) wi = Qpow(wi, mod - 2);
for(int i = 0; i < mid; i ++) { c[i] = (c1[i] + 1LL * w * c2[i]) % mod; c[i + mid] = ((c1[i] - 1LL * w * c2[i]) % mod + mod) % mod;
w = 1LL * w * wi % mod; }
delete [] c1; delete [] c2; }
int main() { scanf("%d%d", &n, &m); for(int i = 0; i <= n; i ++) scanf("%d", &a[i]); for(int i = 0; i <= m; i ++) scanf("%d", &b[i]);
int cn = 1; while(cn <= n + m) cn <<= 1;
Ntt(a, cn, 1); Ntt(b, cn, 1);
for(int i = 0; i < cn; i ++) a[i] = 1LL * a[i] * b[i] % mod;
Ntt(a, cn, 0);
int inv = Qpow(cn, mod - 2); for(int i = 0; i < cn; i ++) a[i] = 1LL * a[i] * inv % mod;
cn = n + m; for(int i = 0; i <= cn; i ++) printf("%d%c", a[i], " \n"[i == cn]);
return 0; }
|