세그먼트 트리 심화: 다중 지연 전파와 세그먼트 트리 비트 구현

세그먼트 트리(Segment Tree)는 배열의 구간 연산을 효율적으로 처리하기 위한 강력한 자료구조입니다. 기본적인 구간 합 구하기를 넘어, 여러 종류의 업데이트를 동시에 처리하거나 구간 내 최솟값/최댓값을 갱신하는 고급 기법들이 존재합니다. 본 글에서는 지연 전파(Lazy Propagation)를 활용한 다중 연산 처리와 세그먼트 트리 비트(Segment Tree Beats)를 포함한 고급 세그먼트 트리 구현 방법을 다룹니다.

1. 기본 구간 덧셈과 구간 합 쿼리

가장 기본적인 형태의 세그먼트 트리입니다. 구간에 특정 값을 더하고, 구간의 합을 구하는 연산을 $O(\log N)$ 시간에 수행합니다. 지연 태그(Lazy Tag)를 사용하여 하위 노드로의 전파를 지연시킴으로써 시간 복잡도를 최적화합니다.

#include <iostream>
#include <vector>

using namespace std;
using ll = long long;

struct SegTreeBasic {
    struct Node {
        ll sum = 0;
        ll lazy = 0;
    };
    int n;
    vector<Node> tree;
    vector<ll> arr;

    SegTreeBasic(int size, const vector<ll>& input) : n(size), arr(input) {
        tree.resize(4 * n);
        build(1, 0, n - 1);
    }

    void build(int node, int start, int end) {
        if (start == end) {
            tree[node].sum = arr[start];
            return;
        }
        int mid = (start + end) / 2;
        build(node * 2, start, mid);
        build(node * 2 + 1, mid + 1, end);
        tree[node].sum = tree[node * 2].sum + tree[node * 2 + 1].sum;
    }

    void propagate(int node, int start, int end) {
        if (tree[node].lazy != 0) {
            int mid = (start + end) / 2;
            apply(node * 2, start, mid, tree[node].lazy);
            apply(node * 2 + 1, mid + 1, end, tree[node].lazy);
            tree[node].lazy = 0;
        }
    }

    void apply(int node, int start, int end, ll val) {
        tree[node].sum += val * (end - start + 1);
        tree[node].lazy += val;
    }

    void update(int node, int start, int end, int l, int r, ll val) {
        if (r < start || end < l) return;
        if (l <= start && end <= r) {
            apply(node, start, end, val);
            return;
        }
        propagate(node, start, end);
        int mid = (start + end) / 2;
        update(node * 2, start, mid, l, r, val);
        update(node * 2 + 1, mid + 1, end, l, r, val);
        tree[node].sum = tree[node * 2].sum + tree[node * 2 + 1].sum;
    }

    ll query(int node, int start, int end, int l, int r) {
        if (r < start || end < l) return 0;
        if (l <= start && end <= r) return tree[node].sum;
        propagate(node, start, end);
        int mid = (start + end) / 2;
        return query(node * 2, start, mid, l, r) + query(node * 2 + 1, mid + 1, end, l, r);
    }
};

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    int n, q;
    if (!(cin >> n >> q)) return 0;
    vector<ll> a(n);
    for (int i = 0; i < n; i++) cin >> a[i];
    
    SegTreeBasic st(n, a);
    while (q--) {
        int type, l, r;
        cin >> type >> l >> r;
        l--; r--; 
        if (type == 1) {
            ll k; cin >> k;
            st.update(1, 0, n - 1, l, r, k);
        } else {
            cout << st.query(1, 0, n - 1, l, r) << "\n";
        }
    }
    return 0;
}

2. 다중 지연 전파: 구간 덧셈과 곱셈

구간에 값을 더하는 연산과 곱하는 연산이 혼합된 경우, 태그를 전파하는 순서가 중요합니다. 일반적으로 곱셈 태그를 먼저 적용한 후 덧셈 태그를 적용하는 방식을 사용하며, 각 노드의 태그를 업데이트할 때 기존 태그와의 상호작용을 정확히 계산해야 합니다.

#include <iostream>
#include <vector>

using namespace std;
using ll = long long;

struct SegTreeMath {
    struct Node {
        ll sum = 0;
        ll add_tag = 0;
        ll mul_tag = 1;
    };
    int n;
    ll mod;
    vector<Node> tree;
    vector<ll> arr;

