행렬 곱셈과 고속 지수 연산

행렬의 기초

  1. 기본 개념

행렬은 행과 열로 구성되는 2차원 배열이다. n×m 행렬은 n개의 행과 m개의 열을 가진 구조를 의미한다.

두 행렬을 곱할 때는 첫 번째 행렬의 열 개수와 두 번째 행렬의 행 개수가 반드시 일치해야 한다. 예를 들어, 2×3 행렬과 3×4 행렬을 곱하면 결과는 2×4 행렬이 된다.

[A_{a \times n} \times B_{n \times m} = C_{a \times m}]

행렬 곱셈의 핵심 공식은 다음과 같다:

[C_{i,j} = \sum_{k=1}^{n} A_{i,k} \times B_{k,j}]

즉, 결과 행렬의 각 원소는 첫 번째 행렬의 i번째 행과 두 번째 행렬의 j번째 열을 순차적으로 곱하여 합산한 값이다.

  1. 단위 행렬

단위 행렬은 행렬 곱셈에서 항등원 역할을 한다. 어떤 행렬에 단위 행렬을 곱해도 원래 행렬이 유지된다. n×n 정사각 행렬에서 단위 행렬은 대각선上の 원소가 1이고, 나머지 원소는 0인 형태를 갖는다.

  1. 연산 특성

행렬 곱셈은 결합 법칙을 만족한다:

[A \times B \times C = A \times (B \times C)]

이 특성 덕분에 지수 연산 최적화 기법을 적용할 수 있다. 그러나 교환 법칙은 성립하지 않으므로:

[A \times B \neq B \times A]

연산 순서가 결과에 영향을 미친다. 코딩 시 순서를 반드시 고려해야 한다.

연습 문제

행렬 고속 지수 연산

행렬의 지수 연산은 일반적인数の 지수 연산과 동일한 원리를 적용한다. 지수 b에 대해 행렬 A^b를 계산할 때, b가 짝수이면 A^(b/2)를 제곱하고, b가 홀수이면 결과에 A를 곱하면서 반복한다.

다음은 n×n 행렬의 k제곱을 구하는 구현 예시이다:

#include <bits/stdc++.h>
using namespace std;
using int64 = long long;

const int MAX_SIZE = 105;
const int64 MOD = 1e9 + 7;

int dimension;
int64 exponent;

struct SquareMatrix {
    int64 data[MAX_SIZE][MAX_SIZE];
    
    void reset() {
        memset(data, 0, sizeof(data));
    }
    
    void setIdentity() {
        reset();
        for (int i = 1; i <= dimension; i++) {
            data[i][i] = 1;
        }
    }
    
    SquareMatrix multiply(const SquareMatrix& other) const {
        SquareMatrix result;
        result.reset();
        
        for (int col = 1; col <= dimension; col++) {
            for (int row = 1; row <= dimension; row++) {
                for (int k = 1; k <= dimension; k++) {
                    result.data[row][col] = (result.data[row][col] + 
                        data[row][k] * other.data[k][col] % MOD) % MOD;
                }
            }
        }
        return result;
    }
};

int64 readInput() {
    int64 val = 0;
    int sign = 1;
    char ch = getchar();
    
    while (ch < '0' || ch > '9') {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    
    while (ch >= '0' && ch <= '9') {
        val = (val << 1) + (val << 3) + (ch ^ 48);
        ch = getchar();
    }
    return val * sign;
}

int main() {
    dimension = (int)readInput();
    exponent = readInput();
    
    SquareMatrix base, result;
    
    for (int i = 1; i <= dimension; i++) {
        for (int j = 1; j <= dimension; j++) {
            base.data[i][j] = readInput();
        }
    }
    
    result.setIdentity();
    
    while (exponent > 0) {
        if (exponent & 1) {
            result = result.multiply(base);
        }
        base = base.multiply(base);
        exponent >>= 1;
    }
    
    for (int i = 1; i <= dimension; i++) {
        for (int j = 1; j <= dimension; j++) {
            cout << result.data[i][j] << ' ';
        }
        cout << '\n';
    }
    
    return 0;
}

수열 가속 행렬

행렬 곱셈을 활용하면 특정 점화식으로 정의되는 수열을 효율적으로 계산할 수 있다. 1, 1, 1로 시작하여 F(n) = F(n-1) + F(n-3)를 만족하는 수열을 생각해보자.

전환 행렬을 다음과 같이 구성하면:

[\begin{bmatrix} F_n & F_{n-1} & F_{n-2} \end{bmatrix} \times \begin{bmatrix} 1 & 0 & 1 \ 1 & 0 & 0 \ 0 & 1 & 0 \end{bmatrix} = \begin{bmatrix} F_{n+1} & F_n & F_{n-1} \end{bmatrix}]

전환 행렬을 적절히 곱셱하여 원하는 항을 구할 수 있다.

#include <bits/stdc++.h>
using namespace std;
using int64 = long long;

const int64 MOD = 1e9 + 7;
const int MATRIX_SIZE = 3;

struct Matrix3x3 {
    int64 element[MATRIX_SIZE][MATRIX_SIZE];
    
    void clear() {
        memset(element, 0, sizeof(element));
    }
    
    void makeUnit() {
        clear();
        for (int i = 0; i < MATRIX_SIZE; i++) {
            element[i][i] = 1;
        }
    }
    
    Matrix3x3 combine(const Matrix3x3& other) const {
        Matrix3x3 temp;
        temp.clear();
        
        for (int i = 0; i < MATRIX_SIZE; i++) {
            for (int j = 0; j < MATRIX_SIZE; j++) {
                for (int k = 0; k < MATRIX_SIZE; k++) {
                    temp.element[i][j] = (temp.element[i][j] + 
                        element[i][k] * other.element[k][j] % MOD) % MOD;
                }
            }
        }
        return temp;
    }
};

int readInt() {
    int value = 0;
    int sign = 1;
    char c = getchar();
    
    while (c < '0' || c > '9') {
        if (c == '-') sign = -1;
        c = getchar();
    }
    
    while (c >= '0' && c <= '9') {
        value = (value << 1) + (value << 3) + (c ^ 48);
        c = getchar();
    }
    return value * sign;
}

int main() {
    int testCases = readInt();
    
    while (testCases--) {
        int target = readInt() - 3;
        
        if (target <= 0) {
            cout << 1 << '\n';
            continue;
        }
        
        Matrix3x3 initial, power, transform;
        
        initial.clear();
        initial.element[0][0] = initial.element[0][1] = initial.element[0][2] = 1;
        
        power.makeUnit();
        
        transform.clear();
        transform.element[0][0] = transform.element[0][1] = 1;
        transform.element[1][2] = transform.element[2][0] = 1;
        
        while (target > 0) {
            if (target & 1) {
                power = power.combine(transform);
            }
            transform = transform.combine(transform);
            target >>= 1;
        }
        
        initial = initial.combine(power);
        cout << initial.element[0][0] << '\n';
    }
    
    return 0;
}

위 구현은 행렬의 지수 연산을 통해 수열의 n번째 항을 O(log n) 시간 내에 계산한다.

태그: algorithm matrix-multiplication fast-exponentiation cpp competitive-programming

6월 11일 19:01에 게시됨