다항식의 빠른 곱셈
두 다항식의 합성곱(convolution)을 계산할 때, 단순한 방법은 모든 항을 직접 곱하는 것으로 시간 복잡도는 $O(n^2)$이다. 하지만 $O(n \log n)$ 시간에 이를 수행할 수 있는 알고리즘이 존재하는데, 대표적으로 FFT(고속 푸리에 변환)와 NTT(수론적 변환)가 있다. 이들은 본질적으로 DFT(이산 푸리에 변환)와 IDFT(역 이산 푸리에 변환)를 효율적으로 계산하는 방법이다.
점값 표현법
$n$차 이하의 다항식 $A(x)$와 서로 다른 $n+1$개의 점 $x_0, \ldots, x_n$이 주어졌을 때, $y_i = A(x_i)$라 하면 $(x_i, y_i)$ 쌍들은 다항식 $A(x)$를 유일하게 결정한다. 두 다항식 $A(x), B(x)$의 점값 표현이 있다면, 같은 위치에서의 $A(x) \pm B(x)$나 $A(x)B(x)$의 점값을 $O(n)$에 구할 수 있다. 핵심은 계수 표현과 점값 표현 간의 효율적인 변환이다.
단위근과 DFT
길이 $n$인 수열 $a$에 대해 $0 \le k < n$일 때,
$$b_k = \sum_{i=0}^{n-1} a_i \cdot \omega_n^{ki}$$
로 정의된 $b$를 $a$의 DFT라 한다. 여기서 $A(x) = \sum a_i x^i$이면 $b_k = A(\omega_n^k)$이므로, DFT 계산은 $A(x)$를 $\omega_n^0, \omega_n^1, \ldots, \omega_n^{n-1}$에서 평가하는 것과 동일하다.
FFT 구현
FFT는 분할 정복을 기반으로 한다. 다항식을 짝수 차수 항과 홀수 차수 항으로 분리하여 재귀적으로 처리한다.
void fft(int sz, complex<double> *coef, int dir) {
if (sz == 1) return;
complex<double> even[sz/2 + 1], odd[sz/2 + 1];
for (int i = 0; i < sz; i += 2) {
even[i/2] = coef[i];
odd[i/2] = coef[i+1];
}
fft(sz/2, even, dir);
fft(sz/2, odd, dir);
complex<double> omega(cos(2.0*PI/sz), dir*sin(2.0*PI/sz)), cur(1.0, 0.0);
for (int i = 0; i < sz/2; ++i, cur *= omega) {
coef[i] = even[i] + cur * odd[i];
coef[i + sz/2] = even[i] - cur * odd[i];
}
}
처리할 다항식의 길이는 반드시 $2^m$ 형태여야 하므로, 실제 구현에서는 필요시 고차 항에 0을 채워 길이를 조정한다.
NTT 구현
NTT는 FFT와 유사하되, 복소수 대신 원근을 사용하여 정수 계산에서도 정확한 결과를 얻는다.
void ntt(int sz, long long *coef, int dir) {
if (sz == 1) return;
long long even[sz/2 + 1], odd[sz/2 + 1];
for (int i = 0; i < sz; i += 2) {
even[i/2] = coef[i];
odd[i/2] = coef[i+1];
}
ntt(sz/2, even, dir);
ntt(sz/2, odd, dir);
long long cur = 1, omega = (dir == 1) ? mod_pow(root, (MOD-1)/sz)
: mod_pow(root_inv, (MOD-1)/sz);
for (int i = 0; i < sz/2; ++i, cur = cur * omega % MOD) {
coef[i] = (even[i] + cur * odd[i]) % MOD;
coef[i + sz/2] = (even[i] - cur * odd[i] + MOD) % MOD;
}
}
합성곱 형태로의 변환 기법
문제의 식을 합성곱 $C_k = \sum_{i=0}^k A_i B_{k-i}$ 형태로 변환하는 것이 핵심이다. 다음과 같은 방법들을 활용할 수 있다:
- 중첩된 합을 각각 합성곱으로 분리
- 합의 범경을 조정하여 닫힌 구간 형태로 변환
- 불변항을 인수로 분리하거나 소거하여 단순화
- 보조 변수 도입으로 구조 파악
- 지수가 $i+j$ 형태일 때, 한 수열을 뒤집어 새 배열로 변환
- 조합수 등으로 재해석 후 합성곱으로 전환
- 각 항의 기여도나 계산 횟수를 직접 세어 원래 수열과 합성곱
생성함수
생성함수는 조합 대상을 기술하는 형식적 다항식으로, 크기나 길이 등의 속성을 만족하는 대상의 수를 계수로 표현한다.
일반 생성함수 (OGF)
수열 $A_0, A_1, \ldots$의 OGF는 $A(x) = \sum_{i=0}^{\infty} A_i x^i$로 정의된다. 무표호 객체를 다루며, 두 수열의 합성곱은 생성함수의 곱에 대응한다.
지수 생성함수 (EGF)
수열 $A_0, A_1, \ldots$의 EGF는 $A(x) = \sum_{i=0}^{\infty} \frac{A_i x^i}{i!}$로 정의된다. 유표호 객체를 다루며, 두 객체를 크기 $n, m$인 것끼리 붙일 때 $\binom{n+m}{n} = \frac{(n+m)!}{n!m!}$가지 방법이 있어 EGF의 곱으로 표현된다.
고속 월시 변환 (FWT)
FWT는 비트 연산을 인덱스로 하는 합성곱 $c_i = \sum_{i=j \oplus k} a_j b_k$를 계산하는 방법이다. $\oplus$는 OR, AND, XOR 등의 이항 연산을 의미한다.
FFT/NTT와 마찬가지로 변환-점별 곱셈-역변환 구조를 따른다:
- $O(n \log n)$으로 $fwt[a], fwt[b]$ 계산
- $O(n)$으로 $fwt[c] = fwt[a] \cdot fwt[b]$ 계산
- $O(n \log n)$으로 역변환하여 $c$ 복원
OR 연산
$c_i = \sum_{j|k=i} a_j b_k$에 대해, $fwt[a]_i = \sum_{j|i=i} a_j$로 정의하면:
$$fwt[a] \times fwt[b] = \sum_{j|i=i}\sum_{k|i=i} a_j b_k = \sum_{(j|k)|i=i} a_j b_k = fwt[c]$$
변환은 $fwt[a] = \text{merge}(fwt[a_0], fwt[a_0] + fwt[a_1])$로 분할 정복한다.
void or_transform(long long *arr, int n, int dir) {
for (int step = 2, half = 1; step <= n; step <<= 1, half <<= 1)
for (int i = 0; i < n; i += step)
for (int j = 0; j < half; ++j)
arr[i + j + half] = ((arr[i + j + half] + dir * arr[i + j]) % MOD + MOD) % MOD;
}
AND 연산
void and_transform(long long *arr, int n, int dir) {
for (int step = 2, half = 1; step <= n; step <<= 1, half <<= 1)
for (int i = 0; i < n; i += step)
for (int j = 0; j < half; ++j)
arr[i + j] = ((arr[i + j] + dir * arr[i + j + half]) % MOD + MOD) % MOD;
}
XOR 연산
$x \otimes y = \text{popcount}(x \& y) \mod 2$를 정의하고, $fwt[a]_i = \sum_{i \otimes j = 0} a_j - \sum_{i \otimes j = 1} a_j$로 설정한다.
void xor_transform(long long *arr, int n, int dir) {
for (int step = 2, half = 1; step <= n; step <<= 1, half <<= 1)
for (int i = 0; i < n; i += step)
for (int j = 0; j < half; ++j) {
long long u = arr[i + j], v = arr[i + j + half];
arr[i + j] = (u + v) % MOD;
arr[i + j + half] = (u - v + MOD) % MOD;
if (dir == -1) {
arr[i + j] = arr[i + j] * inv2 % MOD;
arr[i + j + half] = arr[i + j + half] * inv2 % MOD;
}
}
}