    SegTreeMath(int size, const vector<ll>& input, ll m) : n(size), arr(input), mod(m) {
        tree.resize(4 * n);
        build(1, 0, n - 1);
    }

    void build(int node, int start, int end) {
        if (start == end) {
            tree[node].sum = arr[start] % mod;
            return;
        }
        int mid = (start + end) / 2;
        build(node * 2, start, mid);
        build(node * 2 + 1, mid + 1, end);
        tree[node].sum = (tree[node * 2].sum + tree[node * 2 + 1].sum) % mod;
    }

    void apply(int node, int start, int end, ll mul_val, ll add_val) {
        ll len = end - start + 1;
        tree[node].sum = (tree[node].sum * mul_val % mod + add_val * len % mod) % mod;
        tree[node].mul_tag = tree[node].mul_tag * mul_val % mod;
        tree[node].add_tag = (tree[node].add_tag * mul_val % mod + add_val) % mod;
    }

    void propagate(int node, int start, int end) {
        if (tree[node].mul_tag != 1 || tree[node].add_tag != 0) {
            int mid = (start + end) / 2;
            apply(node * 2, start, mid, tree[node].mul_tag, tree[node].add_tag);
            apply(node * 2 + 1, mid + 1, end, tree[node].mul_tag, tree[node].add_tag);
            tree[node].mul_tag = 1;
            tree[node].add_tag = 0;
        }
    }

    void update_mul(int node, int start, int end, int l, int r, ll val) {
        if (r < start || end < l) return;
        if (l <= start && end <= r) {
            apply(node, start, end, val, 0);
            return;
        }
        propagate(node, start, end);
        int mid = (start + end) / 2;
        update_mul(node * 2, start, mid, l, r, val);
        update_mul(node * 2 + 1, mid + 1, end, l, r, val);
        tree[node].sum = (tree[node * 2].sum + tree[node * 2 + 1].sum) % mod;
    }

    void update_add(int node, int start, int end, int l, int r, ll val) {
        if (r < start || end < l) return;
        if (l <= start && end <= r) {
            apply(node, start, end, 1, val);
            return;
        }
        propagate(node, start, end);
        int mid = (start + end) / 2;
        update_add(node * 2, start, mid, l, r, val);
        update_add(node * 2 + 1, mid + 1, end, l, r, val);
        tree[node].sum = (tree[node * 2].sum + tree[node * 2 + 1].sum) % mod;
    }

    ll query(int node, int start, int end, int l, int r) {
        if (r < start || end < l) return 0;
        if (l <= start && end <= r) return tree[node].sum;
        propagate(node, start, end);
        int mid = (start + end) / 2;
        return (query(node * 2, start, mid, l, r) + query(node * 2 + 1, mid + 1, end, l, r)) % mod;
    }
};

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    int n, q;
    ll mod;
    if (!(cin >> n >> q >> mod)) return 0;
    vector<ll> a(n);
    for (int i = 0; i < n; i++) cin >> a[i];
    
    SegTreeMath st(n, a, mod);
    while (q--) {
        int type, l, r;
        cin >> type >> l >> r;
        l--; r--;
        if (type == 1) {
            ll k; cin >> k;
            st.update_mul(1, 0, n - 1, l, r, k % mod);
        } else if (type == 2) {
            ll k; cin >> k;
            st.update_add(1, 0, n - 1, l, r, k % mod);
        } else {
            cout << st.query(1, 0, n - 1, l, r) << "\n";
        }
    }
    return 0;
}

3. 세그먼트 트리 비트(Segment Tree Beats): 구간 Chmin과 역사적 최댓값

구간의 모든 원소를 특정 값보다 작게 만드는 연산(Chmin, $a_i = \min(a_i, v)$)은 일반적인 세그먼트 트리로 처리하기 어렵습니다. 이를 해결하기 위해 지 드라이버 세그먼트 트리(Ji Driver Segment Tree)를 사용합니다. 구간의 최댓값, 엄격한 두 번째 최댓값, 최댓값의 개수를 유지하여 조건에 따라 가지치기를 수행합니다. 또한, 역사적 최댓값(Historical Maximum)을 추적하기 위해 태그의 역사적 최댓값도 함께 관리합니다.

#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;
using ll = long long;
const ll INF = 1e18;

struct SegTreeBeats {
    struct Node {
        ll sum = 0;
        ll max_val = -INF;
        ll sec_max = -INF;
        ll max_cnt = 0;
        ll hist_max = -INF;
        
        ll add_max = 0;
        ll add_other = 0;
        ll hist_add_max = 0;
        ll hist_add_other = 0;
    };
    
    int n;
    vector<Node> tree;
    vector<ll> arr;

    SegTreeBeats(int size, const vector<ll>& input) : n(size), arr(input) {
        tree.resize(4 * n);
        build(1, 0, n - 1);
    }

    void pull(int node) {
        int l = node * 2, r = node * 2 + 1;
        tree[node].sum = tree[l].sum + tree[r].sum;
        tree[node].hist_max = max(tree[l].hist_max, tree[r].hist_max);
        
        if (tree[l].max_val == tree[r].max_val) {
            tree[node].max_val = tree[l].max_val;
            tree[node].max_cnt = tree[l].max_cnt + tree[r].max_cnt;
            tree[node].sec_max = max(tree[l].sec_max, tree[r].sec_max);
        } else if (tree[l].max_val > tree[r].max_val) {
            tree[node].max_val = tree[l].max_val;
            tree[node].max_cnt = tree[l].max_cnt;
            tree[node].sec_max = max(tree[l].sec_max, tree[r].max_val);
        } else {
            tree[node].max_val = tree[r].max_val;
            tree[node].max_cnt = tree[r].max_cnt;
            tree[node].sec_max = max(tree[l].max_val, tree[r].sec_max);
        }
    }

    void apply(int node, int start, int end, ll am, ll ao, ll ham, ll hao) {
        ll len = end - start + 1;
        tree[node].sum += am * tree[node].max_cnt + ao * (len - tree[node].max_cnt);
        
        tree[node].hist_max = max(tree[node].hist_max, tree[node].max_val + ham);
        
        tree[node].max_val += am;
        if (tree[node].sec_max != -INF) tree[node].sec_max += ao;
        
        tree[node].hist_add_max = max(tree[node].hist_add_max, tree[node].add_max + ham);
        tree[node].hist_add_other = max(tree[node].hist_add_other, tree[node].add_other + hao);
        
        tree[node].add_max += am;
        tree[node].add_other += ao;
    }

    void push(int node, int start, int end) {
        int mid = (start + end) / 2;
        int l = node * 2, r = node * 2 + 1;
        
        ll am = (tree[l].max_val == tree[node].max_val) ? tree[node].add_max : tree[node].add_other;
        ll ao = tree[node].add_other;
        ll ham = (tree[l].max_val == tree[node].max_val) ? tree[node].hist_add_max : tree[node].hist_add_other;
        ll hao = tree[node].hist_add_other;
        apply(l, start, mid, am, ao, ham, hao);
        
        am = (tree[r].max_val == tree[node].max_val) ? tree[node].add_max : tree[node].add_other;
        ham = (tree[r].max_val == tree[node].max_val) ? tree[node].hist_add_max : tree[node].hist_add_other;
        apply(r, mid + 1, end, am, ao, ham, hao);
        
        tree[node].add_max = tree[node].add_other = 0;
        tree[node].hist_add_max = tree[node].hist_add_other = 0;
    }

    void build(int node, int start, int end) {
        if (start == end) {
            tree[node].sum = tree[node].max_val = tree[node].hist_max = arr[start];
            tree[node].max_cnt = 1;
            tree[node].sec_max = -INF;
            return;
        }
        int mid = (start + end) / 2;
        build(node * 2, start, mid);
        build(node * 2 + 1, mid + 1, end);
        pull(node);
    }

    void update_add(int node, int start, int end, int l, int r, ll val) {
        if (r < start || end < l) return;
        if (l <= start && end <= r) {
            apply(node, start, end, val, val, val, val);
            return;
        }
        push(node, start, end);
        int mid = (start + end) / 2;
        update_add(node * 2, start, mid, l, r, val);
        update_add(node * 2 + 1, mid + 1, end, l, r, val);
        pull(node);
    }

    void update_chmin(int node, int start, int end, int l, int r, ll val) {
        if (r < start || end < l || tree[node].max_val <= val) return;
        if (l <= start && end <= r && tree[node].sec_max < val) {
            ll diff = tree[node].max_val - val;
            apply(node, start, end, -diff, 0, -diff, 0);
            return;
        }
        push(node, start, end);
        int mid = (start + end) / 2;
        update_chmin(node * 2, start, mid, l, r, val);
        update_chmin(node * 2 + 1, mid + 1, end, l, r, val);
        pull(node);
    }

    ll query_sum(int node, int start, int end, int l, int r) {
        if (r < start || end < l) return 0;
        if (l <= start && end <= r) return tree[node].sum;
        push(node, start, end);
        int mid = (start + end) / 2;
        return query_sum(node * 2, start, mid, l, r) + query_sum(node * 2 + 1, mid + 1, end, l, r);
    }

    ll query_max(int node, int start, int end, int l, int r) {
        if (r < start || end < l) return -INF;
        if (l <= start && end <= r) return tree[node].max_val;
        push(node, start, end);
        int mid = (start + end) / 2;
        return max(query_max(node * 2, start, mid, l, r), query_max(node * 2 + 1, mid + 1, end, l, r));
    }

    ll query_hist_max(int node, int start, int end, int l, int r) {
        if (r < start || end < l) return -INF;
        if (l <= start && end <= r) return tree[node].hist_max;
        push(node, start, end);
        int mid = (start + end) / 2;
        return max(query_hist_max(node * 2, start, mid, l, r), query_hist_max(node * 2 + 1, mid + 1, end, l, r));
    }
};

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    int n, m;
    if (!(cin >> n >> m)) return 0;
    vector<ll> a(n);
    for (int i = 0; i < n; i++) cin >> a[i];
    
    SegTreeBeats st(n, a);
    while (m--) {
        int type, l, r;
        cin >> type >> l >> r;
        l--; r--;
        if (type == 1) {
            ll k; cin >> k;
            st.update_add(1, 0, n - 1, l, r, k);
        } else if (type == 2) {
            ll k; cin >> k;
            st.update_chmin(1, 0, n - 1, l, r, k);
        } else if (type == 3) {
            cout << st.query_sum(1, 0, n - 1, l, r) << "\n";
        } else if (type == 4) {
            cout << st.query_max(1, 0, n - 1, l, r) << "\n";
        } else if (type == 5) {
            cout << st.query_hist_max(1, 0, n - 1, l, r) << "\n";
        }
    }
    return 0;
}

4. 복잡한 지연 태그: 구간 덮어쓰기와 역사적 최댓값

구간의 값을 특정 값으로 덮어쓰는 연산(Cover)과 덧셈 연산이 혼합되고, 동시에 역사적 최댓값을 구해야 하는 경우 태그의 상태 전이가 매우 복잡해집니다. 덮어쓰기 태그가 존재할 때 덧셈 연산이 들어오면 덮어쓰기 태그 자체를 업데이트해야 하며, 하위 노드로 전파할 때 태그의 적용 순서를 엄격하게 지켜야 합니다.

#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;
using ll = long long;
const ll INF = 1e18;

struct SegTreeCPU {
    struct Node {
        ll max_val = -INF;
        ll hist_max = -INF;
        
        ll lazy_add = 0;
        ll hist_add = 0;
        ll lazy_cov = -INF;
        ll hist_cov = -INF;
    };
    
    int n;
    vector<Node> tree;
    vector<ll> arr;

    SegTreeCPU(int size, const vector<ll>& input) : n(size), arr(input) {
        tree.resize(4 * n);
        build(1, 0, n - 1);
    }

    void pull(int node) {
        int l = node * 2, r = node * 2 + 1;
        tree[node].max_val = max(tree[l].max_val, tree[r].max_val);
        tree[node].hist_max = max(tree[l].hist_max, tree[r].hist_max);
    }

    void apply_add(int node, ll add_val, ll hist_add_val) {
        tree[node].hist_max = max(tree[node].hist_max, tree[node].max_val + hist_add_val);
        tree[node].max_val += add_val;
        
        tree[node].hist_add = max(tree[node].hist_add, tree[node].lazy_add + hist_add_val);
        tree[node].lazy_add += add_val;
    }

    void apply_cov(int node, ll cov_val, ll hist_cov_val) {
        tree[node].hist_max = max(tree[node].hist_max, hist_cov_val);
        tree[node].max_val = cov_val;
        
        tree[node].hist_cov = max(tree[node].hist_cov, hist_cov_val);
        tree[node].lazy_cov = cov_val;
    }

    void push(int node) {
        int l = node * 2, r = node * 2 + 1;
        
        if (tree[l].lazy_cov == -INF) {
            apply_add(l, tree[node].lazy_add, tree[node].hist_add);
        } else {
            apply_cov(l, tree[node].lazy_add + tree[l].max_val, tree[node].hist_add + tree[l].max_val);
        }
        
        if (tree[r].lazy_cov == -INF) {
            apply_add(r, tree[node].lazy_add, tree[node].hist_add);
        } else {
            apply_cov(r, tree[node].lazy_add + tree[r].max_val, tree[node].hist_add + tree[r].max_val);
        }
        
        if (tree[node].lazy_cov != -INF) {
            apply_cov(l, tree[node].lazy_cov, tree[node].hist_cov);
            apply_cov(r, tree[node].lazy_cov, tree[node].hist_cov);
        }
        
        tree[node].lazy_add = 0;
        tree[node].hist_add = 0;
        tree[node].lazy_cov = -INF;
        tree[node].hist_cov = -INF;
    }

    void build(int node, int start, int end) {
        tree[node].lazy_cov = -INF;
        tree[node].hist_cov = -INF;
        if (start == end) {
            tree[node].max_val = tree[node].hist_max = arr[start];
            return;
        }
        int mid = (start + end) / 2;
        build(node * 2, start, mid);
        build(node * 2 + 1, mid + 1, end);
        pull(node);
    }

    void update_add(int node, int start, int end, int l, int r, ll val) {
        if (r < start || end < l) return;
        if (l <= start && end <= r) {
            if (tree[node].lazy_cov != -INF) {
                apply_cov(node, tree[node].lazy_cov + val, tree[node].hist_cov + val);
            } else {
                apply_add(node, val, val);
            }
            return;
        }
        push(node);
        int mid = (start + end) / 2;
        update_add(node * 2, start, mid, l, r, val);
        update_add(node * 2 + 1, mid + 1, end, l, r, val);
        pull(node);
    }

    void update_cov(int node, int start, int end, int l, int r, ll val) {
        if (r < start || end < l) return;
        if (l <= start && end <= r) {
            apply_cov(node, val, val);
            return;
        }
        push(node);
        int mid = (start + end) / 2;
        update_cov(node * 2, start, mid, l, r, val);
        update_cov(node * 2 + 1, mid + 1, end, l, r, val);
        pull(node);
    }

    ll query_max(int node, int start, int end, int l, int r) {
        if (r < start || end < l) return -INF;
        if (l <= start && end <= r) return tree[node].max_val;
        push(node);
        int mid = (start + end) / 2;
        return max(query_max(node * 2, start, mid, l, r), query_max(node * 2 + 1, mid + 1, end, l, r));
    }

    ll query_hist(int node, int start, int end, int l, int r) {
        if (r < start || end < l) return -INF;
        if (l <= start && end <= r) return tree[node].hist_max;
        push(node);
        int mid = (start + end) / 2;
        return max(query_hist(node * 2, start, mid, l, r), query_hist(node * 2 + 1, mid + 1, end, l, r));
    }
};

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    int n;
    if (!(cin >> n)) return 0;
    vector<ll> a(n);
    for (int i = 0; i < n; i++) cin >> a[i];
    
    SegTreeCPU st(n, a);
    int m;
    cin >> m;
    while (m--) {
        char type;
        int l, r;
        cin >> type >> l >> r;
        l--; r--;
        if (type == 'P') {
            ll k; cin >> k;
            st.update_add(1, 0, n - 1, l, r, k);
        } else if (type == 'C') {
            ll k; cin >> k;
            st.update_cov(1, 0, n - 1, l, r, k);
        } else if (type == 'Q') {
            cout << st.query_max(1, 0, n - 1, l, r) << "\n";
        } else if (type == 'A') {
            cout << st.query_hist(1, 0, n - 1, l, r) << "\n";
        }
    }
    return 0;
}

태그: SegmentTree LazyPropagation SegmentTreeBeats C++ DataStructure

5월 27일 05:46에 게시됨